1use hkdf::Hkdf;
46use ml_kem::array::Array;
47use ml_kem::{Decapsulate, DecapsulationKey, EncapsulationKey, KeyExport, MlKem768};
48use sha2::Sha256;
49use zeroize::Zeroizing;
50
51use crate::error::{ProtocolError, Result};
52
/// Length in bytes of an encoded ML-KEM-768 encapsulation (public) key.
///
/// FIPS 203 with k = 3: three 384-byte encoded polynomials plus a 32-byte seed.
pub const MLKEM_EK_LEN: usize = 3 * 384 + 32;
/// Length in bytes of an ML-KEM-768 ciphertext.
///
/// FIPS 203 with k = 3, d_u = 10, d_v = 4: 32 * (d_u * k + d_v).
pub const MLKEM_CT_LEN: usize = 32 * (10 * 3 + 4);
/// Length in bytes of the shared secrets handled by this module (X25519,
/// ML-KEM and the hybrid output) and of the deterministic encapsulation
/// message `m`.
pub const SS_LEN: usize = 32;

/// HKDF salt used when deriving the 64-byte ML-KEM seed from an identity seed.
const MLKEM_SEED_LABEL: &[u8] = b"huddle-mlkem-768-seed-v1";
/// HKDF salt used when combining the X25519 and ML-KEM shared secrets.
const HYBRID_COMBINE_SALT: &[u8] = b"huddle-hybrid-kem-v1";
64
/// Post-quantum ML-KEM-768 keypair.
///
/// Only the decapsulation (secret) key is stored; the encapsulation (public)
/// key is obtained from it when [`PqKeypair::encapsulation_key_bytes`] is
/// called.
pub struct PqKeypair {
    /// ML-KEM-768 decapsulation key.
    dk: DecapsulationKey<MlKem768>,
}
74
75impl PqKeypair {
76 pub fn from_identity_seed(ed25519_seed: &[u8; 32]) -> Self {
85 let mut seed64 = Zeroizing::new([0u8; 64]);
86 let hk = Hkdf::<Sha256>::new(Some(MLKEM_SEED_LABEL), ed25519_seed);
87 hk.expand(b"", seed64.as_mut_slice())
88 .expect("HKDF expand to 64 bytes is within SHA-256's output limit");
89 let seed: ml_kem::Seed =
93 Array::try_from(seed64.as_slice()).expect("ML-KEM seed is exactly 64 bytes");
94 let dk = DecapsulationKey::<MlKem768>::from_seed(seed);
95 Self { dk }
96 }
97
98 pub fn encapsulation_key_bytes(&self) -> [u8; MLKEM_EK_LEN] {
100 let encoded = self.dk.encapsulation_key().to_bytes();
101 let mut out = [0u8; MLKEM_EK_LEN];
102 out.copy_from_slice(&encoded);
103 out
104 }
105
106 pub fn decapsulate(&self, ciphertext: &[u8]) -> Result<Zeroizing<[u8; SS_LEN]>> {
114 if ciphertext.len() != MLKEM_CT_LEN {
115 return Err(ProtocolError::Session(format!(
116 "ML-KEM ciphertext is {} bytes, expected {MLKEM_CT_LEN}",
117 ciphertext.len()
118 )));
119 }
120 let ct = Array::try_from(ciphertext)
121 .map_err(|_| ProtocolError::Session("ML-KEM ciphertext decode failed".into()))?;
122 let ss = self.dk.decapsulate(&ct);
123 let mut out = Zeroizing::new([0u8; SS_LEN]);
124 out.copy_from_slice(&ss);
125 Ok(out)
126 }
127}
128
129pub fn encapsulate_deterministic(
138 partner_ek_bytes: &[u8],
139 m: &[u8; SS_LEN],
140) -> Result<(Vec<u8>, Zeroizing<[u8; SS_LEN]>)> {
141 if partner_ek_bytes.len() != MLKEM_EK_LEN {
142 return Err(ProtocolError::Session(format!(
143 "ML-KEM encapsulation key is {} bytes, expected {MLKEM_EK_LEN}",
144 partner_ek_bytes.len()
145 )));
146 }
147 let ek_arr = Array::try_from(partner_ek_bytes)
148 .map_err(|_| ProtocolError::Session("ML-KEM encapsulation key decode failed".into()))?;
149 let ek = EncapsulationKey::<MlKem768>::new(&ek_arr)
150 .map_err(|_| ProtocolError::Session("invalid ML-KEM encapsulation key".into()))?;
151 let m_arr: ml_kem::B32 =
152 Array::try_from(&m[..]).expect("encapsulation message is exactly 32 bytes");
153 let (ct, ss) = ek.encapsulate_deterministic(&m_arr);
154 let mut ss_out = Zeroizing::new([0u8; SS_LEN]);
155 ss_out.copy_from_slice(&ss);
156 Ok((ct.to_vec(), ss_out))
157}
158
159pub fn combine_hybrid(
179 ss_x25519: &[u8; SS_LEN],
180 ss_mlkem: &[u8; SS_LEN],
181 kem_ciphertext: &[u8],
182 context: &[u8],
183) -> Zeroizing<[u8; SS_LEN]> {
184 let mut ikm = Zeroizing::new([0u8; 2 * SS_LEN]);
185 ikm[..SS_LEN].copy_from_slice(ss_x25519);
186 ikm[SS_LEN..].copy_from_slice(ss_mlkem);
187
188 let mut info = Vec::with_capacity(kem_ciphertext.len() + context.len());
189 info.extend_from_slice(kem_ciphertext);
190 info.extend_from_slice(context);
191
192 let hk = Hkdf::<Sha256>::new(Some(HYBRID_COMBINE_SALT), ikm.as_slice());
193 let mut out = Zeroizing::new([0u8; SS_LEN]);
194 hk.expand(&info, out.as_mut_slice())
195 .expect("HKDF expand to 32 bytes is within SHA-256's output limit");
196 out
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic 32-byte identity seed filled with `n`.
    fn seed(n: u8) -> [u8; 32] {
        [n; 32]
    }

    #[test]
    fn keypair_is_deterministic_from_seed() {
        let a = PqKeypair::from_identity_seed(&seed(7));
        let b = PqKeypair::from_identity_seed(&seed(7));
        assert_eq!(
            a.encapsulation_key_bytes(),
            b.encapsulation_key_bytes(),
            "same identity seed must yield the same ML-KEM public key"
        );
    }

    #[test]
    fn different_seeds_yield_different_keys() {
        let a = PqKeypair::from_identity_seed(&seed(1));
        let b = PqKeypair::from_identity_seed(&seed(2));
        assert_ne!(a.encapsulation_key_bytes(), b.encapsulation_key_bytes());
    }

    #[test]
    fn ek_has_expected_size() {
        // `encapsulation_key_bytes` returns a `[u8; MLKEM_EK_LEN]`, so checking
        // its `len()` is a tautology. Compare the library's own encoding
        // instead, so the test fails if `MLKEM_EK_LEN` drifts from ML-KEM-768.
        let kp = PqKeypair::from_identity_seed(&seed(9));
        assert_eq!(kp.dk.encapsulation_key().to_bytes().len(), MLKEM_EK_LEN);
    }

    #[test]
    fn encapsulate_decapsulate_round_trip() {
        let responder = PqKeypair::from_identity_seed(&seed(42));
        let ek = responder.encapsulation_key_bytes();
        let m = [3u8; SS_LEN];

        let (ct, ss_send) = encapsulate_deterministic(&ek, &m).unwrap();
        assert_eq!(ct.len(), MLKEM_CT_LEN);

        let ss_recv = responder.decapsulate(&ct).unwrap();
        assert_eq!(
            *ss_send, *ss_recv,
            "encapsulator and decapsulator must agree"
        );
    }

    #[test]
    fn deterministic_encapsulation_reproduces() {
        let responder = PqKeypair::from_identity_seed(&seed(11));
        let ek = responder.encapsulation_key_bytes();
        let m = [5u8; SS_LEN];
        let (ct1, ss1) = encapsulate_deterministic(&ek, &m).unwrap();
        let (ct2, ss2) = encapsulate_deterministic(&ek, &m).unwrap();
        assert_eq!(ct1, ct2, "same m + ek must reproduce the same ciphertext");
        assert_eq!(*ss1, *ss2, "same m + ek must reproduce the same secret");
    }

    #[test]
    fn different_m_yields_different_ciphertext_and_secret() {
        let responder = PqKeypair::from_identity_seed(&seed(11));
        let ek = responder.encapsulation_key_bytes();
        let (ct_a, ss_a) = encapsulate_deterministic(&ek, &[1u8; SS_LEN]).unwrap();
        let (ct_b, ss_b) = encapsulate_deterministic(&ek, &[2u8; SS_LEN]).unwrap();
        assert_ne!(ct_a, ct_b);
        assert_ne!(*ss_a, *ss_b);
    }

    #[test]
    fn tampered_ciphertext_does_not_recover_secret() {
        // Decapsulation does not report tampering; it returns Ok with a
        // different secret, which is what this test pins down.
        let responder = PqKeypair::from_identity_seed(&seed(99));
        let ek = responder.encapsulation_key_bytes();
        let (mut ct, ss_send) = encapsulate_deterministic(&ek, &[8u8; SS_LEN]).unwrap();
        ct[0] ^= 0x01;
        let ss_recv = responder.decapsulate(&ct).unwrap();
        assert_ne!(
            *ss_send, *ss_recv,
            "a tampered ciphertext must not recover the encapsulated secret"
        );
    }

    #[test]
    fn wrong_recipient_does_not_recover_secret() {
        let intended = PqKeypair::from_identity_seed(&seed(30));
        let other = PqKeypair::from_identity_seed(&seed(31));
        let ek = intended.encapsulation_key_bytes();
        let (ct, ss_send) = encapsulate_deterministic(&ek, &[4u8; SS_LEN]).unwrap();
        let ss_other = other.decapsulate(&ct).unwrap();
        assert_ne!(
            *ss_send, *ss_other,
            "a different keypair must not recover the encapsulated secret"
        );
    }

    #[test]
    fn wrong_ek_length_is_rejected() {
        let err = encapsulate_deterministic(&[0u8; 10], &[0u8; SS_LEN]);
        assert!(err.is_err());
    }

    #[test]
    fn wrong_ct_length_is_rejected() {
        let kp = PqKeypair::from_identity_seed(&seed(1));
        assert!(kp.decapsulate(&[0u8; 10]).is_err());
    }

    #[test]
    fn combiner_is_deterministic_and_input_sensitive() {
        let ss_x = [1u8; SS_LEN];
        let ss_pq = [2u8; SS_LEN];
        let ct = vec![3u8; MLKEM_CT_LEN];
        let ctx = b"room-1";

        let k = *combine_hybrid(&ss_x, &ss_pq, &ct, ctx);
        let k_again = *combine_hybrid(&ss_x, &ss_pq, &ct, ctx);
        assert_eq!(k, k_again, "combiner must be deterministic");

        assert_ne!(k, *combine_hybrid(&[9u8; SS_LEN], &ss_pq, &ct, ctx));
        assert_ne!(k, *combine_hybrid(&ss_x, &[9u8; SS_LEN], &ct, ctx));
        let mut ct2 = ct.clone();
        ct2[0] ^= 0xFF;
        assert_ne!(k, *combine_hybrid(&ss_x, &ss_pq, &ct2, ctx));
        assert_ne!(k, *combine_hybrid(&ss_x, &ss_pq, &ct, b"room-2"));
    }

    #[test]
    fn combiner_differs_from_either_raw_secret() {
        let ss_x = [4u8; SS_LEN];
        let ss_pq = [5u8; SS_LEN];
        let ct = vec![6u8; MLKEM_CT_LEN];
        let k = *combine_hybrid(&ss_x, &ss_pq, &ct, b"ctx");
        assert_ne!(k, ss_x, "hybrid key must not equal the raw X25519 secret");
        assert_ne!(k, ss_pq, "hybrid key must not equal the raw ML-KEM secret");
    }

    #[test]
    fn full_two_party_hybrid_agreement() {
        let responder = PqKeypair::from_identity_seed(&seed(21));
        let ek = responder.encapsulation_key_bytes();
        let ss_x = [7u8; SS_LEN];
        let m = [13u8; SS_LEN];

        let (ct, ss_pq_send) = encapsulate_deterministic(&ek, &m).unwrap();
        let key_initiator = *combine_hybrid(&ss_x, &ss_pq_send, &ct, b"dm-room");

        let ss_pq_recv = responder.decapsulate(&ct).unwrap();
        let key_responder = *combine_hybrid(&ss_x, &ss_pq_recv, &ct, b"dm-room");

        assert_eq!(key_initiator, key_responder);
    }
}