use super::Error;
use super::aead::HpkeAead;
use super::labeled::{labeled_expand, labeled_extract};
use super::suite::CipherSuite;
use alloc::vec::Vec;
/// HPKE operating mode (RFC 9180 `mode_base`, `mode_psk`, `mode_auth`,
/// `mode_auth_psk`).
///
/// Each discriminant is the one-byte mode identifier that prefixes the key
/// schedule context.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Mode {
    /// `mode_base` (tag `0x00`).
    Base = 0x00,
    /// `mode_psk` (tag `0x01`).
    Psk = 0x01,
    /// `mode_auth` (tag `0x02`).
    Auth = 0x02,
    /// `mode_auth_psk` (tag `0x03`).
    AuthPsk = 0x03,
}

impl Mode {
    /// One-byte mode identifier written at the front of the key schedule
    /// context.
    const fn tag(self) -> u8 {
        self as u8
    }

    /// Whether this mode requires a non-empty PSK and PSK id.
    const fn uses_psk(self) -> bool {
        match self {
            Mode::Psk | Mode::AuthPsk => true,
            Mode::Base | Mode::Auth => false,
        }
    }
}
fn verify_psk_inputs(mode: Mode, psk: &[u8], psk_id: &[u8]) -> Result<(), Error> {
let got_psk = !psk.is_empty();
let got_id = !psk_id.is_empty();
if got_psk != got_id {
return Err(Error::PskInputsInconsistent);
}
if got_psk != mode.uses_psk() {
return Err(Error::PskInputsInconsistent);
}
Ok(())
}
/// Key schedule results, in order: `(key, base_nonce, exporter_secret)`.
type KeyScheduleOutput = (Vec<u8>, Vec<u8>, Vec<u8>);

/// Runs the HPKE key schedule for `suite` and `mode`.
///
/// Steps:
/// 1. Builds `mode || psk_id_hash || info_hash` as the key schedule context.
/// 2. Extracts `secret` from `shared_secret` and `psk`.
/// 3. Expands the AEAD key (`aead.key_len()` bytes), the base nonce
///    (`aead.nonce_len()` bytes) and the exporter secret (`kdf.output_len()`
///    bytes) from `secret`.
///
/// Zero-length key/nonce outputs skip the KDF call entirely. The intermediate
/// `secret` is overwritten with zeros (best effort) before returning.
///
/// # Errors
/// Propagates [`verify_psk_inputs`] failures; nothing after that can fail.
fn key_schedule(
    suite: CipherSuite,
    mode: Mode,
    shared_secret: &[u8],
    info: &[u8],
    psk: &[u8],
    psk_id: &[u8],
) -> Result<KeyScheduleOutput, Error> {
    verify_psk_inputs(mode, psk, psk_id)?;

    let kdf = suite.kdf;
    let suite_id = suite.suite_id();

    let psk_id_hash = labeled_extract(kdf, b"", &suite_id, b"psk_id_hash", psk_id);
    let info_hash = labeled_extract(kdf, b"", &suite_id, b"info_hash", info);
    let mut context = Vec::with_capacity(1 + psk_id_hash.len() + info_hash.len());
    context.push(mode.tag());
    context.extend_from_slice(&psk_id_hash);
    context.extend_from_slice(&info_hash);

    let mut secret = labeled_extract(kdf, shared_secret, &suite_id, b"secret", psk);

    // Allocates `len` zero bytes and fills them from `secret` under `label`,
    // skipping the KDF when nothing is requested.
    let derive_if_nonempty = |label: &[u8], len: usize| {
        let mut out = alloc::vec![0u8; len];
        if len != 0 {
            labeled_expand(kdf, &secret, &suite_id, label, &context, &mut out);
        }
        out
    };
    let key = derive_if_nonempty(b"key", suite.aead.key_len());
    let base_nonce = derive_if_nonempty(b"base_nonce", suite.aead.nonce_len());

    let mut exporter_secret = alloc::vec![0u8; kdf.output_len()];
    labeled_expand(kdf, &secret, &suite_id, b"exp", &context, &mut exporter_secret);

    // Wipe the extracted secret; black_box discourages eliding the stores.
    secret.iter_mut().for_each(|b| *b = 0);
    let _ = core::hint::black_box(&secret);

    Ok((key, base_nonce, exporter_secret))
}
/// Computes the per-message nonce `base_nonce XOR I2OSP(seq, Nn)`.
///
/// The big-endian encoding of `seq` is right-aligned under `base_nonce`. When
/// the nonce is shorter than 8 bytes, only the low-order bytes of `seq`
/// participate; callers keep `seq` in range via `increment_seq`.
fn compute_nonce(base_nonce: &[u8], seq: u64) -> Vec<u8> {
    let counter = seq.to_be_bytes();
    let mut nonce = base_nonce.to_vec();
    // Walk both buffers from their least-significant (rightmost) end; zip
    // stops at whichever is shorter.
    for (slot, byte) in nonce.iter_mut().rev().zip(counter.iter().rev()) {
        *slot ^= *byte;
    }
    nonce
}
/// Sender-side HPKE context: seals messages and exports secrets.
///
/// Holds the key schedule outputs and the per-message sequence counter.
/// Deliberately does not derive `Debug`/`Clone`, so key material is not
/// printed or duplicated by accident.
pub struct SenderContext {
    /// Suite the context was derived for.
    suite: CipherSuite,
    /// AEAD key (`suite.aead.key_len()` bytes).
    key: Vec<u8>,
    /// Base nonce XORed with the sequence number for each message.
    base_nonce: Vec<u8>,
    /// Sequence number of the next message to seal.
    seq: u64,
    /// Set once the sequence limit is hit; every later `seal` then fails.
    exhausted: bool,
    /// Secret used by `export` (`suite.kdf.output_len()` bytes).
    exporter_secret: Vec<u8>,
}

/// Recipient-side HPKE context: opens messages and exports secrets.
///
/// Mirrors [`SenderContext`]; its sequence counter advances only on
/// successful opens.
pub struct ReceiverContext {
    /// Suite the context was derived for.
    suite: CipherSuite,
    /// AEAD key (`suite.aead.key_len()` bytes).
    key: Vec<u8>,
    /// Base nonce XORed with the sequence number for each message.
    base_nonce: Vec<u8>,
    /// Sequence number of the next message expected.
    seq: u64,
    /// Set once the sequence limit is hit; every later `open` then fails.
    exhausted: bool,
    /// Secret used by `export` (`suite.kdf.output_len()` bytes).
    exporter_secret: Vec<u8>,
}
impl SenderContext {
    /// Derives a sender context by running the key schedule.
    ///
    /// # Errors
    /// Propagates key schedule errors (inconsistent PSK inputs).
    pub(super) fn new(
        suite: CipherSuite,
        mode: Mode,
        shared_secret: &[u8],
        info: &[u8],
        psk: &[u8],
        psk_id: &[u8],
    ) -> Result<Self, Error> {
        key_schedule(suite, mode, shared_secret, info, psk, psk_id).map(
            |(key, base_nonce, exporter_secret)| Self {
                suite,
                key,
                base_nonce,
                seq: 0,
                exhausted: false,
                exporter_secret,
            },
        )
    }

    /// Encrypts `pt` with associated data `aad` under the nonce for the
    /// current sequence number, then advances the counter.
    ///
    /// # Errors
    /// - [`Error::ExportOnly`] if the suite's AEAD is export-only.
    /// - [`Error::MessageLimitReached`] if the context is exhausted, or if the
    ///   counter cannot advance after this seal. In that case the ciphertext
    ///   just produced is discarded and the context is marked exhausted.
    /// - Any error returned by the AEAD seal operation.
    pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Result<Vec<u8>, Error> {
        let aead = self.suite.aead;
        if aead.is_export_only() {
            return Err(Error::ExportOnly);
        }
        if self.exhausted {
            return Err(Error::MessageLimitReached);
        }
        let nonce = compute_nonce(&self.base_nonce, self.seq);
        let ct = aead.seal(&self.key, &nonce, aad, pt)?;
        match increment_seq(&mut self.seq, aead) {
            Ok(()) => Ok(ct),
            Err(e) => {
                self.exhausted = true;
                Err(e)
            }
        }
    }

    /// Exports `length` bytes bound to `exporter_context`.
    ///
    /// # Errors
    /// Propagates length-limit errors from the shared export routine.
    pub fn export(&self, exporter_context: &[u8], length: usize) -> Result<Vec<u8>, Error> {
        export(self.suite, &self.exporter_secret, exporter_context, length)
    }
}
impl Drop for SenderContext {
    /// Overwrites all key material with zeros before the buffers are freed.
    /// This is best effort: `black_box` only discourages the optimizer from
    /// eliding the stores.
    fn drop(&mut self) {
        for buffer in [&mut self.key, &mut self.base_nonce, &mut self.exporter_secret] {
            buffer.iter_mut().for_each(|b| *b = 0);
            let _ = core::hint::black_box(&*buffer);
        }
    }
}
impl ReceiverContext {
    /// Derives a recipient context by running the key schedule.
    ///
    /// # Errors
    /// Propagates key schedule errors (inconsistent PSK inputs).
    pub(super) fn new(
        suite: CipherSuite,
        mode: Mode,
        shared_secret: &[u8],
        info: &[u8],
        psk: &[u8],
        psk_id: &[u8],
    ) -> Result<Self, Error> {
        key_schedule(suite, mode, shared_secret, info, psk, psk_id).map(
            |(key, base_nonce, exporter_secret)| Self {
                suite,
                key,
                base_nonce,
                seq: 0,
                exhausted: false,
                exporter_secret,
            },
        )
    }

    /// Decrypts `ct` with associated data `aad` under the nonce for the
    /// current sequence number. The counter advances only when decryption
    /// succeeds.
    ///
    /// # Errors
    /// - [`Error::ExportOnly`] if the suite's AEAD is export-only.
    /// - [`Error::MessageLimitReached`] if the context is exhausted, or if the
    ///   counter cannot advance after this open. In that case the plaintext
    ///   is discarded and the context is marked exhausted.
    /// - Any error returned by the AEAD open operation; the counter is left
    ///   unchanged.
    pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Result<Vec<u8>, Error> {
        let aead = self.suite.aead;
        if aead.is_export_only() {
            return Err(Error::ExportOnly);
        }
        if self.exhausted {
            return Err(Error::MessageLimitReached);
        }
        let nonce = compute_nonce(&self.base_nonce, self.seq);
        let pt = aead.open(&self.key, &nonce, aad, ct)?;
        match increment_seq(&mut self.seq, aead) {
            Ok(()) => Ok(pt),
            Err(e) => {
                self.exhausted = true;
                Err(e)
            }
        }
    }

    /// Exports `length` bytes bound to `exporter_context`.
    ///
    /// # Errors
    /// Propagates length-limit errors from the shared export routine.
    pub fn export(&self, exporter_context: &[u8], length: usize) -> Result<Vec<u8>, Error> {
        export(self.suite, &self.exporter_secret, exporter_context, length)
    }
}
impl Drop for ReceiverContext {
    /// Overwrites all key material with zeros before the buffers are freed.
    /// This is best effort: `black_box` only discourages the optimizer from
    /// eliding the stores.
    fn drop(&mut self) {
        for buffer in [&mut self.key, &mut self.base_nonce, &mut self.exporter_secret] {
            buffer.iter_mut().for_each(|b| *b = 0);
            let _ = core::hint::black_box(&*buffer);
        }
    }
}
/// Derives `length` bytes of exported secret:
/// `LabeledExpand(exporter_secret, "sec", exporter_context, length)`.
///
/// The length is capped at `255 * kdf.output_len()`, the HKDF-Expand output
/// limit, and also at `u16::MAX`. HPKE's LabeledExpand encodes the output
/// length in two bytes (RFC 9180 §4), which is presumably the reason for the
/// second cap.
///
/// A zero-length request returns an empty vector without invoking the KDF.
/// This matches `key_schedule`, which never passes an empty output buffer to
/// `labeled_expand`.
///
/// # Errors
/// Returns [`Error::ExportLengthExceeded`] when `length` exceeds the cap.
fn export(
    suite: CipherSuite,
    exporter_secret: &[u8],
    exporter_context: &[u8],
    length: usize,
) -> Result<Vec<u8>, Error> {
    let max = suite
        .kdf
        .output_len()
        .saturating_mul(255)
        .min(usize::from(u16::MAX));
    if length > max {
        return Err(Error::ExportLengthExceeded);
    }
    if length == 0 {
        return Ok(Vec::new());
    }
    let suite_id = suite.suite_id();
    let mut out = alloc::vec![0u8; length];
    labeled_expand(
        suite.kdf,
        exporter_secret,
        &suite_id,
        b"sec",
        exporter_context,
        &mut out,
    );
    Ok(out)
}
/// Advances the message sequence number after a successful seal/open.
///
/// Export-only AEADs never seal or open, so their counter is left untouched.
/// Otherwise the counter must stay below `2^(8*Nn) - 1`, where `Nn` is the
/// AEAD nonce length in bytes, clamped to `u64::MAX` for nonces of 8+ bytes.
/// This keeps `compute_nonce` from ever repeating a nonce.
///
/// # Errors
/// Returns [`Error::MessageLimitReached`] without modifying `seq` once the
/// limit has been reached.
fn increment_seq(seq: &mut u64, aead: HpkeAead) -> Result<(), Error> {
    if aead.is_export_only() {
        return Ok(());
    }
    let bits = 8 * aead.nonce_len();
    let max_seq = if bits >= 64 {
        u64::MAX
    } else {
        (1u64 << bits) - 1
    };
    // `>=` (as in RFC 9180 IncrementSeq) rather than `==`: a counter that is
    // somehow already past the limit must not keep advancing, because
    // `compute_nonce` would truncate it and reuse nonces.
    if *seq >= max_seq {
        return Err(Error::MessageLimitReached);
    }
    // Cannot overflow: *seq < max_seq <= u64::MAX.
    *seq += 1;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hpke::{HpkeAead, HpkeKdf, HpkeKem};

    /// X25519 / HKDF-SHA256 / AES-128-GCM suite shared by the tests below.
    fn aes128_suite() -> CipherSuite {
        CipherSuite::new(
            HpkeKem::DhkemX25519HkdfSha256,
            HpkeKdf::HkdfSha256,
            HpkeAead::Aes128Gcm,
        )
    }

    /// Builds a sender with all-zero key material whose counter starts at
    /// `seq`, bypassing the key schedule.
    fn sender_at(suite: CipherSuite, seq: u64) -> SenderContext {
        let zeros = |len: usize| alloc::vec![0u8; len];
        SenderContext {
            suite,
            key: zeros(suite.aead.key_len()),
            base_nonce: zeros(suite.aead.nonce_len()),
            seq,
            exhausted: false,
            exporter_secret: zeros(suite.kdf.output_len()),
        }
    }

    #[test]
    fn paired_contexts_seal_open_roundtrip_after_zeroize() {
        let suite = aes128_suite();
        let shared_secret = [0x42u8; 32];
        let info: &[u8] = b"info";
        let aad: &[u8] = b"aad";
        let mut sender =
            SenderContext::new(suite, Mode::Base, &shared_secret, info, b"", b"").unwrap();
        let mut receiver =
            ReceiverContext::new(suite, Mode::Base, &shared_secret, info, b"", b"").unwrap();
        for round in 0u8..4 {
            let message = alloc::vec![round; 16 + usize::from(round)];
            let sealed = sender.seal(aad, &message).unwrap();
            let opened = receiver.open(aad, &sealed).unwrap();
            assert_eq!(opened, message);
        }
        let from_sender = sender.export(b"exp-ctx", 32).unwrap();
        let from_receiver = receiver.export(b"exp-ctx", 32).unwrap();
        assert_eq!(from_sender, from_receiver);
    }

    #[test]
    fn seal_poisons_after_limit_no_nonce_reuse() {
        let mut ctx = sender_at(aes128_suite(), u64::MAX);
        for attempt in 0..2 {
            assert_eq!(
                ctx.seal(b"aad", b"pt"),
                Err(Error::MessageLimitReached),
                "attempt {attempt}"
            );
            assert!(ctx.exhausted, "context must be poisoned after limit");
            assert_eq!(ctx.seq, u64::MAX);
        }
    }

    #[test]
    fn export_rejects_overlong_length() {
        let suite = aes128_suite();
        let ctx = sender_at(suite, 0);
        let limit = 255 * suite.kdf.output_len();
        assert!(ctx.export(b"ctx", limit).is_ok());
        for too_long in [limit + 1, usize::MAX] {
            assert_eq!(
                ctx.export(b"ctx", too_long),
                Err(Error::ExportLengthExceeded)
            );
        }
    }
}