use super::Error;
use super::aead::HpkeAead;
use super::labeled::{labeled_expand, labeled_extract};
use super::suite::CipherSuite;
use alloc::vec::Vec;
/// HPKE operating mode (RFC 9180, Section 5).
///
/// Each variant's discriminant is the one-byte mode identifier from RFC 9180; it is the
/// first byte of the key schedule context.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Mode {
    /// `mode_base`: no PSK, no sender authentication key.
    Base = 0x00,
    /// `mode_psk`: a pre-shared key and PSK ID are mixed into the key schedule.
    Psk = 0x01,
    /// `mode_auth`: sender authentication, no PSK.
    Auth = 0x02,
    /// `mode_auth_psk`: sender authentication plus a pre-shared key.
    AuthPsk = 0x03,
}
impl Mode {
    /// One-byte wire identifier of this mode (the explicit discriminant above).
    const fn tag(self) -> u8 {
        self as u8
    }

    /// `true` for the modes that require a non-empty PSK and PSK ID.
    const fn uses_psk(self) -> bool {
        match self {
            Mode::Psk | Mode::AuthPsk => true,
            Mode::Base | Mode::Auth => false,
        }
    }
}
fn verify_psk_inputs(mode: Mode, psk: &[u8], psk_id: &[u8]) -> Result<(), Error> {
let got_psk = !psk.is_empty();
let got_id = !psk_id.is_empty();
if got_psk != got_id {
return Err(Error::PskInputsInconsistent);
}
if got_psk != mode.uses_psk() {
return Err(Error::PskInputsInconsistent);
}
Ok(())
}
/// `(key, base_nonce, exporter_secret)` as returned by `key_schedule`.
type KeyScheduleOutput = (Vec<u8>, Vec<u8>, Vec<u8>);
/// Derives a context's AEAD key, base nonce and exporter secret (RFC 9180 `KeySchedule`).
///
/// * `key_schedule_context = mode || LabeledExtract("", "psk_id_hash", psk_id)
///   || LabeledExtract("", "info_hash", info)`
/// * `secret = LabeledExtract(shared_secret, "secret", psk)`
/// * `key`, `base_nonce` and `exporter_secret` are `LabeledExpand`s of `secret` with labels
///   `"key"`, `"base_nonce"` and `"exp"`.
///
/// The key and base nonce are sized by `suite.aead.key_len()` / `nonce_len()`, and their
/// expansions are skipped when that size is zero. The exporter secret is always expanded
/// to `kdf.output_len()` bytes.
///
/// # Errors
/// Propagates the error from `verify_psk_inputs` when the mode and PSK inputs disagree.
fn key_schedule(
    suite: CipherSuite,
    mode: Mode,
    shared_secret: &[u8],
    info: &[u8],
    psk: &[u8],
    psk_id: &[u8],
) -> Result<KeyScheduleOutput, Error> {
    verify_psk_inputs(mode, psk, psk_id)?;

    let kdf = suite.kdf;
    let suite_id = suite.suite_id();

    // mode || psk_id_hash || info_hash
    let psk_id_hash = labeled_extract(kdf, b"", &suite_id, b"psk_id_hash", psk_id);
    let info_hash = labeled_extract(kdf, b"", &suite_id, b"info_hash", info);
    let mut context = Vec::with_capacity(1 + psk_id_hash.len() + info_hash.len());
    context.push(mode.tag());
    for digest in [&psk_id_hash, &info_hash] {
        context.extend_from_slice(digest);
    }

    let secret = labeled_extract(kdf, shared_secret, &suite_id, b"secret", psk);

    // Fills `out` with LabeledExpand(secret, label, context, out.len()).
    let expand_into = |label: &'static [u8], out: &mut Vec<u8>| {
        labeled_expand(kdf, &secret, &suite_id, label, &context, out);
    };

    let mut key = alloc::vec![0u8; suite.aead.key_len()];
    if !key.is_empty() {
        expand_into(b"key", &mut key);
    }

    let mut base_nonce = alloc::vec![0u8; suite.aead.nonce_len()];
    if !base_nonce.is_empty() {
        expand_into(b"base_nonce", &mut base_nonce);
    }

    let mut exporter_secret = alloc::vec![0u8; kdf.output_len()];
    expand_into(b"exp", &mut exporter_secret);

    Ok((key, base_nonce, exporter_secret))
}
/// Computes the per-message nonce `base_nonce XOR I2OSP(seq, Nn)` (RFC 9180 `ComputeNonce`).
///
/// `seq` is encoded big-endian and right-aligned against the end of `base_nonce`; bytes
/// before it are XORed with zero. When `base_nonce` is shorter than 8 bytes, only the
/// low-order `base_nonce.len()` bytes of `seq` are used.
fn compute_nonce(base_nonce: &[u8], seq: u64) -> Vec<u8> {
    let mut nonce = base_nonce.to_vec();
    // Pair the nonce's trailing bytes with the counter's least-significant bytes;
    // `zip` stops at whichever side runs out first.
    for (nonce_byte, counter_byte) in nonce.iter_mut().rev().zip(seq.to_le_bytes().iter()) {
        *nonce_byte ^= *counter_byte;
    }
    nonce
}
/// Encryption-side HPKE context (RFC 9180 `ContextS`).
///
/// Built from the output of `key_schedule`; `seq` counts successfully sealed messages.
pub struct SenderContext {
    /// Cipher suite; its `aead` and `kdf` drive sealing and exporting.
    suite: CipherSuite,
    /// Sequence number combined with `base_nonce` to form each message nonce.
    seq: u64,
    /// AEAD key, `suite.aead.key_len()` bytes (may be empty).
    key: Vec<u8>,
    /// Base nonce, `suite.aead.nonce_len()` bytes (may be empty).
    base_nonce: Vec<u8>,
    /// Exporter secret, `suite.kdf.output_len()` bytes.
    exporter_secret: Vec<u8>,
}
/// Decryption-side HPKE context (RFC 9180 `ContextR`).
///
/// Built from the output of `key_schedule`; `seq` counts successfully opened messages.
pub struct ReceiverContext {
    /// Cipher suite; its `aead` and `kdf` drive opening and exporting.
    suite: CipherSuite,
    /// Sequence number combined with `base_nonce` to form each message nonce.
    seq: u64,
    /// AEAD key, `suite.aead.key_len()` bytes (may be empty).
    key: Vec<u8>,
    /// Base nonce, `suite.aead.nonce_len()` bytes (may be empty).
    base_nonce: Vec<u8>,
    /// Exporter secret, `suite.kdf.output_len()` bytes.
    exporter_secret: Vec<u8>,
}
impl SenderContext {
    /// Runs the key schedule and returns a sender context with its sequence number at zero.
    ///
    /// # Errors
    /// Propagates any error from `key_schedule` (e.g. inconsistent PSK inputs).
    pub(super) fn new(
        suite: CipherSuite,
        mode: Mode,
        shared_secret: &[u8],
        info: &[u8],
        psk: &[u8],
        psk_id: &[u8],
    ) -> Result<Self, Error> {
        key_schedule(suite, mode, shared_secret, info, psk, psk_id).map(
            |(key, base_nonce, exporter_secret)| Self {
                suite,
                key,
                base_nonce,
                seq: 0,
                exporter_secret,
            },
        )
    }

    /// Encrypts `pt` with associated data `aad` under the nonce for the current sequence
    /// number, then advances the sequence number.
    ///
    /// # Errors
    /// * `Error::ExportOnly` if the suite's AEAD is export-only.
    /// * Any error from the AEAD seal operation; the sequence number is not advanced.
    /// * `Error::MessageLimitReached` if the sequence number cannot be advanced; the
    ///   ciphertext is discarded in that case.
    pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Result<Vec<u8>, Error> {
        let aead = self.suite.aead;
        if aead.is_export_only() {
            return Err(Error::ExportOnly);
        }
        let nonce = compute_nonce(&self.base_nonce, self.seq);
        let sealed = aead.seal(&self.key, &nonce, aad, pt)?;
        increment_seq(&mut self.seq, aead).map(|()| sealed)
    }

    /// Derives `length` bytes from this context's exporter secret for `exporter_context`.
    pub fn export(&self, exporter_context: &[u8], length: usize) -> Vec<u8> {
        export(self.suite, &self.exporter_secret, exporter_context, length)
    }
}
impl ReceiverContext {
    /// Runs the key schedule and returns a receiver context with its sequence number at zero.
    ///
    /// # Errors
    /// Propagates any error from `key_schedule` (e.g. inconsistent PSK inputs).
    pub(super) fn new(
        suite: CipherSuite,
        mode: Mode,
        shared_secret: &[u8],
        info: &[u8],
        psk: &[u8],
        psk_id: &[u8],
    ) -> Result<Self, Error> {
        key_schedule(suite, mode, shared_secret, info, psk, psk_id).map(
            |(key, base_nonce, exporter_secret)| Self {
                suite,
                key,
                base_nonce,
                seq: 0,
                exporter_secret,
            },
        )
    }

    /// Decrypts `ct` with associated data `aad` under the nonce for the current sequence
    /// number, then advances the sequence number.
    ///
    /// # Errors
    /// * `Error::ExportOnly` if the suite's AEAD is export-only.
    /// * Any error from the AEAD open operation; the sequence number is not advanced, so the
    ///   next call retries the same nonce.
    /// * `Error::MessageLimitReached` if the sequence number cannot be advanced; the
    ///   plaintext is discarded in that case.
    pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Result<Vec<u8>, Error> {
        let aead = self.suite.aead;
        if aead.is_export_only() {
            return Err(Error::ExportOnly);
        }
        let nonce = compute_nonce(&self.base_nonce, self.seq);
        let opened = aead.open(&self.key, &nonce, aad, ct)?;
        increment_seq(&mut self.seq, aead).map(|()| opened)
    }

    /// Derives `length` bytes from this context's exporter secret for `exporter_context`.
    pub fn export(&self, exporter_context: &[u8], length: usize) -> Vec<u8> {
        export(self.suite, &self.exporter_secret, exporter_context, length)
    }
}
/// Secret export (RFC 9180 `Context.Export`):
/// `LabeledExpand(exporter_secret, "sec", exporter_context, length)` under `suite`'s KDF
/// and suite ID.
///
/// A zero `length` returns an empty vector without calling into the KDF, consistent with
/// how `key_schedule` skips zero-length expansions.
///
/// NOTE(review): RFC 9180 bounds `length` by what the KDF can produce (255 * Nh for HKDF,
/// and the 2-byte length encoding in `LabeledExpand`). How `labeled_expand` reacts to a
/// larger request is not visible here; confirm whether it panics.
fn export(
    suite: CipherSuite,
    exporter_secret: &[u8],
    exporter_context: &[u8],
    length: usize,
) -> Vec<u8> {
    if length == 0 {
        return Vec::new();
    }
    let suite_id = suite.suite_id();
    let mut out = alloc::vec![0u8; length];
    labeled_expand(
        suite.kdf,
        exporter_secret,
        &suite_id,
        b"sec",
        exporter_context,
        &mut out,
    );
    out
}
/// Advances a context's sequence number, enforcing the RFC 9180 message limit
/// (`IncrementSeq`).
///
/// The limit is `2^(8*Nn) - 1`, where `Nn = aead.nonce_len()`. When the nonce is 8 bytes
/// or longer, the `u64` counter is the tighter bound, so `u64::MAX` is used instead. For
/// export-only AEADs the counter is left untouched and `Ok(())` is returned.
///
/// # Errors
/// Returns `Error::MessageLimitReached` once `*seq` has reached (or somehow exceeded) the
/// limit; `*seq` is not modified in that case.
fn increment_seq(seq: &mut u64, aead: HpkeAead) -> Result<(), Error> {
    if aead.is_export_only() {
        return Ok(());
    }
    // Bit width of the nonce; an overflowing width is certainly wider than 64 bits.
    let limit = match aead.nonce_len().checked_mul(8) {
        Some(bits) if bits < 64 => (1u64 << bits) - 1,
        _ => u64::MAX,
    };
    // `>=` mirrors RFC 9180 and fails closed if `seq` is ever past the limit; for short
    // nonces `compute_nonce` would otherwise truncate the counter and repeat nonces.
    if *seq >= limit {
        return Err(Error::MessageLimitReached);
    }
    *seq += 1;
    Ok(())
}