1use pqcrypto_mlkem::{
2 ffi::{
3 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES, PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES,
4 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES,
5 },
6 mlkem1024::{self, SharedSecret},
7 mlkem1024_decapsulate, mlkem1024_encapsulate, mlkem1024_keypair,
8};
9use pqcrypto_traits::kem::{PublicKey, SecretKey};
10
11use crate::errors::CryptoError;
12
13pub struct KEMPair {
19 pub_key: mlkem1024::PublicKey,
20 sec_key: mlkem1024::SecretKey,
21}
22
23impl KEMPair {
24 pub fn create() -> Self {
29 let (pk, sk) = mlkem1024_keypair();
30 Self {
31 pub_key: pk,
32 sec_key: sk,
33 }
34 }
35
36 pub fn from_bytes(pub_key: &[u8], sec_key: &[u8]) -> Result<Self, CryptoError> {
45 let pub_key = mlkem1024::PublicKey::from_bytes(pub_key)?;
46 let sec_key = mlkem1024::SecretKey::from_bytes(sec_key)?;
47 Ok(Self { pub_key, sec_key })
48 }
49
50 pub fn to_bytes(
56 &self,
57 ) -> Result<
58 (
59 [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES],
60 [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES],
61 ),
62 CryptoError,
63 > {
64 Ok((
65 self.pub_key.as_bytes().try_into()?,
66 self.sec_key.as_bytes().try_into()?,
67 ))
68 }
69
70 pub fn to_bytes_uniform(&self) -> Vec<u8> {
75 let mut bytes = Vec::new();
76 bytes.extend_from_slice(self.pub_key.as_bytes());
77 bytes.extend_from_slice(self.sec_key.as_bytes());
78 bytes
79 }
80
81 pub fn from_bytes_uniform(bytes: &[u8]) -> Result<Self, CryptoError> {
89 if bytes.len()
90 != PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
91 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES
92 {
93 return Err(CryptoError::IncongruentLength(
94 PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
95 + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES,
96 bytes.len(),
97 ));
98 }
99 let pub_key = mlkem1024::PublicKey::from_bytes(
100 &bytes[..PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES],
101 )?;
102 let sec_key = mlkem1024::SecretKey::from_bytes(
103 &bytes[PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES..],
104 )?;
105 Ok(Self { pub_key, sec_key })
106 }
107
108 pub fn encapsulate(
116 &self,
117 receiver_pubkey: &mlkem1024::PublicKey,
118 ) -> (SharedSecret, mlkem1024::Ciphertext) {
119 mlkem1024_encapsulate(receiver_pubkey)
120 }
121
122 pub fn decapsulate(
130 &self,
131 ciphertext: &mlkem1024::Ciphertext,
132 ) -> Result<SharedSecret, CryptoError> {
133 let shared_secret = mlkem1024_decapsulate(ciphertext, &self.sec_key);
134 Ok(shared_secret)
135 }
136}
137
138pub fn ss2b(ss: &SharedSecret) -> [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES] {
146 unsafe { *(ss as *const SharedSecret as *const [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES]) }
147}
148
149pub fn b2ss(bytes: &[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES]) -> SharedSecret {
157 unsafe {
158 std::ptr::read(
159 bytes as *const [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES] as *const SharedSecret,
160 )
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Separate-array serialization round-trips both keys.
    #[test]
    fn test_keypair() {
        let keypair = KEMPair::create();
        let (pub_key, sec_key) = keypair.to_bytes().unwrap();
        let new_keypair = KEMPair::from_bytes(&pub_key, &sec_key).unwrap();
        assert_eq!(keypair.pub_key.as_bytes(), new_keypair.pub_key.as_bytes());
        assert_eq!(keypair.sec_key.as_bytes(), new_keypair.sec_key.as_bytes());
    }

    /// Single-buffer serialization has the expected length and round-trips.
    #[test]
    fn test_keypair_uniform_roundtrip() {
        let keypair = KEMPair::create();
        let bytes = keypair.to_bytes_uniform();
        assert_eq!(
            bytes.len(),
            PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
                + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES
        );
        let restored = KEMPair::from_bytes_uniform(&bytes).unwrap();
        assert_eq!(keypair.pub_key.as_bytes(), restored.pub_key.as_bytes());
        assert_eq!(keypair.sec_key.as_bytes(), restored.sec_key.as_bytes());
    }

    /// Buffers of the wrong length are rejected instead of panicking.
    #[test]
    fn test_from_bytes_uniform_rejects_bad_length() {
        let bytes = KEMPair::create().to_bytes_uniform();
        assert!(KEMPair::from_bytes_uniform(&bytes[..bytes.len() - 1]).is_err());
        assert!(KEMPair::from_bytes_uniform(&[]).is_err());
    }

    /// Both sides of an encapsulation agree on the shared secret.
    #[test]
    fn test_encapsulate_decapsulate() {
        let sender = KEMPair::create();
        let receiver = KEMPair::create();

        let (shared_secret, ciphertext) = sender.encapsulate(&receiver.pub_key);
        let dec_shared_secret = receiver.decapsulate(&ciphertext).unwrap();

        let ss1 = ss2b(&shared_secret);
        let ss2 = ss2b(&dec_shared_secret);

        assert_eq!(ss1, ss2, "Difference in shared secrets!");
    }

    /// `ss2b` matches the library's own byte view, and `b2ss` inverts it.
    #[test]
    fn test_shared_secret_byte_roundtrip() {
        let receiver = KEMPair::create();
        let (shared_secret, _ciphertext) = receiver.encapsulate(&receiver.pub_key);

        let bytes = ss2b(&shared_secret);
        assert_eq!(
            &bytes[..],
            pqcrypto_traits::kem::SharedSecret::as_bytes(&shared_secret)
        );
        assert_eq!(ss2b(&b2ss(&bytes)), bytes);
    }
}