use ml_kem::kem::Encapsulate;
use ml_kem::{Encoded, EncodedSizeUser, KemCore, MlKem1024, MlKem768};
use rust_util::{opt_result, simple_error, XResult};
/// ML-KEM parameter set that was used for an encapsulation, as reported by
/// [`try_ml_kem_encapsulate`].
///
/// Besides `Clone`/`Copy`/`Debug`, the type derives `PartialEq`, `Eq` and
/// `Hash`. This lets callers compare the detected algorithm directly
/// (`algo == MlKemAlgo::MlKem768`) or use it as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MlKemAlgo {
    /// The ML-KEM-768 parameter set.
    MlKem768,
    /// The ML-KEM-1024 parameter set.
    MlKem1024,
}
/// Encapsulates a fresh shared secret to an ML-KEM-768 encapsulation key.
///
/// `public_key` must be exactly the encoded size of an ML-KEM-768
/// encapsulation key. A slice of any other length fails the conversion and is
/// rejected. Randomness for encapsulation is drawn from `rand::rngs::OsRng`.
///
/// Returns `(shared_key, ciphertext)`, in that order.
///
/// Note: the returned `Vec<u8>` holding the shared secret is an ordinary heap
/// buffer and is not wiped when dropped.
///
/// # Errors
///
/// Returns an error if `public_key` cannot be converted into an encoded
/// ML-KEM-768 encapsulation key (wrong length), or if encapsulation itself
/// reports an error.
pub fn ml_kem_768_encapsulate(public_key: &[u8]) -> XResult<(Vec<u8>, Vec<u8>)> {
    type Ek = <MlKem768 as KemCore>::EncapsulationKey;

    // The slice -> fixed-size array conversion is where a wrong key length is caught.
    let ek_bytes: Encoded<Ek> = opt_result!(
        public_key.try_into(),
        "Parse ML-KEM 768 encapsulation key failed: {}"
    );
    let ek = Ek::from_bytes(&ek_bytes);

    let (ct, ss) = opt_result!(
        ek.encapsulate(&mut rand::rngs::OsRng),
        "Encapsulate shared key failed: {:?}"
    );

    let shared_key = ss.0.to_vec();
    let ciphertext = ct.0.to_vec();
    Ok((shared_key, ciphertext))
}
/// Encapsulates a fresh shared secret to an ML-KEM-1024 encapsulation key.
///
/// `public_key` must be exactly the encoded size of an ML-KEM-1024
/// encapsulation key. A slice of any other length fails the conversion and is
/// rejected. Randomness for encapsulation is drawn from `rand::rngs::OsRng`.
///
/// Returns `(shared_key, ciphertext)`, in that order.
///
/// Note: the returned `Vec<u8>` holding the shared secret is an ordinary heap
/// buffer and is not wiped when dropped.
///
/// # Errors
///
/// Returns an error if `public_key` cannot be converted into an encoded
/// ML-KEM-1024 encapsulation key (wrong length), or if encapsulation itself
/// reports an error.
pub fn ml_kem_1024_encapsulate(public_key: &[u8]) -> XResult<(Vec<u8>, Vec<u8>)> {
    type Ek = <MlKem1024 as KemCore>::EncapsulationKey;

    // The slice -> fixed-size array conversion is where a wrong key length is caught.
    let ek_bytes: Encoded<Ek> = opt_result!(
        public_key.try_into(),
        "Parse ML-KEM 1024 encapsulation key failed: {}"
    );
    let ek = Ek::from_bytes(&ek_bytes);

    let (ct, ss) = opt_result!(
        ek.encapsulate(&mut rand::rngs::OsRng),
        "Encapsulate shared key failed: {:?}"
    );

    let shared_key = ss.0.to_vec();
    let ciphertext = ct.0.to_vec();
    Ok((shared_key, ciphertext))
}
pub fn try_ml_kem_encapsulate(public_key: &[u8]) -> XResult<(Vec<u8>, Vec<u8>, MlKemAlgo)> {
if let Ok((shared_key, ciphertext)) = ml_kem_768_encapsulate(public_key) {
return Ok((shared_key, ciphertext, MlKemAlgo::MlKem768));
}
if let Ok((shared_key, ciphertext)) = ml_kem_1024_encapsulate(public_key) {
return Ok((shared_key, ciphertext, MlKemAlgo::MlKem1024));
}
simple_error!("Only supports ML-KEM 768 or ML-KEM 1024.")
}