rust-auth-utils 1.0.0

A Rust port of @better-auth/utils.
Documentation
// based on https://github.com/better-auth/utils/blob/main/src/hmac.ts

use crate::types::{EncodingFormat, SHAFamily};
use crate::{base64, hex};
use hmac::{Hmac, Mac};
use sha1::Sha1;
use sha2::{Sha256, Sha384, Sha512};

/// Configuration for computing and verifying HMAC signatures.
///
/// Pairs the hash function the HMAC is built on with the encoding applied to
/// signatures produced by `sign` and expected by `verify`.
pub struct HmacBuilder {
    /// Hash function used inside the HMAC.
    algorithm: SHAFamily,
    /// Output encoding for `sign` / input encoding for `verify`.
    encoding: EncodingFormat,
}

impl Default for HmacBuilder {
    fn default() -> Self {
        Self {
            algorithm: SHAFamily::SHA256,
            encoding: EncodingFormat::None,
        }
    }
}

impl HmacBuilder {
    /// Creates a builder for the given hash algorithm and signature encoding.
    ///
    /// `None` selects the default for that argument: [`SHAFamily::SHA256`] for
    /// the algorithm and [`EncodingFormat::None`] (raw MAC bytes) for the
    /// encoding.
    pub fn new(algorithm: Option<SHAFamily>, encoding: Option<EncodingFormat>) -> Self {
        Self {
            algorithm: algorithm.unwrap_or(SHAFamily::SHA256),
            encoding: encoding.unwrap_or(EncodingFormat::None),
        }
    }

    /// Computes the HMAC of `data` under `key` and encodes it with the
    /// configured [`EncodingFormat`].
    ///
    /// For the hex / base64 / base64url encodings the returned bytes are the
    /// ASCII text of the encoded MAC; for [`EncodingFormat::None`] they are the
    /// raw MAC bytes.
    ///
    /// # Errors
    ///
    /// Returns `"Failed to create HMAC"` if the HMAC instance cannot be
    /// constructed from `key`.
    pub fn sign(&self, key: &[u8], data: &[u8]) -> Result<Vec<u8>, &'static str> {
        let signature = match self.algorithm {
            SHAFamily::SHA1 => {
                let mut mac =
                    Hmac::<Sha1>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
            SHAFamily::SHA256 => {
                let mut mac =
                    Hmac::<Sha256>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
            SHAFamily::SHA384 => {
                let mut mac =
                    Hmac::<Sha384>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
            SHAFamily::SHA512 => {
                let mut mac =
                    Hmac::<Sha512>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.finalize().into_bytes().to_vec()
            }
        };

        Ok(self.encode_signature(signature))
    }

    /// Checks whether `signature`, encoded with the configured format, is the
    /// HMAC of `data` under `key`.
    ///
    /// Returns `Ok(true)` on a match and `Ok(false)` when the decoded bytes do
    /// not match, including when they have the wrong length. The comparison is
    /// delegated to `Mac::verify_slice`, which the `hmac` crate documents as
    /// constant-time.
    ///
    /// # Errors
    ///
    /// Returns an error if `signature` is not well-formed for the configured
    /// encoding. Causes include invalid UTF-8, characters outside the alphabet,
    /// bad padding, or a decoder failure. An error is also returned if the HMAC
    /// instance cannot be constructed from `key`.
    pub fn verify(&self, key: &[u8], data: &[u8], signature: &[u8]) -> Result<bool, &'static str> {
        let decoded_signature = self.decode_signature(signature)?;

        let result = match self.algorithm {
            SHAFamily::SHA1 => {
                let mut mac =
                    Hmac::<Sha1>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.verify_slice(&decoded_signature).is_ok()
            }
            SHAFamily::SHA256 => {
                let mut mac =
                    Hmac::<Sha256>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.verify_slice(&decoded_signature).is_ok()
            }
            SHAFamily::SHA384 => {
                let mut mac =
                    Hmac::<Sha384>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.verify_slice(&decoded_signature).is_ok()
            }
            SHAFamily::SHA512 => {
                let mut mac =
                    Hmac::<Sha512>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
                mac.update(data);
                mac.verify_slice(&decoded_signature).is_ok()
            }
        };

        Ok(result)
    }

    /// Encodes raw MAC bytes with the configured format.
    ///
    /// Text encodings are returned as the bytes of the encoded string.
    fn encode_signature(&self, signature: Vec<u8>) -> Vec<u8> {
        match self.encoding {
            EncodingFormat::Hex => hex::Hex::encode(&signature).as_bytes().to_vec(),
            EncodingFormat::Base64 => base64::Base64::encode(&signature, Some(true))
                .as_bytes()
                .to_vec(),
            EncodingFormat::Base64Url => base64::Base64Url::encode(&signature, Some(true))
                .as_bytes()
                .to_vec(),
            EncodingFormat::Base64UrlNoPad => base64::Base64Url::encode(&signature, Some(false))
                .as_bytes()
                .to_vec(),
            EncodingFormat::None => signature,
        }
    }

    /// Validates and decodes a signature according to the configured format.
    ///
    /// - Hex: must be UTF-8, then is handed to the hex decoder.
    /// - Base64: only `A-Z a-z 0-9 + / =`, with padded-form rules enforced
    ///   (see `check_padding`).
    /// - Base64Url: only `A-Z a-z 0-9 - _ =`, with padded-form rules enforced.
    /// - Base64UrlNoPad: same alphabet, but any `=` is rejected.
    /// - None: the bytes are used as-is.
    ///
    /// There is deliberately no requirement that the text contain
    /// alphabet-specific characters (`+`/`/` or `-`/`_`). A valid encoding of
    /// random MAC bytes frequently contains none of them, so such a check would
    /// reject genuine signatures.
    fn decode_signature(&self, signature: &[u8]) -> Result<Vec<u8>, &'static str> {
        match self.encoding {
            EncodingFormat::None => Ok(signature.to_vec()),
            EncodingFormat::Hex => {
                let hex_str = std::str::from_utf8(signature).map_err(|_| "Invalid UTF-8")?;
                hex::Hex::decode(hex_str).map_err(|_| "Invalid hex encoding")
            }
            EncodingFormat::Base64 => {
                let base64_str = std::str::from_utf8(signature).map_err(|_| "Invalid UTF-8")?;
                if !base64_str
                    .chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '/' | '='))
                {
                    return Err("Invalid base64 encoding");
                }
                Self::check_padding(
                    base64_str,
                    "Invalid base64 encoding: length not multiple of 4",
                    "Invalid base64 padding",
                    "Invalid base64 encoding: padding in wrong position",
                )?;
                base64::Base64::decode(base64_str).map_err(|_| "Invalid base64 encoding")
            }
            EncodingFormat::Base64Url | EncodingFormat::Base64UrlNoPad => {
                let base64_str = std::str::from_utf8(signature).map_err(|_| "Invalid UTF-8")?;
                if !base64_str
                    .chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '='))
                {
                    return Err("Invalid base64url encoding");
                }
                if matches!(self.encoding, EncodingFormat::Base64Url) {
                    Self::check_padding(
                        base64_str,
                        "Invalid base64url encoding: length not multiple of 4",
                        "Invalid base64url padding",
                        "Invalid base64url encoding: padding in wrong position",
                    )?;
                } else if base64_str.contains('=') {
                    // Unpadded variant: any '=' at all is malformed.
                    return Err("Invalid base64url encoding: unexpected padding");
                }
                base64::Base64Url::decode(base64_str).map_err(|_| "Invalid base64url encoding")
            }
        }
    }

    /// Enforces padded base64 / base64url layout rules on `encoded`.
    ///
    /// The rules are:
    /// - the length is a multiple of 4;
    /// - there are at most two trailing `=`;
    /// - no `=` appears before that trailing run.
    ///
    /// The caller supplies the three error messages, so each encoding reports
    /// its own wording.
    fn check_padding(
        encoded: &str,
        length_err: &'static str,
        padding_err: &'static str,
        position_err: &'static str,
    ) -> Result<(), &'static str> {
        if encoded.len() % 4 != 0 {
            return Err(length_err);
        }
        let padding_count = encoded.bytes().rev().take_while(|&b| b == b'=').count();
        if padding_count > 2 {
            return Err(padding_err);
        }
        // The cut point is immediately before a run of ASCII '=' bytes, so it
        // is always a valid char boundary.
        if encoded[..encoded.len() - padding_count].contains('=') {
            return Err(position_err);
        }
        Ok(())
    }
}