Skip to main content

malwaredb_api/
digest.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::borrow::Borrow;
4use std::error::Error;
5use std::fmt::{Display, Formatter};
6use std::ops::Deref;
7
8use base64::{Engine, engine::general_purpose};
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use uuid::Uuid;
11
12// Adapted from
13// https://github.com/profianinc/steward/commit/69a4f297e06cbc95f327d271a691198230c97429#diff-adf0e917b493348b9f22a754b89ff8644fd3af28a769f75caaec2ffd47edfea4
14// Idea for this Digest struct by Roman Volosatovs <roman@profian.com>
15
/// Fixed-size digest holding exactly `N` raw bytes.
///
/// The serde implementations in this file read and write the value as a hexadecimal string
/// rather than as a sequence of bytes.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Digest<const N: usize>(pub [u8; N]);
19
20impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
21    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22    where
23        D: Deserializer<'de>,
24    {
25        use serde::de::Error;
26
27        let dig: String = Deserialize::deserialize(deserializer)?;
28        let dig = hex::decode(dig).map_err(|e| Error::custom(format!("invalid hex: {e}")))?;
29        let dig = dig.try_into().map_err(|v: Vec<_>| {
30            Error::custom(format!("expected digest to have length of {N}, got {}", v.len()))
31        })?;
32        Ok(Digest(dig))
33    }
34}
35
36impl<const N: usize> Serialize for Digest<N> {
37    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
38    where
39        S: Serializer,
40    {
41        let hex = self.to_string();
42        serializer.serialize_str(&hex)
43    }
44}
45
46impl<const N: usize> AsRef<[u8; N]> for Digest<N> {
47    fn as_ref(&self) -> &[u8; N] {
48        &self.0
49    }
50}
51
52impl<const N: usize> Borrow<[u8; N]> for Digest<N> {
53    fn borrow(&self) -> &[u8; N] {
54        &self.0
55    }
56}
57
58impl<const N: usize> Deref for Digest<N> {
59    type Target = [u8; N];
60
61    fn deref(&self) -> &Self::Target {
62        &self.0
63    }
64}
65
66impl From<Uuid> for Digest<16> {
67    fn from(uuid: Uuid) -> Self {
68        let bytes = uuid.into_bytes();
69        let mut array = [0u8; 16];
70        array.copy_from_slice(&bytes[..16]);
71        Digest(array)
72    }
73}
74
/// Error raised by the digest conversions and parsers in this file; carries a human-readable
/// message (for example, a hash of an unexpected size).
#[derive(Clone, Debug)]
pub struct DigestError(String);

impl Display for DigestError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl Error for DigestError {}
86
87impl<const N: usize> TryFrom<Vec<u8>> for Digest<N> {
88    type Error = DigestError;
89
90    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
91        let len = value.len();
92        let array: [u8; N] = value
93            .try_into()
94            .map_err(|_| DigestError(format!("Expected a Vec of length {N} but it was {len}")))?;
95        Ok(Digest(array))
96    }
97}
98
99impl<const N: usize> Display for Digest<N> {
100    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101        write!(f, "{}", hex::encode(self.0))
102    }
103}
104
105/// The hash by which a sample is identified
106#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Ord, PartialOrd, Hash)]
107pub enum HashType {
108    /// MD5
109    Md5(Digest<16>),
110
111    /// SHA-1
112    SHA1(Digest<20>),
113
114    /// SHA-256, assumed to be SHA2-256
115    SHA256(Digest<32>),
116
117    /// SHA-384, assumed to be SHA2-384
118    SHA384(Digest<48>),
119
120    /// SHA-512, assumed to be SHA2-512
121    SHA512(Digest<64>),
122}
123
124impl Display for HashType {
125    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126        match self {
127            HashType::Md5(h) => write!(f, "MD5: {h}"),
128            HashType::SHA1(h) => write!(f, "SHA-1: {h}"),
129            HashType::SHA256(h) => write!(f, "SHA-256: {h}"),
130            HashType::SHA384(h) => write!(f, "SHA-384: {h}"),
131            HashType::SHA512(h) => write!(f, "SHA-512: {h}"),
132        }
133    }
134}
135
136impl HashType {
137    /// Get the hash type from the `content-digest` header.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the header is malformed or if the base64 decoding fails.
142    pub fn from_content_digest_header(s: &str) -> Result<Self, DigestError> {
143        let parts: Vec<&str> = s.splitn(2, '=').collect();
144        if parts.len() != 2 {
145            return Err(DigestError("Invalid header".into()));
146        }
147
148        let first_colon = parts[1]
149            .find(':')
150            .ok_or_else(|| DigestError("Invalid header".into()))?;
151        let second_colon = parts[1]
152            .rfind(':')
153            .ok_or_else(|| DigestError("Invalid header".into()))?;
154
155        let file_contents_b64 = general_purpose::STANDARD
156            .decode(&parts[1][first_colon + 1..second_colon])
157            .map_err(|_| DigestError("Invalid base64".into()))?;
158
159        match parts[0] {
160            "md5" | "md-5" => Ok(HashType::Md5(file_contents_b64.try_into()?)),
161            "sha1" | "sha-1" => Ok(HashType::SHA1(file_contents_b64.try_into()?)),
162            "sha256" | "sha-256" => Ok(HashType::SHA256(file_contents_b64.try_into()?)),
163            "sha384" | "sha-384" => Ok(HashType::SHA384(file_contents_b64.try_into()?)),
164            "sha512" | "sha-512" => Ok(HashType::SHA512(file_contents_b64.try_into()?)),
165            _ => Err(DigestError("Invalid hash type".into())),
166        }
167    }
168
169    /// Return the name of the hash type, used to decide
170    /// on the database field to find the match
171    #[inline]
172    #[must_use]
173    pub fn name(&self) -> &'static str {
174        match self {
175            HashType::Md5(_) => "md5",
176            HashType::SHA1(_) => "sha1",
177            HashType::SHA256(_) => "sha256",
178            HashType::SHA384(_) => "sha384",
179            HashType::SHA512(_) => "sha512",
180        }
181    }
182
183    /// Unwrap the hash from the enum's types
184    #[inline]
185    #[must_use]
186    pub fn the_hash(&self) -> String {
187        match self {
188            HashType::Md5(h) => h.to_string(),
189            HashType::SHA1(h) => h.to_string(),
190            HashType::SHA256(h) => h.to_string(),
191            HashType::SHA384(h) => h.to_string(),
192            HashType::SHA512(h) => h.to_string(),
193        }
194    }
195
196    /// Get the inner bytes of the hash
197    #[inline]
198    #[must_use]
199    pub fn bytes(&self) -> &[u8] {
200        match self {
201            HashType::Md5(h) => &h.0,
202            HashType::SHA1(h) => &h.0,
203            HashType::SHA256(h) => &h.0,
204            HashType::SHA384(h) => &h.0,
205            HashType::SHA512(h) => &h.0,
206        }
207    }
208
209    /// Create a `content-digest` header from the hash type.
210    #[inline]
211    #[must_use]
212    pub fn content_digest_header(&self) -> String {
213        format!("{}={}", self.name(), general_purpose::STANDARD.encode(self.the_hash()))
214    }
215
216    /// Test that this hash matches the given bytes.
217    #[must_use]
218    pub fn verify(&self, bytes: &[u8]) -> bool {
219        use md5::Digest;
220
221        match self {
222            HashType::Md5(h) => md5::Md5::digest(bytes).as_slice().eq(&h.0),
223            HashType::SHA1(h) => sha1::Sha1::digest(bytes).as_slice().eq(&h.0),
224            HashType::SHA256(h) => sha2::Sha256::digest(bytes).as_slice().eq(&h.0),
225            HashType::SHA384(h) => sha2::Sha384::digest(bytes).as_slice().eq(&h.0),
226            HashType::SHA512(h) => sha2::Sha512::digest(bytes).as_slice().eq(&h.0),
227        }
228    }
229}
230
231impl TryFrom<&str> for HashType {
232    type Error = DigestError;
233
234    fn try_from(value: &str) -> Result<Self, Self::Error> {
235        let decoded = hex::decode(value).map_err(|e| DigestError(e.to_string()))?;
236        Ok(match decoded.len() {
237            16 => HashType::Md5(Digest::try_from(decoded)?),
238            20 => HashType::SHA1(Digest::try_from(decoded)?),
239            32 => HashType::SHA256(Digest::try_from(decoded)?),
240            48 => HashType::SHA384(Digest::try_from(decoded)?),
241            64 => HashType::SHA512(Digest::try_from(decoded)?),
242            _ => return Err(DigestError(format!("unknown hash size {}", value.len()))),
243        })
244    }
245}
246
247impl TryFrom<&[u8]> for HashType {
248    type Error = DigestError;
249    fn try_from(digest: &[u8]) -> Result<Self, Self::Error> {
250        Ok(match digest.len() {
251            16 => HashType::Md5(Digest(
252                digest
253                    .try_into()
254                    .map_err(|_| DigestError("Invalid MD5".into()))?,
255            )),
256            20 => HashType::SHA1(Digest(
257                digest
258                    .try_into()
259                    .map_err(|_| DigestError("Invalid SHA1".into()))?,
260            )),
261            32 => HashType::SHA256(Digest(
262                digest
263                    .try_into()
264                    .map_err(|_| DigestError("Invalid SHA-256".into()))?,
265            )),
266            48 => HashType::SHA384(Digest(
267                digest
268                    .try_into()
269                    .map_err(|_| DigestError("Invalid SHA-384".into()))?,
270            )),
271            64 => HashType::SHA512(Digest(
272                digest
273                    .try_into()
274                    .map_err(|_| DigestError("Invalid SHA-512".into()))?,
275            )),
276            _ => return Err(DigestError(format!("unknown hash size {}", digest.len()))),
277        })
278    }
279}
280
281impl From<Uuid> for HashType {
282    fn from(uuid: Uuid) -> Self {
283        HashType::Md5(Digest::from(uuid))
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// A digest renders as hex, and hex whose decoded length matches no known hash is rejected.
    #[test]
    fn strings() {
        assert_eq!(Digest([0x00, 0x11, 0x22, 0x33]).to_string(), "00112233");

        let parsed = HashType::try_from("00112233");
        assert!(parsed.is_err());
    }

    /// A 20-byte hex string must be classified as SHA-1 and nothing else.
    #[test]
    fn sha1() {
        const TEST: &str = "3204c1ca863c2068214900e831fb8047b934bf88";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha1");

        match digest {
            HashType::SHA1(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-1 hash was made into MD-5"),
            HashType::SHA256(_) => panic!("Failed: SHA-1 hash was made into SHA-256"),
            HashType::SHA384(_) => panic!("Failed: SHA-1 hash was made into SHA-384"),
            HashType::SHA512(_) => panic!("Failed: SHA-1 hash was made into SHA-512"),
        }
    }

    /// A 32-byte hex string must be classified as SHA-256 and nothing else.
    #[test]
    fn sha256() {
        const TEST: &str = "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha256");

        match digest {
            HashType::SHA256(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-256 hash was made into MD-5"),
            HashType::SHA1(_) => panic!("Failed: SHA-256 hash was made into SHA-1"),
            HashType::SHA384(_) => panic!("Failed: SHA-256 hash was made into SHA-384"),
            HashType::SHA512(_) => panic!("Failed: SHA-256 hash was made into SHA-512"),
        }
    }

    /// A 64-byte hex string must be classified as SHA-512 and nothing else.
    #[test]
    fn sha512() {
        const TEST: &str = "dafe60f7d02b0151909550d6f20343d0fe374b044d40221c13295a312489e1b702edbeac99ffda85f61b812b1ddd0c9394cda0c1162bffb716f04d996ff73cdf";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha512");

        match digest {
            HashType::SHA512(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-512 hash was made into MD-5"),
            HashType::SHA1(_) => panic!("Failed: SHA-512 hash was made into SHA-1"),
            HashType::SHA256(_) => panic!("Failed: SHA-512 hash was made into SHA-256"),
            HashType::SHA384(_) => panic!("Failed: SHA-512 hash was made into SHA-384"),
        }
    }
}