use alloc::vec::Vec;
use core::marker::PhantomData;
use zeroize::Zeroizing;
use crate::HpkeError;
use crate::aead::{Aead, SealingAead};
use crate::ciphersuite;
use crate::kdf::{Kdf, labeled_expand};
use crate::kem::Kem;
/// Encryption context produced by the HPKE key schedule.
///
/// Holds the initialized AEAD cipher, the base nonce, the exporter secret and the
/// message sequence number from which per-message nonces are derived. The base
/// nonce, the exporter secret and (in test/KAT/differential builds) the raw key are
/// stored in [`Zeroizing`] buffers, so their contents are wiped when dropped.
pub struct Context<K: Kem, F: Kdf, A: Aead> {
/// AEAD state created by `A::init` from the key-schedule key.
cipher: A::Cipher,
/// Base nonce; `compute_nonce` XORs the sequence number into it for each message.
base_nonce: Zeroizing<Vec<u8>>,
/// Secret that `export` expands with `labeled_expand`.
exporter_secret: Zeroizing<Vec<u8>>,
/// Sequence number of the next message; advanced by one after each successful
/// `seal`/`open`.
seq: u64,
/// Copy of the raw AEAD key, retained only so test/KAT/differential builds can
/// read it back through `key()`.
#[cfg(any(test, feature = "kat-internals", feature = "differential"))]
raw_key: Zeroizing<Vec<u8>>,
/// Binds the `K`, `F` and `A` type parameters to the struct without storing
/// values of those types.
_kfa: PhantomData<(K, F, A)>,
}
/// Compile-time guard on an AEAD's nonce length.
///
/// Referencing `CHECK` for a concrete `A` forces const evaluation, which fails the
/// build unless `A::NONCE_LEN` lies in `8..=12`: the nonce must be wide enough to
/// hold the 64-bit sequence number, and small enough to fit the 12-byte buffer
/// that `compute_nonce` fills.
struct AssertNonceRange<A: Aead>(PhantomData<A>);
impl<A: Aead> AssertNonceRange<A> {
    const CHECK: () = {
        let nonce_len = A::NONCE_LEN;
        if nonce_len < 8 {
            panic!("AEAD::NONCE_LEN must be >= 8");
        }
        if nonce_len > 12 {
            panic!("AEAD::NONCE_LEN must be <= 12");
        }
    };
}
impl<K: Kem, F: Kdf, A: Aead> Context<K, F, A> {
    /// Creates a context from the key-schedule outputs, starting at sequence number 0.
    ///
    /// `key` is wrapped in [`Zeroizing`] before it is handed to `A::init`, so the key
    /// buffer is wiped when this function returns, on both the success and the error
    /// path. Test/KAT/differential builds keep a zeroizing copy so `key()` can expose
    /// it. `base_nonce` and `exporter_secret` are moved into zeroizing storage.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `A::init`.
    pub(crate) fn new(
        key: Vec<u8>,
        base_nonce: Vec<u8>,
        exporter_secret: Vec<u8>,
    ) -> Result<Self, HpkeError> {
        let key_z = Zeroizing::new(key);
        let cipher = A::init(&key_z)?;
        Ok(Self {
            cipher,
            base_nonce: Zeroizing::new(base_nonce),
            exporter_secret: Zeroizing::new(exporter_secret),
            seq: 0,
            #[cfg(any(test, feature = "kat-internals", feature = "differential"))]
            raw_key: Zeroizing::new(key_z.to_vec()),
            _kfa: PhantomData,
        })
    }

    /// Derives `length` bytes of secret material bound to `exporter_context`.
    ///
    /// Computes `LabeledExpand(exporter_secret, "sec", exporter_context, length)`
    /// under this context's ciphersuite. The result depends only on the exporter
    /// secret and the arguments; the sequence number is not involved, so exporting
    /// never interferes with `seal`/`open`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `labeled_expand`, e.g. [`HpkeError::ExportLengthExceeded`]
    /// when `length` is larger than the KDF can expand (see `export_length_bound`).
    pub fn export(&self, exporter_context: &[u8], length: usize) -> Result<Vec<u8>, HpkeError> {
        let suite = ciphersuite::<K, F, A>();
        labeled_expand::<F>(
            &self.exporter_secret,
            &suite,
            b"sec",
            exporter_context,
            length,
        )
    }

    /// Returns the nonce for the current sequence number: `base_nonce` XOR the
    /// big-endian encoding of `seq`, right-aligned within `A::NONCE_LEN` bytes.
    ///
    /// Only the first `A::NONCE_LEN` bytes of the returned array are meaningful; the
    /// remainder is zero. `AssertNonceRange` rejects, at compile time, any AEAD whose
    /// nonce length is outside `8..=12`, which keeps every slice below in bounds.
    ///
    /// # Panics
    ///
    /// Panics if `base_nonce` is shorter than `A::NONCE_LEN`.
    fn compute_nonce(&self) -> [u8; 12] {
        // Force the compile-time nonce-length check for this `A`.
        let () = AssertNonceRange::<A>::CHECK;
        let len = A::NONCE_LEN;
        let mut nonce = [0u8; 12];
        nonce[..len].copy_from_slice(&self.base_nonce[..len]);
        // A u64 occupies only the low 8 bytes of the Nn-byte big-endian sequence
        // encoding; the higher bytes are zero, so XOR leaves them untouched.
        let seq_window = &mut nonce[len - 8..len];
        for (byte, seq_byte) in seq_window.iter_mut().zip(self.seq.to_be_bytes()) {
            *byte ^= seq_byte;
        }
        nonce
    }
}
/// Introspection accessors, compiled only for tests and the `kat-internals` /
/// `differential` features, so known-answer and cross-implementation checks can
/// inspect the key-schedule outputs.
#[cfg(any(test, feature = "kat-internals", feature = "differential"))]
impl<K: Kem, F: Kdf, A: Aead> Context<K, F, A> {
    /// Raw AEAD key this context was initialized with.
    #[must_use]
    pub fn key(&self) -> &[u8] {
        self.raw_key.as_slice()
    }

    /// Base nonce (before any sequence-number XOR).
    #[must_use]
    pub fn nonce(&self) -> &[u8] {
        self.base_nonce.as_slice()
    }

    /// Exporter secret used by `export`.
    #[must_use]
    pub fn exporter_secret(&self) -> &[u8] {
        self.exporter_secret.as_slice()
    }

    /// Sequence number that the next `seal`/`open` will use.
    #[must_use]
    pub fn sequence_number(&self) -> u64 {
        let current = self.seq;
        current
    }

    /// Overwrites the sequence number; lets unit tests reach boundary states.
    #[cfg(test)]
    pub(crate) fn set_seq_for_test(&mut self, seq: u64) {
        let slot = &mut self.seq;
        *slot = seq;
    }
}
impl<K: Kem, F: Kdf, A: SealingAead> Context<K, F, A> {
    /// Encrypts `pt` with associated data `aad` under the current message nonce.
    ///
    /// On success the sequence number advances by one, so the next call uses a
    /// fresh nonce. If any step fails, the sequence number is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`HpkeError::MessageLimitReached`] once the sequence number has
    /// reached `u64::MAX`, and propagates any error from `A::seal`.
    pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Result<Vec<u8>, HpkeError> {
        self.ensure_seq_available()?;
        let nonce = self.compute_nonce();
        let ct = A::seal(&self.cipher, &nonce[..A::NONCE_LEN], aad, pt)?;
        self.seq += 1;
        Ok(ct)
    }

    /// Decrypts `ct` with associated data `aad` under the current message nonce.
    ///
    /// On success the sequence number advances by one. If `A::open` fails, its
    /// error is returned and the sequence number is not advanced, so a rejected
    /// ciphertext does not consume a nonce.
    ///
    /// # Errors
    ///
    /// Returns [`HpkeError::MessageLimitReached`] once the sequence number has
    /// reached `u64::MAX`, and propagates any error from `A::open`.
    pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Result<Vec<u8>, HpkeError> {
        self.ensure_seq_available()?;
        let nonce = self.compute_nonce();
        let pt = A::open(&self.cipher, &nonce[..A::NONCE_LEN], aad, ct)?;
        self.seq += 1;
        Ok(pt)
    }

    /// Fails with [`HpkeError::MessageLimitReached`] when the sequence number can no
    /// longer be used.
    ///
    /// `seq` is a `u64`, so `u64::MAX` is refused rather than used. This guarantees
    /// that the `self.seq += 1` performed after a successful `seal`/`open` cannot
    /// overflow. The bound is at most the RFC 9180 limit of `2^(8*Nn) - 1`, because
    /// `Nn >= 8`.
    fn ensure_seq_available(&self) -> Result<(), HpkeError> {
        if self.seq == u64::MAX {
            Err(HpkeError::MessageLimitReached)
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ChaCha20Poly1305, DhKemX25519HkdfSha256, HkdfSha256};

    type Ctx = Context<DhKemX25519HkdfSha256, HkdfSha256, ChaCha20Poly1305>;

    #[test]
    fn seal_open_roundtrip_with_known_state() {
        let key = vec![0x42u8; 32];
        let base_nonce = vec![0x77u8; 12];
        let exporter_secret = vec![0u8; 32];
        let mut sender: Ctx =
            Context::new(key.clone(), base_nonce.clone(), exporter_secret.clone()).unwrap();
        let mut receiver: Ctx = Context::new(key, base_nonce, exporter_secret).unwrap();
        let ct = sender.seal(b"aad", b"message").unwrap();
        let pt = receiver.open(b"aad", &ct).unwrap();
        assert_eq!(pt, b"message");
        assert_eq!(sender.sequence_number(), 1);
        assert_eq!(receiver.sequence_number(), 1);
        for i in 0..3 {
            let pt = alloc::format!("msg-{i}");
            let ct = sender.seal(b"aad", pt.as_bytes()).unwrap();
            let recovered = receiver.open(b"aad", &ct).unwrap();
            assert_eq!(recovered, pt.as_bytes());
        }
        assert_eq!(sender.sequence_number(), 4);
    }

    #[test]
    fn export_is_deterministic() {
        let ctx: Ctx = Context::new(vec![0u8; 32], vec![0u8; 12], vec![1u8; 32]).unwrap();
        let a = ctx.export(b"context", 32).unwrap();
        let b = ctx.export(b"context", 32).unwrap();
        assert_eq!(a, b);
        assert_eq!(a.len(), 32);
        let c = ctx.export(b"different", 32).unwrap();
        assert_ne!(a, c);
    }

    #[test]
    fn export_length_bound() {
        let ctx: Ctx = Context::new(vec![0u8; 32], vec![0u8; 12], vec![1u8; 32]).unwrap();
        assert_eq!(
            ctx.export(b"ctx", 8161),
            Err(HpkeError::ExportLengthExceeded)
        );
    }

    #[test]
    fn seal_rejects_at_message_limit() {
        let mut ctx: Ctx =
            Context::new(vec![0x42u8; 32], vec![0x77u8; 12], vec![0u8; 32]).unwrap();
        ctx.set_seq_for_test(u64::MAX);
        let r = ctx.seal(b"aad", b"hello");
        assert_eq!(r, Err(HpkeError::MessageLimitReached));
        assert_eq!(ctx.sequence_number(), u64::MAX);
    }

    #[test]
    fn open_rejects_at_message_limit() {
        let mut ctx: Ctx =
            Context::new(vec![0x42u8; 32], vec![0x77u8; 12], vec![0u8; 32]).unwrap();
        let mut sibling: Ctx =
            Context::new(vec![0x42u8; 32], vec![0x77u8; 12], vec![0u8; 32]).unwrap();
        let ct = sibling.seal(b"aad", b"hello").unwrap();
        ctx.set_seq_for_test(u64::MAX);
        let r = ctx.open(b"aad", &ct);
        // Assert the specific error: `is_err()` alone would also pass with the limit
        // check removed, because this ciphertext was sealed at seq 0 and would simply
        // fail authentication under the seq-u64::MAX nonce.
        assert_eq!(r, Err(HpkeError::MessageLimitReached));
        assert_eq!(ctx.sequence_number(), u64::MAX);
    }

    #[test]
    fn failed_open_does_not_advance_sequence() {
        let mut sender: Ctx =
            Context::new(vec![0x42u8; 32], vec![0x77u8; 12], vec![0u8; 32]).unwrap();
        let mut receiver: Ctx =
            Context::new(vec![0x42u8; 32], vec![0x77u8; 12], vec![0u8; 32]).unwrap();
        let ct = sender.seal(b"aad", b"hello").unwrap();
        let mut tampered = ct.clone();
        tampered[0] ^= 0x01;
        assert!(receiver.open(b"aad", &tampered).is_err());
        assert_eq!(receiver.sequence_number(), 0);
        // The untouched ciphertext still opens because no nonce was consumed.
        assert_eq!(receiver.open(b"aad", &ct).unwrap(), b"hello");
        assert_eq!(receiver.sequence_number(), 1);
    }

    #[test]
    fn nonce_is_base_nonce_xor_big_endian_seq() {
        let mut ctx: Ctx =
            Context::new(vec![0x42u8; 32], vec![0x77u8; 12], vec![0u8; 32]).unwrap();
        // seq 0 leaves the base nonce unchanged.
        assert_eq!(ctx.compute_nonce(), [0x77u8; 12]);
        // The sequence number is XORed, big-endian, into the rightmost 8 bytes.
        ctx.set_seq_for_test(0x0102_0304_0506_0708);
        assert_eq!(
            ctx.compute_nonce(),
            [0x77u8, 0x77, 0x77, 0x77, 0x76, 0x75, 0x74, 0x73, 0x72, 0x71, 0x70, 0x7f]
        );
    }
}