Skip to main content

malwaredb_api/
digest.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::borrow::Borrow;
4use std::error::Error;
5use std::fmt::{Display, Formatter};
6use std::ops::Deref;
7
8use base64::{engine::general_purpose, Engine};
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use uuid::Uuid;
11
12// Adapted from
13// https://github.com/profianinc/steward/commit/69a4f297e06cbc95f327d271a691198230c97429#diff-adf0e917b493348b9f22a754b89ff8644fd3af28a769f75caaec2ffd47edfea4
14// Idea for this Digest struct by Roman Volosatovs <roman@profian.com>
15
/// Fixed-size digest of `N` bytes.
///
/// The serde implementations in this file read and write the value as a
/// hexadecimal string rather than as a sequence of bytes.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Digest<const N: usize>(pub [u8; N]);
19
20impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
21    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22    where
23        D: Deserializer<'de>,
24    {
25        use serde::de::Error;
26
27        let dig: String = Deserialize::deserialize(deserializer)?;
28        let dig = hex::decode(dig).map_err(|e| Error::custom(format!("invalid hex: {e}")))?;
29        let dig = dig.try_into().map_err(|v: Vec<_>| {
30            Error::custom(format!(
31                "expected digest to have length of {N}, got {}",
32                v.len()
33            ))
34        })?;
35        Ok(Digest(dig))
36    }
37}
38
39impl<const N: usize> Serialize for Digest<N> {
40    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
41    where
42        S: Serializer,
43    {
44        let hex = self.to_string();
45        serializer.serialize_str(&hex)
46    }
47}
48
49impl<const N: usize> AsRef<[u8; N]> for Digest<N> {
50    fn as_ref(&self) -> &[u8; N] {
51        &self.0
52    }
53}
54
55impl<const N: usize> Borrow<[u8; N]> for Digest<N> {
56    fn borrow(&self) -> &[u8; N] {
57        &self.0
58    }
59}
60
61impl<const N: usize> Deref for Digest<N> {
62    type Target = [u8; N];
63
64    fn deref(&self) -> &Self::Target {
65        &self.0
66    }
67}
68
69impl From<Uuid> for Digest<16> {
70    fn from(uuid: Uuid) -> Self {
71        let bytes = uuid.into_bytes();
72        let mut array = [0u8; 16];
73        array.copy_from_slice(&bytes[..16]);
74        Digest(array)
75    }
76}
77
/// Error raised when a digest cannot be built, typically because the hash had
/// an unexpected size. Carries a human-readable message.
#[derive(Clone, Debug)]
pub struct DigestError(String);
81
82impl Display for DigestError {
83    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
84        write!(f, "{}", self.0)
85    }
86}
87
88impl Error for DigestError {}
89
90impl<const N: usize> TryFrom<Vec<u8>> for Digest<N> {
91    type Error = DigestError;
92
93    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
94        let len = value.len();
95        let array: [u8; N] = value
96            .try_into()
97            .map_err(|_| DigestError(format!("Expected a Vec of length {N} but it was {len}")))?;
98        Ok(Digest(array))
99    }
100}
101
102impl<const N: usize> Display for Digest<N> {
103    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
104        write!(f, "{}", hex::encode(self.0))
105    }
106}
107
108/// The hash by which a sample is identified
109#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Ord, PartialOrd, Hash)]
110pub enum HashType {
111    /// MD5
112    Md5(Digest<16>),
113
114    /// SHA-1
115    SHA1(Digest<20>),
116
117    /// SHA-256, assumed to be SHA2-256
118    SHA256(Digest<32>),
119
120    /// SHA-384, assumed to be SHA2-384
121    SHA384(Digest<48>),
122
123    /// SHA-512, assumed to be SHA2-512
124    SHA512(Digest<64>),
125}
126
127impl Display for HashType {
128    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
129        match self {
130            HashType::Md5(h) => write!(f, "MD5: {h}"),
131            HashType::SHA1(h) => write!(f, "SHA-1: {h}"),
132            HashType::SHA256(h) => write!(f, "SHA-256: {h}"),
133            HashType::SHA384(h) => write!(f, "SHA-384: {h}"),
134            HashType::SHA512(h) => write!(f, "SHA-512: {h}"),
135        }
136    }
137}
138
139impl HashType {
140    /// Get the hash type from the `content-digest` header.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the header is malformed or if the base64 decoding fails.
145    pub fn from_content_digest_header(s: &str) -> Result<Self, DigestError> {
146        let parts: Vec<&str> = s.splitn(2, '=').collect();
147        if parts.len() != 2 {
148            return Err(DigestError("Invalid header".into()));
149        }
150
151        let first_colon = parts[1]
152            .find(':')
153            .ok_or_else(|| DigestError("Invalid header".into()))?;
154        let second_colon = parts[1]
155            .rfind(':')
156            .ok_or_else(|| DigestError("Invalid header".into()))?;
157
158        let file_contents_b64 = general_purpose::STANDARD
159            .decode(&parts[1][first_colon + 1..second_colon])
160            .map_err(|_| DigestError("Invalid base64".into()))?;
161
162        match parts[0] {
163            "md5" | "md-5" => Ok(HashType::Md5(file_contents_b64.try_into()?)),
164            "sha1" | "sha-1" => Ok(HashType::SHA1(file_contents_b64.try_into()?)),
165            "sha256" | "sha-256" => Ok(HashType::SHA256(file_contents_b64.try_into()?)),
166            "sha384" | "sha-384" => Ok(HashType::SHA384(file_contents_b64.try_into()?)),
167            "sha512" | "sha-512" => Ok(HashType::SHA512(file_contents_b64.try_into()?)),
168            _ => Err(DigestError("Invalid hash type".into())),
169        }
170    }
171
172    /// Return the name of the hash type, used to decide
173    /// on the database field to find the match
174    #[must_use]
175    pub fn name(&self) -> &'static str {
176        match self {
177            HashType::Md5(_) => "md5",
178            HashType::SHA1(_) => "sha1",
179            HashType::SHA256(_) => "sha256",
180            HashType::SHA384(_) => "sha384",
181            HashType::SHA512(_) => "sha512",
182        }
183    }
184
185    /// Unwrap the hash from the enum's types
186    #[must_use]
187    pub fn the_hash(&self) -> String {
188        match self {
189            HashType::Md5(h) => h.to_string(),
190            HashType::SHA1(h) => h.to_string(),
191            HashType::SHA256(h) => h.to_string(),
192            HashType::SHA384(h) => h.to_string(),
193            HashType::SHA512(h) => h.to_string(),
194        }
195    }
196
197    /// Get the inner bytes of the hash
198    #[must_use]
199    pub fn bytes(&self) -> &[u8] {
200        match self {
201            HashType::Md5(h) => &h.0,
202            HashType::SHA1(h) => &h.0,
203            HashType::SHA256(h) => &h.0,
204            HashType::SHA384(h) => &h.0,
205            HashType::SHA512(h) => &h.0,
206        }
207    }
208
209    /// Create a `content-digest` header from the hash type.
210    #[must_use]
211    pub fn content_digest_header(&self) -> String {
212        format!(
213            "{}={}",
214            self.name(),
215            general_purpose::STANDARD.encode(self.the_hash())
216        )
217    }
218
219    /// See if the hash matches the given bytes.
220    #[must_use]
221    pub fn verify(&self, bytes: &[u8]) -> bool {
222        use md5::Digest;
223
224        match self {
225            HashType::Md5(h) => {
226                let mut md5 = md5::Md5::new();
227                md5.update(bytes);
228                let md5 = md5.finalize();
229                h.0 == md5.as_slice()
230            }
231            HashType::SHA1(h) => {
232                let mut sha1 = sha1::Sha1::new();
233                sha1.update(bytes);
234                let sha1 = sha1.finalize();
235                h.0 == sha1.as_slice()
236            }
237            HashType::SHA256(h) => {
238                let mut sha256 = sha2::Sha256::new();
239                sha256.update(bytes);
240                let sha256 = sha256.finalize();
241                h.0 == sha256.as_slice()
242            }
243            HashType::SHA384(h) => {
244                let mut sha384 = sha2::Sha384::new();
245                sha384.update(bytes);
246                let sha384 = sha384.finalize();
247                h.0 == sha384.as_slice()
248            }
249            HashType::SHA512(h) => {
250                let mut sha512 = sha2::Sha512::new();
251                sha512.update(bytes);
252                let sha512 = sha512.finalize();
253                h.0 == sha512.as_slice()
254            }
255        }
256    }
257}
258
259impl TryFrom<&str> for HashType {
260    type Error = DigestError;
261
262    fn try_from(value: &str) -> Result<Self, Self::Error> {
263        let decoded = hex::decode(value).map_err(|e| DigestError(e.to_string()))?;
264        Ok(match decoded.len() {
265            16 => HashType::Md5(Digest::try_from(decoded)?),
266            20 => HashType::SHA1(Digest::try_from(decoded)?),
267            32 => HashType::SHA256(Digest::try_from(decoded)?),
268            48 => HashType::SHA384(Digest::try_from(decoded)?),
269            64 => HashType::SHA512(Digest::try_from(decoded)?),
270            _ => return Err(DigestError(format!("unknown hash size {}", value.len()))),
271        })
272    }
273}
274
275impl TryFrom<&[u8]> for HashType {
276    type Error = DigestError;
277    fn try_from(digest: &[u8]) -> Result<Self, Self::Error> {
278        Ok(match digest.len() {
279            16 => HashType::Md5(Digest(
280                digest
281                    .try_into()
282                    .map_err(|_| DigestError("Invalid MD5".into()))?,
283            )),
284            20 => HashType::SHA1(Digest(
285                digest
286                    .try_into()
287                    .map_err(|_| DigestError("Invalid SHA1".into()))?,
288            )),
289            32 => HashType::SHA256(Digest(
290                digest
291                    .try_into()
292                    .map_err(|_| DigestError("Invalid SHA-256".into()))?,
293            )),
294            48 => HashType::SHA384(Digest(
295                digest
296                    .try_into()
297                    .map_err(|_| DigestError("Invalid SHA-384".into()))?,
298            )),
299            64 => HashType::SHA512(Digest(
300                digest
301                    .try_into()
302                    .map_err(|_| DigestError("Invalid SHA-512".into()))?,
303            )),
304            _ => return Err(DigestError(format!("unknown hash size {}", digest.len()))),
305        })
306    }
307}
308
309impl From<Uuid> for HashType {
310    fn from(uuid: Uuid) -> Self {
311        HashType::Md5(Digest::from(uuid))
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A four-byte digest prints as eight hex characters, and four bytes is
    /// not a size any `HashType` variant accepts.
    #[test]
    fn strings() {
        let four_bytes = Digest([0x00, 0x11, 0x22, 0x33]);
        assert_eq!(four_bytes.to_string(), "00112233");
        assert!(HashType::try_from("00112233").is_err());
    }

    /// A 40-character hex string must be classified as SHA-1 and nothing else.
    #[test]
    fn sha1() {
        const TEST: &str = "3204c1ca863c2068214900e831fb8047b934bf88";

        let parsed = HashType::try_from(TEST).unwrap();
        assert_eq!(parsed.name(), "sha1");

        match parsed {
            HashType::SHA1(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-1 hash was made into MD-5"),
            HashType::SHA256(_) => panic!("Failed: SHA-1 hash was made into SHA-256"),
            HashType::SHA384(_) => panic!("Failed: SHA-1 hash was made into SHA-384"),
            HashType::SHA512(_) => panic!("Failed: SHA-1 hash was made into SHA-512"),
        }
    }

    /// A 64-character hex string must be classified as SHA-256 and nothing else.
    #[test]
    fn sha256() {
        const TEST: &str = "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0";

        let parsed = HashType::try_from(TEST).unwrap();
        assert_eq!(parsed.name(), "sha256");

        match parsed {
            HashType::SHA256(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-256 hash was made into MD-5"),
            HashType::SHA1(_) => panic!("Failed: SHA-256 hash was made into SHA-1"),
            HashType::SHA384(_) => panic!("Failed: SHA-256 hash was made into SHA-384"),
            HashType::SHA512(_) => panic!("Failed: SHA-256 hash was made into SHA-512"),
        }
    }

    /// A 128-character hex string must be classified as SHA-512 and nothing else.
    #[test]
    fn sha512() {
        const TEST: &str = "dafe60f7d02b0151909550d6f20343d0fe374b044d40221c13295a312489e1b702edbeac99ffda85f61b812b1ddd0c9394cda0c1162bffb716f04d996ff73cdf";

        let parsed = HashType::try_from(TEST).unwrap();
        assert_eq!(parsed.name(), "sha512");

        match parsed {
            HashType::SHA512(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-512 hash was made into MD-5"),
            HashType::SHA1(_) => panic!("Failed: SHA-512 hash was made into SHA-1"),
            HashType::SHA256(_) => panic!("Failed: SHA-512 hash was made into SHA-256"),
            HashType::SHA384(_) => panic!("Failed: SHA-512 hash was made into SHA-384"),
        }
    }
}