malwaredb_api/
digest.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::borrow::Borrow;
4use std::error::Error;
5use std::fmt::{Display, Formatter};
6use std::ops::Deref;
7
8use base64::{engine::general_purpose, Engine};
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10
11// Adapted from
12// https://github.com/profianinc/steward/commit/69a4f297e06cbc95f327d271a691198230c97429#diff-adf0e917b493348b9f22a754b89ff8644fd3af28a769f75caaec2ffd47edfea4
13// Idea for this Digest struct by Roman Volosatovs <roman@profian.com>
14
/// Fixed-size digest holding exactly `N` raw bytes.
///
/// Serialized and deserialized as a hexadecimal string.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Digest<const N: usize>(
    /// The raw digest bytes.
    pub [u8; N],
);
18
19impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
20    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
21    where
22        D: Deserializer<'de>,
23    {
24        use serde::de::Error;
25
26        let dig: String = Deserialize::deserialize(deserializer)?;
27        let dig = hex::decode(dig).map_err(|e| Error::custom(format!("invalid hex: {e}")))?;
28        let dig = dig.try_into().map_err(|v: Vec<_>| {
29            Error::custom(format!(
30                "expected digest to have length of {N}, got {}",
31                v.len()
32            ))
33        })?;
34        Ok(Digest(dig))
35    }
36}
37
38impl<const N: usize> Serialize for Digest<N> {
39    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40    where
41        S: Serializer,
42    {
43        let hex = self.to_string();
44        serializer.serialize_str(&hex)
45    }
46}
47
48impl<const N: usize> AsRef<[u8; N]> for Digest<N> {
49    fn as_ref(&self) -> &[u8; N] {
50        &self.0
51    }
52}
53
54impl<const N: usize> Borrow<[u8; N]> for Digest<N> {
55    fn borrow(&self) -> &[u8; N] {
56        &self.0
57    }
58}
59
60impl<const N: usize> Deref for Digest<N> {
61    type Target = [u8; N];
62
63    fn deref(&self) -> &Self::Target {
64        &self.0
65    }
66}
67
/// Error produced when a digest cannot be built or parsed, for example when a
/// hash has an unexpected size. Carries a human-readable message.
#[derive(Clone, Debug)]
pub struct DigestError(String);

impl Display for DigestError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}

impl Error for DigestError {}
79
80impl<const N: usize> TryFrom<Vec<u8>> for Digest<N> {
81    type Error = DigestError;
82
83    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
84        let len = value.len();
85        let array: [u8; N] = value
86            .try_into()
87            .map_err(|_| DigestError(format!("Expected a Vec of length {N} but it was {len}")))?;
88        Ok(Digest(array))
89    }
90}
91
92impl<const N: usize> Display for Digest<N> {
93    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
94        write!(f, "{}", hex::encode(self.0))
95    }
96}
97
98/// The hash by which a sample is identified
99#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Ord, PartialOrd, Hash)]
100pub enum HashType {
101    /// MD5
102    Md5(Digest<16>),
103
104    /// SHA-1
105    SHA1(Digest<20>),
106
107    /// SHA-256, assumed to be SHA2-256
108    SHA256(Digest<32>),
109
110    /// SHA-384, assumed to be SHA2-384
111    SHA384(Digest<48>),
112
113    /// SHA-512, assumed to be SHA2-512
114    SHA512(Digest<64>),
115}
116
117impl Display for HashType {
118    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
119        match self {
120            HashType::Md5(h) => write!(f, "MD5: {h}"),
121            HashType::SHA1(h) => write!(f, "SHA-1: {h}"),
122            HashType::SHA256(h) => write!(f, "SHA-256: {h}"),
123            HashType::SHA384(h) => write!(f, "SHA-384: {h}"),
124            HashType::SHA512(h) => write!(f, "SHA-512: {h}"),
125        }
126    }
127}
128
129impl HashType {
130    /// Get the hash type from the `content-digest` header.
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if the header is malformed or if the base64 decoding fails.
135    pub fn from_content_digest_header(s: &str) -> Result<Self, DigestError> {
136        let parts: Vec<&str> = s.splitn(2, '=').collect();
137        if parts.len() != 2 {
138            return Err(DigestError("Invalid header".into()));
139        }
140
141        let first_colon = parts[1]
142            .find(':')
143            .ok_or_else(|| DigestError("Invalid header".into()))?;
144        let second_colon = parts[1]
145            .rfind(':')
146            .ok_or_else(|| DigestError("Invalid header".into()))?;
147
148        let file_contents_b64 = general_purpose::STANDARD
149            .decode(&parts[1][first_colon + 1..second_colon])
150            .map_err(|_| DigestError("Invalid base64".into()))?;
151
152        match parts[0] {
153            "md5" | "md-5" => Ok(HashType::Md5(file_contents_b64.try_into()?)),
154            "sha1" | "sha-1" => Ok(HashType::SHA1(file_contents_b64.try_into()?)),
155            "sha256" | "sha-256" => Ok(HashType::SHA256(file_contents_b64.try_into()?)),
156            "sha384" | "sha-384" => Ok(HashType::SHA384(file_contents_b64.try_into()?)),
157            "sha512" | "sha-512" => Ok(HashType::SHA512(file_contents_b64.try_into()?)),
158            _ => Err(DigestError("Invalid hash type".into())),
159        }
160    }
161
162    /// Return the name of the hash type, used to decide
163    /// on the database field to find the match
164    #[must_use]
165    pub fn name(&self) -> &'static str {
166        match self {
167            HashType::Md5(_) => "md5",
168            HashType::SHA1(_) => "sha1",
169            HashType::SHA256(_) => "sha256",
170            HashType::SHA384(_) => "sha384",
171            HashType::SHA512(_) => "sha512",
172        }
173    }
174
175    /// Unwrap the hash from the enum's types
176    #[must_use]
177    pub fn the_hash(&self) -> String {
178        match self {
179            HashType::Md5(h) => h.to_string(),
180            HashType::SHA1(h) => h.to_string(),
181            HashType::SHA256(h) => h.to_string(),
182            HashType::SHA384(h) => h.to_string(),
183            HashType::SHA512(h) => h.to_string(),
184        }
185    }
186
187    /// Get the inner bytes of the hash
188    #[must_use]
189    pub fn bytes(&self) -> &[u8] {
190        match self {
191            HashType::Md5(h) => &h.0,
192            HashType::SHA1(h) => &h.0,
193            HashType::SHA256(h) => &h.0,
194            HashType::SHA384(h) => &h.0,
195            HashType::SHA512(h) => &h.0,
196        }
197    }
198
199    /// Create a `content-digest` header from the hash type.
200    #[must_use]
201    pub fn content_digest_header(&self) -> String {
202        format!(
203            "{}={}",
204            self.name(),
205            general_purpose::STANDARD.encode(self.the_hash())
206        )
207    }
208
209    /// See if the hash matches the given bytes.
210    #[must_use]
211    pub fn verify(&self, bytes: &[u8]) -> bool {
212        use md5::Digest;
213
214        match self {
215            HashType::Md5(h) => {
216                let mut md5 = md5::Md5::new();
217                md5.update(bytes);
218                let md5 = md5.finalize();
219                h.0 == md5.as_slice()
220            }
221            HashType::SHA1(h) => {
222                let mut sha1 = sha1::Sha1::new();
223                sha1.update(bytes);
224                let sha1 = sha1.finalize();
225                h.0 == sha1.as_slice()
226            }
227            HashType::SHA256(h) => {
228                let mut sha256 = sha2::Sha256::new();
229                sha256.update(bytes);
230                let sha256 = sha256.finalize();
231                h.0 == sha256.as_slice()
232            }
233            HashType::SHA384(h) => {
234                let mut sha384 = sha2::Sha384::new();
235                sha384.update(bytes);
236                let sha384 = sha384.finalize();
237                h.0 == sha384.as_slice()
238            }
239            HashType::SHA512(h) => {
240                let mut sha512 = sha2::Sha512::new();
241                sha512.update(bytes);
242                let sha512 = sha512.finalize();
243                h.0 == sha512.as_slice()
244            }
245        }
246    }
247}
248
249impl TryFrom<&str> for HashType {
250    type Error = DigestError;
251
252    fn try_from(value: &str) -> Result<Self, Self::Error> {
253        let decoded = hex::decode(value).map_err(|e| DigestError(e.to_string()))?;
254        Ok(match decoded.len() {
255            16 => HashType::Md5(Digest::try_from(decoded)?),
256            20 => HashType::SHA1(Digest::try_from(decoded)?),
257            32 => HashType::SHA256(Digest::try_from(decoded)?),
258            48 => HashType::SHA384(Digest::try_from(decoded)?),
259            64 => HashType::SHA512(Digest::try_from(decoded)?),
260            _ => return Err(DigestError(format!("unknown hash size {}", value.len()))),
261        })
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// A digest prints as hex, and a four-byte value is not accepted as a
    /// known hash type.
    #[test]
    fn strings() {
        let four_bytes = Digest([0x00, 0x11, 0x22, 0x33]);
        assert_eq!(four_bytes.to_string(), "00112233");

        let parsed = HashType::try_from("00112233");
        assert!(parsed.is_err());
    }

    /// A 20-byte hex string is recognised as SHA-1 and nothing else.
    #[test]
    fn sha1() {
        const TEST: &str = "3204c1ca863c2068214900e831fb8047b934bf88";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha1");

        match digest {
            HashType::SHA1(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-1 hash was made into MD-5"),
            HashType::SHA256(_) => panic!("Failed: SHA-1 hash was made into SHA-256"),
            HashType::SHA384(_) => panic!("Failed: SHA-1 hash was made into SHA-384"),
            HashType::SHA512(_) => panic!("Failed: SHA-1 hash was made into SHA-512"),
        }
    }

    /// A 32-byte hex string is recognised as SHA-256 and nothing else.
    #[test]
    fn sha256() {
        const TEST: &str = "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha256");

        match digest {
            HashType::SHA256(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-256 hash was made into MD-5"),
            HashType::SHA1(_) => panic!("Failed: SHA-256 hash was made into SHA-1"),
            HashType::SHA384(_) => panic!("Failed: SHA-256 hash was made into SHA-384"),
            HashType::SHA512(_) => panic!("Failed: SHA-256 hash was made into SHA-512"),
        }
    }

    /// A 64-byte hex string is recognised as SHA-512 and nothing else.
    #[test]
    fn sha512() {
        const TEST: &str = "dafe60f7d02b0151909550d6f20343d0fe374b044d40221c13295a312489e1b702edbeac99ffda85f61b812b1ddd0c9394cda0c1162bffb716f04d996ff73cdf";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha512");

        match digest {
            HashType::SHA512(_) => {}
            HashType::Md5(_) => panic!("Failed: SHA-512 hash was made into MD-5"),
            HashType::SHA1(_) => panic!("Failed: SHA-512 hash was made into SHA-1"),
            HashType::SHA256(_) => panic!("Failed: SHA-512 hash was made into SHA-256"),
            HashType::SHA384(_) => panic!("Failed: SHA-512 hash was made into SHA-384"),
        }
    }
}