// hanzo_crypto/kem.rs

//! Key Encapsulation Mechanism (KEM) implementation.
//! FIPS 203 (ML-KEM/Kyber) support with a hybrid X25519 option.
3
4use crate::{PqcError, Result};
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use zeroize::Zeroize;
8
9/// KEM algorithms supported
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11pub enum KemAlgorithm {
12    /// ML-KEM-512 (NIST Level 1 - 128-bit security)
13    MlKem512,
14    /// ML-KEM-768 (NIST Level 3 - 192-bit security) - RECOMMENDED DEFAULT
15    MlKem768,
16    /// ML-KEM-1024 (NIST Level 5 - 256-bit security)
17    MlKem1024,
18    /// Classic X25519 for compatibility
19    X25519,
20}
21
22impl KemAlgorithm {
23    /// Get the encapsulation key size in bytes
24    pub fn encap_key_size(&self) -> usize {
25        match self {
26            Self::MlKem512 => 800,   // Per FIPS 203
27            Self::MlKem768 => 1184,  // Per FIPS 203
28            Self::MlKem1024 => 1568, // Per FIPS 203
29            Self::X25519 => 32,
30        }
31    }
32
33    /// Get the ciphertext size in bytes
34    pub fn ciphertext_size(&self) -> usize {
35        match self {
36            Self::MlKem512 => 768,   // Per FIPS 203
37            Self::MlKem768 => 1088,  // Per FIPS 203
38            Self::MlKem1024 => 1568, // Per FIPS 203
39            Self::X25519 => 32,
40        }
41    }
42
43    /// Get the shared secret size (always 32 bytes for ML-KEM)
44    pub fn shared_secret_size(&self) -> usize {
45        32 // All ML-KEM variants produce 32-byte shared secrets
46    }
47
48    /// Get the saorsa_pqc ML-KEM variant
49    #[cfg(feature = "ml-kem")]
50    #[allow(dead_code)] // Future use when ML-KEM integration is complete
51    pub(crate) fn to_saorsa_variant(&self) -> saorsa_pqc::MlKemVariant {
52        match self {
53            Self::MlKem512 => saorsa_pqc::MlKemVariant::MlKem512,
54            Self::MlKem768 => saorsa_pqc::MlKemVariant::MlKem768,
55            Self::MlKem1024 => saorsa_pqc::MlKemVariant::MlKem1024,
56            Self::X25519 => panic!("X25519 is not a ML-KEM algorithm"),
57        }
58    }
59}
60
61impl Default for KemAlgorithm {
62    fn default() -> Self {
63        Self::MlKem768 // NIST recommended default
64    }
65}
66
67/// Encapsulation key (public key for KEM)
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct EncapsulationKey {
70    pub algorithm: KemAlgorithm,
71    pub key_bytes: Vec<u8>,
72}
73
74/// Decapsulation key (private key for KEM)
75#[derive(Clone)]
76pub struct DecapsulationKey {
77    pub algorithm: KemAlgorithm,
78    pub key_bytes: Vec<u8>,
79}
80
81impl Drop for DecapsulationKey {
82    fn drop(&mut self) {
83        self.key_bytes.zeroize();
84    }
85}
86
87/// KEM key pair
88pub struct KemKeyPair {
89    pub encap_key: EncapsulationKey,
90    pub decap_key: DecapsulationKey,
91}
92
/// Result of [`Kem::encapsulate`]: the ciphertext to send to the key owner
/// together with the shared secret derived locally.
///
/// NOTE(review): no zeroize-on-drop is visible for this type, so
/// `shared_secret` is not wiped automatically — confirm whether callers are
/// expected to zeroize it themselves.
pub struct KemOutput {
    /// Ciphertext to transmit to the holder of the decapsulation key.
    pub ciphertext: Vec<u8>,
    /// 32-byte shared secret.
    pub shared_secret: [u8; 32],
}
98
99/// Trait for KEM operations
100#[async_trait]
101pub trait Kem: Send + Sync {
102    /// Generate a new key pair
103    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair>;
104
105    /// Encapsulate (generate ciphertext and shared secret)
106    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput>;
107
108    /// Decapsulate (recover shared secret from ciphertext)
109    async fn decapsulate(
110        &self,
111        decap_key: &DecapsulationKey,
112        ciphertext: &[u8],
113    ) -> Result<[u8; 32]>;
114}
115
/// ML-KEM backend built on `saorsa_pqc`.
///
/// The type holds no state: each operation constructs its own `saorsa_pqc`
/// instance.
#[cfg(feature = "ml-kem")]
pub struct MlKem {
    // Private marker field; it keeps construction outside this module going
    // through `MlKem::new`. Nothing is cached here.
    _phantom: std::marker::PhantomData<()>,
}
122
#[cfg(feature = "ml-kem")]
impl MlKem {
    /// Create a new (stateless) ML-KEM backend.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

/// `MlKem` has no configuration, so `Default` is equivalent to `MlKem::new`
/// (also satisfies clippy's `new_without_default`).
#[cfg(feature = "ml-kem")]
impl Default for MlKem {
    fn default() -> Self {
        Self::new()
    }
}
131
/// Accept only the ML-KEM parameter set this backend actually implements.
///
/// The `saorsa_pqc` integration below only drives ML-KEM-768. Other
/// parameter sets are refused rather than silently served by ML-KEM-768
/// under the requested label. That mislabelling would, for example, give a
/// caller who asked for ML-KEM-1024 category-3 keys tagged as category 5.
#[cfg(feature = "ml-kem")]
fn ensure_ml_kem_768(alg: KemAlgorithm) -> Result<()> {
    match alg {
        KemAlgorithm::MlKem768 => Ok(()),
        KemAlgorithm::X25519 => Err(PqcError::UnsupportedAlgorithm(
            "Use X25519Kem for X25519".into(),
        )),
        KemAlgorithm::MlKem512 | KemAlgorithm::MlKem1024 => {
            Err(PqcError::UnsupportedAlgorithm(format!(
                "{:?} is not implemented yet; only MlKem768 is supported",
                alg
            )))
        }
    }
}

#[cfg(feature = "ml-kem")]
#[async_trait]
impl Kem for MlKem {
    /// Generate an ML-KEM-768 key pair.
    ///
    /// # Errors
    ///
    /// `UnsupportedAlgorithm` for any `alg` other than `MlKem768`;
    /// `KemError` if `saorsa_pqc` fails to generate keys.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use saorsa_pqc::{MlKem768, MlKemOperations};

        ensure_ml_kem_768(alg)?;

        let ml_kem = MlKem768::new();
        let (pub_key, sec_key) = ml_kem
            .generate_keypair()
            .map_err(|e| PqcError::KemError(format!("Keypair generation failed: {:?}", e)))?;

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: pub_key.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: sec_key.as_bytes().to_vec(),
            },
        })
    }

    /// Encapsulate to an ML-KEM-768 public key.
    ///
    /// # Errors
    ///
    /// `UnsupportedAlgorithm` if the key is not tagged `MlKem768`;
    /// `KemError` for a malformed key or an encapsulation failure.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use saorsa_pqc::{MlKem768, MlKemOperations, MlKemPublicKey};

        ensure_ml_kem_768(encap_key.algorithm)?;

        // Reject wrong-length keys up front with a descriptive error.
        let expected = encap_key.algorithm.encap_key_size();
        if encap_key.key_bytes.len() != expected {
            return Err(PqcError::KemError(format!(
                "Invalid public key: expected {} bytes, got {}",
                expected,
                encap_key.key_bytes.len()
            )));
        }

        let ml_kem = MlKem768::new();
        let pub_key = MlKemPublicKey::from_bytes(&encap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid public key: {:?}", e)))?;

        let (ciphertext, shared_secret) = ml_kem
            .encapsulate(&pub_key)
            .map_err(|e| PqcError::KemError(format!("Encapsulation failed: {:?}", e)))?;

        let mut secret_array = [0u8; 32];
        secret_array.copy_from_slice(shared_secret.as_bytes());

        Ok(KemOutput {
            ciphertext: ciphertext.as_bytes().to_vec(),
            shared_secret: secret_array,
        })
    }

    /// Recover the shared secret from an ML-KEM-768 ciphertext.
    ///
    /// # Errors
    ///
    /// `UnsupportedAlgorithm` if the key is not tagged `MlKem768`;
    /// `KemError` for a malformed key or ciphertext, or a decapsulation failure.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use saorsa_pqc::{MlKem768, MlKemCiphertext, MlKemOperations, MlKemSecretKey};

        ensure_ml_kem_768(decap_key.algorithm)?;

        // The ciphertext typically arrives from a peer; check its length first.
        let expected = decap_key.algorithm.ciphertext_size();
        if ciphertext.len() != expected {
            return Err(PqcError::KemError(format!(
                "Invalid ciphertext: expected {} bytes, got {}",
                expected,
                ciphertext.len()
            )));
        }

        let ml_kem = MlKem768::new();
        let sec_key = MlKemSecretKey::from_bytes(&decap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid secret key: {:?}", e)))?;

        let ciphertext = MlKemCiphertext::from_bytes(ciphertext)
            .map_err(|e| PqcError::KemError(format!("Invalid ciphertext: {:?}", e)))?;

        let shared_secret = ml_kem
            .decapsulate(&sec_key, &ciphertext)
            .map_err(|e| PqcError::KemError(format!("Decapsulation failed: {:?}", e)))?;

        let mut secret_array = [0u8; 32];
        secret_array.copy_from_slice(shared_secret.as_bytes());
        Ok(secret_array)
    }
}
206
/// X25519-based KEM for backward compatibility and hybrid mode.
///
/// Stateless unit type; it derives the usual value traits so it can be
/// debug-printed, copied and default-constructed like other backends.
#[cfg(feature = "hybrid")]
#[derive(Debug, Clone, Copy, Default)]
pub struct X25519Kem;
210
/// Copy a 32-byte X25519 value (public key, secret scalar or ciphertext) into an array.
///
/// These bytes can come from untrusted peers, so a wrong length is reported
/// as a `KemError` instead of panicking inside `copy_from_slice`.
#[cfg(feature = "hybrid")]
fn x25519_bytes(bytes: &[u8], what: &str) -> Result<[u8; 32]> {
    if bytes.len() != 32 {
        return Err(PqcError::KemError(format!(
            "Invalid X25519 {}: expected 32 bytes, got {}",
            what,
            bytes.len()
        )));
    }
    let mut out = [0u8; 32];
    out.copy_from_slice(bytes);
    Ok(out)
}

/// Reject requests or keys that are not tagged as X25519.
#[cfg(feature = "hybrid")]
fn ensure_x25519(alg: KemAlgorithm) -> Result<()> {
    if matches!(alg, KemAlgorithm::X25519) {
        Ok(())
    } else {
        Err(PqcError::UnsupportedAlgorithm("Use MlKem for ML-KEM".into()))
    }
}

/// Reject an all-zero Diffie-Hellman output.
///
/// A low-order peer public key forces the X25519 result to zero regardless of
/// our own secret (RFC 7748 §6.1), which would yield a predictable "shared secret".
#[cfg(feature = "hybrid")]
fn ensure_contributory(shared: &x25519_dalek::SharedSecret) -> Result<()> {
    if shared.was_contributory() {
        Ok(())
    } else {
        Err(PqcError::KemError(
            "X25519 produced an all-zero shared secret (low-order public key)".into(),
        ))
    }
}

#[cfg(feature = "hybrid")]
#[async_trait]
impl Kem for X25519Kem {
    /// Generate a static X25519 key pair.
    ///
    /// # Errors
    ///
    /// `UnsupportedAlgorithm` for any `alg` other than `X25519`.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(alg)?;

        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public = PublicKey::from(&secret);

        // `to_bytes` returns a stack copy of the secret scalar; wipe it once the
        // bytes live in the DecapsulationKey (which zeroizes itself on drop).
        let mut secret_bytes = secret.to_bytes();
        let decap_bytes = secret_bytes.to_vec();
        secret_bytes.zeroize();

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: public.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: decap_bytes,
            },
        })
    }

    /// Encapsulate via an ephemeral X25519 key; the ciphertext is the ephemeral public key.
    ///
    /// # Errors
    ///
    /// `UnsupportedAlgorithm` if the key is not tagged X25519; `KemError` if the
    /// key is not 32 bytes or the exchange is non-contributory.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(encap_key.algorithm)?;

        // Validate the recipient key before generating the ephemeral pair.
        let recipient_public =
            PublicKey::from(x25519_bytes(&encap_key.key_bytes, "public key")?);

        let ephemeral_secret = StaticSecret::random_from_rng(&mut OsRng);
        let ephemeral_public = PublicKey::from(&ephemeral_secret);

        let shared = ephemeral_secret.diffie_hellman(&recipient_public);
        ensure_contributory(&shared)?;

        Ok(KemOutput {
            ciphertext: ephemeral_public.as_bytes().to_vec(),
            shared_secret: *shared.as_bytes(),
        })
    }

    /// Recover the shared secret from the sender's ephemeral public key.
    ///
    /// # Errors
    ///
    /// `UnsupportedAlgorithm` if the key is not tagged X25519; `KemError` if the
    /// key or ciphertext is not 32 bytes or the exchange is non-contributory.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(decap_key.algorithm)?;

        let mut sk_bytes = x25519_bytes(&decap_key.key_bytes, "secret key")?;
        let secret = StaticSecret::from(sk_bytes);
        // `StaticSecret` owns its own copy now; wipe the stack copy.
        sk_bytes.zeroize();

        // For this KEM the ciphertext is the sender's ephemeral public key.
        let ephemeral_public = PublicKey::from(x25519_bytes(ciphertext, "ciphertext")?);

        let shared = secret.diffie_hellman(&ephemeral_public);
        ensure_contributory(&shared)?;
        Ok(*shared.as_bytes())
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Size tables back the length checks, so pin them to FIPS 203 / X25519 values.
    #[test]
    fn test_algorithm_sizes_and_default() {
        assert_eq!(KemAlgorithm::default(), KemAlgorithm::MlKem768);

        assert_eq!(KemAlgorithm::MlKem512.encap_key_size(), 800);
        assert_eq!(KemAlgorithm::MlKem768.encap_key_size(), 1184);
        assert_eq!(KemAlgorithm::MlKem1024.encap_key_size(), 1568);
        assert_eq!(KemAlgorithm::X25519.encap_key_size(), 32);

        assert_eq!(KemAlgorithm::MlKem512.ciphertext_size(), 768);
        assert_eq!(KemAlgorithm::MlKem768.ciphertext_size(), 1088);
        assert_eq!(KemAlgorithm::MlKem1024.ciphertext_size(), 1568);
        assert_eq!(KemAlgorithm::X25519.ciphertext_size(), 32);

        for alg in [
            KemAlgorithm::MlKem512,
            KemAlgorithm::MlKem768,
            KemAlgorithm::MlKem1024,
            KemAlgorithm::X25519,
        ] {
            assert_eq!(alg.shared_secret_size(), 32);
        }
    }

    #[tokio::test]
    #[cfg(feature = "ml-kem")]
    async fn test_ml_kem_768() {
        let alg = KemAlgorithm::MlKem768;
        let kem = MlKem::new();
        let keypair = kem.generate_keypair(alg).await.unwrap();
        assert_eq!(keypair.encap_key.key_bytes.len(), alg.encap_key_size());

        let output = kem.encapsulate(&keypair.encap_key).await.unwrap();
        let recovered = kem
            .decapsulate(&keypair.decap_key, &output.ciphertext)
            .await
            .unwrap();

        assert_eq!(output.shared_secret, recovered);
        // 1088 bytes for ML-KEM-768.
        assert_eq!(output.ciphertext.len(), alg.ciphertext_size());
    }

    #[tokio::test]
    #[cfg(feature = "ml-kem")]
    async fn test_ml_kem_rejects_x25519() {
        let kem = MlKem::new();
        assert!(kem.generate_keypair(KemAlgorithm::X25519).await.is_err());
    }

    #[tokio::test]
    #[cfg(feature = "hybrid")]
    async fn test_x25519() {
        let alg = KemAlgorithm::X25519;
        let kem = X25519Kem;
        let keypair = kem.generate_keypair(alg).await.unwrap();
        assert_eq!(keypair.encap_key.key_bytes.len(), alg.encap_key_size());
        assert_eq!(keypair.decap_key.key_bytes.len(), 32);

        let output = kem.encapsulate(&keypair.encap_key).await.unwrap();
        let recovered = kem
            .decapsulate(&keypair.decap_key, &output.ciphertext)
            .await
            .unwrap();

        assert_eq!(output.shared_secret, recovered);
        assert_eq!(output.ciphertext.len(), alg.ciphertext_size());
    }

    #[tokio::test]
    #[cfg(feature = "hybrid")]
    async fn test_x25519_rejects_ml_kem() {
        let kem = X25519Kem;
        assert!(kem.generate_keypair(KemAlgorithm::MlKem768).await.is_err());
    }
}