hanzo-crypto 0.1.0

Pure-Rust cryptography library for the Hanzo ecosystem, including post-quantum algorithms
Documentation
//! Key Encapsulation Mechanism (KEM) implementation
//! FIPS 203 (ML-KEM/Kyber) support with hybrid X25519 option

use crate::{PqcError, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;

/// Key-encapsulation algorithms this module can name.
///
/// The three `MlKem*` variants are the ML-KEM parameter sets; `X25519` is the
/// classical (non-post-quantum) option. The `Default` impl in this file
/// selects [`KemAlgorithm::MlKem768`].
///
/// The enum is `Copy` and derives `Serialize`/`Deserialize`, so it can be
/// stored alongside key material as an algorithm tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KemAlgorithm {
    /// ML-KEM-512 (NIST Level 1 - 128-bit security)
    MlKem512,
    /// ML-KEM-768 (NIST Level 3 - 192-bit security) - RECOMMENDED DEFAULT
    MlKem768,
    /// ML-KEM-1024 (NIST Level 5 - 256-bit security)
    MlKem1024,
    /// Classic X25519 for compatibility
    X25519,
}

impl KemAlgorithm {
    /// Get the encapsulation key size in bytes
    pub fn encap_key_size(&self) -> usize {
        match self {
            Self::MlKem512 => 800,   // Per FIPS 203
            Self::MlKem768 => 1184,  // Per FIPS 203
            Self::MlKem1024 => 1568, // Per FIPS 203
            Self::X25519 => 32,
        }
    }

    /// Get the ciphertext size in bytes
    pub fn ciphertext_size(&self) -> usize {
        match self {
            Self::MlKem512 => 768,   // Per FIPS 203
            Self::MlKem768 => 1088,  // Per FIPS 203
            Self::MlKem1024 => 1568, // Per FIPS 203
            Self::X25519 => 32,
        }
    }

    /// Get the shared secret size (always 32 bytes for ML-KEM)
    pub fn shared_secret_size(&self) -> usize {
        32 // All ML-KEM variants produce 32-byte shared secrets
    }

    /// Get the saorsa_pqc ML-KEM variant
    #[cfg(feature = "ml-kem")]
    #[allow(dead_code)] // Future use when ML-KEM integration is complete
    pub(crate) fn to_saorsa_variant(&self) -> saorsa_pqc::MlKemVariant {
        match self {
            Self::MlKem512 => saorsa_pqc::MlKemVariant::MlKem512,
            Self::MlKem768 => saorsa_pqc::MlKemVariant::MlKem768,
            Self::MlKem1024 => saorsa_pqc::MlKemVariant::MlKem1024,
            Self::X25519 => panic!("X25519 is not a ML-KEM algorithm"),
        }
    }
}

impl Default for KemAlgorithm {
    fn default() -> Self {
        Self::MlKem768 // NIST recommended default
    }
}

/// Encapsulation key (public key for KEM).
///
/// Pairs the raw public-key bytes with the algorithm tag they were generated
/// for. The type is serializable and safe to print via `Debug`, as it holds
/// only public material.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncapsulationKey {
    /// Algorithm this key was generated for.
    pub algorithm: KemAlgorithm,
    /// Raw encoded public key bytes.
    pub key_bytes: Vec<u8>,
}

/// Decapsulation key (private key for KEM).
///
/// Holds secret key material. The `Drop` impl in this file zeroizes
/// `key_bytes` when the value is dropped. The type derives neither `Debug`
/// nor the serde traits; each `Clone` is an independent copy of the secret
/// that is wiped when it is dropped.
#[derive(Clone)]
pub struct DecapsulationKey {
    /// Algorithm this key was generated for.
    pub algorithm: KemAlgorithm,
    /// Raw encoded secret key bytes.
    pub key_bytes: Vec<u8>,
}

impl Drop for DecapsulationKey {
    fn drop(&mut self) {
        self.key_bytes.zeroize();
    }
}

/// KEM key pair: the public encapsulation key together with its matching
/// secret decapsulation key, as returned by `Kem::generate_keypair`.
pub struct KemKeyPair {
    /// Public half; used to encapsulate.
    pub encap_key: EncapsulationKey,
    /// Secret half; used to decapsulate.
    pub decap_key: DecapsulationKey,
}

/// KEM ciphertext and shared secret, as returned by `Kem::encapsulate`.
///
/// NOTE(review): no `Drop`/zeroize impl for this type is visible in this
/// file, so `shared_secret` is not wiped automatically — callers handling
/// long-lived values may want to clear it themselves.
pub struct KemOutput {
    /// Ciphertext to transmit to the holder of the decapsulation key.
    pub ciphertext: Vec<u8>,
    /// 32-byte shared secret derived by the encapsulating side.
    pub shared_secret: [u8; 32],
}

/// Trait for KEM operations.
///
/// Implementors provide key generation, encapsulation and decapsulation.
/// All methods are `async` (via `async_trait`) and return the crate's
/// `Result`, so failures are reported as `PqcError` values.
#[async_trait]
pub trait Kem: Send + Sync {
    /// Generate a new key pair for `alg`.
    ///
    /// The implementations in this file return an error when `alg` belongs
    /// to the other backend (ML-KEM vs. X25519).
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair>;

    /// Encapsulate (generate ciphertext and shared secret) against the
    /// recipient's public `encap_key`.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput>;

    /// Decapsulate (recover shared secret from ciphertext) using the secret
    /// `decap_key`. Returns the 32-byte shared secret.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]>;
}

/// ML-KEM implementation using saorsa_pqc.
///
/// Only compiled when the `ml-kem` feature is enabled. The struct carries no
/// state of its own.
#[cfg(feature = "ml-kem")]
pub struct MlKem {
    // Zero-sized placeholder: `PhantomData<()>` stores nothing. It presumably
    // reserves the spot for a future cache of algorithm instances — no cache
    // exists today.
    _phantom: std::marker::PhantomData<()>,
}

#[cfg(feature = "ml-kem")]
impl MlKem {
    /// Creates a new, stateless ML-KEM handle.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

/// `MlKem::new()` takes no arguments, so the type also offers `Default`
/// (lets it be used with `..Default::default()` and `#[derive(Default)]`
/// on containing types).
#[cfg(feature = "ml-kem")]
impl Default for MlKem {
    fn default() -> Self {
        Self::new()
    }
}

/// Accepts only ML-KEM-768, the single parameter set the `MlKem` backend is
/// currently wired to.
///
/// Without this check a request for ML-KEM-512/1024 would be served with
/// ML-KEM-768 keys that are merely *labelled* with the requested variant.
#[cfg(feature = "ml-kem")]
fn ensure_ml_kem_768(alg: KemAlgorithm) -> Result<()> {
    match alg {
        KemAlgorithm::MlKem768 => Ok(()),
        KemAlgorithm::X25519 => Err(PqcError::UnsupportedAlgorithm(
            "Use X25519Kem for X25519".into(),
        )),
        KemAlgorithm::MlKem512 | KemAlgorithm::MlKem1024 => Err(PqcError::UnsupportedAlgorithm(
            "Only ML-KEM-768 is currently supported by MlKem".into(),
        )),
    }
}

/// Copies a backend shared secret into a fixed 32-byte array, returning an
/// error (rather than panicking) if the backend hands back another length.
#[cfg(feature = "ml-kem")]
fn shared_secret_to_array(bytes: &[u8]) -> Result<[u8; 32]> {
    if bytes.len() != 32 {
        return Err(PqcError::KemError(format!(
            "Unexpected shared secret length: expected 32 bytes, got {}",
            bytes.len()
        )));
    }
    let mut out = [0u8; 32];
    out.copy_from_slice(bytes);
    Ok(out)
}

#[cfg(feature = "ml-kem")]
#[async_trait]
impl Kem for MlKem {
    /// Generates an ML-KEM-768 key pair.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if `alg` is anything other than
    ///   `KemAlgorithm::MlKem768` (other variants are not implemented yet).
    /// * `PqcError::KemError` if the backend fails to generate a key pair.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use saorsa_pqc::{MlKem768, MlKemOperations};

        // TODO: Support other variants based on alg parameter. Until then,
        // refuse them instead of returning mislabelled ML-KEM-768 keys.
        ensure_ml_kem_768(alg)?;

        let ml_kem = MlKem768::new();
        let (pub_key, sec_key) = ml_kem
            .generate_keypair()
            .map_err(|e| PqcError::KemError(format!("Keypair generation failed: {:?}", e)))?;

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: pub_key.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: sec_key.as_bytes().to_vec(),
            },
        })
    }

    /// Encapsulates against an ML-KEM-768 public key.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if the key is not tagged
    ///   `KemAlgorithm::MlKem768`.
    /// * `PqcError::KemError` if the key bytes cannot be parsed, the backend
    ///   fails, or the shared secret is not 32 bytes long.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use saorsa_pqc::{MlKem768, MlKemOperations, MlKemPublicKey};

        ensure_ml_kem_768(encap_key.algorithm)?;

        let ml_kem = MlKem768::new();
        let pub_key = MlKemPublicKey::from_bytes(&encap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid public key: {:?}", e)))?;

        let (ciphertext, shared_secret) = ml_kem
            .encapsulate(&pub_key)
            .map_err(|e| PqcError::KemError(format!("Encapsulation failed: {:?}", e)))?;

        // Explicit `&[u8]` binding so the helper works whether the backend
        // returns a slice or a reference to a fixed-size array.
        let secret_bytes: &[u8] = shared_secret.as_bytes();
        let secret_array = shared_secret_to_array(secret_bytes)?;

        Ok(KemOutput {
            ciphertext: ciphertext.as_bytes().to_vec(),
            shared_secret: secret_array,
        })
    }

    /// Recovers the shared secret from `ciphertext` with an ML-KEM-768
    /// secret key.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if the key is not tagged
    ///   `KemAlgorithm::MlKem768`.
    /// * `PqcError::KemError` if the key or ciphertext cannot be parsed, the
    ///   backend fails, or the shared secret is not 32 bytes long.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use saorsa_pqc::{MlKem768, MlKemCiphertext, MlKemOperations, MlKemSecretKey};

        ensure_ml_kem_768(decap_key.algorithm)?;

        let ml_kem = MlKem768::new();
        let sec_key = MlKemSecretKey::from_bytes(&decap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid secret key: {:?}", e)))?;

        let ciphertext = MlKemCiphertext::from_bytes(ciphertext)
            .map_err(|e| PqcError::KemError(format!("Invalid ciphertext: {:?}", e)))?;

        let shared_secret = ml_kem
            .decapsulate(&sec_key, &ciphertext)
            .map_err(|e| PqcError::KemError(format!("Decapsulation failed: {:?}", e)))?;

        let secret_bytes: &[u8] = shared_secret.as_bytes();
        shared_secret_to_array(secret_bytes)
    }
}

/// X25519 KEM for backward compatibility and hybrid mode.
///
/// Stateless unit struct, only compiled when the `hybrid` feature is
/// enabled. Its `Kem` impl in this file encapsulates by generating a fresh
/// X25519 key pair and using that public key as the "ciphertext".
#[cfg(feature = "hybrid")]
pub struct X25519Kem;

/// Copies `bytes` into a fixed 32-byte array.
///
/// Returns `PqcError::KemError` naming the offending input (`what`) when the
/// length is not exactly 32, instead of panicking inside `copy_from_slice`.
#[cfg(feature = "hybrid")]
fn x25519_array(bytes: &[u8], what: &str) -> Result<[u8; 32]> {
    if bytes.len() != 32 {
        return Err(PqcError::KemError(format!(
            "Invalid X25519 {}: expected 32 bytes, got {}",
            what,
            bytes.len()
        )));
    }
    let mut out = [0u8; 32];
    out.copy_from_slice(bytes);
    Ok(out)
}

/// Rejects any algorithm tag other than `KemAlgorithm::X25519`.
#[cfg(feature = "hybrid")]
fn ensure_x25519(alg: KemAlgorithm) -> Result<()> {
    if matches!(alg, KemAlgorithm::X25519) {
        Ok(())
    } else {
        Err(PqcError::UnsupportedAlgorithm(
            "Use MlKem for ML-KEM".into(),
        ))
    }
}

#[cfg(feature = "hybrid")]
#[async_trait]
impl Kem for X25519Kem {
    /// Generates an X25519 key pair from the OS random number generator.
    ///
    /// # Errors
    ///
    /// `PqcError::UnsupportedAlgorithm` if `alg` is not `KemAlgorithm::X25519`.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(alg)?;

        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public = PublicKey::from(&secret);

        // `to_bytes` yields a stack copy of the secret; wipe it once the
        // heap copy owned by `DecapsulationKey` has been made.
        let mut sk_bytes = secret.to_bytes();
        let sk_vec = sk_bytes.to_vec();
        sk_bytes.zeroize();

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: public.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: sk_vec,
            },
        })
    }

    /// Encapsulates to an X25519 public key.
    ///
    /// A fresh key pair is generated; its public key is returned as the
    /// ciphertext and the Diffie-Hellman output as the shared secret.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if the key is not tagged X25519.
    /// * `PqcError::KemError` if the key is not exactly 32 bytes.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(encap_key.algorithm)?;

        // Validate the recipient's key before spending randomness.
        let pk_bytes = x25519_array(&encap_key.key_bytes, "public key")?;
        let recipient_public = PublicKey::from(pk_bytes);

        // Generate ephemeral key pair
        let ephemeral_secret = StaticSecret::random_from_rng(&mut OsRng);
        let ephemeral_public = PublicKey::from(&ephemeral_secret);

        // Compute shared secret
        let shared = ephemeral_secret.diffie_hellman(&recipient_public);

        Ok(KemOutput {
            ciphertext: ephemeral_public.as_bytes().to_vec(),
            shared_secret: *shared.as_bytes(),
        })
    }

    /// Recovers the shared secret from the sender's ephemeral public key
    /// (`ciphertext`).
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if the key is not tagged X25519.
    /// * `PqcError::KemError` if the secret key or the ciphertext is not
    ///   exactly 32 bytes. `ciphertext` is untrusted input, so a malformed
    ///   length must be an error rather than a panic.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(decap_key.algorithm)?;

        // Parse the untrusted ephemeral public key first, so no secret
        // material is copied if the ciphertext is malformed.
        let ephem_bytes = x25519_array(ciphertext, "ciphertext")?;
        let ephemeral_public = PublicKey::from(ephem_bytes);

        // Parse private key; wipe the temporary stack copy once the
        // `StaticSecret` has taken its own copy.
        let mut sk_bytes = x25519_array(&decap_key.key_bytes, "secret key")?;
        let secret = StaticSecret::from(sk_bytes);
        sk_bytes.zeroize();

        // Compute shared secret
        let shared = secret.diffie_hellman(&ephemeral_public);
        Ok(*shared.as_bytes())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips an ML-KEM-768 encapsulation and checks that both sides
    /// agree on the secret and that the ciphertext has the expected length.
    #[tokio::test]
    #[cfg(feature = "ml-kem")]
    async fn test_ml_kem_768() {
        let backend = MlKem::new();
        let pair = backend
            .generate_keypair(KemAlgorithm::MlKem768)
            .await
            .unwrap();

        let KemOutput {
            ciphertext,
            shared_secret,
        } = backend.encapsulate(&pair.encap_key).await.unwrap();

        let decapsulated = backend
            .decapsulate(&pair.decap_key, &ciphertext)
            .await
            .unwrap();

        assert_eq!(shared_secret, decapsulated);
        assert_eq!(ciphertext.len(), 1088); // ML-KEM-768 ciphertext size
    }

    /// Round-trips an X25519 encapsulation and checks that both sides derive
    /// the same shared secret.
    #[tokio::test]
    #[cfg(feature = "hybrid")]
    async fn test_x25519() {
        let backend = X25519Kem;
        let pair = backend
            .generate_keypair(KemAlgorithm::X25519)
            .await
            .unwrap();

        let KemOutput {
            ciphertext,
            shared_secret,
        } = backend.encapsulate(&pair.encap_key).await.unwrap();

        let decapsulated = backend
            .decapsulate(&pair.decap_key, &ciphertext)
            .await
            .unwrap();

        assert_eq!(shared_secret, decapsulated);
    }
}