malwaredb-api 0.3.3

Common API endpoints and data types for MalwareDB components.
Documentation
// SPDX-License-Identifier: Apache-2.0

use std::borrow::Borrow;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::ops::Deref;

use base64::{engine::general_purpose, Engine};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use uuid::Uuid;

// Adapted from
// https://github.com/profianinc/steward/commit/69a4f297e06cbc95f327d271a691198230c97429#diff-adf0e917b493348b9f22a754b89ff8644fd3af28a769f75caaec2ffd47edfea4
// Idea for this Digest struct by Roman Volosatovs <roman@profian.com>

/// Digest generic in hash size `N`, serialized and deserialized as hexadecimal strings.
///
/// The wrapped array holds the raw digest bytes; `N` is the digest length in bytes
/// (for example 16 for MD5 or 32 for SHA-256). Serialization is implemented by hand
/// below rather than derived, so the wire form is hex text, not a byte array.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Digest<const N: usize>(pub [u8; N]);

impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;

        let dig: String = Deserialize::deserialize(deserializer)?;
        let dig = hex::decode(dig).map_err(|e| Error::custom(format!("invalid hex: {e}")))?;
        let dig = dig.try_into().map_err(|v: Vec<_>| {
            Error::custom(format!(
                "expected digest to have length of {N}, got {}",
                v.len()
            ))
        })?;
        Ok(Digest(dig))
    }
}

impl<const N: usize> Serialize for Digest<N> {
    /// Serialize as the same hexadecimal text that `Display` produces.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&format!("{self}"))
    }
}

impl<const N: usize> AsRef<[u8; N]> for Digest<N> {
    fn as_ref(&self) -> &[u8; N] {
        &self.0
    }
}

impl<const N: usize> Borrow<[u8; N]> for Digest<N> {
    fn borrow(&self) -> &[u8; N] {
        &self.0
    }
}

impl<const N: usize> Deref for Digest<N> {
    type Target = [u8; N];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Uuid> for Digest<16> {
    /// Reinterpret the 16 bytes of a UUID as a 16-byte digest.
    fn from(uuid: Uuid) -> Self {
        // `into_bytes` already yields a `[u8; 16]`, so no intermediate buffer or copy is needed.
        Digest(uuid.into_bytes())
    }
}

/// Error raised when input cannot be turned into a digest: for example a hash of an
/// unexpected size, malformed hex or base64, or an unrecognized algorithm name.
#[derive(Debug, Clone)]
pub struct DigestError(String);

impl Display for DigestError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The stored message is written verbatim; formatter width/fill options are not applied.
        let DigestError(message) = self;
        f.write_str(message)
    }
}

impl Error for DigestError {}

impl<const N: usize> TryFrom<Vec<u8>> for Digest<N> {
    type Error = DigestError;

    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
        let len = value.len();
        let array: [u8; N] = value
            .try_into()
            .map_err(|_| DigestError(format!("Expected a Vec of length {N} but it was {len}")))?;
        Ok(Digest(array))
    }
}

impl<const N: usize> Display for Digest<N> {
    /// Render the digest as hexadecimal text, two characters per byte.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let encoded = hex::encode(self.0);
        f.write_str(&encoded)
    }
}

/// The hash by which a sample is identified.
///
/// Each variant wraps a [`Digest`] sized for its algorithm. Variant order matters: the
/// derived `Ord`/`PartialOrd` compare variants in declaration order before comparing the
/// digests. The serde derives use the variant names, so renaming a variant changes the
/// serialized form.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum HashType {
    /// MD5 (16-byte digest)
    Md5(Digest<16>),

    /// SHA-1 (20-byte digest)
    SHA1(Digest<20>),

    /// SHA-256, assumed to be SHA2-256 (32-byte digest)
    SHA256(Digest<32>),

    /// SHA-384, assumed to be SHA2-384 (48-byte digest)
    SHA384(Digest<48>),

    /// SHA-512, assumed to be SHA2-512 (64-byte digest)
    SHA512(Digest<64>),
}

impl Display for HashType {
    /// Render as `<ALGORITHM>: <hex digest>`, e.g. `SHA-256: d154…`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let (label, digest) = match self {
            HashType::Md5(h) => ("MD5", h.to_string()),
            HashType::SHA1(h) => ("SHA-1", h.to_string()),
            HashType::SHA256(h) => ("SHA-256", h.to_string()),
            HashType::SHA384(h) => ("SHA-384", h.to_string()),
            HashType::SHA512(h) => ("SHA-512", h.to_string()),
        };
        write!(f, "{label}: {digest}")
    }
}

impl HashType {
    /// Get the hash type from a `Content-Digest` header value, e.g. `sha-256=:<base64>:`.
    ///
    /// The text before the first `=` names the algorithm; both the hyphenated and the
    /// compact spellings are accepted (`md-5`/`md5`, `sha-1`/`sha1`, `sha-256`/`sha256`,
    /// `sha-384`/`sha384`, `sha-512`/`sha512`). The digest is the standard-alphabet base64
    /// text between the first and the last `:` of the remainder.
    ///
    /// # Errors
    ///
    /// Returns an error if the header is malformed (no `=`, or fewer than two `:` around
    /// the digest), if the base64 decoding fails, if the algorithm is not recognized, or if
    /// the decoded digest has the wrong length for the named algorithm.
    pub fn from_content_digest_header(s: &str) -> Result<Self, DigestError> {
        let (algorithm, value) = s
            .split_once('=')
            .ok_or_else(|| DigestError("Invalid header".into()))?;

        // Require two distinct colons: with a single `:` the first and last positions
        // coincide and the slice `start + 1..end` below would panic.
        let (start, end) = match (value.find(':'), value.rfind(':')) {
            (Some(start), Some(end)) if start < end => (start, end),
            _ => return Err(DigestError("Invalid header".into())),
        };

        let decoded = general_purpose::STANDARD
            .decode(&value[start + 1..end])
            .map_err(|_| DigestError("Invalid base64".into()))?;

        match algorithm {
            "md5" | "md-5" => Ok(HashType::Md5(decoded.try_into()?)),
            "sha1" | "sha-1" => Ok(HashType::SHA1(decoded.try_into()?)),
            "sha256" | "sha-256" => Ok(HashType::SHA256(decoded.try_into()?)),
            "sha384" | "sha-384" => Ok(HashType::SHA384(decoded.try_into()?)),
            "sha512" | "sha-512" => Ok(HashType::SHA512(decoded.try_into()?)),
            _ => Err(DigestError("Invalid hash type".into())),
        }
    }

    /// Return the name of the hash type, used to decide
    /// on the database field to find the match.
    #[inline]
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            HashType::Md5(_) => "md5",
            HashType::SHA1(_) => "sha1",
            HashType::SHA256(_) => "sha256",
            HashType::SHA384(_) => "sha384",
            HashType::SHA512(_) => "sha512",
        }
    }

    /// Unwrap the hash from the enum's types, as a hexadecimal string.
    #[inline]
    #[must_use]
    pub fn the_hash(&self) -> String {
        match self {
            HashType::Md5(h) => h.to_string(),
            HashType::SHA1(h) => h.to_string(),
            HashType::SHA256(h) => h.to_string(),
            HashType::SHA384(h) => h.to_string(),
            HashType::SHA512(h) => h.to_string(),
        }
    }

    /// Get the inner (raw, non-hex) bytes of the hash.
    #[inline]
    #[must_use]
    pub fn bytes(&self) -> &[u8] {
        match self {
            HashType::Md5(h) => &h.0,
            HashType::SHA1(h) => &h.0,
            HashType::SHA256(h) => &h.0,
            HashType::SHA384(h) => &h.0,
            HashType::SHA512(h) => &h.0,
        }
    }

    /// Create a `Content-Digest` header value from the hash type, in the RFC 9530 form
    /// `<algorithm>=:<base64 of the raw digest bytes>:` (e.g. `sha-256=:47DEQ…=:`).
    ///
    /// The result is accepted by [`HashType::from_content_digest_header`].
    #[inline]
    #[must_use]
    pub fn content_digest_header(&self) -> String {
        // Hyphenated names match the RFC 9530 algorithm registry spelling for SHA-256/512;
        // the parser above accepts them for every variant.
        let algorithm = match self {
            HashType::Md5(_) => "md5",
            HashType::SHA1(_) => "sha-1",
            HashType::SHA256(_) => "sha-256",
            HashType::SHA384(_) => "sha-384",
            HashType::SHA512(_) => "sha-512",
        };
        // Encode the raw digest bytes (not their hex text) and wrap them in colons,
        // as the byte-sequence syntax requires.
        format!(
            "{algorithm}=:{}:",
            general_purpose::STANDARD.encode(self.bytes())
        )
    }

    /// Test that this hash matches the given bytes.
    ///
    /// Hashes `bytes` with this variant's algorithm and compares the result with the
    /// stored digest.
    #[must_use]
    pub fn verify(&self, bytes: &[u8]) -> bool {
        // Brings the hashing trait shared by the md5/sha1/sha2 hashers into scope for
        // `::digest`; it shadows this file's `Digest` struct only inside this function.
        use md5::Digest;

        match self {
            HashType::Md5(h) => md5::Md5::digest(bytes).as_slice().eq(&h.0),
            HashType::SHA1(h) => sha1::Sha1::digest(bytes).as_slice().eq(&h.0),
            HashType::SHA256(h) => sha2::Sha256::digest(bytes).as_slice().eq(&h.0),
            HashType::SHA384(h) => sha2::Sha384::digest(bytes).as_slice().eq(&h.0),
            HashType::SHA512(h) => sha2::Sha512::digest(bytes).as_slice().eq(&h.0),
        }
    }
}

impl TryFrom<&str> for HashType {
    type Error = DigestError;

    /// Parse a hexadecimal digest, choosing the algorithm from the decoded length in bytes
    /// (16 → MD5, 20 → SHA-1, 32 → SHA-256, 48 → SHA-384, 64 → SHA-512).
    ///
    /// # Errors
    ///
    /// Returns an error if `value` is not valid hex, or if it decodes to a length that
    /// matches none of the supported algorithms.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let decoded = hex::decode(value).map_err(|e| DigestError(e.to_string()))?;
        // Delegate to the byte-slice conversion so the length dispatch lives in one place
        // and an unknown size is reported in bytes, not hex characters.
        HashType::try_from(decoded.as_slice())
    }
}

impl TryFrom<&[u8]> for HashType {
    type Error = DigestError;
    fn try_from(digest: &[u8]) -> Result<Self, Self::Error> {
        Ok(match digest.len() {
            16 => HashType::Md5(Digest(
                digest
                    .try_into()
                    .map_err(|_| DigestError("Invalid MD5".into()))?,
            )),
            20 => HashType::SHA1(Digest(
                digest
                    .try_into()
                    .map_err(|_| DigestError("Invalid SHA1".into()))?,
            )),
            32 => HashType::SHA256(Digest(
                digest
                    .try_into()
                    .map_err(|_| DigestError("Invalid SHA-256".into()))?,
            )),
            48 => HashType::SHA384(Digest(
                digest
                    .try_into()
                    .map_err(|_| DigestError("Invalid SHA-384".into()))?,
            )),
            64 => HashType::SHA512(Digest(
                digest
                    .try_into()
                    .map_err(|_| DigestError("Invalid SHA-512".into()))?,
            )),
            _ => return Err(DigestError(format!("unknown hash size {}", digest.len()))),
        })
    }
}

impl From<Uuid> for HashType {
    fn from(uuid: Uuid) -> Self {
        HashType::Md5(Digest::from(uuid))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// SHA-256 of the empty input.
    const EMPTY_SHA256: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    /// MD5 of the empty input.
    const EMPTY_MD5: &str = "d41d8cd98f00b204e9800998ecf8427e";

    #[test]
    fn strings() {
        let digest = Digest([0x00, 0x11, 0x22, 0x33]);
        assert_eq!(digest.to_string(), "00112233");
        assert!(HashType::try_from("00112233").is_err());
    }

    #[test]
    fn md5() {
        let digest = HashType::try_from(EMPTY_MD5).unwrap();
        assert_eq!(digest.name(), "md5");
        assert!(matches!(digest, HashType::Md5(_)), "MD5 hash parsed as {digest}");
    }

    #[test]
    fn sha1() {
        const TEST: &str = "3204c1ca863c2068214900e831fb8047b934bf88";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha1");
        assert!(matches!(digest, HashType::SHA1(_)), "SHA-1 hash parsed as {digest}");
        assert_eq!(digest.to_string(), format!("SHA-1: {TEST}"));
    }

    #[test]
    fn sha256() {
        const TEST: &str = "d154b8420fc56a629df2e6d918be53310d8ac39a926aa5f60ae59a66298969a0";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha256");
        assert!(matches!(digest, HashType::SHA256(_)), "SHA-256 hash parsed as {digest}");
        assert_eq!(digest.the_hash(), TEST);
    }

    #[test]
    fn sha384() {
        // SHA-384 of the empty input.
        const TEST: &str = "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha384");
        assert!(matches!(digest, HashType::SHA384(_)), "SHA-384 hash parsed as {digest}");
    }

    #[test]
    fn sha512() {
        const TEST: &str = "dafe60f7d02b0151909550d6f20343d0fe374b044d40221c13295a312489e1b702edbeac99ffda85f61b812b1ddd0c9394cda0c1162bffb716f04d996ff73cdf";

        let digest = HashType::try_from(TEST).unwrap();
        assert_eq!(digest.name(), "sha512");
        assert!(matches!(digest, HashType::SHA512(_)), "SHA-512 hash parsed as {digest}");
    }

    #[test]
    fn byte_slices() {
        let hash = HashType::try_from([0xabu8; 20].as_slice()).unwrap();
        assert_eq!(hash.name(), "sha1");
        assert_eq!(hash.bytes(), [0xabu8; 20].as_slice());
        assert_eq!(hash.the_hash(), "ab".repeat(20));
        assert!(HashType::try_from([0u8; 10].as_slice()).is_err());
    }

    #[test]
    fn verify_empty_input() {
        let md5 = HashType::try_from(EMPTY_MD5).unwrap();
        assert!(md5.verify(b""));
        assert!(!md5.verify(b"x"));

        let sha256 = HashType::try_from(EMPTY_SHA256).unwrap();
        assert!(sha256.verify(b""));
        assert!(!sha256.verify(b"x"));
    }

    #[test]
    fn content_digest_header_parsing() {
        let parsed = HashType::from_content_digest_header(
            "sha-256=:47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=:",
        )
        .unwrap();
        assert_eq!(parsed, HashType::try_from(EMPTY_SHA256).unwrap());

        assert!(HashType::from_content_digest_header("sha-256").is_err());
        assert!(HashType::from_content_digest_header("crc32=:AAAAAA==:").is_err());
    }
}