use std::collections::HashMap;
use std::fmt;
use rand_core::CryptoRngCore;
use subtle::ConstantTimeEq;
use x25519_dalek::{PublicKey, StaticSecret};
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use crate::{
error::RatchetError,
kdf::{kdf_ck, kdf_rk},
scka::{PqCt, PqEk, SckaState, PQ_CT_LEN, PQ_DK_LEN, PQ_EK_LEN, PQ_SS_LEN},
};
/// Maximum number of message keys that may be skipped on a single receiving chain
/// when processing one incoming header.
pub const MAX_SKIP: usize = 1000;
/// Maximum number of skipped message keys that may be cached at the same time,
/// across all receiving chains.
pub const MAX_SKIP_TOTAL: usize = 2 * MAX_SKIP;
/// Ratchet header built by `ratchet_encrypt` for every outgoing message and
/// consumed by `ratchet_decrypt` on the receiving side.
#[derive(Debug, Clone)]
pub struct Header {
    /// Sender's current X25519 public key.
    pub(crate) dh_pk: [u8; 32],
    /// Message number within the sender's current sending chain.
    pub(crate) n: u32,
    /// Number of messages the sender sent on its previous sending chain.
    pub(crate) pn: u32,
    /// Sender's current post-quantum encapsulation key, when present.
    pub(crate) pq_ek: Option<PqEk>,
    /// Post-quantum ciphertext the sender encapsulated to the receiver, when one is pending.
    pub(crate) pq_ct: Option<PqCt>,
}
/// Per-message key returned by `ratchet_encrypt` / `ratchet_decrypt`.
///
/// The key bytes are wiped when the value is dropped, and `Debug` never prints them.
#[derive(ZeroizeOnDrop)]
pub struct MessageKey(pub [u8; 32]);

impl MessageKey {
    /// Borrows the raw 32 key bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        let MessageKey(bytes) = self;
        bytes
    }
}

impl fmt::Debug for MessageKey {
    /// Prints a fixed placeholder instead of the key material.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        f.debug_tuple("MessageKey").field(&REDACTED).finish()
    }
}
impl Header {
pub fn new(
dh_pk: [u8; 32],
n: u32,
pn: u32,
pq_ek: Option<PqEk>,
pq_ct: Option<PqCt>,
) -> Self {
Self {
dh_pk,
n,
pn,
pq_ek,
pq_ct,
}
}
pub fn dh_pk(&self) -> [u8; 32] {
self.dh_pk
}
pub fn n(&self) -> u32 {
self.n
}
pub fn pn(&self) -> u32 {
self.pn
}
pub fn pq_ek(&self) -> Option<&PqEk> {
self.pq_ek.as_ref()
}
pub fn pq_ct(&self) -> Option<&PqCt> {
self.pq_ct.as_ref()
}
}
impl Header {
    /// Serializes the header into a freshly allocated buffer of exactly the encoded size.
    ///
    /// Wire format: `dh_pk` (32) ‖ `n` (u32 LE) ‖ `pn` (u32 LE) ‖ flags (1) ‖ optional EK ‖
    /// optional CT. Flag bit 0 marks a present EK, bit 1 a present CT.
    pub fn encode(&self) -> Vec<u8> {
        let ek_len = if self.pq_ek.is_some() { PQ_EK_LEN } else { 0 };
        let ct_len = if self.pq_ct.is_some() { PQ_CT_LEN } else { 0 };
        let mut out = Vec::with_capacity(41 + ek_len + ct_len);
        self.write_to(&mut out);
        out
    }

    /// Appends the wire encoding of this header to `buf`.
    pub fn write_to(&self, buf: &mut Vec<u8>) {
        let mut flags = 0u8;
        if self.pq_ek.is_some() {
            flags |= 0x01;
        }
        if self.pq_ct.is_some() {
            flags |= 0x02;
        }
        buf.extend_from_slice(&self.dh_pk);
        buf.extend_from_slice(&self.n.to_le_bytes());
        buf.extend_from_slice(&self.pn.to_le_bytes());
        buf.push(flags);
        if let Some(ek) = &self.pq_ek {
            buf.extend_from_slice(&ek.0);
        }
        if let Some(ct) = &self.pq_ct {
            buf.extend_from_slice(&ct.0);
        }
    }

    /// Parses a header in the format written by [`Header::write_to`].
    ///
    /// # Errors
    ///
    /// [`RatchetError::MalformedHeader`] when the input is shorter than the 41-byte fixed
    /// prefix, sets a flag bit other than bits 0/1, lacks the bytes of a flagged EK/CT, or
    /// has bytes left over after the last field.
    pub fn decode(bytes: &[u8]) -> Result<Self, RatchetError> {
        const FIXED_LEN: usize = 41;
        if bytes.len() < FIXED_LEN {
            return Err(RatchetError::MalformedHeader("too short"));
        }
        let (fixed, mut rest) = bytes.split_at(FIXED_LEN);

        let mut dh_pk = [0u8; 32];
        dh_pk.copy_from_slice(&fixed[..32]);
        let n = u32::from_le_bytes([fixed[32], fixed[33], fixed[34], fixed[35]]);
        let pn = u32::from_le_bytes([fixed[36], fixed[37], fixed[38], fixed[39]]);
        let flags = fixed[40];
        if flags & !0x03 != 0 {
            return Err(RatchetError::MalformedHeader("unknown flags"));
        }

        let pq_ek = if flags & 0x01 == 0 {
            None
        } else {
            if rest.len() < PQ_EK_LEN {
                return Err(RatchetError::MalformedHeader("truncated EK"));
            }
            let (chunk, tail) = rest.split_at(PQ_EK_LEN);
            rest = tail;
            let mut ek = [0u8; PQ_EK_LEN];
            ek.copy_from_slice(chunk);
            Some(PqEk(ek))
        };

        let pq_ct = if flags & 0x02 == 0 {
            None
        } else {
            if rest.len() < PQ_CT_LEN {
                return Err(RatchetError::MalformedHeader("truncated CT"));
            }
            let (chunk, tail) = rest.split_at(PQ_CT_LEN);
            rest = tail;
            let mut ct = [0u8; PQ_CT_LEN];
            ct.copy_from_slice(chunk);
            Some(PqCt(ct))
        };

        if !rest.is_empty() {
            return Err(RatchetError::MalformedHeader("trailing bytes"));
        }
        Ok(Header {
            dh_pk,
            n,
            pn,
            pq_ek,
            pq_ct,
        })
    }
}
/// Per-session state of the hybrid X25519 + post-quantum double ratchet.
pub struct HybridRatchet {
    /// Our current X25519 private key.
    dh_sk: StaticSecret,
    /// Public key corresponding to `dh_sk`.
    dh_pk: PublicKey,
    /// Most recent remote X25519 public key, once known.
    dh_pk_remote: Option<[u8; 32]>,
    /// Root key.
    rk: [u8; 32],
    /// Sending chain key, absent until a sending chain exists.
    cks: Option<[u8; 32]>,
    /// Receiving chain key, absent until a receiving chain exists.
    ckr: Option<[u8; 32]>,
    /// Messages sent on the current sending chain.
    ns: u32,
    /// Messages received on the current receiving chain.
    nr: u32,
    /// Length of our previous sending chain (value of `ns` before the last DH ratchet).
    pn: u32,
    /// Cached keys of skipped messages, keyed by (remote DH public key, message number).
    skipped: HashMap<([u8; 32], u32), Zeroizing<[u8; 32]>>,
    /// Post-quantum KEM (SCKA) state.
    scka: SckaState,
}
impl Drop for HybridRatchet {
fn drop(&mut self) {
self.rk.zeroize();
if let Some(ref mut k) = self.cks {
k.zeroize();
}
if let Some(ref mut k) = self.ckr {
k.zeroize();
}
for v in self.skipped.values_mut() {
v.zeroize();
}
}
}
impl HybridRatchet {
    /// Creates the state of the party that sends the first message.
    ///
    /// A fresh X25519 key pair is generated and combined with `peer_dh_pk`. The result,
    /// mixed with `shared_secret` and an all-zero PQ input, yields the root key and the
    /// first sending chain.
    pub fn init_sender(
        shared_secret: &[u8; 32],
        peer_dh_pk: &[u8; 32],
        rng: &mut impl CryptoRngCore,
    ) -> Self {
        // The RNG is drawn from in a fixed order: the DH key first, then the SCKA state.
        let dh_sk = StaticSecret::random_from_rng(&mut *rng);
        let dh_pk = PublicKey::from(&dh_sk);
        let (rk, cks) = {
            let dh_out = dh_sk.diffie_hellman(&PublicKey::from(*peer_dh_pk));
            kdf_rk(shared_secret, dh_out.as_bytes(), &[0u8; PQ_SS_LEN])
        };
        let scka = SckaState::new(rng);
        Self {
            dh_sk,
            dh_pk,
            dh_pk_remote: Some(*peer_dh_pk),
            rk,
            cks: Some(cks),
            ckr: None,
            ns: 0,
            nr: 0,
            pn: 0,
            // Sized for the global skipped-key limit up front. NOTE(review): presumably so
            // the table never has to move into a new allocation — confirm intent.
            skipped: HashMap::with_capacity(MAX_SKIP_TOTAL),
            scka,
        }
    }

    /// Creates the state of the party that receives the first message.
    ///
    /// `shared_secret` becomes the root key as-is. Neither chain exists until the first
    /// incoming header triggers a DH ratchet step.
    pub fn init_receiver(
        shared_secret: &[u8; 32],
        our_dh_sk: StaticSecret,
        rng: &mut impl CryptoRngCore,
    ) -> Self {
        let our_pk = PublicKey::from(&our_dh_sk);
        let scka = SckaState::new(rng);
        Self {
            dh_sk: our_dh_sk,
            dh_pk: our_pk,
            dh_pk_remote: None,
            rk: *shared_secret,
            cks: None,
            ckr: None,
            ns: 0,
            nr: 0,
            pn: 0,
            skipped: HashMap::with_capacity(MAX_SKIP_TOTAL),
            scka,
        }
    }

    /// Our current X25519 public key.
    pub fn our_dh_pk(&self) -> [u8; 32] {
        let pk: &[u8; 32] = self.dh_pk.as_bytes();
        *pk
    }

    /// A copy of our current post-quantum encapsulation key.
    pub fn our_pq_ek(&self) -> PqEk {
        self.scka.our_ek().clone()
    }
}
impl HybridRatchet {
pub fn ratchet_encrypt(
&mut self,
_rng: &mut impl CryptoRngCore,
) -> Result<(Header, MessageKey), RatchetError> {
let cks = self.cks.as_mut().ok_or(RatchetError::NoSendingChain)?;
let (new_ck, mk) = kdf_ck(cks);
*cks = new_ck;
let n = self.ns;
self.ns += 1;
let header = Header {
dh_pk: *self.dh_pk.as_bytes(),
n,
pn: self.pn,
pq_ek: Some(self.scka.our_ek().clone()),
pq_ct: self.scka.pending_ct_ref().cloned(),
};
Ok((header, MessageKey(mk)))
}
pub fn ratchet_decrypt(
&mut self,
header: &Header,
rng: &mut impl CryptoRngCore,
) -> Result<MessageKey, RatchetError> {
if let Some(mk) = self.skipped.remove(&(header.dh_pk, header.n)) {
return Ok(MessageKey(*mk));
}
let is_new_dh = match self.dh_pk_remote {
Some(pk) => pk.ct_ne(&header.dh_pk).into(),
None => true,
};
if is_new_dh {
if self.ckr.is_some() {
self.skip_message_keys(header.pn)?;
}
self.dh_ratchet(header, rng)?;
}
if header.n < self.nr {
return Err(RatchetError::MessageKeyNotFound);
}
self.skip_message_keys(header.n)?;
let ckr = self.ckr.as_mut().ok_or(RatchetError::NoReceivingChain)?;
let (new_ck, mk) = kdf_ck(ckr);
*ckr = new_ck;
self.nr += 1;
Ok(MessageKey(mk))
}
}
impl HybridRatchet {
fn dh_ratchet(
&mut self,
header: &Header,
rng: &mut impl CryptoRngCore,
) -> Result<(), RatchetError> {
let pq_recv: [u8; PQ_SS_LEN] = match &header.pq_ct {
Some(ct) => self.scka.decap(ct)?,
None => [0u8; PQ_SS_LEN],
};
let (pq_send, opt_pending_ct): ([u8; PQ_SS_LEN], Option<PqCt>) = match &header.pq_ek {
Some(ek) => {
let (ss, ct) = self.scka.encap_to(ek, rng)?;
(ss, Some(ct))
}
None => ([0u8; PQ_SS_LEN], None),
};
let ct_to_carry = opt_pending_ct.or_else(|| self.scka.pending_ct_ref().cloned());
let new_dh_sk = StaticSecret::random_from_rng(&mut *rng);
let new_dh_pk = PublicKey::from(&new_dh_sk);
let mut new_scka = SckaState::new(rng);
if let Some(ct) = ct_to_carry {
new_scka.set_pending_ct(ct);
}
let peer_pk = PublicKey::from(header.dh_pk);
let dh_recv = self.dh_sk.diffie_hellman(&peer_pk);
let (rk1, ckr) = kdf_rk(&self.rk, dh_recv.as_bytes(), &pq_recv);
let dh_send = new_dh_sk.diffie_hellman(&peer_pk);
let (rk2, cks) = kdf_rk(&rk1, dh_send.as_bytes(), &pq_send);
self.pn = self.ns;
self.ns = 0;
self.nr = 0;
self.dh_pk_remote = Some(header.dh_pk);
self.rk = rk2;
self.ckr = Some(ckr);
self.cks = Some(cks);
self.dh_sk = new_dh_sk;
self.dh_pk = new_dh_pk;
self.scka = new_scka;
Ok(())
}
fn skip_message_keys(&mut self, until: u32) -> Result<(), RatchetError> {
if until < self.nr {
return Ok(());
}
let to_skip = (until - self.nr) as usize;
if to_skip > MAX_SKIP {
return Err(RatchetError::TooManySkipped(to_skip));
}
if self.skipped.len() + to_skip > MAX_SKIP_TOTAL {
return Err(RatchetError::TooManySkipped(self.skipped.len() + to_skip));
}
if let Some(ref mut ckr) = self.ckr {
let remote_pk = self.dh_pk_remote
.expect("dh_pk_remote is always Some when ckr is Some -- both set in dh_ratchet");
for i in self.nr..until {
let (new_ck, mk) = kdf_ck(ckr);
*ckr = new_ck;
self.skipped.insert((remote_pk, i), Zeroizing::new(mk));
}
self.nr = until;
}
Ok(())
}
}
impl HybridRatchet {
pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
let n_skip = self.skipped.len();
let mut buf: Vec<u8> = Vec::with_capacity(
1 + 32
+ 1
+ 32
+ 32
+ 1
+ 32
+ 1
+ 32
+ 4
+ 4
+ 4
+ 4
+ n_skip * 68
+ PQ_DK_LEN
+ PQ_EK_LEN
+ 1
+ PQ_CT_LEN,
);
buf.push(0x01);
buf.extend_from_slice(&self.dh_sk.to_bytes());
match self.dh_pk_remote {
Some(pk) => {
buf.push(1);
buf.extend_from_slice(&pk);
}
None => {
buf.push(0);
}
}
buf.extend_from_slice(&self.rk);
for opt in [self.cks, self.ckr] {
match opt {
Some(k) => {
buf.push(1);
buf.extend_from_slice(&k);
}
None => {
buf.push(0);
}
}
}
buf.extend_from_slice(&self.ns.to_le_bytes());
buf.extend_from_slice(&self.nr.to_le_bytes());
buf.extend_from_slice(&self.pn.to_le_bytes());
buf.extend_from_slice(&(n_skip as u32).to_le_bytes());
for ((rpk, idx), mk) in &self.skipped {
buf.extend_from_slice(rpk);
buf.extend_from_slice(&idx.to_le_bytes());
buf.extend_from_slice(mk.as_ref());
}
buf.extend_from_slice(&self.scka.dk_bytes());
buf.extend_from_slice(self.scka.ek_bytes_raw());
match self.scka.pending_ct_ref() {
Some(ct) => {
buf.push(1);
buf.extend_from_slice(&ct.0);
}
None => {
buf.push(0);
}
}
Zeroizing::new(buf)
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, RatchetError> {
let err = |msg| RatchetError::MalformedState(msg);
let mut pos = 0usize;
macro_rules! read {
($n:expr) => {{
let end = pos + $n;
if end > bytes.len() {
return Err(err("truncated"));
}
let slice = &bytes[pos..end];
pos = end;
slice
}};
}
macro_rules! read_arr {
($n:expr) => {{
let s = read!($n);
let arr: [u8; $n] = s.try_into().unwrap();
arr
}};
}
macro_rules! read_u32 {
() => {
u32::from_le_bytes(read_arr!(4))
};
}
macro_rules! read_opt {
($n:expr) => {{
let flag = read_arr!(1)[0];
match flag {
0 => None,
1 => Some(read_arr!($n)),
_ => return Err(err("invalid flag")),
}
}};
}
if read_arr!(1)[0] != 0x01 {
return Err(err("unknown version"));
}
let dh_sk_bytes = read_arr!(32);
let dh_pk_remote = read_opt!(32);
let rk = read_arr!(32);
let cks = read_opt!(32);
let ckr = read_opt!(32);
let ns = read_u32!();
let nr = read_u32!();
let pn = read_u32!();
if ns > (u32::MAX / 2) || nr > (u32::MAX / 2) || pn > (u32::MAX / 2) {
return Err(err("message counter out of safe range"));
}
let n_skip = read_u32!() as usize;
if n_skip > MAX_SKIP_TOTAL {
return Err(err("skipped cache exceeds limit"));
}
let mut skipped = HashMap::with_capacity(MAX_SKIP_TOTAL);
for _ in 0..n_skip {
let rpk: [u8; 32] = read_arr!(32);
let idx = read_u32!();
let mk = read_arr!(32);
skipped.insert((rpk, idx), Zeroizing::new(mk));
}
let dk_bytes: [u8; PQ_DK_LEN] = read_arr!(PQ_DK_LEN);
let ek_bytes: [u8; PQ_EK_LEN] = read_arr!(PQ_EK_LEN);
let pending_ct = read_opt!(PQ_CT_LEN).map(PqCt);
if pos != bytes.len() {
return Err(err("trailing bytes"));
}
let dh_sk = StaticSecret::from(dh_sk_bytes);
let dh_pk = PublicKey::from(&dh_sk);
let scka = SckaState::from_parts(&dk_bytes, ek_bytes, pending_ct)
.ok_or_else(|| err("invalid ML-KEM DK"))?;
Ok(HybridRatchet {
dh_sk,
dh_pk,
dh_pk_remote,
rk,
cks,
ckr,
ns,
nr,
pn,
skipped,
scka,
})
}
}
impl fmt::Debug for HybridRatchet {
    /// Shows counters, cache size and chain presence only; key material is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("HybridRatchet");
        out.field("ns", &self.ns);
        out.field("nr", &self.nr);
        out.field("pn", &self.pn);
        out.field("skipped_count", &self.skipped.len());
        out.field("has_cks", &self.cks.is_some());
        out.field("has_ckr", &self.ckr.is_some());
        out.finish_non_exhaustive()
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for HybridRatchet {
    /// Emits the session as one opaque byte string produced by [`HybridRatchet::to_bytes`].
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        // `to_bytes` returns a wiping buffer, so the encoded secrets are cleared afterwards.
        let encoded = self.to_bytes();
        s.serialize_bytes(encoded.as_slice())
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for HybridRatchet {
    /// Reads the opaque byte string written by the `Serialize` impl and restores the state
    /// with [`HybridRatchet::from_bytes`]. Accepts a borrowed or owned byte buffer, or a
    /// sequence of `u8`.
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        struct Visitor;
        impl<'de> serde::de::Visitor<'de> for Visitor {
            type Value = HybridRatchet;
            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "pq-ratchet session state as bytes")
            }
            fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<HybridRatchet, E> {
                HybridRatchet::from_bytes(v).map_err(E::custom)
            }
            fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<HybridRatchet, E> {
                // We own the serialized secrets here; wipe them once parsing is done.
                let v = Zeroizing::new(v);
                self.visit_bytes(v.as_slice())
            }
            fn visit_seq<A: serde::de::SeqAccess<'de>>(
                self,
                mut seq: A,
            ) -> Result<HybridRatchet, A::Error> {
                // Largest blob `to_bytes` can produce: full skipped-key cache plus a pending CT.
                const MAX_STATE_LEN: usize = 1 + 32 + 1 + 32 + 32 + 1 + 32 + 1 + 32
                    + 4 * 4
                    + MAX_SKIP_TOTAL * (32 + 4 + 32)
                    + PQ_DK_LEN
                    + PQ_EK_LEN
                    + 1
                    + PQ_CT_LEN;
                // Allocate the maximum once and never grow past it. A reallocation would
                // free an unwiped copy of the secrets, and an unbounded sequence could
                // otherwise force arbitrarily large allocations.
                let mut buf: Zeroizing<Vec<u8>> =
                    Zeroizing::new(Vec::with_capacity(MAX_STATE_LEN));
                while let Some(b) = seq.next_element::<u8>()? {
                    if buf.len() == MAX_STATE_LEN {
                        return Err(serde::de::Error::custom(
                            "pq-ratchet state exceeds maximum size",
                        ));
                    }
                    buf.push(b);
                }
                self.visit_bytes(buf.as_slice())
            }
        }
        d.deserialize_bytes(Visitor)
    }
}