1use crate::{PqcError, Result};
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use zeroize::Zeroize;
8
/// Key-encapsulation algorithms understood by this module.
///
/// NOTE(review): keep the variant order stable — serde formats that encode
/// enums by variant index (e.g. bincode) depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KemAlgorithm {
    /// ML-KEM-512 (800-byte encapsulation key, 768-byte ciphertext).
    MlKem512,
    /// ML-KEM-768 (1184-byte encapsulation key, 1088-byte ciphertext); the default.
    MlKem768,
    /// ML-KEM-1024 (1568-byte encapsulation key, 1568-byte ciphertext).
    MlKem1024,
    /// Classical X25519 Diffie-Hellman driven through the KEM interface.
    X25519,
}
21
22impl KemAlgorithm {
23 pub fn encap_key_size(&self) -> usize {
25 match self {
26 Self::MlKem512 => 800, Self::MlKem768 => 1184, Self::MlKem1024 => 1568, Self::X25519 => 32,
30 }
31 }
32
33 pub fn ciphertext_size(&self) -> usize {
35 match self {
36 Self::MlKem512 => 768, Self::MlKem768 => 1088, Self::MlKem1024 => 1568, Self::X25519 => 32,
40 }
41 }
42
43 pub fn shared_secret_size(&self) -> usize {
45 32 }
47
48 #[cfg(feature = "ml-kem")]
50 #[allow(dead_code)] pub(crate) fn to_saorsa_variant(&self) -> saorsa_pqc::MlKemVariant {
52 match self {
53 Self::MlKem512 => saorsa_pqc::MlKemVariant::MlKem512,
54 Self::MlKem768 => saorsa_pqc::MlKemVariant::MlKem768,
55 Self::MlKem1024 => saorsa_pqc::MlKemVariant::MlKem1024,
56 Self::X25519 => panic!("X25519 is not a ML-KEM algorithm"),
57 }
58 }
59}
60
61impl Default for KemAlgorithm {
62 fn default() -> Self {
63 Self::MlKem768 }
65}
66
/// Public (encapsulation) key tagged with the algorithm it was generated for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncapsulationKey {
    /// Algorithm this key belongs to.
    pub algorithm: KemAlgorithm,
    /// Raw serialized key bytes.
    // NOTE(review): the public fields mean nothing here enforces
    // `key_bytes.len() == algorithm.encap_key_size()`; backends must validate.
    pub key_bytes: Vec<u8>,
}
73
74#[derive(Clone)]
76pub struct DecapsulationKey {
77 pub algorithm: KemAlgorithm,
78 pub key_bytes: Vec<u8>,
79}
80
81impl Drop for DecapsulationKey {
82 fn drop(&mut self) {
83 self.key_bytes.zeroize();
84 }
85}
86
/// Matching key pair returned by [`Kem::generate_keypair`].
pub struct KemKeyPair {
    /// Public half, handed to peers that encapsulate to this pair.
    pub encap_key: EncapsulationKey,
    /// Secret half; its bytes are zeroized when it is dropped.
    pub decap_key: DecapsulationKey,
}
92
/// Result of [`Kem::encapsulate`].
pub struct KemOutput {
    /// Bytes to send to the holder of the matching decapsulation key.
    pub ciphertext: Vec<u8>,
    /// 32-byte shared secret derived by the encapsulating side.
    // NOTE(review): unlike `DecapsulationKey`, this type has no `Drop` that
    // zeroizes the secret; callers should wipe it once it has been consumed.
    pub shared_secret: [u8; 32],
}
98
/// Asynchronous key-encapsulation mechanism.
///
/// Backends in this file: `MlKem` (feature `ml-kem`) and `X25519Kem`
/// (feature `hybrid`).
#[async_trait]
pub trait Kem: Send + Sync {
    /// Generates a fresh key pair for `alg`.
    ///
    /// # Errors
    ///
    /// The backends in this file return `PqcError::UnsupportedAlgorithm` when
    /// `alg` belongs to the other backend, and may fail during key generation.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair>;

    /// Encapsulates a fresh shared secret to `encap_key`, returning the
    /// ciphertext for the peer together with the 32-byte shared secret.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput>;

    /// Recovers the 32-byte shared secret from `ciphertext` using `decap_key`.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]>;
}
115
/// ML-KEM backend for [`Kem`], implemented on top of `saorsa_pqc`.
#[cfg(feature = "ml-kem")]
pub struct MlKem {
    // Private field: code outside this module cannot build `MlKem` with a
    // struct literal and has to go through `MlKem::new()`.
    _phantom: std::marker::PhantomData<()>,
}
122
#[cfg(feature = "ml-kem")]
impl MlKem {
    /// Creates a new ML-KEM backend.
    ///
    /// The value carries no state; each operation builds its own `saorsa_pqc`
    /// context.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

/// Equivalent to [`MlKem::new`]; lets `MlKem` be used wherever `Default` is
/// expected (and satisfies `clippy::new_without_default`).
#[cfg(feature = "ml-kem")]
impl Default for MlKem {
    fn default() -> Self {
        Self::new()
    }
}
131
#[cfg(feature = "ml-kem")]
#[async_trait]
impl Kem for MlKem {
    /// Generates an ML-KEM-768 key pair.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` for any `alg` other than
    ///   `KemAlgorithm::MlKem768` (X25519 belongs to `X25519Kem`; ML-KEM-512
    ///   and ML-KEM-1024 are not implemented by this backend).
    /// * `PqcError::KemError` if `saorsa_pqc` key generation fails.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use saorsa_pqc::{MlKem768, MlKemOperations};

        // Only the 768 parameter set is wired up below. Refuse other variants
        // instead of returning a 768 key mislabelled as the requested one.
        ensure_ml_kem_768(alg)?;

        let ml_kem = MlKem768::new();
        let (pub_key, sec_key) = ml_kem
            .generate_keypair()
            .map_err(|e| PqcError::KemError(format!("Keypair generation failed: {:?}", e)))?;

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: pub_key.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: sec_key.as_bytes().to_vec(),
            },
        })
    }

    /// Encapsulates a fresh shared secret to an ML-KEM-768 public key.
    ///
    /// # Errors
    ///
    /// `PqcError::UnsupportedAlgorithm` if the key is not tagged `MlKem768`.
    /// `PqcError::KemError` if the key bytes are rejected or encapsulation fails.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use saorsa_pqc::{MlKem768, MlKemOperations, MlKemPublicKey};

        ensure_ml_kem_768(encap_key.algorithm)?;

        let ml_kem = MlKem768::new();
        let pub_key = MlKemPublicKey::from_bytes(&encap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid public key: {:?}", e)))?;

        let (ciphertext, shared_secret) = ml_kem
            .encapsulate(&pub_key)
            .map_err(|e| PqcError::KemError(format!("Encapsulation failed: {:?}", e)))?;

        // FIPS 203 fixes the ML-KEM shared secret at 32 bytes (see
        // `KemAlgorithm::shared_secret_size`), so this copy is not expected to panic.
        let mut secret_array = [0u8; 32];
        secret_array.copy_from_slice(shared_secret.as_bytes());

        Ok(KemOutput {
            ciphertext: ciphertext.as_bytes().to_vec(),
            shared_secret: secret_array,
        })
    }

    /// Recovers the shared secret from an ML-KEM-768 ciphertext.
    ///
    /// # Errors
    ///
    /// `PqcError::UnsupportedAlgorithm` if the key is not tagged `MlKem768`.
    /// `PqcError::KemError` if the key or ciphertext bytes are rejected, or
    /// decapsulation fails.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use saorsa_pqc::{MlKem768, MlKemCiphertext, MlKemOperations, MlKemSecretKey};

        ensure_ml_kem_768(decap_key.algorithm)?;

        let ml_kem = MlKem768::new();
        let sec_key = MlKemSecretKey::from_bytes(&decap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid secret key: {:?}", e)))?;

        let ciphertext = MlKemCiphertext::from_bytes(ciphertext)
            .map_err(|e| PqcError::KemError(format!("Invalid ciphertext: {:?}", e)))?;

        let shared_secret = ml_kem
            .decapsulate(&sec_key, &ciphertext)
            .map_err(|e| PqcError::KemError(format!("Decapsulation failed: {:?}", e)))?;

        let mut secret_array = [0u8; 32];
        secret_array.copy_from_slice(shared_secret.as_bytes());
        Ok(secret_array)
    }
}

/// Accepts `KemAlgorithm::MlKem768`, the only parameter set `MlKem`
/// implements, and rejects every other variant.
#[cfg(feature = "ml-kem")]
fn ensure_ml_kem_768(alg: KemAlgorithm) -> Result<()> {
    match alg {
        KemAlgorithm::MlKem768 => Ok(()),
        KemAlgorithm::X25519 => Err(PqcError::UnsupportedAlgorithm(
            "Use X25519Kem for X25519".into(),
        )),
        KemAlgorithm::MlKem512 | KemAlgorithm::MlKem1024 => {
            // TODO: route through `KemAlgorithm::to_saorsa_variant` once the
            // backend supports the other parameter sets.
            Err(PqcError::UnsupportedAlgorithm(
                "MlKem only implements MlKem768; MlKem512 and MlKem1024 are not supported yet"
                    .into(),
            ))
        }
    }
}
206
/// Classical X25519 Diffie-Hellman exposed through the [`Kem`] interface.
///
/// The type carries no state; key pairs and per-encapsulation ephemeral
/// secrets are drawn from `OsRng`.
#[cfg(feature = "hybrid")]
#[derive(Debug, Clone, Copy, Default)]
pub struct X25519Kem;
210
#[cfg(feature = "hybrid")]
#[async_trait]
impl Kem for X25519Kem {
    /// Generates an X25519 key pair.
    ///
    /// # Errors
    ///
    /// `PqcError::UnsupportedAlgorithm` unless `alg` is `KemAlgorithm::X25519`.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(alg)?;

        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public = PublicKey::from(&secret);

        // `to_bytes` hands back a stack copy of the scalar; wipe it once the
        // bytes live in the (drop-zeroized) `DecapsulationKey` buffer.
        let mut secret_bytes = secret.to_bytes();
        let decap_bytes = secret_bytes.to_vec();
        secret_bytes.zeroize();

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: public.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: decap_bytes,
            },
        })
    }

    /// Performs an ephemeral-static X25519 exchange with `encap_key`.
    ///
    /// The returned ciphertext is the 32-byte ephemeral public key.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if the key is not tagged `X25519`.
    /// * `PqcError::KemError` if the key is not exactly 32 bytes, or is a
    ///   low-order point that would force an all-zero shared secret.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(encap_key.algorithm)?;
        // Validate before touching the RNG; a wrong-length key used to panic here.
        let recipient_public = PublicKey::from(x25519_bytes(&encap_key.key_bytes, "public key")?);

        let ephemeral_secret = StaticSecret::random_from_rng(&mut OsRng);
        let ephemeral_public = PublicKey::from(&ephemeral_secret);

        let shared = ephemeral_secret.diffie_hellman(&recipient_public);
        // A low-order recipient key makes the output independent of our secret
        // (all zeros), which an attacker can predict.
        if !shared.was_contributory() {
            return Err(PqcError::KemError(
                "X25519 public key is a low-order point".into(),
            ));
        }

        Ok(KemOutput {
            ciphertext: ephemeral_public.as_bytes().to_vec(),
            shared_secret: *shared.as_bytes(),
        })
    }

    /// Recovers the shared secret from the peer's ephemeral public key.
    ///
    /// # Errors
    ///
    /// * `PqcError::UnsupportedAlgorithm` if the key is not tagged `X25519`.
    /// * `PqcError::KemError` if the secret key or ciphertext is not exactly
    ///   32 bytes, or the ciphertext is a low-order point.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use x25519_dalek::{PublicKey, StaticSecret};

        ensure_x25519(decap_key.algorithm)?;
        // The ciphertext comes from the peer; reject bad lengths instead of panicking.
        let ephemeral_public = PublicKey::from(x25519_bytes(ciphertext, "ciphertext")?);

        let mut secret_bytes = x25519_bytes(&decap_key.key_bytes, "secret key")?;
        let secret = StaticSecret::from(secret_bytes);
        // `StaticSecret::from` copied the array; wipe the leftover stack copy.
        secret_bytes.zeroize();

        let shared = secret.diffie_hellman(&ephemeral_public);
        if !shared.was_contributory() {
            return Err(PqcError::KemError(
                "X25519 ciphertext is a low-order point".into(),
            ));
        }
        Ok(*shared.as_bytes())
    }
}

/// Rejects every algorithm except `KemAlgorithm::X25519`.
#[cfg(feature = "hybrid")]
fn ensure_x25519(alg: KemAlgorithm) -> Result<()> {
    if matches!(alg, KemAlgorithm::X25519) {
        Ok(())
    } else {
        Err(PqcError::UnsupportedAlgorithm(
            "Use MlKem for ML-KEM".into(),
        ))
    }
}

/// Copies `bytes` into a 32-byte array, or fails with a `KemError` naming
/// `what` when the length is wrong.
#[cfg(feature = "hybrid")]
fn x25519_bytes(bytes: &[u8], what: &str) -> Result<[u8; 32]> {
    if bytes.len() != 32 {
        return Err(PqcError::KemError(format!(
            "Invalid X25519 {}: expected 32 bytes, got {}",
            what,
            bytes.len()
        )));
    }
    let mut out = [0u8; 32];
    out.copy_from_slice(bytes);
    Ok(out)
}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips an ML-KEM-768 encapsulation and checks the ciphertext length.
    #[cfg(feature = "ml-kem")]
    #[tokio::test]
    async fn test_ml_kem_768() {
        let backend = MlKem::new();
        let KemKeyPair {
            encap_key,
            decap_key,
        } = backend
            .generate_keypair(KemAlgorithm::MlKem768)
            .await
            .unwrap();

        let encapsulated = backend.encapsulate(&encap_key).await.unwrap();
        let decapsulated = backend
            .decapsulate(&decap_key, &encapsulated.ciphertext)
            .await
            .unwrap();

        assert_eq!(encapsulated.shared_secret, decapsulated);
        assert_eq!(encapsulated.ciphertext.len(), 1088);
    }

    /// Round-trips an X25519 encapsulation.
    #[cfg(feature = "hybrid")]
    #[tokio::test]
    async fn test_x25519() {
        let backend = X25519Kem;
        let KemKeyPair {
            encap_key,
            decap_key,
        } = backend
            .generate_keypair(KemAlgorithm::X25519)
            .await
            .unwrap();

        let encapsulated = backend.encapsulate(&encap_key).await.unwrap();
        let decapsulated = backend
            .decapsulate(&decap_key, &encapsulated.ciphertext)
            .await
            .unwrap();

        assert_eq!(encapsulated.shared_secret, decapsulated);
    }
}