Skip to main content

pq_ratchet/
ratchet.rs

1//! Hybrid Double Ratchet state machine.
2//!
3//! Extends the Signal Double Ratchet algorithm with ML-KEM-768 post-quantum
4//! ratcheting through the SCKA epoch model. Two root-KDF invocations per step:
5//!
6//! Receiving chain: `KDF_RK(rk, DH(old_dh, their_dh) || decap(their_ct))`
7//! Sending chain:   `KDF_RK(rk, DH(new_dh, their_dh) || encap(their_ek).ss)`
8//!
//! An attacker must break both X25519 AND ML-KEM-768 to recover the output of
//! a step in which both secrets were mixed; if only one primitive falls, the
//! other still protects that step's keys.
11
12use std::collections::HashMap;
13use std::fmt;
14
15use rand_core::CryptoRngCore;
16use subtle::ConstantTimeEq;
17use x25519_dalek::{PublicKey, StaticSecret};
18use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
19
20use crate::{
21    error::RatchetError,
22    kdf::{kdf_ck, kdf_rk},
23    scka::{PqCt, PqEk, SckaState, PQ_CT_LEN, PQ_DK_LEN, PQ_EK_LEN, PQ_SS_LEN},
24};
25
/// Upper bound on how many message keys one skip operation may derive and
/// cache (checked per call in `skip_message_keys`).
///
/// This is the value Signal recommends.  A transport that reorders more deeply
/// than 1,000 messages needs a larger bound.
pub const MAX_SKIP: usize = 1_000;

/// Upper bound on the number of cached skipped keys, summed over every DH
/// epoch.
///
/// Caps the memory a hostile peer can pin by ratcheting through many epochs
/// that each leave skipped messages behind.
pub const MAX_SKIP_TOTAL: usize = 2 * MAX_SKIP;
34
35// ── Public types ──────────────────────────────────────────────────────────────
36
37/// Per-message header transmitted alongside each ciphertext.
38///
39/// Callers are responsible for authenticating this header (e.g. AEAD additional
40/// data)  --  this crate derives keys only, it does not encrypt.
41///
42/// Fields are not publicly settable to prevent construction of invalid headers.
43/// Use [`Header::new`] to create a header, or read fields through the accessor
44/// methods after receiving one from [`HybridRatchet::ratchet_encrypt`].
45#[derive(Clone, Debug)]
46pub struct Header {
47    /// Sender's current X25519 DH ratchet public key.
48    pub(crate) dh_pk: [u8; 32],
49    /// Index of this message within the current sending chain.
50    pub(crate) n: u32,
51    /// Length of the *previous* sending chain (needed to skip stale messages).
52    pub(crate) pn: u32,
53    /// Sender's current ML-KEM-768 encapsulation key.
54    /// Receiver encapsulates to this and returns a ciphertext in their next header.
55    pub(crate) pq_ek: Option<PqEk>,
56    /// ML-KEM-768 ciphertext  --  the sender's encapsulation response to the receiver's
57    /// most recently seen EK.  Receiver decapsulates to obtain the PQ shared-secret.
58    pub(crate) pq_ct: Option<PqCt>,
59}
60
61/// Single-use message key derived from the symmetric ratchet chain.
62///
63/// Use this to key an AEAD cipher (e.g. ChaCha20-Poly1305).  Zeroed on drop.
64#[derive(ZeroizeOnDrop)]
65pub struct MessageKey(pub [u8; 32]);
66
67impl MessageKey {
68    /// Return the raw 32-byte message key.
69    pub fn as_bytes(&self) -> &[u8; 32] {
70        &self.0
71    }
72}
73
74impl fmt::Debug for MessageKey {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_tuple("MessageKey").field(&"[REDACTED]").finish()
77    }
78}
79
80// ── Header constructors and accessors ────────────────────────────────────────
81
82impl Header {
83    /// Construct a [`Header`] with the given fields.
84    pub fn new(
85        dh_pk: [u8; 32],
86        n: u32,
87        pn: u32,
88        pq_ek: Option<PqEk>,
89        pq_ct: Option<PqCt>,
90    ) -> Self {
91        Self {
92            dh_pk,
93            n,
94            pn,
95            pq_ek,
96            pq_ct,
97        }
98    }
99
100    /// Sender's current X25519 DH ratchet public key.
101    pub fn dh_pk(&self) -> [u8; 32] {
102        self.dh_pk
103    }
104
105    /// Index of this message within the current sending chain.
106    pub fn n(&self) -> u32 {
107        self.n
108    }
109
110    /// Length of the previous sending chain.
111    pub fn pn(&self) -> u32 {
112        self.pn
113    }
114
115    /// Sender's current ML-KEM-768 encapsulation key, if present.
116    pub fn pq_ek(&self) -> Option<&PqEk> {
117        self.pq_ek.as_ref()
118    }
119
120    /// ML-KEM-768 ciphertext response to receiver's EK, if present.
121    pub fn pq_ct(&self) -> Option<&PqCt> {
122        self.pq_ct.as_ref()
123    }
124}
125
126// ── Header wire encoding ──────────────────────────────────────────────────────
127
128impl Header {
129    /// Encode to a canonical byte string for use as AEAD additional data.
130    ///
131    /// Format (all integers little-endian):
132    /// ```text
133    /// dh_pk (32) | n (4) | pn (4) | flags (1) | [pq_ek (1184)] | [pq_ct (1088)]
134    /// ```
135    /// `flags` bit 0 = EK present; bit 1 = CT present.
136    ///
137    /// Allocates a new `Vec`.  Use [`write_to`](Self::write_to) to append into
138    /// a reusable buffer instead.
139    pub fn encode(&self) -> Vec<u8> {
140        let mut buf = Vec::with_capacity(
141            41 + self.pq_ek.as_ref().map_or(0, |_| PQ_EK_LEN)
142                + self.pq_ct.as_ref().map_or(0, |_| PQ_CT_LEN),
143        );
144        self.write_to(&mut buf);
145        buf
146    }
147
148    /// Append the canonical encoding to an existing buffer.
149    ///
150    /// Prefer this over [`encode`](Self::encode) when the caller already has a
151    /// pre-allocated `Vec` (e.g. a per-connection AEAD header buffer) to avoid
152    /// a per-message allocation.
153    pub fn write_to(&self, buf: &mut Vec<u8>) {
154        let flags: u8 = (self.pq_ek.is_some() as u8) | ((self.pq_ct.is_some() as u8) << 1);
155        buf.extend_from_slice(&self.dh_pk);
156        buf.extend_from_slice(&self.n.to_le_bytes());
157        buf.extend_from_slice(&self.pn.to_le_bytes());
158        buf.push(flags);
159        if let Some(ek) = &self.pq_ek {
160            buf.extend_from_slice(&ek.0);
161        }
162        if let Some(ct) = &self.pq_ct {
163            buf.extend_from_slice(&ct.0);
164        }
165    }
166
167    /// Decode from the canonical byte representation produced by [`Header::encode`].
168    pub fn decode(bytes: &[u8]) -> Result<Self, RatchetError> {
169        const MIN: usize = 41; // dh_pk(32) + n(4) + pn(4) + flags(1)
170        if bytes.len() < MIN {
171            return Err(RatchetError::MalformedHeader("too short"));
172        }
173        // Slices are guaranteed to be the right length by the MIN check above.
174        let dh_pk: [u8; 32] = bytes[..32].try_into().expect("32-byte slice after MIN check");
175        let n = u32::from_le_bytes(bytes[32..36].try_into().expect("4-byte slice after MIN check"));
176        let pn = u32::from_le_bytes(bytes[36..40].try_into().expect("4-byte slice after MIN check"));
177        let flags = bytes[40];
178        if flags & !0x03 != 0 {
179            return Err(RatchetError::MalformedHeader("unknown flags"));
180        }
181        let has_ek = flags & 0x01 != 0;
182        let has_ct = flags & 0x02 != 0;
183
184        let mut pos = MIN;
185        let pq_ek = if has_ek {
186            let end = pos + PQ_EK_LEN;
187            if bytes.len() < end {
188                return Err(RatchetError::MalformedHeader("truncated EK"));
189            }
190            let ek: [u8; PQ_EK_LEN] = bytes[pos..end].try_into().expect("slice length guaranteed by bounds check");
191            pos = end;
192            Some(PqEk(ek))
193        } else {
194            None
195        };
196        let pq_ct = if has_ct {
197            let end = pos + PQ_CT_LEN;
198            if bytes.len() < end {
199                return Err(RatchetError::MalformedHeader("truncated CT"));
200            }
201            let ct: [u8; PQ_CT_LEN] = bytes[pos..end].try_into().expect("slice length guaranteed by bounds check");
202            pos = end;
203            Some(PqCt(ct))
204        } else {
205            None
206        };
207        if pos != bytes.len() {
208            return Err(RatchetError::MalformedHeader("trailing bytes"));
209        }
210        Ok(Header {
211            dh_pk,
212            n,
213            pn,
214            pq_ek,
215            pq_ct,
216        })
217    }
218}
219
220// ── State ─────────────────────────────────────────────────────────────────────
221
222/// Hybrid Double Ratchet session state.
223///
224/// Holds all key material for one end of a conversation.  Zeroes all secret
225/// fields on drop via [`ZeroizeOnDrop`] and manual [`Drop`].
226pub struct HybridRatchet {
227    // ── DH ratchet ───────────────────────────────────────────────────────────
228    dh_sk: StaticSecret,            // our current DH private key
229    dh_pk: PublicKey,               // our current DH public key
230    dh_pk_remote: Option<[u8; 32]>, // peer's latest DH public key; None until first recv
231
232    // ── Root & chain keys ────────────────────────────────────────────────────
233    rk: [u8; 32],          // root key, mixed at each DH ratchet step
234    cks: Option<[u8; 32]>, // sending chain key  (None for receiver before first decrypt)
235    ckr: Option<[u8; 32]>, // receiving chain key (None for sender before first decrypt)
236
237    // ── Message counters ─────────────────────────────────────────────────────
238    ns: u32, // messages sent in current sending chain
239    nr: u32, // messages received in current receiving chain
240    pn: u32, // length of *previous* sending chain
241
242    // ── Out-of-order cache ───────────────────────────────────────────────────
243    /// (remote_dh_pk, message_index) → message_key
244    ///
245    /// Values are wrapped in `Zeroizing` so they are explicitly zeroed when
246    /// removed or when the map is dropped.  The map is pre-allocated at
247    /// `MAX_SKIP_TOTAL` capacity on construction so it never reallocates  -- 
248    /// preventing freed backing memory from retaining plaintext key material.
249    skipped: HashMap<([u8; 32], u32), Zeroizing<[u8; 32]>>,
250
251    // ── Post-quantum state ───────────────────────────────────────────────────
252    scka: SckaState,
253}
254
255impl Drop for HybridRatchet {
256    fn drop(&mut self) {
257        self.rk.zeroize();
258        if let Some(ref mut k) = self.cks {
259            k.zeroize();
260        }
261        if let Some(ref mut k) = self.ckr {
262            k.zeroize();
263        }
264        // Zeroizing<[u8;32]> values already zero themselves on drop, but
265        // explicitly zeroing here clears the memory before the HashMap
266        // allocator reclaims it.
267        for v in self.skipped.values_mut() {
268            v.zeroize();
269        }
270    }
271}
272
273// ── Constructors ──────────────────────────────────────────────────────────────
274
275impl HybridRatchet {
276    /// Initialise as the **sending** party (e.g. Alice in Signal X3DH).
277    ///
278    /// Performs an immediate DH ratchet step against `peer_dh_pk`, deriving the
279    /// initial sending chain key.  The result can call `ratchet_encrypt` immediately.
280    ///
281    /// # Arguments
282    /// * `shared_secret`  --  32-byte pre-shared secret from the key agreement phase
283    ///   (e.g. the output of PQXDH or X3DH).
284    /// * `peer_dh_pk`     --  peer's X25519 ratchet public key (from their prekey bundle).
285    pub fn init_sender(
286        shared_secret: &[u8; 32],
287        peer_dh_pk: &[u8; 32],
288        rng: &mut impl CryptoRngCore,
289    ) -> Self {
290        let dh_sk = StaticSecret::random_from_rng(&mut *rng);
291        let dh_pk = PublicKey::from(&dh_sk);
292        let peer_pk = PublicKey::from(*peer_dh_pk);
293
294        let dh_ss = dh_sk.diffie_hellman(&peer_pk);
295        // No PQ exchange has occurred yet; contribute zero bytes for this first step.
296        let (rk, cks) = kdf_rk(shared_secret, dh_ss.as_bytes(), &[0u8; PQ_SS_LEN]);
297
298        let scka = SckaState::new(rng);
299
300        Self {
301            dh_sk,
302            dh_pk,
303            dh_pk_remote: Some(*peer_dh_pk),
304            rk,
305            cks: Some(cks),
306            ckr: None,
307            ns: 0,
308            nr: 0,
309            pn: 0,
310            skipped: HashMap::with_capacity(MAX_SKIP_TOTAL),
311            scka,
312        }
313    }
314
315    /// Initialise as the **receiving** party (e.g. Bob in Signal X3DH).
316    ///
317    /// Does not perform a DH ratchet step yet; call `ratchet_decrypt` on the first
318    /// incoming message to derive both the receiving and sending chain keys.
319    ///
320    /// # Arguments
321    /// * `shared_secret`  --  32-byte pre-shared secret (same as sender's).
322    /// * `our_dh_sk`      --  our X25519 ratchet secret key (the one whose public key
323    ///   was shared with the sender in the prekey bundle).
324    pub fn init_receiver(
325        shared_secret: &[u8; 32],
326        our_dh_sk: StaticSecret,
327        rng: &mut impl CryptoRngCore,
328    ) -> Self {
329        let dh_pk = PublicKey::from(&our_dh_sk);
330        let scka = SckaState::new(rng);
331
332        Self {
333            dh_sk: our_dh_sk,
334            dh_pk,
335            dh_pk_remote: None,
336            rk: *shared_secret,
337            cks: None,
338            ckr: None,
339            ns: 0,
340            nr: 0,
341            pn: 0,
342            skipped: HashMap::with_capacity(MAX_SKIP_TOTAL),
343            scka,
344        }
345    }
346
347    /// Return our current X25519 DH public key (for sharing in a prekey bundle).
348    ///
349    /// Returns raw bytes.  X25519 public keys are conventionally exchanged as
350    /// byte arrays; use [`x25519_dalek::PublicKey::from`] to convert if needed.
351    /// The PQ equivalent, [`our_pq_ek`](Self::our_pq_ek), returns a typed
352    /// [`PqEk`] wrapper instead  --  the difference is intentional since ML-KEM
353    /// keys are less commonly manipulated as raw bytes.
354    pub fn our_dh_pk(&self) -> [u8; 32] {
355        *self.dh_pk.as_bytes()
356    }
357
358    /// Return our current ML-KEM-768 encapsulation key (for sharing in a prekey bundle).
359    ///
360    /// Returns a typed [`PqEk`] wrapper around the 1184-byte encoding.
361    pub fn our_pq_ek(&self) -> PqEk {
362        self.scka.our_ek().clone()
363    }
364}
365
366// ── Encrypt / Decrypt ─────────────────────────────────────────────────────────
367
368impl HybridRatchet {
369    /// Advance the sending chain and return a `(Header, MessageKey)` pair.
370    ///
371    /// The message key is derived from the current sending chain key.  Use it to
372    /// AEAD-encrypt the plaintext, then transmit `header` alongside the ciphertext.
373    ///
374    /// The `rng` parameter is reserved for future protocol extensions and is not
375    /// consumed by this call.  Any [`CryptoRngCore`] implementor is accepted.
376    ///
377    /// **PQ ciphertext retransmission**: if a pending ML-KEM ciphertext exists (set
378    /// during the last DH ratchet step), it is included in every outgoing header
379    /// until the peer's next DH ratchet step replaces it.  This ensures the PQ
380    /// shared-secret is established even if the first message carrying the CT is
381    /// lost in transit.  A caller MUST NOT assume the CT was received until the
382    /// peer's subsequent DH ratchet is observed.
383    ///
384    /// # Errors
385    /// Returns [`RatchetError::NoSendingChain`] if the sending chain has not been
386    /// initialised yet (receiver must call `ratchet_decrypt` at least once first).
387    pub fn ratchet_encrypt(
388        &mut self,
389        _rng: &mut impl CryptoRngCore,
390    ) -> Result<(Header, MessageKey), RatchetError> {
391        let cks = self.cks.as_mut().ok_or(RatchetError::NoSendingChain)?;
392
393        let (new_ck, mk) = kdf_ck(cks);
394        *cks = new_ck;
395
396        let n = self.ns;
397        self.ns += 1;
398
399        let header = Header {
400            dh_pk: *self.dh_pk.as_bytes(),
401            n,
402            pn: self.pn,
403            pq_ek: Some(self.scka.our_ek().clone()),
404            // Peek (clone without consuming) so the CT is retransmitted in every
405            // message until the peer's next DH ratchet generates a replacement.
406            pq_ct: self.scka.pending_ct_ref().cloned(),
407        };
408
409        Ok((header, MessageKey(mk)))
410    }
411
412    /// Advance the receiving chain and return the `MessageKey` for this message.
413    ///
414    /// Handles DH ratchet steps, out-of-order delivery (caching skipped keys), and
415    /// the hybrid PQ key exchange within each DH ratchet step.
416    ///
417    /// The `rng` parameter is only consumed when the incoming header carries a new
418    /// DH public key (triggering a full DH ratchet step with ML-KEM encapsulation).
419    /// For within-epoch messages  --  including cache hits from out-of-order delivery  -- 
420    /// the RNG is untouched.
421    ///
422    /// **Implicit key confirmation**: successful AEAD decryption by the caller
423    /// serves as the only confirmation that both parties derived matching keys.
424    /// There is no explicit key confirmation step  --  this matches Signal's design.
425    ///
426    /// # Errors
427    /// - [`RatchetError::TooManySkipped`] if the out-of-order cache would overflow.
428    /// - [`RatchetError::MessageKeyNotFound`] if this message was already decrypted
429    ///   or is older than what remains in the cache.
430    /// - Crypto errors from malformed PQ material in the header.
431    pub fn ratchet_decrypt(
432        &mut self,
433        header: &Header,
434        rng: &mut impl CryptoRngCore,
435    ) -> Result<MessageKey, RatchetError> {
436        // 1. Fast path: this key was already cached from a previous out-of-order scan.
437        if let Some(mk) = self.skipped.remove(&(header.dh_pk, header.n)) {
438            return Ok(MessageKey(*mk));
439        }
440
441        // 2. Determine if the remote DH key has changed (DHRatchet trigger).
442        // Constant-time comparison prevents timing side-channels on the DH key.
443        // DH public keys are public, but constant-time is best practice for all
444        // cryptographic material.
445        let is_new_dh = match self.dh_pk_remote {
446            Some(pk) => pk.ct_ne(&header.dh_pk).into(),
447            None => true,
448        };
449
450        if is_new_dh {
451            // Skip over any remaining messages in the current receiving chain
452            // (they were sent before the peer ratcheted their DH key).
453            if self.ckr.is_some() {
454                self.skip_message_keys(header.pn)?;
455            }
456            self.dh_ratchet(header, rng)?;
457        }
458
459        // 3. Skip to the target message number within the current receiving chain.
460        // Reject within-epoch replays: n < nr means the key was already consumed.
461        // (dh_ratchet resets nr to 0, so this check is only relevant for same-epoch msgs.)
462        if header.n < self.nr {
463            return Err(RatchetError::MessageKeyNotFound);
464        }
465        self.skip_message_keys(header.n)?;
466
467        // 4. Derive the message key.
468        let ckr = self.ckr.as_mut().ok_or(RatchetError::NoReceivingChain)?;
469        let (new_ck, mk) = kdf_ck(ckr);
470        *ckr = new_ck;
471        self.nr += 1;
472
473        Ok(MessageKey(mk))
474    }
475}
476
477// ── Internal helpers ──────────────────────────────────────────────────────────
478
479impl HybridRatchet {
480    /// Perform a full DHRatchet step (two root-KDF invocations) upon receiving a
481    /// message with a new remote DH public key.
482    fn dh_ratchet(
483        &mut self,
484        header: &Header,
485        rng: &mut impl CryptoRngCore,
486    ) -> Result<(), RatchetError> {
487        // ── All fallible and panicking operations  --  state is NOT modified yet ─
488        // Any error or panic leaves the session completely intact.
489
490        // PQ receiving: decapsulate the CT the peer sent for our current EK.
491        let pq_recv: [u8; PQ_SS_LEN] = match &header.pq_ct {
492            Some(ct) => self.scka.decap(ct)?,
493            None => [0u8; PQ_SS_LEN],
494        };
495
496        // PQ sending: encapsulate to the EK announced in this header.
497        let (pq_send, opt_pending_ct): ([u8; PQ_SS_LEN], Option<PqCt>) = match &header.pq_ek {
498            Some(ek) => {
499                let (ss, ct) = self.scka.encap_to(ek, rng)?;
500                (ss, Some(ct))
501            }
502            None => ([0u8; PQ_SS_LEN], None),
503        };
504
505        // When the peer sends no EK, preserve any pending CT from the current
506        // epoch so it is not silently dropped when we rotate the SCKA state.
507        let ct_to_carry = opt_pending_ct.or_else(|| self.scka.pending_ct_ref().cloned());
508
509        // Generate new DH keypair  --  random_from_rng is allowed to panic.
510        let new_dh_sk = StaticSecret::random_from_rng(&mut *rng);
511        let new_dh_pk = PublicKey::from(&new_dh_sk);
512
513        // Rotate to a fresh ML-KEM epoch  --  MlKem768::generate is allowed to panic.
514        let mut new_scka = SckaState::new(rng);
515        if let Some(ct) = ct_to_carry {
516            new_scka.set_pending_ct(ct);
517        }
518
519        // ── All state mutations follow; every panicking operation is complete ─
520        let peer_pk = PublicKey::from(header.dh_pk);
521
522        // Step 1: Receiving chain (old DH secret × new peer key, PQ recv secret).
523        let dh_recv = self.dh_sk.diffie_hellman(&peer_pk);
524        let (rk1, ckr) = kdf_rk(&self.rk, dh_recv.as_bytes(), &pq_recv);
525
526        // Step 2: Sending chain (new DH secret × new peer key, PQ send secret).
527        let dh_send = new_dh_sk.diffie_hellman(&peer_pk);
528        let (rk2, cks) = kdf_rk(&rk1, dh_send.as_bytes(), &pq_send);
529
530        // Commit atomically  --  session is fully consistent after this block.
531        self.pn = self.ns;
532        self.ns = 0;
533        self.nr = 0;
534        self.dh_pk_remote = Some(header.dh_pk);
535        self.rk = rk2;
536        self.ckr = Some(ckr);
537        self.cks = Some(cks);
538        self.dh_sk = new_dh_sk;
539        self.dh_pk = new_dh_pk;
540        self.scka = new_scka;
541
542        Ok(())
543    }
544
545    /// Cache message keys for messages `self.nr .. until` in the current receiving
546    /// chain, enforcing the `MAX_SKIP` limit.
547    fn skip_message_keys(&mut self, until: u32) -> Result<(), RatchetError> {
548        if until < self.nr {
549            return Ok(());
550        }
551
552        let to_skip = (until - self.nr) as usize;
553        // Per-batch limit: prevents a single malformed header from forcing a
554        // huge number of symmetric-ratchet steps in one call.
555        if to_skip > MAX_SKIP {
556            return Err(RatchetError::TooManySkipped(to_skip));
557        }
558        // Global total limit: prevents a malicious peer from exhausting memory
559        // across many DH epochs each with skipped messages.
560        if self.skipped.len() + to_skip > MAX_SKIP_TOTAL {
561            return Err(RatchetError::TooManySkipped(self.skipped.len() + to_skip));
562        }
563
564        if let Some(ref mut ckr) = self.ckr {
565            let remote_pk = self.dh_pk_remote
566                .expect("dh_pk_remote is always Some when ckr is Some  --  both set in dh_ratchet");
567            for i in self.nr..until {
568                let (new_ck, mk) = kdf_ck(ckr);
569                *ckr = new_ck;
570                self.skipped.insert((remote_pk, i), Zeroizing::new(mk));
571            }
572            self.nr = until;
573        }
574
575        Ok(())
576    }
577}
578
579// ── State serialization ───────────────────────────────────────────────────────
580
581impl HybridRatchet {
582    /// Serialize the full session state to a zeroize-on-drop byte vector.
583    ///
584    /// The format is a versioned binary blob (v1).  Deserialize with
585    /// [`from_bytes`](Self::from_bytes).  Callers are responsible for encrypting
586    /// this blob at rest  --  it contains all secret key material.
587    ///
588    /// Binary layout (little-endian integers):
589    /// ```text
590    /// version(1) | dh_sk(32) | remote_flag(1) | [dh_pk_remote(32)]
591    /// | rk(32) | cks_flag(1) | [cks(32)] | ckr_flag(1) | [ckr(32)]
592    /// | ns(4) | nr(4) | pn(4)
593    /// | skipped_count(4) | (remote_pk(32) | idx(4) | mk(32)) * count
594    /// | scka_dk(2400) | scka_ek(1184) | pct_flag(1) | [pct(1088)]
595    /// ```
596    ///
597    /// # Forward-compatibility warning
598    /// The serialized format is versioned (`v1`) but no migration path exists.
599    /// If the ml-kem crate changes key encoding (e.g., NIST FIPS 203 revision)
600    /// or this crate changes the binary layout, ALL persisted sessions become
601    /// unrecoverable with `from_bytes` returning an error on unknown versions.
602    /// Re-key all sessions before upgrading the crate.  Treat persisted blobs
603    /// as opaque  --  do not rely on their internal structure across crate versions.
604    pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
605        let n_skip = self.skipped.len();
606        let mut buf: Vec<u8> = Vec::with_capacity(
607            1 + 32
608                + 1
609                + 32
610                + 32
611                + 1
612                + 32
613                + 1
614                + 32
615                + 4
616                + 4
617                + 4
618                + 4
619                + n_skip * 68
620                + PQ_DK_LEN
621                + PQ_EK_LEN
622                + 1
623                + PQ_CT_LEN,
624        );
625
626        buf.push(0x01); // format version
627
628        buf.extend_from_slice(&self.dh_sk.to_bytes());
629
630        match self.dh_pk_remote {
631            Some(pk) => {
632                buf.push(1);
633                buf.extend_from_slice(&pk);
634            }
635            None => {
636                buf.push(0);
637            }
638        }
639
640        buf.extend_from_slice(&self.rk);
641
642        for opt in [self.cks, self.ckr] {
643            match opt {
644                Some(k) => {
645                    buf.push(1);
646                    buf.extend_from_slice(&k);
647                }
648                None => {
649                    buf.push(0);
650                }
651            }
652        }
653
654        buf.extend_from_slice(&self.ns.to_le_bytes());
655        buf.extend_from_slice(&self.nr.to_le_bytes());
656        buf.extend_from_slice(&self.pn.to_le_bytes());
657
658        buf.extend_from_slice(&(n_skip as u32).to_le_bytes());
659        for ((rpk, idx), mk) in &self.skipped {
660            buf.extend_from_slice(rpk);
661            buf.extend_from_slice(&idx.to_le_bytes());
662            buf.extend_from_slice(mk.as_ref());
663        }
664
665        buf.extend_from_slice(&self.scka.dk_bytes());
666        buf.extend_from_slice(self.scka.ek_bytes_raw());
667        match self.scka.pending_ct_ref() {
668            Some(ct) => {
669                buf.push(1);
670                buf.extend_from_slice(&ct.0);
671            }
672            None => {
673                buf.push(0);
674            }
675        }
676
677        Zeroizing::new(buf)
678    }
679
680    /// Restore session state from bytes produced by [`to_bytes`](Self::to_bytes).
681    ///
682    /// Fails with [`RatchetError::MalformedState`] if the version tag is not
683    /// `0x01` (the only currently supported format).
684    pub fn from_bytes(bytes: &[u8]) -> Result<Self, RatchetError> {
685        let err = |msg| RatchetError::MalformedState(msg);
686        let mut pos = 0usize;
687
688        macro_rules! read {
689            ($n:expr) => {{
690                let end = pos + $n;
691                if end > bytes.len() {
692                    return Err(err("truncated"));
693                }
694                let slice = &bytes[pos..end];
695                pos = end;
696                slice
697            }};
698        }
699        macro_rules! read_arr {
700            ($n:expr) => {{
701                let s = read!($n);
702                let arr: [u8; $n] = s.try_into().unwrap();
703                arr
704            }};
705        }
706        macro_rules! read_u32 {
707            () => {
708                u32::from_le_bytes(read_arr!(4))
709            };
710        }
711        macro_rules! read_opt {
712            ($n:expr) => {{
713                let flag = read_arr!(1)[0];
714                match flag {
715                    0 => None,
716                    1 => Some(read_arr!($n)),
717                    _ => return Err(err("invalid flag")),
718                }
719            }};
720        }
721
722        // version
723        if read_arr!(1)[0] != 0x01 {
724            return Err(err("unknown version"));
725        }
726
727        let dh_sk_bytes = read_arr!(32);
728        let dh_pk_remote = read_opt!(32);
729
730        let rk = read_arr!(32);
731        let cks = read_opt!(32);
732        let ckr = read_opt!(32);
733        let ns = read_u32!();
734        let nr = read_u32!();
735        let pn = read_u32!();
736
737        // Reject counter values that are dangerously close to u32::MAX.  After
738        // 2^31 messages on a single chain the session must be re-keyed anyway.
739        if ns > (u32::MAX / 2) || nr > (u32::MAX / 2) || pn > (u32::MAX / 2) {
740            return Err(err("message counter out of safe range"));
741        }
742
743        let n_skip = read_u32!() as usize;
744        if n_skip > MAX_SKIP_TOTAL {
745            return Err(err("skipped cache exceeds limit"));
746        }
747        // Pre-allocate at the maximum to prevent reallocation (which would leave
748        // unzeroed key material in freed heap memory).
749        let mut skipped = HashMap::with_capacity(MAX_SKIP_TOTAL);
750        for _ in 0..n_skip {
751            let rpk: [u8; 32] = read_arr!(32);
752            let idx = read_u32!();
753            let mk = read_arr!(32);
754            skipped.insert((rpk, idx), Zeroizing::new(mk));
755        }
756
757        let dk_bytes: [u8; PQ_DK_LEN] = read_arr!(PQ_DK_LEN);
758        let ek_bytes: [u8; PQ_EK_LEN] = read_arr!(PQ_EK_LEN);
759        let pending_ct = read_opt!(PQ_CT_LEN).map(PqCt);
760
761        if pos != bytes.len() {
762            return Err(err("trailing bytes"));
763        }
764
765        let dh_sk = StaticSecret::from(dh_sk_bytes);
766        let dh_pk = PublicKey::from(&dh_sk);
767        let scka = SckaState::from_parts(&dk_bytes, ek_bytes, pending_ct)
768            .ok_or_else(|| err("invalid ML-KEM DK"))?;
769
770        Ok(HybridRatchet {
771            dh_sk,
772            dh_pk,
773            dh_pk_remote,
774            rk,
775            cks,
776            ckr,
777            ns,
778            nr,
779            pn,
780            skipped,
781            scka,
782        })
783    }
784}
785
786// ── Redacted Debug ────────────────────────────────────────────────────────────
787
788impl fmt::Debug for HybridRatchet {
789    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
790        f.debug_struct("HybridRatchet")
791            .field("ns", &self.ns)
792            .field("nr", &self.nr)
793            .field("pn", &self.pn)
794            .field("skipped_count", &self.skipped.len())
795            .field("has_cks", &self.cks.is_some())
796            .field("has_ckr", &self.ckr.is_some())
797            .finish_non_exhaustive()
798    }
799}
800
801// ── Optional serde support ────────────────────────────────────────────────────
802
/// # Binary formats only
///
/// This impl hands the output of `to_bytes()`, i.e. raw key material, to the
/// serializer as one byte string.  A text-based format (JSON, YAML, TOML)
/// re-encodes those bytes into ordinary buffers that are **not** zeroed on
/// drop, so key material may linger in heap memory.  Use binary serializers
/// only (bincode, postcard, messagepack) and encrypt the result immediately
/// at rest.
#[cfg(feature = "serde")]
impl serde::Serialize for HybridRatchet {
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        // `blob` is zeroize-on-drop; it is wiped once the serializer is done.
        let blob = self.to_bytes();
        s.serialize_bytes(blob.as_slice())
    }
}
816
/// Deserializes the byte blob written by the `Serialize` impl via
/// [`HybridRatchet::from_bytes`].
///
/// Every buffer this impl owns holds the full secret session state, so each
/// one is wrapped in [`Zeroizing`] and wiped when parsing finishes, whether it
/// succeeds or fails.  The sequence fallback allocates once at the largest
/// size `to_bytes` can produce and rejects longer input.  Its buffer therefore
/// never reallocates, and an untrusted source cannot force unbounded growth.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for HybridRatchet {
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        /// Largest blob `to_bytes` can emit: the 180-byte fixed part with
        /// every optional field present, a full skipped-key cache (68 bytes
        /// per entry), and the ML-KEM section with a pending ciphertext.
        const MAX_STATE_LEN: usize =
            180 + MAX_SKIP_TOTAL * 68 + PQ_DK_LEN + PQ_EK_LEN + 1 + PQ_CT_LEN;

        struct Visitor;
        impl<'de> serde::de::Visitor<'de> for Visitor {
            type Value = HybridRatchet;

            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "pq-ratchet session state as bytes")
            }

            fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<HybridRatchet, E> {
                HybridRatchet::from_bytes(v).map_err(E::custom)
            }

            fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<HybridRatchet, E> {
                // We own this copy of the secrets; wipe it when done.
                let v = Zeroizing::new(v);
                self.visit_bytes(v.as_slice())
            }

            // Fallback for formats that present bytes as a sequence of u8.
            fn visit_seq<A: serde::de::SeqAccess<'de>>(
                self,
                mut seq: A,
            ) -> Result<HybridRatchet, A::Error> {
                let mut buf = Zeroizing::new(Vec::with_capacity(MAX_STATE_LEN));
                while let Some(b) = seq.next_element::<u8>()? {
                    // `from_bytes` would reject anything longer anyway; stop
                    // before the buffer has to grow.
                    if buf.len() == MAX_STATE_LEN {
                        return Err(<A::Error as serde::de::Error>::custom(
                            "pq-ratchet session state too long",
                        ));
                    }
                    buf.push(b);
                }
                self.visit_bytes(buf.as_slice())
            }
        }
        d.deserialize_bytes(Visitor)
    }
}
845}