1use borsh::{BorshDeserialize, BorshSerialize};
26use hkdf::Hkdf;
27use ml_kem::array::Array;
28use ml_kem::kem::{Decapsulate, Encapsulate};
29use ml_kem::{Encoded, EncodedSizeUser, KemCore, MlKem768};
30use rand::rngs::OsRng;
31use sha2::Sha256;
32use std::fmt;
33use zeroize::ZeroizeOnDrop;
34
35#[cfg(not(feature = "fips"))]
36use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
37
38#[cfg(feature = "fips")]
39use aws_lc_rs::{
40 agreement::{self, agree, EphemeralPrivateKey, PrivateKey, UnparsedPublicKey, ECDH_P256},
41 rand::SystemRandom,
42};
43
44type MlKem768DecapKey = <MlKem768 as KemCore>::DecapsulationKey;
45type MlKem768EncapKey = <MlKem768 as KemCore>::EncapsulationKey;
46
/// Length in bytes of the classical (ECDH) public key carried in key packages and
/// ciphertexts.
///
/// 32 bytes for X25519 (default build). 65 bytes for an uncompressed SEC1 P-256 point
/// when the `fips` feature switches the classical half to ECDH P-256.
pub const CLASSICAL_PK_BYTES: usize = if cfg!(feature = "fips") { 65 } else { 32 };
55
/// HKDF `info` label used when combining the classical and ML-KEM shared secrets.
///
/// The label names the classical curve, so the X25519 and P-256 (`fips`) builds use
/// different labels and are domain-separated from each other.
const COMBINE_LABEL: &[u8] = if cfg!(feature = "fips") {
    b"HybridKEM_P256_Kyber768"
} else {
    b"HybridKEM_X25519_Kyber768"
};
64
65#[derive(ZeroizeOnDrop)]
77pub struct HybridSecretKey {
78 #[cfg(not(feature = "fips"))]
82 pub classical_sk: StaticSecret,
83 #[cfg(feature = "fips")]
84 #[zeroize(skip)] pub classical_sk: PrivateKey,
86
87 #[zeroize(skip)] pub ml_kem_dk: Box<MlKem768DecapKey>,
91}
92
93impl HybridSecretKey {
94 pub fn generate() -> (Self, HybridKeyPackage) {
95 let mut rng = OsRng;
96
97 #[cfg(not(feature = "fips"))]
101 let (classical_sk, classical_pk_bytes) = {
102 let sk = StaticSecret::random_from_rng(rng);
103 let pk = X25519PublicKey::from(&sk);
104 (sk, *pk.as_bytes())
105 };
106 #[cfg(feature = "fips")]
107 let (classical_sk, classical_pk_bytes) = {
108 #[allow(clippy::expect_used)]
117 let sk = PrivateKey::generate(&ECDH_P256)
118 .expect("aws-lc-rs ECDH-P-256 generate must succeed");
119 #[allow(clippy::expect_used)]
120 let pk = sk
121 .compute_public_key()
122 .expect("aws-lc-rs ECDH-P-256 compute_public_key must succeed");
123 let mut bytes = [0u8; CLASSICAL_PK_BYTES];
124 bytes.copy_from_slice(pk.as_ref());
125 (sk, bytes)
126 };
127
128 let (dk, ek) = MlKem768::generate(&mut rng);
131
132 let secret_key = HybridSecretKey {
133 classical_sk,
134 ml_kem_dk: Box::new(dk),
135 };
136 let key_package = HybridKeyPackage {
137 classical_pk: classical_pk_bytes,
138 ml_kem_pk: ek.as_bytes().to_vec(),
139 };
140 (secret_key, key_package)
141 }
142
143 pub fn decapsulate(&self, ciphertext: &HybridCiphertext) -> Result<[u8; 32], anyhow::Error> {
144 #[cfg(not(feature = "fips"))]
146 let classical_shared: [u8; 32] = {
147 let peer = X25519PublicKey::from(ciphertext.classical_pk);
148 let s = self.classical_sk.diffie_hellman(&peer);
149 *s.as_bytes()
150 };
151 #[cfg(feature = "fips")]
152 let classical_shared: [u8; 32] = {
153 let peer = UnparsedPublicKey::new(&ECDH_P256, &ciphertext.classical_pk[..]);
154 agree(
159 &self.classical_sk,
160 peer,
161 anyhow::anyhow!("aws-lc-rs ECDH-P-256 agree failed (peer key parse)"),
162 |km| -> Result<[u8; 32], anyhow::Error> {
163 let mut out = [0u8; 32];
165 out.copy_from_slice(km);
166 Ok(out)
167 },
168 )?
169 };
170
171 let ct_array = decode_ml_kem_ciphertext(&ciphertext.ml_kem_ct)
173 .ok_or_else(|| anyhow::anyhow!("invalid ML-KEM-768 ciphertext length"))?;
174 let ml_kem_shared = self
175 .ml_kem_dk
176 .decapsulate(&ct_array)
177 .map_err(|e| anyhow::anyhow!("ML-KEM decapsulation failed: {:?}", e))?;
178
179 Self::combine_secrets(&classical_shared, ml_kem_shared.as_slice())
181 }
182
183 pub(crate) fn combine_secrets(
184 ecc_secret: &[u8],
185 pq_secret: &[u8],
186 ) -> Result<[u8; 32], anyhow::Error> {
187 let ikm = zeroize::Zeroizing::new([ecc_secret, pq_secret].concat());
191 let hkdf = Hkdf::<Sha256>::new(None, &ikm);
192 let mut okm = [0u8; 32];
193 hkdf.expand(COMBINE_LABEL, &mut okm)
194 .map_err(|_| anyhow::anyhow!("HKDF expansion failed"))?;
195 Ok(okm)
196 }
197}
198
199impl fmt::Debug for HybridSecretKey {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("HybridSecretKey")
202 .field("classical_sk", &"REDACTED")
203 .field("ml_kem_dk", &"REDACTED")
204 .finish()
205 }
206}
207
208#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
209pub struct HybridKeyPackage {
210 pub classical_pk: [u8; CLASSICAL_PK_BYTES],
214 pub ml_kem_pk: Vec<u8>,
215}
216
217impl HybridKeyPackage {
218 pub fn encapsulate(&self) -> Result<([u8; 32], HybridCiphertext), anyhow::Error> {
219 let mut rng = OsRng;
220
221 #[cfg(not(feature = "fips"))]
223 let (eph_pk_bytes, classical_shared) = {
224 let eph_sk = StaticSecret::random_from_rng(rng);
225 let eph_pk = X25519PublicKey::from(&eph_sk);
226 let peer = X25519PublicKey::from(self.classical_pk);
227 let shared = eph_sk.diffie_hellman(&peer);
228 (*eph_pk.as_bytes(), *shared.as_bytes())
229 };
230 #[cfg(feature = "fips")]
231 let (eph_pk_bytes, classical_shared): ([u8; CLASSICAL_PK_BYTES], [u8; 32]) = {
232 let aws_rng = SystemRandom::new();
233 let eph_sk = EphemeralPrivateKey::generate(&ECDH_P256, &aws_rng)
234 .map_err(|e| anyhow::anyhow!("aws-lc-rs ECDH-P-256 ephemeral generate: {:?}", e))?;
235 let eph_pk = eph_sk
236 .compute_public_key()
237 .map_err(|e| anyhow::anyhow!("compute_public_key: {:?}", e))?;
238 let mut pk_bytes = [0u8; CLASSICAL_PK_BYTES];
239 pk_bytes.copy_from_slice(eph_pk.as_ref());
240 let peer = UnparsedPublicKey::new(&ECDH_P256, &self.classical_pk[..]);
241 let shared = agreement::agree_ephemeral(
242 eph_sk,
243 peer,
244 anyhow::anyhow!("aws-lc-rs ECDH-P-256 agree_ephemeral failed (peer parse)"),
245 |km| -> Result<[u8; 32], anyhow::Error> {
246 let mut o = [0u8; 32];
247 o.copy_from_slice(km);
248 Ok(o)
249 },
250 )?;
251 (pk_bytes, shared)
252 };
253
254 let ek_array = decode_ml_kem_encap_key(&self.ml_kem_pk)
256 .ok_or_else(|| anyhow::anyhow!("invalid ML-KEM-768 public key length"))?;
257 let ek = MlKem768EncapKey::from_bytes(&ek_array);
258 let (ct, ml_kem_shared) = ek
259 .encapsulate(&mut rng)
260 .map_err(|e| anyhow::anyhow!("ML-KEM encapsulation failed: {:?}", e))?;
261
262 let shared_secret =
264 HybridSecretKey::combine_secrets(&classical_shared, ml_kem_shared.as_slice())?;
265
266 let ciphertext = HybridCiphertext {
267 classical_pk: eph_pk_bytes,
268 ml_kem_ct: ct.as_slice().to_vec(),
269 };
270 Ok((shared_secret, ciphertext))
271 }
272}
273
274#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
275pub struct HybridCiphertext {
276 pub classical_pk: [u8; CLASSICAL_PK_BYTES],
279 pub ml_kem_ct: Vec<u8>,
281}
282
283fn decode_ml_kem_encap_key(bytes: &[u8]) -> Option<Encoded<MlKem768EncapKey>> {
292 Encoded::<MlKem768EncapKey>::try_from(bytes).ok()
293}
294
295fn decode_ml_kem_ciphertext(
296 bytes: &[u8],
297) -> Option<Array<u8, <MlKem768 as KemCore>::CiphertextSize>> {
298 Array::<u8, <MlKem768 as KemCore>::CiphertextSize>::try_from(bytes).ok()
299}
300
#[cfg(test)]
mod tests {
    //! Round-trip, sizing, and malformed-input tests for the hybrid KEM.

    use super::*;

    #[test]
    fn hybrid_kem_round_trip() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss_send, ct) = pk.encapsulate().expect("encap");
        let ss_recv = sk.decapsulate(&ct).expect("decap");
        assert_eq!(
            ss_send, ss_recv,
            "encap/decap must agree on the shared secret"
        );
    }

    #[test]
    fn hybrid_kem_two_handshakes_yield_distinct_secrets() {
        let (_sk, pk) = HybridSecretKey::generate();
        let (ss1, _ct1) = pk.encapsulate().expect("first encap");
        let (ss2, _ct2) = pk.encapsulate().expect("second encap");
        assert_ne!(ss1, ss2);
    }

    #[test]
    fn ml_kem_ciphertext_size_matches_fips_203() {
        let (_sk, pk) = HybridSecretKey::generate();
        let (_ss, ct) = pk.encapsulate().expect("encap");
        assert_eq!(ct.ml_kem_ct.len(), 1088);
    }

    #[test]
    fn ml_kem_public_key_size_matches_fips_203() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(pk.ml_kem_pk.len(), 1184);
    }

    #[test]
    fn hybrid_kem_two_secrets_distinct_under_same_recipient_key() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss1, ct1) = pk.encapsulate().expect("encap1");
        let (ss2, ct2) = pk.encapsulate().expect("encap2");
        // The name promises distinct secrets: check that, and that each ciphertext
        // still decapsulates to its own secret.
        assert_ne!(ss1, ss2, "independent encapsulations must not collide");
        assert_eq!(sk.decapsulate(&ct1).expect("decap1"), ss1);
        assert_eq!(sk.decapsulate(&ct2).expect("decap2"), ss2);
    }

    #[test]
    fn classical_public_key_size_matches_backend() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(pk.classical_pk.len(), CLASSICAL_PK_BYTES);
        #[cfg(not(feature = "fips"))]
        assert_eq!(CLASSICAL_PK_BYTES, 32, "X25519 public key is 32 bytes");
        #[cfg(feature = "fips")]
        assert_eq!(
            CLASSICAL_PK_BYTES, 65,
            "ECDH-P-256 uncompressed SEC1 public key is 65 bytes"
        );
    }

    #[test]
    fn encapsulate_rejects_wrong_length_ml_kem_public_key() {
        let (_sk, mut pk) = HybridSecretKey::generate();
        pk.ml_kem_pk.pop();
        assert!(pk.encapsulate().is_err());
    }

    #[test]
    fn decapsulate_rejects_wrong_length_ml_kem_ciphertext() {
        let (sk, pk) = HybridSecretKey::generate();
        let (_ss, mut ct) = pk.encapsulate().expect("encap");
        ct.ml_kem_ct.push(0);
        assert!(sk.decapsulate(&ct).is_err());
    }

    #[test]
    fn tampered_ml_kem_ciphertext_yields_different_secret() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss, mut ct) = pk.encapsulate().expect("encap");
        ct.ml_kem_ct[0] ^= 0x01;
        // ML-KEM implicit rejection: decapsulation still succeeds, but with an
        // unrelated secret.
        let recovered = sk
            .decapsulate(&ct)
            .expect("implicit rejection still returns a key");
        assert_ne!(ss, recovered);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn fips_classical_public_key_is_uncompressed_sec1() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(
            pk.classical_pk[0], 0x04,
            "uncompressed SEC1 P-256 key must lead with 0x04"
        );
    }
}