//! rust-auth-utils 1.0.0
//!
//! A Rust port of `@better-auth/utils`.
// based on https://github.com/better-auth/utils/blob/main/src/rsa.ts

use crate::types::ExportKeyFormat;
use rsa::{
    pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
    Oaep, Pss, RsaPrivateKey, RsaPublicKey,
};
use sha2::{Digest, Sha256, Sha384, Sha512};

/// Modulus size in bits used by `RSA::generate_key_pair` when no
/// `modulus_length` is given. Also the size of the placeholder private key
/// that `RSA::import_key` generates for SPKI input.
const DEFAULT_RSA_KEY_SIZE: usize = 2048;
/// RSA-PSS salt length in bytes used by `RSA::sign` and `RSA::verify` when
/// no `salt_length` is given.
const DEFAULT_SALT_LENGTH: usize = 32;

/// An RSA key pair plus the hash algorithm used for OAEP encryption and,
/// by default, for PSS signing/verification.
///
/// All fields are public. Cloning duplicates the private key material.
///
/// NOTE(review): pairs produced by `RSA::import_key` from an SPKI public key
/// hold a randomly generated `private_key` that does NOT correspond to
/// `public_key`. Only the public-key operations are meaningful for them.
#[derive(Clone)]
pub struct RsaKeyPair {
    /// Private half; used by `RSA::decrypt` and `RSA::sign`.
    pub private_key: RsaPrivateKey,
    /// Public half; used by `RSA::encrypt`, `RSA::verify` and SPKI export.
    pub public_key: RsaPublicKey,
    /// Hash used for OAEP padding, and the default hash for PSS
    /// signing/verification when none is passed explicitly.
    pub hash_algorithm: HashAlgorithm,
}

/// SHA-2 variant used for RSA-OAEP padding and RSA-PSS message hashing.
///
/// Defaults to [`HashAlgorithm::SHA256`], the same fallback this module uses
/// whenever no hash is specified.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum HashAlgorithm {
    /// SHA-256 (32-byte digest).
    #[default]
    SHA256,
    /// SHA-384 (48-byte digest).
    SHA384,
    /// SHA-512 (64-byte digest).
    SHA512,
}

impl HashAlgorithm {
    /// Computes the digest of `data` with this SHA-2 variant and returns it
    /// as an owned byte vector.
    fn digest(&self, data: &[u8]) -> Vec<u8> {
        // The one-shot `Digest::digest` is equivalent to new -> update -> finalize.
        match *self {
            Self::SHA256 => Sha256::digest(data).to_vec(),
            Self::SHA384 => Sha384::digest(data).to_vec(),
            Self::SHA512 => Sha512::digest(data).to_vec(),
        }
    }

    /// Builds an RSA-OAEP padding scheme parameterised by this hash
    /// (`Oaep::new` uses the same hash for the label and for MGF1).
    fn oaep_padding(&self) -> Oaep {
        match *self {
            Self::SHA256 => Oaep::new::<Sha256>(),
            Self::SHA384 => Oaep::new::<Sha384>(),
            Self::SHA512 => Oaep::new::<Sha512>(),
        }
    }
}

pub struct RSA;

impl RSA {
    pub async fn generate_key_pair(
        modulus_length: Option<usize>,
        hash: Option<HashAlgorithm>,
    ) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
        let mut rng = rand::thread_rng();
        let bits = modulus_length.unwrap_or(DEFAULT_RSA_KEY_SIZE);
        let private_key = RsaPrivateKey::new(&mut rng, bits)?;
        let public_key = RsaPublicKey::from(&private_key);
        let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
        Ok(RsaKeyPair {
            private_key,
            public_key,
            hash_algorithm,
        })
    }

    pub async fn export_key_public(
        key: &RsaKeyPair,
        format: ExportKeyFormat,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        match format {
            ExportKeyFormat::SPKI => Ok(key.public_key.to_public_key_der()?.as_bytes().to_vec()),
            ExportKeyFormat::JWK => {
                // TODO: Implement JWK format
                Err("JWK format not yet implemented".into())
            }
            _ => Err("Unsupported export format for public key".into()),
        }
    }

    pub async fn export_key_private(
        key: &RsaKeyPair,
        format: ExportKeyFormat,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        match format {
            ExportKeyFormat::PKCS8 => Ok(key.private_key.to_pkcs8_der()?.as_bytes().to_vec()),
            ExportKeyFormat::JWK => {
                // TODO: Implement JWK format
                Err("JWK format not yet implemented".into())
            }
            _ => Err("Unsupported export format for private key".into()),
        }
    }

    pub async fn import_key(
        key_data: &[u8],
        format: ExportKeyFormat,
        _for_encryption: bool,
        hash: Option<HashAlgorithm>,
    ) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
        let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
        match format {
            ExportKeyFormat::SPKI => {
                let public_key = RsaPublicKey::from_public_key_der(key_data)?;
                // Note: private key will be None for imported public keys
                let private_key =
                    RsaPrivateKey::new(&mut rand::thread_rng(), DEFAULT_RSA_KEY_SIZE)?;
                Ok(RsaKeyPair {
                    private_key,
                    public_key,
                    hash_algorithm,
                })
            }
            ExportKeyFormat::JWK => {
                // TODO: Implement JWK format
                Err("JWK format not yet implemented".into())
            }
            _ => Err("Unsupported import format".into()),
        }
    }

    pub async fn encrypt(
        key: &RsaKeyPair,
        data: impl AsRef<[u8]>,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let mut rng = rand::thread_rng();
        let padding = key.hash_algorithm.oaep_padding();
        Ok(key.public_key.encrypt(&mut rng, padding, data.as_ref())?)
    }

    pub async fn decrypt(
        key: &RsaKeyPair,
        encrypted_data: impl AsRef<[u8]>,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let padding = key.hash_algorithm.oaep_padding();
        Ok(key.private_key.decrypt(padding, encrypted_data.as_ref())?)
    }

    pub async fn sign(
        key: &RsaKeyPair,
        data: impl AsRef<[u8]>,
        salt_length: Option<usize>,
        hash: Option<HashAlgorithm>,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let hash_alg = hash.unwrap_or(key.hash_algorithm);
        let hashed = hash_alg.digest(data.as_ref());

        let mut rng = rand::thread_rng();
        let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
        let padding = match hash_alg {
            HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
            HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
            HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
        };
        Ok(key.private_key.sign_with_rng(&mut rng, padding, &hashed)?)
    }

    pub async fn verify(
        key: &RsaKeyPair,
        signature: impl AsRef<[u8]>,
        data: impl AsRef<[u8]>,
        salt_length: Option<usize>,
        hash: Option<HashAlgorithm>,
    ) -> Result<bool, Box<dyn std::error::Error>> {
        let hash_alg = hash.unwrap_or(key.hash_algorithm);
        let hashed = hash_alg.digest(data.as_ref());

        let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
        let padding = match hash_alg {
            HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
            HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
            HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
        };
        Ok(key
            .public_key
            .verify(padding, &hashed, signature.as_ref())
            .is_ok())
    }
}