Skip to main content

phantom_protocol/crypto/
hybrid_kem.rs

1//! Hybrid KEM: classical ECDH + ML-KEM-768 (FIPS 203, post-quantum).
2//!
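//! Phase 5.1 — switched the PQ half from `pqcrypto-kyber`'s C reference
//! implementation of NIST PQC round-3 Kyber768 to the RustCrypto pure-Rust
//! `ml-kem` crate's FIPS-203 ML-KEM-768. Key, ciphertext and shared-secret
//! sizes and the underlying lattice arithmetic are unchanged. However,
//! FIPS 203 changed the key derivation:
//! - key generation gains a domain-separation byte;
//! - encapsulation no longer hashes its random message;
//! - the shared secret comes straight out of `G` instead of `KDF(K̄ ‖ H(c))`.
//! As a result, ML-KEM-768 does not interoperate with round-3 Kyber768 peers.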
7//!
8//! Under `--features fips`, the classical half swaps from X25519
9//! to ECDH-P-256 via `aws-lc-rs`. The classical public-key length on
10//! the wire grows from 32 bytes (X25519) to 65 bytes (uncompressed
11//! SEC1 P-256). Cross-mode interop (fips ↔ non-fips) is **not
12//! supported** — both peers MUST be compiled with matching feature
13//! flags, and the `PROTOCOL_VARIANT` handshake constant
14//! (`transport::handshake::PROTOCOL_VARIANT`) is baked into the
15//! signed transcript so a mixed-mode attempt fails on the client's
16//! signature check rather than producing a silently-wrong shared
17//! secret.
18//!
19//! Both KEM halves contribute 32 bytes of shared secret, combined via
20//! `HKDF-SHA-256` with the label `"HybridKEM_X25519_Kyber768"` on the
21//! default build and `"HybridKEM_P256_Kyber768"` under fips. The label
22//! divergence is intentional defense-in-depth: even if `PROTOCOL_VARIANT`
23//! were stripped, the derived traffic secret would differ.
24
25use borsh::{BorshDeserialize, BorshSerialize};
26use hkdf::Hkdf;
27use ml_kem::array::Array;
28use ml_kem::kem::{Decapsulate, Encapsulate};
29use ml_kem::{Encoded, EncodedSizeUser, KemCore, MlKem768};
30use rand::rngs::OsRng;
31use sha2::Sha256;
32use std::fmt;
33use zeroize::ZeroizeOnDrop;
34
35#[cfg(not(feature = "fips"))]
36use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};
37
38#[cfg(feature = "fips")]
39use aws_lc_rs::{
40    agreement::{self, agree, EphemeralPrivateKey, PrivateKey, UnparsedPublicKey, ECDH_P256},
41    rand::SystemRandom,
42};
43
44type MlKem768DecapKey = <MlKem768 as KemCore>::DecapsulationKey;
45type MlKem768EncapKey = <MlKem768 as KemCore>::EncapsulationKey;
46
/// Wire length, in bytes, of the classical (ECDH) public key carried in
/// `HybridKeyPackage::classical_pk` and `HybridCiphertext::classical_pk`.
///
/// * Default build: X25519 public key, 32 bytes (RFC 7748).
/// * `--features fips`: ECDH-P-256 point in uncompressed SEC1 form
///   (`0x04 ‖ X ‖ Y`), 1 + 32 + 32 = 65 bytes.
pub const CLASSICAL_PK_BYTES: usize = if cfg!(feature = "fips") { 65 } else { 32 };
55
/// HKDF `info` label used when combining the classical and ML-KEM shared
/// secrets.
///
/// The default (X25519) build reuses the V1/V2 label byte-for-byte, so the
/// protocol's KDF-label inventory does not drift. The fips build gets its own
/// label because its classical primitive (ECDH-P-256) is different.
const COMBINE_LABEL: &[u8] = if cfg!(feature = "fips") {
    b"HybridKEM_P256_Kyber768"
} else {
    b"HybridKEM_X25519_Kyber768"
};
64
65/// Hybrid secret key. Holds the classical long-term secret (X25519 by
66/// default, ECDH-P-256 under fips) and the ML-KEM-768 decapsulation
67/// key. Both halves are zeroized on drop — `ml_kem`'s `DecapsulationKey`
68/// implements `Zeroize` natively, and the classical side either uses
69/// `x25519_dalek`'s `Zeroize` impl (default) or aws-lc-rs's internal
70/// Drop, which frees the underlying key material.
71///
72/// `ml_kem_dk` is `Box`-ed so the (~2.4 KiB) decapsulation key lives on
73/// the heap; constructing several `HybridSecretKey`s in a deep call
74/// chain (as happens during the handshake) would otherwise stress
75/// tokio's default test thread stack.
76#[derive(ZeroizeOnDrop)]
77pub struct HybridSecretKey {
78    /// Classical long-lived secret. Type depends on the active backend:
79    /// `x25519_dalek::StaticSecret` (default) or
80    /// `aws_lc_rs::agreement::PrivateKey` (`--features fips`, ECDH-P-256).
81    #[cfg(not(feature = "fips"))]
82    pub classical_sk: StaticSecret,
83    #[cfg(feature = "fips")]
84    #[zeroize(skip)] // aws-lc-rs frees the inner key on Drop
85    pub classical_sk: PrivateKey,
86
87    /// ML-KEM-768 decapsulation key (FIPS 203). Boxed to keep stack
88    /// pressure down — the structure is ~2.4 KiB.
89    #[zeroize(skip)] // Box's Drop calls T::Drop which zeroes the inner key
90    pub ml_kem_dk: Box<MlKem768DecapKey>,
91}
92
93impl HybridSecretKey {
94    pub fn generate() -> (Self, HybridKeyPackage) {
95        let mut rng = OsRng;
96
97        // Classical (X25519 or ECDH-P-256) key generation + public key
98        // derivation. Branch is fully cfg-gated; the build pulls in
99        // exactly one path.
100        #[cfg(not(feature = "fips"))]
101        let (classical_sk, classical_pk_bytes) = {
102            let sk = StaticSecret::random_from_rng(rng);
103            let pk = X25519PublicKey::from(&sk);
104            (sk, *pk.as_bytes())
105        };
106        #[cfg(feature = "fips")]
107        let (classical_sk, classical_pk_bytes) = {
108            // PANIC-SAFETY: `PrivateKey::generate` only fails when the
109            // underlying AWS-LC random source is broken — same failure
110            // mode as `getrandom` on the default build, where we also
111            // panic via `OsRng`. `compute_public_key` derives a
112            // P-256 public from a fresh, just-generated valid private,
113            // which cannot fail. A failure here means the FIPS module
114            // is in a non-recoverable state; loud panic is the correct
115            // surface for the embedder.
116            #[allow(clippy::expect_used)]
117            let sk = PrivateKey::generate(&ECDH_P256)
118                .expect("aws-lc-rs ECDH-P-256 generate must succeed");
119            #[allow(clippy::expect_used)]
120            let pk = sk
121                .compute_public_key()
122                .expect("aws-lc-rs ECDH-P-256 compute_public_key must succeed");
123            let mut bytes = [0u8; CLASSICAL_PK_BYTES];
124            bytes.copy_from_slice(pk.as_ref());
125            (sk, bytes)
126        };
127
128        // ML-KEM-768 (post-quantum, FIPS 203). Box the decap key so the
129        // ~2.4 KiB structure never lives on the stack.
130        let (dk, ek) = MlKem768::generate(&mut rng);
131
132        let secret_key = HybridSecretKey {
133            classical_sk,
134            ml_kem_dk: Box::new(dk),
135        };
136        let key_package = HybridKeyPackage {
137            classical_pk: classical_pk_bytes,
138            ml_kem_pk: ek.as_bytes().to_vec(),
139        };
140        (secret_key, key_package)
141    }
142
143    pub fn decapsulate(&self, ciphertext: &HybridCiphertext) -> Result<[u8; 32], anyhow::Error> {
144        // 1. Classical ECDH.
145        #[cfg(not(feature = "fips"))]
146        let classical_shared: [u8; 32] = {
147            let peer = X25519PublicKey::from(ciphertext.classical_pk);
148            let s = self.classical_sk.diffie_hellman(&peer);
149            *s.as_bytes()
150        };
151        #[cfg(feature = "fips")]
152        let classical_shared: [u8; 32] = {
153            let peer = UnparsedPublicKey::new(&ECDH_P256, &ciphertext.classical_pk[..]);
154            // aws-lc-rs's `agree` returns `Result<R, E>` where the
155            // closure is `FnOnce(&[u8]) -> Result<R, E>`. The
156            // `error_value` arg is the E returned when peer-key parse
157            // fails before the closure runs.
158            agree(
159                &self.classical_sk,
160                peer,
161                anyhow::anyhow!("aws-lc-rs ECDH-P-256 agree failed (peer key parse)"),
162                |km| -> Result<[u8; 32], anyhow::Error> {
163                    // ECDH-P-256 shared secret is the 32-byte X coordinate.
164                    let mut out = [0u8; 32];
165                    out.copy_from_slice(km);
166                    Ok(out)
167                },
168            )?
169        };
170
171        // 2. ML-KEM-768 decapsulation.
172        let ct_array = decode_ml_kem_ciphertext(&ciphertext.ml_kem_ct)
173            .ok_or_else(|| anyhow::anyhow!("invalid ML-KEM-768 ciphertext length"))?;
174        let ml_kem_shared = self
175            .ml_kem_dk
176            .decapsulate(&ct_array)
177            .map_err(|e| anyhow::anyhow!("ML-KEM decapsulation failed: {:?}", e))?;
178
179        // 3. Combine the two 32-byte secrets via HKDF.
180        Self::combine_secrets(&classical_shared, ml_kem_shared.as_slice())
181    }
182
183    pub(crate) fn combine_secrets(
184        ecc_secret: &[u8],
185        pq_secret: &[u8],
186    ) -> Result<[u8; 32], anyhow::Error> {
187        // CRYPTO-3: the combined IKM holds both raw classical and ML-KEM shared
188        // secrets — wipe it on every exit path rather than leaving it in freed
189        // memory.
190        let ikm = zeroize::Zeroizing::new([ecc_secret, pq_secret].concat());
191        let hkdf = Hkdf::<Sha256>::new(None, &ikm);
192        let mut okm = [0u8; 32];
193        hkdf.expand(COMBINE_LABEL, &mut okm)
194            .map_err(|_| anyhow::anyhow!("HKDF expansion failed"))?;
195        Ok(okm)
196    }
197}
198
199impl fmt::Debug for HybridSecretKey {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        f.debug_struct("HybridSecretKey")
202            .field("classical_sk", &"REDACTED")
203            .field("ml_kem_dk", &"REDACTED")
204            .finish()
205    }
206}
207
208#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
209pub struct HybridKeyPackage {
210    /// Classical public key. Encoded as raw bytes; semantics depend on
211    /// the build (X25519 32-byte key by default, P-256 uncompressed
212    /// SEC1 65-byte key under fips).
213    pub classical_pk: [u8; CLASSICAL_PK_BYTES],
214    pub ml_kem_pk: Vec<u8>,
215}
216
217impl HybridKeyPackage {
218    pub fn encapsulate(&self) -> Result<([u8; 32], HybridCiphertext), anyhow::Error> {
219        let mut rng = OsRng;
220
221        // 1. Classical ECDH: fresh ephemeral on the sender side.
222        #[cfg(not(feature = "fips"))]
223        let (eph_pk_bytes, classical_shared) = {
224            let eph_sk = StaticSecret::random_from_rng(rng);
225            let eph_pk = X25519PublicKey::from(&eph_sk);
226            let peer = X25519PublicKey::from(self.classical_pk);
227            let shared = eph_sk.diffie_hellman(&peer);
228            (*eph_pk.as_bytes(), *shared.as_bytes())
229        };
230        #[cfg(feature = "fips")]
231        let (eph_pk_bytes, classical_shared): ([u8; CLASSICAL_PK_BYTES], [u8; 32]) = {
232            let aws_rng = SystemRandom::new();
233            let eph_sk = EphemeralPrivateKey::generate(&ECDH_P256, &aws_rng)
234                .map_err(|e| anyhow::anyhow!("aws-lc-rs ECDH-P-256 ephemeral generate: {:?}", e))?;
235            let eph_pk = eph_sk
236                .compute_public_key()
237                .map_err(|e| anyhow::anyhow!("compute_public_key: {:?}", e))?;
238            let mut pk_bytes = [0u8; CLASSICAL_PK_BYTES];
239            pk_bytes.copy_from_slice(eph_pk.as_ref());
240            let peer = UnparsedPublicKey::new(&ECDH_P256, &self.classical_pk[..]);
241            let shared = agreement::agree_ephemeral(
242                eph_sk,
243                peer,
244                anyhow::anyhow!("aws-lc-rs ECDH-P-256 agree_ephemeral failed (peer parse)"),
245                |km| -> Result<[u8; 32], anyhow::Error> {
246                    let mut o = [0u8; 32];
247                    o.copy_from_slice(km);
248                    Ok(o)
249                },
250            )?;
251            (pk_bytes, shared)
252        };
253
254        // 2. ML-KEM-768 encapsulation against the peer's encap key.
255        let ek_array = decode_ml_kem_encap_key(&self.ml_kem_pk)
256            .ok_or_else(|| anyhow::anyhow!("invalid ML-KEM-768 public key length"))?;
257        let ek = MlKem768EncapKey::from_bytes(&ek_array);
258        let (ct, ml_kem_shared) = ek
259            .encapsulate(&mut rng)
260            .map_err(|e| anyhow::anyhow!("ML-KEM encapsulation failed: {:?}", e))?;
261
262        // 3. Combine via HKDF.
263        let shared_secret =
264            HybridSecretKey::combine_secrets(&classical_shared, ml_kem_shared.as_slice())?;
265
266        let ciphertext = HybridCiphertext {
267            classical_pk: eph_pk_bytes,
268            ml_kem_ct: ct.as_slice().to_vec(),
269        };
270        Ok((shared_secret, ciphertext))
271    }
272}
273
274#[derive(BorshSerialize, BorshDeserialize, Debug, Clone)]
275pub struct HybridCiphertext {
276    /// Sender's ephemeral classical public key. Encoding matches
277    /// [`HybridKeyPackage::classical_pk`].
278    pub classical_pk: [u8; CLASSICAL_PK_BYTES],
279    /// ML-KEM-768 ciphertext bytes (FIPS-203 encoded).
280    pub ml_kem_ct: Vec<u8>,
281}
282
// ─── Encoding helpers ─────────────────────────────────────────────────────
//
// `ml-kem` exposes its byte-encoded keys and ciphertexts as fixed-length
// `hybrid_array::Array<u8, N>` values; for keys it also names this type
// `Encoded<T>`. We carry them on the wire as `Vec<u8>` (borsh-friendly) and
// convert back via these helpers. Length mismatches return `None`, so
// callers can map them to a proper handshake / KEM error.
290
291fn decode_ml_kem_encap_key(bytes: &[u8]) -> Option<Encoded<MlKem768EncapKey>> {
292    Encoded::<MlKem768EncapKey>::try_from(bytes).ok()
293}
294
295fn decode_ml_kem_ciphertext(
296    bytes: &[u8],
297) -> Option<Array<u8, <MlKem768 as KemCore>::CiphertextSize>> {
298    Array::<u8, <MlKem768 as KemCore>::CiphertextSize>::try_from(bytes).ok()
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hybrid_kem_round_trip() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss_send, ct) = pk.encapsulate().expect("encap");
        let ss_recv = sk.decapsulate(&ct).expect("decap");
        assert_eq!(
            ss_send, ss_recv,
            "encap/decap must agree on the shared secret"
        );
    }

    #[test]
    fn hybrid_kem_two_handshakes_yield_distinct_secrets() {
        let (_sk, pk) = HybridSecretKey::generate();
        let (ss1, _ct1) = pk.encapsulate().expect("first encap");
        let (ss2, _ct2) = pk.encapsulate().expect("second encap");
        // Same recipient, different sender ephemeral classical + different
        // ML-KEM randomness → different shared secrets.
        assert_ne!(ss1, ss2);
    }

    #[test]
    fn ml_kem_ciphertext_size_matches_fips_203() {
        // FIPS-203 ML-KEM-768 ciphertext is 1088 bytes.
        let (_sk, pk) = HybridSecretKey::generate();
        let (_ss, ct) = pk.encapsulate().expect("encap");
        assert_eq!(ct.ml_kem_ct.len(), 1088);
    }

    #[test]
    fn ml_kem_public_key_size_matches_fips_203() {
        // FIPS-203 ML-KEM-768 encap key is 1184 bytes.
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(pk.ml_kem_pk.len(), 1184);
    }

    #[test]
    fn hybrid_kem_two_secrets_distinct_under_same_recipient_key() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss1, ct1) = pk.encapsulate().expect("encap1");
        let (ss2, ct2) = pk.encapsulate().expect("encap2");
        // Each ciphertext decapsulates to its own sender-side secret...
        assert_eq!(sk.decapsulate(&ct1).expect("decap1"), ss1);
        assert_eq!(sk.decapsulate(&ct2).expect("decap2"), ss2);
        // ...and the two secrets are distinct, as the test name promises.
        assert_ne!(ss1, ss2);
    }

    /// FIPS-203 implicit rejection: a correctly sized but altered ML-KEM
    /// ciphertext still decapsulates, but to an unrelated secret.
    #[test]
    fn tampered_ml_kem_ciphertext_yields_different_secret() {
        let (sk, pk) = HybridSecretKey::generate();
        let (ss, mut ct) = pk.encapsulate().expect("encap");
        ct.ml_kem_ct[0] ^= 0x01;
        let ss_tampered = sk
            .decapsulate(&ct)
            .expect("correctly sized ciphertext must not error");
        assert_ne!(ss, ss_tampered);
    }

    /// A truncated ML-KEM ciphertext is rejected with an error, not a panic.
    #[test]
    fn decapsulate_rejects_wrong_length_ml_kem_ciphertext() {
        let (sk, pk) = HybridSecretKey::generate();
        let (_ss, mut ct) = pk.encapsulate().expect("encap");
        ct.ml_kem_ct.pop();
        assert!(sk.decapsulate(&ct).is_err());
    }

    /// An over-long ML-KEM encapsulation key is rejected with an error.
    #[test]
    fn encapsulate_rejects_wrong_length_ml_kem_public_key() {
        let (_sk, mut pk) = HybridSecretKey::generate();
        pk.ml_kem_pk.push(0);
        assert!(pk.encapsulate().is_err());
    }

    /// Classical public key length matches the active backend.
    #[test]
    fn classical_public_key_size_matches_backend() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(pk.classical_pk.len(), CLASSICAL_PK_BYTES);
        #[cfg(not(feature = "fips"))]
        assert_eq!(CLASSICAL_PK_BYTES, 32, "X25519 public key is 32 bytes");
        #[cfg(feature = "fips")]
        assert_eq!(
            CLASSICAL_PK_BYTES, 65,
            "ECDH-P-256 uncompressed SEC1 public key is 65 bytes"
        );
    }

    /// fips-only: P-256 SEC1 uncompressed encoding starts with 0x04.
    #[cfg(feature = "fips")]
    #[test]
    fn fips_classical_public_key_is_uncompressed_sec1() {
        let (_sk, pk) = HybridSecretKey::generate();
        assert_eq!(
            pk.classical_pk[0], 0x04,
            "uncompressed SEC1 P-256 key must lead with 0x04"
        );
    }
}