pq_ratchet/ratchet.rs
1//! Hybrid Double Ratchet state machine.
2//!
3//! Extends the Signal Double Ratchet algorithm with ML-KEM-768 post-quantum
4//! ratcheting through the SCKA epoch model. Two root-KDF invocations per step:
5//!
6//! Receiving chain: `KDF_RK(rk, DH(old_dh, their_dh) || decap(their_ct))`
7//! Sending chain: `KDF_RK(rk, DH(new_dh, their_dh) || encap(their_ek).ss)`
8//!
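//! Every root-KDF step that mixes in an ML-KEM secret requires an attacker to
//! break both X25519 and ML-KEM-768 to recover the derived keys. Some steps mix
//! in an all-zero PQ input and rely on X25519 alone:
//! - the sender's initial chain;
//! - any step whose header lacks the corresponding EK or CT.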
11
12use std::collections::HashMap;
13use std::fmt;
14
15use rand_core::CryptoRngCore;
16use subtle::ConstantTimeEq;
17use x25519_dalek::{PublicKey, StaticSecret};
18use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
19
20use crate::{
21 error::RatchetError,
22 kdf::{kdf_ck, kdf_rk},
23 scka::{PqCt, PqEk, SckaState, PQ_CT_LEN, PQ_DK_LEN, PQ_EK_LEN, PQ_SS_LEN},
24};
25
/// Maximum number of message keys that may be skipped (and cached) while
/// advancing a single receiving chain in one step.
///
/// A header whose message index lies more than `MAX_SKIP` positions ahead of
/// the current receiving chain is rejected with
/// [`RatchetError::TooManySkipped`]. This bounds the symmetric-ratchet work a
/// single (possibly forged) header can force. The value is a compile-time
/// constant of this crate; choose it to cover the deepest reordering/loss your
/// transport is expected to produce.
pub const MAX_SKIP: usize = 1_000;

/// Hard cap on the total number of entries in the skipped-key cache across all
/// DH epochs. Without this, a malicious peer can force unbounded memory growth
/// by ratcheting through many epochs, each with skipped messages.
pub const MAX_SKIP_TOTAL: usize = 2_000;

// A single skip batch must always be able to fit in the global cache;
// otherwise every maximal batch would be rejected by the total limit.
const _: () = assert!(MAX_SKIP <= MAX_SKIP_TOTAL);
34
35// ── Public types ──────────────────────────────────────────────────────────────
36
37/// Per-message header transmitted alongside each ciphertext.
38///
39/// Callers are responsible for authenticating this header (e.g. AEAD additional
40/// data) -- this crate derives keys only, it does not encrypt.
41///
42/// Fields are not publicly settable to prevent construction of invalid headers.
43/// Use [`Header::new`] to create a header, or read fields through the accessor
44/// methods after receiving one from [`HybridRatchet::ratchet_encrypt`].
45#[derive(Clone, Debug)]
46pub struct Header {
47 /// Sender's current X25519 DH ratchet public key.
48 pub(crate) dh_pk: [u8; 32],
49 /// Index of this message within the current sending chain.
50 pub(crate) n: u32,
51 /// Length of the *previous* sending chain (needed to skip stale messages).
52 pub(crate) pn: u32,
53 /// Sender's current ML-KEM-768 encapsulation key.
54 /// Receiver encapsulates to this and returns a ciphertext in their next header.
55 pub(crate) pq_ek: Option<PqEk>,
56 /// ML-KEM-768 ciphertext -- the sender's encapsulation response to the receiver's
57 /// most recently seen EK. Receiver decapsulates to obtain the PQ shared-secret.
58 pub(crate) pq_ct: Option<PqCt>,
59}
60
61/// Single-use message key derived from the symmetric ratchet chain.
62///
63/// Use this to key an AEAD cipher (e.g. ChaCha20-Poly1305). Zeroed on drop.
64#[derive(ZeroizeOnDrop)]
65pub struct MessageKey(pub [u8; 32]);
66
67impl MessageKey {
68 /// Return the raw 32-byte message key.
69 pub fn as_bytes(&self) -> &[u8; 32] {
70 &self.0
71 }
72}
73
74impl fmt::Debug for MessageKey {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_tuple("MessageKey").field(&"[REDACTED]").finish()
77 }
78}
79
80// ── Header constructors and accessors ────────────────────────────────────────
81
82impl Header {
83 /// Construct a [`Header`] with the given fields.
84 pub fn new(
85 dh_pk: [u8; 32],
86 n: u32,
87 pn: u32,
88 pq_ek: Option<PqEk>,
89 pq_ct: Option<PqCt>,
90 ) -> Self {
91 Self {
92 dh_pk,
93 n,
94 pn,
95 pq_ek,
96 pq_ct,
97 }
98 }
99
100 /// Sender's current X25519 DH ratchet public key.
101 pub fn dh_pk(&self) -> [u8; 32] {
102 self.dh_pk
103 }
104
105 /// Index of this message within the current sending chain.
106 pub fn n(&self) -> u32 {
107 self.n
108 }
109
110 /// Length of the previous sending chain.
111 pub fn pn(&self) -> u32 {
112 self.pn
113 }
114
115 /// Sender's current ML-KEM-768 encapsulation key, if present.
116 pub fn pq_ek(&self) -> Option<&PqEk> {
117 self.pq_ek.as_ref()
118 }
119
120 /// ML-KEM-768 ciphertext response to receiver's EK, if present.
121 pub fn pq_ct(&self) -> Option<&PqCt> {
122 self.pq_ct.as_ref()
123 }
124}
125
126// ── Header wire encoding ──────────────────────────────────────────────────────
127
128impl Header {
129 /// Encode to a canonical byte string for use as AEAD additional data.
130 ///
131 /// Format (all integers little-endian):
132 /// ```text
133 /// dh_pk (32) | n (4) | pn (4) | flags (1) | [pq_ek (1184)] | [pq_ct (1088)]
134 /// ```
135 /// `flags` bit 0 = EK present; bit 1 = CT present.
136 ///
137 /// Allocates a new `Vec`. Use [`write_to`](Self::write_to) to append into
138 /// a reusable buffer instead.
139 pub fn encode(&self) -> Vec<u8> {
140 let mut buf = Vec::with_capacity(
141 41 + self.pq_ek.as_ref().map_or(0, |_| PQ_EK_LEN)
142 + self.pq_ct.as_ref().map_or(0, |_| PQ_CT_LEN),
143 );
144 self.write_to(&mut buf);
145 buf
146 }
147
148 /// Append the canonical encoding to an existing buffer.
149 ///
150 /// Prefer this over [`encode`](Self::encode) when the caller already has a
151 /// pre-allocated `Vec` (e.g. a per-connection AEAD header buffer) to avoid
152 /// a per-message allocation.
153 pub fn write_to(&self, buf: &mut Vec<u8>) {
154 let flags: u8 = (self.pq_ek.is_some() as u8) | ((self.pq_ct.is_some() as u8) << 1);
155 buf.extend_from_slice(&self.dh_pk);
156 buf.extend_from_slice(&self.n.to_le_bytes());
157 buf.extend_from_slice(&self.pn.to_le_bytes());
158 buf.push(flags);
159 if let Some(ek) = &self.pq_ek {
160 buf.extend_from_slice(&ek.0);
161 }
162 if let Some(ct) = &self.pq_ct {
163 buf.extend_from_slice(&ct.0);
164 }
165 }
166
167 /// Decode from the canonical byte representation produced by [`Header::encode`].
168 pub fn decode(bytes: &[u8]) -> Result<Self, RatchetError> {
169 const MIN: usize = 41; // dh_pk(32) + n(4) + pn(4) + flags(1)
170 if bytes.len() < MIN {
171 return Err(RatchetError::MalformedHeader("too short"));
172 }
173 // Slices are guaranteed to be the right length by the MIN check above.
174 let dh_pk: [u8; 32] = bytes[..32].try_into().expect("32-byte slice after MIN check");
175 let n = u32::from_le_bytes(bytes[32..36].try_into().expect("4-byte slice after MIN check"));
176 let pn = u32::from_le_bytes(bytes[36..40].try_into().expect("4-byte slice after MIN check"));
177 let flags = bytes[40];
178 if flags & !0x03 != 0 {
179 return Err(RatchetError::MalformedHeader("unknown flags"));
180 }
181 let has_ek = flags & 0x01 != 0;
182 let has_ct = flags & 0x02 != 0;
183
184 let mut pos = MIN;
185 let pq_ek = if has_ek {
186 let end = pos + PQ_EK_LEN;
187 if bytes.len() < end {
188 return Err(RatchetError::MalformedHeader("truncated EK"));
189 }
190 let ek: [u8; PQ_EK_LEN] = bytes[pos..end].try_into().expect("slice length guaranteed by bounds check");
191 pos = end;
192 Some(PqEk(ek))
193 } else {
194 None
195 };
196 let pq_ct = if has_ct {
197 let end = pos + PQ_CT_LEN;
198 if bytes.len() < end {
199 return Err(RatchetError::MalformedHeader("truncated CT"));
200 }
201 let ct: [u8; PQ_CT_LEN] = bytes[pos..end].try_into().expect("slice length guaranteed by bounds check");
202 pos = end;
203 Some(PqCt(ct))
204 } else {
205 None
206 };
207 if pos != bytes.len() {
208 return Err(RatchetError::MalformedHeader("trailing bytes"));
209 }
210 Ok(Header {
211 dh_pk,
212 n,
213 pn,
214 pq_ek,
215 pq_ct,
216 })
217 }
218}
219
220// ── State ─────────────────────────────────────────────────────────────────────
221
222/// Hybrid Double Ratchet session state.
223///
224/// Holds all key material for one end of a conversation. Zeroes all secret
225/// fields on drop via [`ZeroizeOnDrop`] and manual [`Drop`].
226pub struct HybridRatchet {
227 // ── DH ratchet ───────────────────────────────────────────────────────────
228 dh_sk: StaticSecret, // our current DH private key
229 dh_pk: PublicKey, // our current DH public key
230 dh_pk_remote: Option<[u8; 32]>, // peer's latest DH public key; None until first recv
231
232 // ── Root & chain keys ────────────────────────────────────────────────────
233 rk: [u8; 32], // root key, mixed at each DH ratchet step
234 cks: Option<[u8; 32]>, // sending chain key (None for receiver before first decrypt)
235 ckr: Option<[u8; 32]>, // receiving chain key (None for sender before first decrypt)
236
237 // ── Message counters ─────────────────────────────────────────────────────
238 ns: u32, // messages sent in current sending chain
239 nr: u32, // messages received in current receiving chain
240 pn: u32, // length of *previous* sending chain
241
242 // ── Out-of-order cache ───────────────────────────────────────────────────
243 /// (remote_dh_pk, message_index) → message_key
244 ///
245 /// Values are wrapped in `Zeroizing` so they are explicitly zeroed when
246 /// removed or when the map is dropped. The map is pre-allocated at
247 /// `MAX_SKIP_TOTAL` capacity on construction so it never reallocates --
248 /// preventing freed backing memory from retaining plaintext key material.
249 skipped: HashMap<([u8; 32], u32), Zeroizing<[u8; 32]>>,
250
251 // ── Post-quantum state ───────────────────────────────────────────────────
252 scka: SckaState,
253}
254
255impl Drop for HybridRatchet {
256 fn drop(&mut self) {
257 self.rk.zeroize();
258 if let Some(ref mut k) = self.cks {
259 k.zeroize();
260 }
261 if let Some(ref mut k) = self.ckr {
262 k.zeroize();
263 }
264 // Zeroizing<[u8;32]> values already zero themselves on drop, but
265 // explicitly zeroing here clears the memory before the HashMap
266 // allocator reclaims it.
267 for v in self.skipped.values_mut() {
268 v.zeroize();
269 }
270 }
271}
272
273// ── Constructors ──────────────────────────────────────────────────────────────
274
275impl HybridRatchet {
276 /// Initialise as the **sending** party (e.g. Alice in Signal X3DH).
277 ///
278 /// Performs an immediate DH ratchet step against `peer_dh_pk`, deriving the
279 /// initial sending chain key. The result can call `ratchet_encrypt` immediately.
280 ///
281 /// # Arguments
282 /// * `shared_secret` -- 32-byte pre-shared secret from the key agreement phase
283 /// (e.g. the output of PQXDH or X3DH).
284 /// * `peer_dh_pk` -- peer's X25519 ratchet public key (from their prekey bundle).
285 pub fn init_sender(
286 shared_secret: &[u8; 32],
287 peer_dh_pk: &[u8; 32],
288 rng: &mut impl CryptoRngCore,
289 ) -> Self {
290 let dh_sk = StaticSecret::random_from_rng(&mut *rng);
291 let dh_pk = PublicKey::from(&dh_sk);
292 let peer_pk = PublicKey::from(*peer_dh_pk);
293
294 let dh_ss = dh_sk.diffie_hellman(&peer_pk);
295 // No PQ exchange has occurred yet; contribute zero bytes for this first step.
296 let (rk, cks) = kdf_rk(shared_secret, dh_ss.as_bytes(), &[0u8; PQ_SS_LEN]);
297
298 let scka = SckaState::new(rng);
299
300 Self {
301 dh_sk,
302 dh_pk,
303 dh_pk_remote: Some(*peer_dh_pk),
304 rk,
305 cks: Some(cks),
306 ckr: None,
307 ns: 0,
308 nr: 0,
309 pn: 0,
310 skipped: HashMap::with_capacity(MAX_SKIP_TOTAL),
311 scka,
312 }
313 }
314
315 /// Initialise as the **receiving** party (e.g. Bob in Signal X3DH).
316 ///
317 /// Does not perform a DH ratchet step yet; call `ratchet_decrypt` on the first
318 /// incoming message to derive both the receiving and sending chain keys.
319 ///
320 /// # Arguments
321 /// * `shared_secret` -- 32-byte pre-shared secret (same as sender's).
322 /// * `our_dh_sk` -- our X25519 ratchet secret key (the one whose public key
323 /// was shared with the sender in the prekey bundle).
324 pub fn init_receiver(
325 shared_secret: &[u8; 32],
326 our_dh_sk: StaticSecret,
327 rng: &mut impl CryptoRngCore,
328 ) -> Self {
329 let dh_pk = PublicKey::from(&our_dh_sk);
330 let scka = SckaState::new(rng);
331
332 Self {
333 dh_sk: our_dh_sk,
334 dh_pk,
335 dh_pk_remote: None,
336 rk: *shared_secret,
337 cks: None,
338 ckr: None,
339 ns: 0,
340 nr: 0,
341 pn: 0,
342 skipped: HashMap::with_capacity(MAX_SKIP_TOTAL),
343 scka,
344 }
345 }
346
347 /// Return our current X25519 DH public key (for sharing in a prekey bundle).
348 ///
349 /// Returns raw bytes. X25519 public keys are conventionally exchanged as
350 /// byte arrays; use [`x25519_dalek::PublicKey::from`] to convert if needed.
351 /// The PQ equivalent, [`our_pq_ek`](Self::our_pq_ek), returns a typed
352 /// [`PqEk`] wrapper instead -- the difference is intentional since ML-KEM
353 /// keys are less commonly manipulated as raw bytes.
354 pub fn our_dh_pk(&self) -> [u8; 32] {
355 *self.dh_pk.as_bytes()
356 }
357
358 /// Return our current ML-KEM-768 encapsulation key (for sharing in a prekey bundle).
359 ///
360 /// Returns a typed [`PqEk`] wrapper around the 1184-byte encoding.
361 pub fn our_pq_ek(&self) -> PqEk {
362 self.scka.our_ek().clone()
363 }
364}
365
366// ── Encrypt / Decrypt ─────────────────────────────────────────────────────────
367
368impl HybridRatchet {
369 /// Advance the sending chain and return a `(Header, MessageKey)` pair.
370 ///
371 /// The message key is derived from the current sending chain key. Use it to
372 /// AEAD-encrypt the plaintext, then transmit `header` alongside the ciphertext.
373 ///
374 /// The `rng` parameter is reserved for future protocol extensions and is not
375 /// consumed by this call. Any [`CryptoRngCore`] implementor is accepted.
376 ///
377 /// **PQ ciphertext retransmission**: if a pending ML-KEM ciphertext exists (set
378 /// during the last DH ratchet step), it is included in every outgoing header
379 /// until the peer's next DH ratchet step replaces it. This ensures the PQ
380 /// shared-secret is established even if the first message carrying the CT is
381 /// lost in transit. A caller MUST NOT assume the CT was received until the
382 /// peer's subsequent DH ratchet is observed.
383 ///
384 /// # Errors
385 /// Returns [`RatchetError::NoSendingChain`] if the sending chain has not been
386 /// initialised yet (receiver must call `ratchet_decrypt` at least once first).
387 pub fn ratchet_encrypt(
388 &mut self,
389 _rng: &mut impl CryptoRngCore,
390 ) -> Result<(Header, MessageKey), RatchetError> {
391 let cks = self.cks.as_mut().ok_or(RatchetError::NoSendingChain)?;
392
393 let (new_ck, mk) = kdf_ck(cks);
394 *cks = new_ck;
395
396 let n = self.ns;
397 self.ns += 1;
398
399 let header = Header {
400 dh_pk: *self.dh_pk.as_bytes(),
401 n,
402 pn: self.pn,
403 pq_ek: Some(self.scka.our_ek().clone()),
404 // Peek (clone without consuming) so the CT is retransmitted in every
405 // message until the peer's next DH ratchet generates a replacement.
406 pq_ct: self.scka.pending_ct_ref().cloned(),
407 };
408
409 Ok((header, MessageKey(mk)))
410 }
411
412 /// Advance the receiving chain and return the `MessageKey` for this message.
413 ///
414 /// Handles DH ratchet steps, out-of-order delivery (caching skipped keys), and
415 /// the hybrid PQ key exchange within each DH ratchet step.
416 ///
417 /// The `rng` parameter is only consumed when the incoming header carries a new
418 /// DH public key (triggering a full DH ratchet step with ML-KEM encapsulation).
419 /// For within-epoch messages -- including cache hits from out-of-order delivery --
420 /// the RNG is untouched.
421 ///
422 /// **Implicit key confirmation**: successful AEAD decryption by the caller
423 /// serves as the only confirmation that both parties derived matching keys.
424 /// There is no explicit key confirmation step -- this matches Signal's design.
425 ///
426 /// # Errors
427 /// - [`RatchetError::TooManySkipped`] if the out-of-order cache would overflow.
428 /// - [`RatchetError::MessageKeyNotFound`] if this message was already decrypted
429 /// or is older than what remains in the cache.
430 /// - Crypto errors from malformed PQ material in the header.
431 pub fn ratchet_decrypt(
432 &mut self,
433 header: &Header,
434 rng: &mut impl CryptoRngCore,
435 ) -> Result<MessageKey, RatchetError> {
436 // 1. Fast path: this key was already cached from a previous out-of-order scan.
437 if let Some(mk) = self.skipped.remove(&(header.dh_pk, header.n)) {
438 return Ok(MessageKey(*mk));
439 }
440
441 // 2. Determine if the remote DH key has changed (DHRatchet trigger).
442 // Constant-time comparison prevents timing side-channels on the DH key.
443 // DH public keys are public, but constant-time is best practice for all
444 // cryptographic material.
445 let is_new_dh = match self.dh_pk_remote {
446 Some(pk) => pk.ct_ne(&header.dh_pk).into(),
447 None => true,
448 };
449
450 if is_new_dh {
451 // Skip over any remaining messages in the current receiving chain
452 // (they were sent before the peer ratcheted their DH key).
453 if self.ckr.is_some() {
454 self.skip_message_keys(header.pn)?;
455 }
456 self.dh_ratchet(header, rng)?;
457 }
458
459 // 3. Skip to the target message number within the current receiving chain.
460 // Reject within-epoch replays: n < nr means the key was already consumed.
461 // (dh_ratchet resets nr to 0, so this check is only relevant for same-epoch msgs.)
462 if header.n < self.nr {
463 return Err(RatchetError::MessageKeyNotFound);
464 }
465 self.skip_message_keys(header.n)?;
466
467 // 4. Derive the message key.
468 let ckr = self.ckr.as_mut().ok_or(RatchetError::NoReceivingChain)?;
469 let (new_ck, mk) = kdf_ck(ckr);
470 *ckr = new_ck;
471 self.nr += 1;
472
473 Ok(MessageKey(mk))
474 }
475}
476
477// ── Internal helpers ──────────────────────────────────────────────────────────
478
impl HybridRatchet {
    /// Perform a full DHRatchet step (two root-KDF invocations) upon receiving a
    /// message with a new remote DH public key.
    ///
    /// All fallible and RNG-consuming work runs before any field of `self` is
    /// written: ML-KEM decapsulation/encapsulation, the new X25519 key pair and
    /// the fresh `SckaState`. An `Err` return therefore leaves the session
    /// unchanged.
    ///
    /// On success, the following are all replaced together:
    /// - root key and both chain keys;
    /// - message counters (`pn = ns`, `ns = nr = 0`);
    /// - remote DH key;
    /// - our DH key pair;
    /// - the SCKA epoch.
    ///
    /// # Errors
    /// Propagates errors from `SckaState::decap` / `SckaState::encap_to` for
    /// malformed PQ material in `header`.
    fn dh_ratchet(
        &mut self,
        header: &Header,
        rng: &mut impl CryptoRngCore,
    ) -> Result<(), RatchetError> {
        // ── All fallible and panicking operations -- state is NOT modified yet ─
        // Any error or panic leaves the session completely intact.

        // PQ receiving: decapsulate the CT the peer sent for our current EK.
        // A header without a CT contributes an all-zero PQ input.
        let pq_recv: [u8; PQ_SS_LEN] = match &header.pq_ct {
            Some(ct) => self.scka.decap(ct)?,
            None => [0u8; PQ_SS_LEN],
        };

        // PQ sending: encapsulate to the EK announced in this header. The
        // resulting CT becomes the pending CT of the new SCKA epoch. A header
        // without an EK contributes an all-zero PQ input.
        let (pq_send, opt_pending_ct): ([u8; PQ_SS_LEN], Option<PqCt>) = match &header.pq_ek {
            Some(ek) => {
                let (ss, ct) = self.scka.encap_to(ek, rng)?;
                (ss, Some(ct))
            }
            None => ([0u8; PQ_SS_LEN], None),
        };

        // When the peer sends no EK, preserve any pending CT from the current
        // epoch so it is not silently dropped when we rotate the SCKA state.
        // NOTE(review): in that case `pq_send` is all zeros, while the carried
        // CT was produced in an earlier DH epoch, by `encap_to` against an
        // earlier header's EK. If the peer decapsulates it during its next DH
        // ratchet, it may mix a different secret into its receiving chain than
        // the zeros mixed into our sending chain here. Confirm that this path
        // is safe before relying on EK-less headers.
        let ct_to_carry = opt_pending_ct.or_else(|| self.scka.pending_ct_ref().cloned());

        // Generate new DH keypair -- random_from_rng is allowed to panic.
        let new_dh_sk = StaticSecret::random_from_rng(&mut *rng);
        let new_dh_pk = PublicKey::from(&new_dh_sk);

        // Rotate to a fresh ML-KEM epoch -- MlKem768::generate is allowed to panic.
        let mut new_scka = SckaState::new(rng);
        if let Some(ct) = ct_to_carry {
            new_scka.set_pending_ct(ct);
        }

        // ── All state mutations follow; every panicking operation is complete ─
        let peer_pk = PublicKey::from(header.dh_pk);

        // Step 1: Receiving chain (old DH secret × new peer key, PQ recv secret).
        let dh_recv = self.dh_sk.diffie_hellman(&peer_pk);
        let (rk1, ckr) = kdf_rk(&self.rk, dh_recv.as_bytes(), &pq_recv);

        // Step 2: Sending chain (new DH secret × new peer key, PQ send secret),
        // chained from the root key produced by step 1.
        let dh_send = new_dh_sk.diffie_hellman(&peer_pk);
        let (rk2, cks) = kdf_rk(&rk1, dh_send.as_bytes(), &pq_send);

        // Commit -- no fallible or panicking call remains, so the session is
        // fully consistent after this block.
        self.pn = self.ns;
        self.ns = 0;
        self.nr = 0;
        self.dh_pk_remote = Some(header.dh_pk);
        self.rk = rk2;
        self.ckr = Some(ckr);
        self.cks = Some(cks);
        self.dh_sk = new_dh_sk;
        self.dh_pk = new_dh_pk;
        self.scka = new_scka;

        Ok(())
    }

    /// Cache message keys for messages `self.nr .. until` in the current receiving
    /// chain, enforcing the `MAX_SKIP` limit.
    ///
    /// `until < self.nr` is a no-op. Both limits are checked before any key is
    /// derived, so an `Err` leaves the state untouched. Without a receiving
    /// chain (`ckr` is `None`), only the limit checks run and `nr` is not changed.
    /// Keys are cached under `(dh_pk_remote, index)`.
    ///
    /// # Errors
    /// [`RatchetError::TooManySkipped`] if `until - nr` exceeds [`MAX_SKIP`], or
    /// if caching that many keys would push the cache past [`MAX_SKIP_TOTAL`].
    ///
    /// # Panics
    /// If `ckr` is `Some` while `dh_pk_remote` is `None`. `dh_ratchet` sets both
    /// together.
    fn skip_message_keys(&mut self, until: u32) -> Result<(), RatchetError> {
        if until < self.nr {
            return Ok(());
        }

        let to_skip = (until - self.nr) as usize;
        // Per-batch limit: prevents a single malformed header from forcing a
        // huge number of symmetric-ratchet steps in one call.
        if to_skip > MAX_SKIP {
            return Err(RatchetError::TooManySkipped(to_skip));
        }
        // Global total limit: prevents a malicious peer from exhausting memory
        // across many DH epochs each with skipped messages.
        if self.skipped.len() + to_skip > MAX_SKIP_TOTAL {
            return Err(RatchetError::TooManySkipped(self.skipped.len() + to_skip));
        }

        if let Some(ref mut ckr) = self.ckr {
            let remote_pk = self.dh_pk_remote
                .expect("dh_pk_remote is always Some when ckr is Some -- both set in dh_ratchet");
            for i in self.nr..until {
                let (new_ck, mk) = kdf_ck(ckr);
                *ckr = new_ck;
                self.skipped.insert((remote_pk, i), Zeroizing::new(mk));
            }
            self.nr = until;
        }

        Ok(())
    }
}
578
579// ── State serialization ───────────────────────────────────────────────────────
580
581impl HybridRatchet {
582 /// Serialize the full session state to a zeroize-on-drop byte vector.
583 ///
584 /// The format is a versioned binary blob (v1). Deserialize with
585 /// [`from_bytes`](Self::from_bytes). Callers are responsible for encrypting
586 /// this blob at rest -- it contains all secret key material.
587 ///
588 /// Binary layout (little-endian integers):
589 /// ```text
590 /// version(1) | dh_sk(32) | remote_flag(1) | [dh_pk_remote(32)]
591 /// | rk(32) | cks_flag(1) | [cks(32)] | ckr_flag(1) | [ckr(32)]
592 /// | ns(4) | nr(4) | pn(4)
593 /// | skipped_count(4) | (remote_pk(32) | idx(4) | mk(32)) * count
594 /// | scka_dk(2400) | scka_ek(1184) | pct_flag(1) | [pct(1088)]
595 /// ```
596 ///
597 /// # Forward-compatibility warning
598 /// The serialized format is versioned (`v1`) but no migration path exists.
599 /// If the ml-kem crate changes key encoding (e.g., NIST FIPS 203 revision)
600 /// or this crate changes the binary layout, ALL persisted sessions become
601 /// unrecoverable with `from_bytes` returning an error on unknown versions.
602 /// Re-key all sessions before upgrading the crate. Treat persisted blobs
603 /// as opaque -- do not rely on their internal structure across crate versions.
604 pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
605 let n_skip = self.skipped.len();
606 let mut buf: Vec<u8> = Vec::with_capacity(
607 1 + 32
608 + 1
609 + 32
610 + 32
611 + 1
612 + 32
613 + 1
614 + 32
615 + 4
616 + 4
617 + 4
618 + 4
619 + n_skip * 68
620 + PQ_DK_LEN
621 + PQ_EK_LEN
622 + 1
623 + PQ_CT_LEN,
624 );
625
626 buf.push(0x01); // format version
627
628 buf.extend_from_slice(&self.dh_sk.to_bytes());
629
630 match self.dh_pk_remote {
631 Some(pk) => {
632 buf.push(1);
633 buf.extend_from_slice(&pk);
634 }
635 None => {
636 buf.push(0);
637 }
638 }
639
640 buf.extend_from_slice(&self.rk);
641
642 for opt in [self.cks, self.ckr] {
643 match opt {
644 Some(k) => {
645 buf.push(1);
646 buf.extend_from_slice(&k);
647 }
648 None => {
649 buf.push(0);
650 }
651 }
652 }
653
654 buf.extend_from_slice(&self.ns.to_le_bytes());
655 buf.extend_from_slice(&self.nr.to_le_bytes());
656 buf.extend_from_slice(&self.pn.to_le_bytes());
657
658 buf.extend_from_slice(&(n_skip as u32).to_le_bytes());
659 for ((rpk, idx), mk) in &self.skipped {
660 buf.extend_from_slice(rpk);
661 buf.extend_from_slice(&idx.to_le_bytes());
662 buf.extend_from_slice(mk.as_ref());
663 }
664
665 buf.extend_from_slice(&self.scka.dk_bytes());
666 buf.extend_from_slice(self.scka.ek_bytes_raw());
667 match self.scka.pending_ct_ref() {
668 Some(ct) => {
669 buf.push(1);
670 buf.extend_from_slice(&ct.0);
671 }
672 None => {
673 buf.push(0);
674 }
675 }
676
677 Zeroizing::new(buf)
678 }
679
680 /// Restore session state from bytes produced by [`to_bytes`](Self::to_bytes).
681 ///
682 /// Fails with [`RatchetError::MalformedState`] if the version tag is not
683 /// `0x01` (the only currently supported format).
684 pub fn from_bytes(bytes: &[u8]) -> Result<Self, RatchetError> {
685 let err = |msg| RatchetError::MalformedState(msg);
686 let mut pos = 0usize;
687
688 macro_rules! read {
689 ($n:expr) => {{
690 let end = pos + $n;
691 if end > bytes.len() {
692 return Err(err("truncated"));
693 }
694 let slice = &bytes[pos..end];
695 pos = end;
696 slice
697 }};
698 }
699 macro_rules! read_arr {
700 ($n:expr) => {{
701 let s = read!($n);
702 let arr: [u8; $n] = s.try_into().unwrap();
703 arr
704 }};
705 }
706 macro_rules! read_u32 {
707 () => {
708 u32::from_le_bytes(read_arr!(4))
709 };
710 }
711 macro_rules! read_opt {
712 ($n:expr) => {{
713 let flag = read_arr!(1)[0];
714 match flag {
715 0 => None,
716 1 => Some(read_arr!($n)),
717 _ => return Err(err("invalid flag")),
718 }
719 }};
720 }
721
722 // version
723 if read_arr!(1)[0] != 0x01 {
724 return Err(err("unknown version"));
725 }
726
727 let dh_sk_bytes = read_arr!(32);
728 let dh_pk_remote = read_opt!(32);
729
730 let rk = read_arr!(32);
731 let cks = read_opt!(32);
732 let ckr = read_opt!(32);
733 let ns = read_u32!();
734 let nr = read_u32!();
735 let pn = read_u32!();
736
737 // Reject counter values that are dangerously close to u32::MAX. After
738 // 2^31 messages on a single chain the session must be re-keyed anyway.
739 if ns > (u32::MAX / 2) || nr > (u32::MAX / 2) || pn > (u32::MAX / 2) {
740 return Err(err("message counter out of safe range"));
741 }
742
743 let n_skip = read_u32!() as usize;
744 if n_skip > MAX_SKIP_TOTAL {
745 return Err(err("skipped cache exceeds limit"));
746 }
747 // Pre-allocate at the maximum to prevent reallocation (which would leave
748 // unzeroed key material in freed heap memory).
749 let mut skipped = HashMap::with_capacity(MAX_SKIP_TOTAL);
750 for _ in 0..n_skip {
751 let rpk: [u8; 32] = read_arr!(32);
752 let idx = read_u32!();
753 let mk = read_arr!(32);
754 skipped.insert((rpk, idx), Zeroizing::new(mk));
755 }
756
757 let dk_bytes: [u8; PQ_DK_LEN] = read_arr!(PQ_DK_LEN);
758 let ek_bytes: [u8; PQ_EK_LEN] = read_arr!(PQ_EK_LEN);
759 let pending_ct = read_opt!(PQ_CT_LEN).map(PqCt);
760
761 if pos != bytes.len() {
762 return Err(err("trailing bytes"));
763 }
764
765 let dh_sk = StaticSecret::from(dh_sk_bytes);
766 let dh_pk = PublicKey::from(&dh_sk);
767 let scka = SckaState::from_parts(&dk_bytes, ek_bytes, pending_ct)
768 .ok_or_else(|| err("invalid ML-KEM DK"))?;
769
770 Ok(HybridRatchet {
771 dh_sk,
772 dh_pk,
773 dh_pk_remote,
774 rk,
775 cks,
776 ckr,
777 ns,
778 nr,
779 pn,
780 skipped,
781 scka,
782 })
783 }
784}
785
786// ── Redacted Debug ────────────────────────────────────────────────────────────
787
788impl fmt::Debug for HybridRatchet {
789 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
790 f.debug_struct("HybridRatchet")
791 .field("ns", &self.ns)
792 .field("nr", &self.nr)
793 .field("pn", &self.pn)
794 .field("skipped_count", &self.skipped.len())
795 .field("has_cks", &self.cks.is_some())
796 .field("has_ckr", &self.ckr.is_some())
797 .finish_non_exhaustive()
798 }
799}
800
801// ── Optional serde support ────────────────────────────────────────────────────
802
/// # Binary formats only
///
/// Serialization hands the raw output of `to_bytes()` (all key material) to the
/// serializer as a byte string. Text formats such as JSON, YAML or TOML turn it
/// into a base64 `String` that is **not** zeroed on drop, so key material may
/// remain in heap memory. Use a binary serializer (bincode, postcard,
/// messagepack) and encrypt the output at rest right away.
#[cfg(feature = "serde")]
impl serde::Serialize for HybridRatchet {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let state = self.to_bytes();
        serializer.serialize_bytes(state.as_slice())
    }
}
816
/// Deserializes session state produced by the matching `Serialize` impl.
///
/// Accepts a byte string (borrowed or owned) or a sequence of `u8`. Every
/// owned intermediate buffer holding key material is wrapped in `Zeroizing`
/// so it is wiped when dropped. Input longer than the largest state that
/// `to_bytes` can produce is rejected.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for HybridRatchet {
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // Largest blob `to_bytes` can emit: 180 bytes of fixed fields, a full
        // skipped-key cache of MAX_SKIP_TOTAL 68-byte entries, and the ML-KEM
        // DK, EK and pending-CT (flag + value).
        const MAX_STATE_LEN: usize = 1
            + 32
            + (1 + 32)
            + 32
            + (1 + 32) * 2
            + 4 * 3
            + 4
            + MAX_SKIP_TOTAL * 68
            + PQ_DK_LEN
            + PQ_EK_LEN
            + 1
            + PQ_CT_LEN;

        struct Visitor;
        impl<'de> serde::de::Visitor<'de> for Visitor {
            type Value = HybridRatchet;
            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "pq-ratchet session state as bytes")
            }
            fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<HybridRatchet, E> {
                HybridRatchet::from_bytes(v).map_err(E::custom)
            }
            fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<HybridRatchet, E> {
                // Take ownership so the secret-bearing buffer is wiped on return.
                let v = Zeroizing::new(v);
                self.visit_bytes(&v)
            }
            fn visit_seq<A: serde::de::SeqAccess<'de>>(
                self,
                mut seq: A,
            ) -> Result<HybridRatchet, A::Error> {
                // Allocate the maximum up front: growing the Vec would leave
                // unzeroed copies of key material in freed heap memory.
                let mut buf = Zeroizing::new(Vec::with_capacity(MAX_STATE_LEN));
                while let Some(b) = seq.next_element::<u8>()? {
                    if buf.len() == MAX_STATE_LEN {
                        return Err(<A::Error as serde::de::Error>::custom(
                            "pq-ratchet session state exceeds maximum size",
                        ));
                    }
                    buf.push(b);
                }
                self.visit_bytes(&buf)
            }
        }
        d.deserialize_bytes(Visitor)
    }
}