orion/hazardous/kem/ml_kem/
mlkem768.rs1use crate::errors::UnknownCryptoError;
80use crate::hazardous::kem::ml_kem::internal::*;
81pub use crate::hazardous::kem::ml_kem::Seed;
82use zeroize::Zeroize;
83
// Shared secret produced by ML-KEM-768 encapsulation and decapsulation.
// The same value, `MlKem768Internal::SHARED_SECRET_SIZE`, is passed as both size arguments.
// These are presumably the lower and upper length bounds, so only that exact length is accepted.
// `test_shared_key` presumably names a macro-generated test module; confirm against the
// `construct_secret_key!` definition.
construct_secret_key! {
    (SharedSecret, test_shared_key, MlKem768Internal::SHARED_SECRET_SIZE, MlKem768Internal::SHARED_SECRET_SIZE)
}

// NOTE(review): presumably generates `From<[u8; SHARED_SECRET_SIZE]>` for `SharedSecret`.
// Confirm against the `impl_from_trait!` definition.
impl_from_trait!(SharedSecret, MlKem768Internal::SHARED_SECRET_SIZE);
97
// Public ML-KEM-768 ciphertext produced by encapsulation and consumed by decapsulation.
// The same value, `MlKem768Internal::CIPHERTEXT_SIZE`, is passed as both size arguments.
// These are presumably the lower and upper length bounds, so only that exact length is accepted.
// `test_kem_ciphertext` presumably names a macro-generated test module; confirm against the
// `construct_public!` definition.
construct_public! {
    (Ciphertext, test_kem_ciphertext, MlKem768Internal::CIPHERTEXT_SIZE, MlKem768Internal::CIPHERTEXT_SIZE)
}

// NOTE(review): presumably generates `From<[u8; CIPHERTEXT_SIZE]>` for `Ciphertext`.
// Confirm against the `impl_from_trait!` definition.
impl_from_trait!(Ciphertext, MlKem768Internal::CIPHERTEXT_SIZE);
108
/// An ML-KEM-768 key pair.
///
/// It holds the [`Seed`] the pair is associated with and the [`DecapsulationKey`]. The
/// decapsulation key also carries the matching [`EncapsulationKey`].
#[derive(Debug, PartialEq)]
pub struct KeyPair {
    /// Seed the pair was derived from (`generate`, `TryFrom<&Seed>`) or was checked against
    /// (`from_keys`).
    seed: Seed,
    /// Decapsulation key; the encapsulation key is cached inside it.
    dk: DecapsulationKey,
}
115
116impl KeyPair {
117 #[cfg(feature = "safe_api")]
118 #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
119 pub fn generate() -> Result<Self, UnknownCryptoError> {
121 let seed = Seed::generate();
122 let (ek, dk) = KeyPairInternal::<MlKem768Internal>::from_seed::<3, 1184, 2400>(&seed)?;
123
124 Ok(Self {
125 seed,
126 dk: DecapsulationKey {
127 value: dk,
128 cached_ek: EncapsulationKey { value: ek },
129 },
130 })
131 }
132
133 #[cfg(feature = "safe_api")]
134 #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
135 pub fn from_keys(seed: &Seed, dk: &DecapsulationKey) -> Result<Self, UnknownCryptoError> {
142 let unchecked_ek = EncapsulationKey::try_from(dk)?;
143 let (ek, dk) = KeyPairInternal::<MlKem768Internal>::from_keys::<3, 1184, 2400, 1088>(
144 seed,
145 &unchecked_ek.value,
146 &dk.value,
147 )?;
148
149 Ok(Self {
150 seed: Seed::from_slice(seed.unprotected_as_bytes()).unwrap(),
151 dk: DecapsulationKey {
152 value: dk,
153 cached_ek: EncapsulationKey { value: ek },
154 },
155 })
156 }
157
158 pub fn seed(&self) -> &Seed {
161 &self.seed
162 }
163
164 pub fn public(&self) -> &EncapsulationKey {
166 &self.dk.cached_ek
167 }
168
169 pub fn private(&self) -> &DecapsulationKey {
172 &self.dk
173 }
174}
175
176impl TryFrom<&Seed> for KeyPair {
177 type Error = UnknownCryptoError;
178
179 fn try_from(value: &Seed) -> Result<Self, Self::Error> {
180 let (ek, dk) = KeyPairInternal::<MlKem768Internal>::from_seed::<3, 1184, 2400>(value)?;
181
182 Ok(Self {
183 seed: Seed::from_slice(value.unprotected_as_bytes()).unwrap(),
184 dk: DecapsulationKey {
185 value: dk,
186 cached_ek: EncapsulationKey { value: ek },
187 },
188 })
189 }
190}
191
/// An ML-KEM-768 decapsulation (private) key.
///
/// The matching [`EncapsulationKey`] is stored alongside it. `decap` passes that cached key
/// to the internal decapsulation routine.
#[derive(Debug, PartialEq)]
pub struct DecapsulationKey {
    /// The decapsulation key itself.
    pub(crate) value: DecapKey<3, 1184, 2400, MlKem768Internal>,
    /// Encapsulation key matching `value`, cached when the key is constructed.
    pub(crate) cached_ek: EncapsulationKey,
}
201
202impl PartialEq<&[u8]> for DecapsulationKey {
203 fn eq(&self, other: &&[u8]) -> bool {
204 self.value == *other
206 }
207}
208
209impl DecapsulationKey {
210 pub fn unchecked_from_slice(slice: &[u8]) -> Result<Self, UnknownCryptoError> {
212 let dk_unchecked =
213 DecapKey::<3, 1184, 2400, MlKem768Internal>::unchecked_from_slice(slice)?;
214 let ek_unchecked =
215 EncapsulationKey::from_slice(dk_unchecked.get_encapsulation_key_bytes())?;
216
217 Ok(Self {
218 value: dk_unchecked,
219 cached_ek: ek_unchecked,
220 })
221 }
222
223 pub fn decap(&self, c: &Ciphertext) -> Result<SharedSecret, UnknownCryptoError> {
225 let mut c_prime_buf = [0u8; MlKem768Internal::CIPHERTEXT_SIZE];
226 let mut k_internal = self.value.mlkem_decap_internal_with_ek(
227 c.as_ref(),
228 &mut c_prime_buf,
229 &self.cached_ek.value,
230 )?;
231 let k = SharedSecret::from_slice(&k_internal)?;
232 k_internal.zeroize();
233
234 Ok(k)
235 }
236}
237
/// An ML-KEM-768 encapsulation (public) key.
#[derive(Debug, PartialEq, Clone)]
pub struct EncapsulationKey {
    /// The encapsulation key itself.
    pub(crate) value: EncapKey<3, 1184, MlKem768Internal>,
}
243
244impl PartialEq<&[u8]> for EncapsulationKey {
245 fn eq(&self, other: &&[u8]) -> bool {
246 self.value == *other
247 }
248}
249
250impl TryFrom<&DecapsulationKey> for EncapsulationKey {
251 type Error = UnknownCryptoError;
252
253 fn try_from(value: &DecapsulationKey) -> Result<Self, Self::Error> {
254 Ok(Self {
255 value: EncapKey::<3, 1184, MlKem768Internal>::from_slice(
256 value.value.get_encapsulation_key_bytes(),
257 )?,
258 })
259 }
260}
261
262impl AsRef<[u8]> for EncapsulationKey {
263 fn as_ref(&self) -> &[u8] {
264 self.value.as_ref()
265 }
266}
267
268impl EncapsulationKey {
269 pub fn from_slice(slice: &[u8]) -> Result<Self, UnknownCryptoError> {
271 Ok(Self {
272 value: EncapKey::<3, 1184, MlKem768Internal>::from_slice(slice)?,
273 })
274 }
275
276 #[cfg(feature = "safe_api")]
277 #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
278 pub fn encap(&self) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
280 use zeroize::Zeroizing;
281
282 let mut m = Zeroizing::new([0u8; 32]);
283 getrandom::fill(m.as_mut())?;
284
285 self.encap_deterministic(m.as_ref())
286 }
287
288 pub fn encap_deterministic(
290 &self,
291 m: &[u8],
292 ) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
293 if m.len() != 32 {
294 return Err(UnknownCryptoError);
295 }
296
297 let mut c = Ciphertext::from_slice(&[0u8; MlKem768Internal::CIPHERTEXT_SIZE])?;
298 let mut k_internal = self.value.mlkem_encap_internal(m.as_ref(), &mut c.value)?;
299 let k = SharedSecret::from_slice(k_internal.as_slice())?;
300 k_internal.zeroize();
301
302 Ok((k, c))
303 }
304}
305
/// Marker type for the ML-KEM-768 parameter set.
///
/// Groups the size constants and the one-shot `encap` / `decap` functions.
#[derive(PartialEq, Debug)]
pub struct MlKem768;
309
310impl MlKem768 {
311 pub const EK_SIZE: usize = MlKem768Internal::EK_SIZE;
313 pub const DK_SIZE: usize = MlKem768Internal::DK_SIZE;
315 pub const CIPHERTEXT_SIZE: usize = MlKem768Internal::CIPHERTEXT_SIZE;
317 pub const SHARED_SECRET_SIZE: usize = MlKem768Internal::SHARED_SECRET_SIZE;
319
320 #[cfg(feature = "safe_api")]
321 #[cfg_attr(docsrs, doc(cfg(feature = "safe_api")))]
322 pub fn encap(ek: &EncapsulationKey) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
324 ek.encap()
325 }
326
327 pub fn decap(
329 dk: &DecapsulationKey,
330 c: &Ciphertext,
331 ) -> Result<SharedSecret, UnknownCryptoError> {
332 dk.decap(c)
333 }
334}
335
// Unit tests for the ML-KEM-768 wrapper types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "safe_api")]
    use crate::test_framework::kem_interface::{KemTester, TestableKem};

    // Adapter that plugs `MlKem768` into the crate's generic KEM test harness.
    // The harness passes keys around as raw byte slices.
    #[cfg(feature = "safe_api")]
    impl TestableKem<SharedSecret, Ciphertext> for MlKem768 {
        // Derive a key pair from `seed` and return `(ek_bytes, dk_bytes)`.
        // Panics (`unwrap`) instead of erroring if the seed is rejected.
        fn keygen(seed: &[u8]) -> Result<(Vec<u8>, Vec<u8>), UnknownCryptoError> {
            let kp = KeyPair::try_from(&Seed::from_slice(seed).unwrap()).unwrap();

            Ok((
                kp.dk.cached_ek.as_ref().to_vec(),
                kp.dk.value.unprotected_as_bytes().to_vec(),
            ))
        }

        fn ciphertext_from_bytes(b: &[u8]) -> Result<Ciphertext, UnknownCryptoError> {
            Ciphertext::from_slice(b)
        }

        // Parse `ek` (panicking if it is rejected) and encapsulate with fresh randomness.
        fn encap(ek: &[u8]) -> Result<(SharedSecret, Ciphertext), UnknownCryptoError> {
            let ek = EncapsulationKey::from_slice(ek).unwrap();
            ek.encap()
        }

        // Parse `dk` via `unchecked_from_slice` (panicking if it is rejected), then decapsulate `c`.
        fn decap(dk: &[u8], c: &Ciphertext) -> Result<SharedSecret, UnknownCryptoError> {
            let dk = DecapsulationKey::unchecked_from_slice(dk).unwrap();
            dk.decap(c)
        }
    }

    // `KeyPair::public` returns the encapsulation key cached inside the decapsulation key.
    #[test]
    fn test_keypair_dk_ek_match_internal() {
        let seed = Seed::from_slice(&[128u8; 64]).unwrap();
        let kp = KeyPair::try_from(&seed).unwrap();
        assert_eq!(kp.public(), &kp.private().cached_ek);
    }

    // Decapsulating via the internal `mlkem_decap_internal` (no cached-ek argument) and via the
    // public API both recover the shared secret from a deterministic encapsulation.
    #[test]
    #[cfg(feature = "safe_api")]
    fn test_dk_cached_ek() {
        let seed = Seed::from_slice(&[128u8; 64]).unwrap();
        let kp = KeyPair::try_from(&seed).unwrap();
        let (ss_pubapi, ct_pubapi) = kp.public().encap_deterministic(&[125u8; 32]).unwrap();
        let mut c_prime = [0u8; MlKem768Internal::CIPHERTEXT_SIZE];
        let ss_privapi = kp
            .private()
            .value
            .mlkem_decap_internal(ct_pubapi.as_ref(), &mut c_prime)
            .unwrap();
        assert_eq!(ss_privapi.as_ref(), ss_pubapi.unprotected_as_bytes());
        assert_eq!(
            MlKem768::decap(kp.private(), &ct_pubapi).unwrap(),
            ss_pubapi
        );
    }

    // `EncapsulationKey::try_from(&DecapsulationKey)` agrees with the cached encapsulation key.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_dk_to_ek_conversions() {
        let kp = KeyPair::generate().unwrap();
        assert_eq!(
            kp.dk.cached_ek,
            EncapsulationKey::try_from(kp.private()).unwrap()
        );
    }

    // `encap_deterministic` gives identical output for the same 32-byte `m` and rejects
    // messages of any other length.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_bad_m_length() {
        let kp = KeyPair::generate().unwrap();
        let mut m = [0u8; 32];
        getrandom::fill(m.as_mut()).unwrap();

        assert_eq!(
            kp.public().encap_deterministic(&m).unwrap(),
            kp.public().encap_deterministic(&m).unwrap()
        );
        assert!(kp.public().encap_deterministic(&[0u8; 31]).is_err());
        assert!(kp.public().encap_deterministic(&[0u8; 33]).is_err());
    }

    // Keys compare equal to their own bytes via `PartialEq<&[u8]>`.
    // Deriving twice from the same seed yields the same keys.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_dk_ek_partialeq() {
        let s0 = Seed::generate();
        let kp = KeyPair::try_from(&s0).unwrap();

        let dk_bytes = kp.private().value.bytes;
        let ek_bytes = kp.public().value.bytes;

        assert_eq!(
            KeyPair::try_from(&s0).unwrap().private(),
            &dk_bytes.as_ref()
        );
        assert_eq!(KeyPair::try_from(&s0).unwrap().public(), &ek_bytes.as_ref());
    }

    // `KeyPair::from_keys` accepts a seed together with its own decapsulation key and rejects
    // mismatched pairs. The pairs it accepts equal the ones derived directly from the seed.
    #[cfg(feature = "safe_api")]
    #[test]
    fn test_keypair_from_keys() {
        let s0 = Seed::generate();
        let s1 = Seed::generate();

        let kp0 = KeyPair::try_from(&s0).unwrap();
        let kp1 = KeyPair::try_from(&s1).unwrap();
        assert_eq!(kp0.seed(), &s0);
        assert_eq!(kp1.seed(), &s1);

        assert!(KeyPair::from_keys(&s0, kp0.private()).is_ok());
        assert!(KeyPair::from_keys(&s1, kp1.private()).is_ok());
        assert!(KeyPair::from_keys(&s1, kp0.private()).is_err());
        assert!(KeyPair::from_keys(&s0, kp1.private()).is_err());

        let kp0_keys = KeyPair::from_keys(&s0, kp0.private()).unwrap();
        let kp1_keys = KeyPair::from_keys(&s1, kp1.private()).unwrap();
        assert_eq!(kp0.seed(), kp0_keys.seed());
        assert_eq!(kp1.seed(), kp1_keys.seed());

        assert_eq!(kp0.private(), kp0_keys.private());
        assert_eq!(kp0.public(), kp0_keys.public());
    }

    // Run the shared KEM test suite against `MlKem768`, using a random seed.
    #[cfg(feature = "safe_api")]
    #[test]
    fn run_basic_kem_tests() {
        let seed = Seed::generate();
        KemTester::<MlKem768, SharedSecret, Ciphertext>::run_all_tests(seed.unprotected_as_bytes());
    }

    // A deterministic encapsulate/decapsulate round trip yields matching shared secrets.
    #[test]
    fn basic_roundtrip() {
        let seed = Seed::from_slice(&[127u8; 64]).unwrap();
        let kp = KeyPair::try_from(&seed).unwrap();

        let (k, c) = kp.public().encap_deterministic(&[255u8; 32]).unwrap();
        let k_prime = kp.private().decap(&c).unwrap();

        assert_eq!(k, k_prime);
    }
}