#![forbid(unsafe_code)]
#![allow(non_snake_case)]
use core::{
fmt::{self, Debug, Display},
marker::PhantomData,
num::NonZeroU16,
result::Result,
};
use aranya_buggy::{bug, Bug, BugExt};
use generic_array::ArrayLength;
use subtle::{Choice, ConstantTimeEq};
use crate::{
aead::{Aead, IndCca2, KeyData, Nonce, OpenError, SealError},
csprng::Csprng,
import::{ExportError, Import, ImportError},
kdf::{Context, Expand, Kdf, KdfError, Prk},
kem::{Kem, KemError},
AlgId,
};
/// RFC 8017 `I2OSP`: big-endian integer-to-octet-string conversion.
///
/// - `i2osp!(v)` yields `v.to_be_bytes()` (usable in `const` contexts).
/// - `i2osp!(v, N)` yields a `GenericArray<u8, N>`: the big-endian bytes of
///   `v`, left-padded with zeros when `N` is wider than `v`, or truncated to
///   the `N` low-order bytes when `N` is narrower.
macro_rules! i2osp {
    ($v:expr) => {
        $v.to_be_bytes()
    };
    ($v:expr, $n:ty) => {{
        let be = $v.to_be_bytes();
        let mut out = generic_array::GenericArray::<u8, $n>::default();
        match out.len().checked_sub(be.len()) {
            // Output is at least as wide as the integer: zero-pad on the left.
            Some(pad) => out[pad..].copy_from_slice(&be),
            // Output is narrower: keep only the low-order bytes.
            None => {
                let skip = be.len().saturating_sub(out.len());
                out.copy_from_slice(&be[skip..]);
            }
        }
        out
    }};
}
/// An HPKE operating mode (RFC 9180 §5).
///
/// `T` is the sender-authentication key. `Hpke::setup_send` takes the
/// sender's private key (`&K::DecapKey`) here, and `Hpke::setup_recv` takes
/// the sender's public key (`&K::EncapKey`).
#[cfg_attr(test, derive(Debug))]
pub enum Mode<'a, T> {
    /// `mode_base`: no PSK, no sender authentication.
    Base,
    /// `mode_psk`: uses a pre-shared key.
    Psk(Psk<'a>),
    /// `mode_auth`: authenticates the sender with its asymmetric key.
    Auth(T),
    /// `mode_auth_psk`: uses both the sender's asymmetric key and a PSK.
    AuthPsk(T, Psk<'a>),
}
impl<T> Display for Mode<'_, T> {
    /// Writes the RFC 9180 name of the mode, e.g. `mode_base`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Base => "mode_base",
            Self::Psk(..) => "mode_psk",
            Self::Auth(..) => "mode_auth",
            Self::AuthPsk(..) => "mode_auth_psk",
        };
        f.write_str(label)
    }
}
impl<'a, T> Mode<'a, T> {
    /// Empty PSK and PSK ID, used by the modes that carry no PSK.
    const DEFAULT_PSK: Psk<'static> = Psk {
        psk: &[],
        psk_id: &[],
    };

    /// Borrows the mode's authentication key, copying any PSK.
    pub const fn as_ref(&self) -> Mode<'_, &T> {
        match self {
            Self::Base => Mode::Base,
            Self::Psk(p) => Mode::Psk(*p),
            Self::Auth(key) => Mode::Auth(key),
            Self::AuthPsk(key, p) => Mode::AuthPsk(key, *p),
        }
    }

    /// Returns the mode's PSK, or the empty default for non-PSK modes.
    fn psk(&self) -> &Psk<'a> {
        match self {
            Self::Psk(p) | Self::AuthPsk(_, p) => p,
            Self::Base | Self::Auth(_) => &Self::DEFAULT_PSK,
        }
    }

    /// Returns the one-byte mode identifier used in the key schedule context.
    const fn id(&self) -> u8 {
        match self {
            Self::Base => 0,
            Self::Psk(..) => 1,
            Self::Auth(..) => 2,
            Self::AuthPsk(..) => 3,
        }
    }
}
/// Error returned by `Psk::new` when the PSK or its identifier is empty.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InvalidPsk;

impl Display for InvalidPsk {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid pre-shared key: PSK or PSK ID are empty")
    }
}

impl core::error::Error for InvalidPsk {}
/// A pre-shared key together with its identifier.
///
/// Values built through `Psk::new` always have a non-empty `psk` and
/// `psk_id`.
#[cfg_attr(test, derive(Debug))]
#[derive(Copy, Clone)]
pub struct Psk<'a> {
    psk: &'a [u8],
    psk_id: &'a [u8],
}

impl<'a> Psk<'a> {
    /// Creates a PSK.
    ///
    /// # Errors
    ///
    /// Returns `InvalidPsk` if either `psk` or `psk_id` is empty.
    pub fn new(psk: &'a [u8], psk_id: &'a [u8]) -> Result<Self, InvalidPsk> {
        if !psk.is_empty() && !psk_id.is_empty() {
            return Ok(Self { psk, psk_id });
        }
        Err(InvalidPsk)
    }
}

impl ConstantTimeEq for Psk<'_> {
    /// Compares both the key and the identifier using `subtle`'s
    /// constant-time slice comparison, combined with a non-short-circuiting `&`.
    fn ct_eq(&self, other: &Self) -> Choice {
        let same_key = self.psk.ct_eq(other.psk);
        let same_id = self.psk_id.ct_eq(other.psk_id);
        same_key & same_id
    }
}
/// HPKE KEM identifiers.
///
/// Identifiers without a dedicated variant are represented by `Other`. The
/// unit tests check that unassigned values in several ranges decode this way.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KemId {
    /// DHKEM(P-256, HKDF-SHA256).
    #[alg_id(0x0010)]
    DhKemP256HkdfSha256,
    /// DHKEM(P-384, HKDF-SHA384).
    #[alg_id(0x0011)]
    DhKemP384HkdfSha384,
    /// DHKEM(P-521, HKDF-SHA512).
    #[alg_id(0x0012)]
    DhKemP521HkdfSha512,
    /// DHKEM(CP-256, HKDF-SHA256).
    #[alg_id(0x0013)]
    DhKemCp256HkdfSha256,
    /// DHKEM(CP-384, HKDF-SHA384).
    #[alg_id(0x0014)]
    DhKemCp384HkdfSha384,
    /// DHKEM(CP-521, HKDF-SHA512).
    #[alg_id(0x0015)]
    DhKemCp521HkdfSha512,
    /// DHKEM(secp256k1, HKDF-SHA256).
    #[alg_id(0x0016)]
    DhKemSecp256k1HkdfSha256,
    /// DHKEM(X25519, HKDF-SHA256).
    #[alg_id(0x0020)]
    DhKemX25519HkdfSha256,
    /// DHKEM(X448, HKDF-SHA512).
    #[alg_id(0x0021)]
    DhKemX448HkdfSha512,
    /// X25519Kyber768Draft00.
    #[alg_id(0x0030)]
    X25519Kyber768Draft00,
    /// Any identifier without a dedicated variant.
    #[alg_id(Other)]
    Other(NonZeroU16),
}
impl Display for KemId {
    /// Writes the human-readable KEM name, or `Kem(0x..)` for `Other`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::DhKemP256HkdfSha256 => "DHKEM(P-256, HKDF-SHA256)",
            Self::DhKemP384HkdfSha384 => "DHKEM(P-384, HKDF-SHA384)",
            Self::DhKemP521HkdfSha512 => "DHKEM(P-521, HKDF-SHA512)",
            Self::DhKemCp256HkdfSha256 => "DHKEM(CP-256, HKDF-SHA256)",
            Self::DhKemCp384HkdfSha384 => "DHKEM(CP-384, HKDF-SHA384)",
            Self::DhKemCp521HkdfSha512 => "DHKEM(CP-521, HKDF-SHA512)",
            Self::DhKemSecp256k1HkdfSha256 => "DHKEM(secp256k1, HKDF-SHA256)",
            Self::DhKemX25519HkdfSha256 => "DHKEM(X25519, HKDF-SHA256)",
            Self::DhKemX448HkdfSha512 => "DHKEM(X448, HKDF-SHA512)",
            Self::X25519Kyber768Draft00 => "X25519Kyber768Draft00",
            // NOTE(review): with `#`, the width 2 is already filled by `0x`,
            // so this never zero-pads; `{:#06x}` may have been intended.
            Self::Other(id) => return write!(f, "Kem({:#02x})", id),
        };
        f.write_str(name)
    }
}
/// HPKE KDF identifiers.
///
/// Identifiers without a dedicated variant are represented by `Other`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KdfId {
    /// HKDF-SHA256.
    #[alg_id(0x0001)]
    HkdfSha256,
    /// HKDF-SHA384.
    #[alg_id(0x0002)]
    HkdfSha384,
    /// HKDF-SHA512.
    #[alg_id(0x0003)]
    HkdfSha512,
    /// Any identifier without a dedicated variant.
    #[alg_id(Other)]
    Other(NonZeroU16),
}
impl Display for KdfId {
    /// Writes the KDF variant name, or `Kdf(0x..)` for `Other`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::HkdfSha256 => "HkdfSha256",
            Self::HkdfSha384 => "HkdfSha384",
            Self::HkdfSha512 => "HkdfSha512",
            // NOTE(review): `{:#02x}` never pads (the `0x` fills the width).
            Self::Other(id) => return write!(f, "Kdf({:#02x})", id),
        };
        f.write_str(name)
    }
}
/// HPKE AEAD identifiers.
///
/// Identifiers without a dedicated variant are represented by `Other`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum AeadId {
    /// AES-128-GCM.
    #[alg_id(0x0001)]
    Aes128Gcm,
    /// AES-256-GCM.
    #[alg_id(0x0002)]
    Aes256Gcm,
    /// ChaCha20-Poly1305.
    #[alg_id(0x0003)]
    ChaCha20Poly1305,
    /// `Cmt1Aes256Gcm` (identifier `0xfffd`).
    #[alg_id(0xfffd)]
    Cmt1Aes256Gcm,
    /// `Cmt4Aes256Gcm` (identifier `0xfffe`).
    #[alg_id(0xfffe)]
    Cmt4Aes256Gcm,
    /// Any identifier without a dedicated variant.
    #[alg_id(Other)]
    Other(NonZeroU16),
    /// The export-only identifier (`0xffff`).
    #[alg_id(0xffff)]
    ExportOnly,
}
impl Display for AeadId {
    /// Writes the AEAD variant name, or `Aead(0x..)` for `Other`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Aes128Gcm => "Aes128Gcm",
            Self::Aes256Gcm => "Aes256Gcm",
            Self::ChaCha20Poly1305 => "ChaCha20Poly1305",
            Self::Cmt1Aes256Gcm => "Cmt1Aes256Gcm",
            Self::Cmt4Aes256Gcm => "Cmt4Aes256Gcm",
            // NOTE(review): `{:#02x}` never pads (the `0x` fills the width).
            Self::Other(id) => return write!(f, "Aead({:#02x})", id),
            Self::ExportOnly => "ExportOnly",
        };
        f.write_str(name)
    }
}
/// Errors returned by HPKE operations in this module.
#[derive(Debug, Eq, PartialEq)]
pub enum HpkeError {
    /// The AEAD failed to seal (encrypt).
    Seal(SealError),
    /// The AEAD failed to open (decrypt/authenticate).
    Open(OpenError),
    /// A KDF operation (labeled extract/expand) failed.
    Kdf(KdfError),
    /// The KEM failed.
    Kem(KemError),
    /// Key material could not be imported (e.g. raw AEAD key bytes).
    Import(ImportError),
    /// Key material could not be exported.
    Export(ExportError),
    /// The sequence number hit its limit; no further nonces can be derived.
    MessageLimitReached,
    /// An internal invariant was violated.
    Bug(Bug),
}
impl Display for HpkeError {
    /// Forwards to the wrapped error's `Display`. `MessageLimitReached` has
    /// its own fixed message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Seal(e) => write!(f, "{e}"),
            Self::Open(e) => write!(f, "{e}"),
            Self::Kdf(e) => write!(f, "{e}"),
            Self::Kem(e) => write!(f, "{e}"),
            Self::Import(e) => write!(f, "{e}"),
            Self::Export(e) => write!(f, "{e}"),
            Self::MessageLimitReached => f.write_str("message limit reached"),
            Self::Bug(e) => write!(f, "{e}"),
        }
    }
}
impl core::error::Error for HpkeError {
    /// Exposes the wrapped error. `MessageLimitReached` has no inner error.
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        let inner: &(dyn core::error::Error + 'static) = match self {
            Self::Seal(e) => e,
            Self::Open(e) => e,
            Self::Kdf(e) => e,
            Self::Kem(e) => e,
            Self::Import(e) => e,
            Self::Export(e) => e,
            Self::Bug(e) => e,
            Self::MessageLimitReached => return None,
        };
        Some(inner)
    }
}
impl From<SealError> for HpkeError {
fn from(err: SealError) -> Self {
Self::Seal(err)
}
}
impl From<OpenError> for HpkeError {
fn from(err: OpenError) -> Self {
Self::Open(err)
}
}
impl From<KdfError> for HpkeError {
fn from(err: KdfError) -> Self {
Self::Kdf(err)
}
}
impl From<KemError> for HpkeError {
fn from(err: KemError) -> Self {
Self::Kem(err)
}
}
impl From<ImportError> for HpkeError {
fn from(err: ImportError) -> Self {
Self::Import(err)
}
}
impl From<ExportError> for HpkeError {
fn from(err: ExportError) -> Self {
Self::Export(err)
}
}
impl From<Bug> for HpkeError {
fn from(err: Bug) -> Self {
Self::Bug(err)
}
}
impl From<MessageLimitReached> for HpkeError {
fn from(_err: MessageLimitReached) -> Self {
Self::MessageLimitReached
}
}
/// HPKE (RFC 9180) parameterized by KEM `K`, KDF `F`, and AEAD `A`.
///
/// No constructor is defined here. The type only namespaces the associated
/// functions and constants in the `impl` block below.
pub struct Hpke<K, F, A> {
    _kem: PhantomData<K>,
    _kdf: PhantomData<F>,
    _aead: PhantomData<A>,
}
impl<K: Kem, F: Kdf, A: Aead + IndCca2> Hpke<K, F, A> {
    /// Sets up a sending context for the recipient public key `pkR`.
    ///
    /// In `Mode::Auth`/`Mode::AuthPsk`, the sender's private key `skS` is used
    /// for authenticated encapsulation (`K::auth_encap`). In `Mode::Base` and
    /// `Mode::Psk`, a plain `K::encap` is performed. `rng` is handed to the
    /// KEM. The shared secret, `mode` (including any PSK) and `info` then feed
    /// the key schedule.
    ///
    /// Returns the encapsulation `enc`, which the recipient passes to
    /// `setup_recv`, together with the sending context.
    ///
    /// # Errors
    ///
    /// Propagates errors from the KEM and from the key schedule.
    // The tuple return type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    pub fn setup_send<R: Csprng>(
        rng: &mut R,
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: &[u8],
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => K::auth_encap::<R>(rng, pkR, skS)?,
            Mode::Base | Mode::Psk(_) => K::encap::<R>(rng, pkR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }
    /// Like `setup_send`, but uses the caller-supplied ephemeral private key
    /// `skE` instead of drawing randomness from an RNG.
    ///
    /// NOTE(review): presumably intended for known-answer/test-vector use.
    /// Nothing here prevents `skE` from being reused across calls, so confirm
    /// that callers never reuse it.
    ///
    /// # Errors
    ///
    /// Propagates errors from the KEM and from the key schedule.
    #[allow(clippy::type_complexity)]
    pub fn setup_send_deterministically(
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: &[u8],
        skE: K::DecapKey,
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => {
                K::auth_encap_deterministically(pkR, skS, skE)?
            }
            Mode::Base | Mode::Psk(_) => K::encap_deterministically(pkR, skE)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }
    /// Sets up a receiving context from the sender's encapsulation `enc` and
    /// the recipient's private key `skR`.
    ///
    /// In `Mode::Auth`/`Mode::AuthPsk`, the sender's public key `pkS` is used
    /// for authenticated decapsulation. Otherwise a plain `K::decap` is
    /// performed. `mode` and `info` must match what the sender used for the
    /// derived keys to agree.
    ///
    /// # Errors
    ///
    /// Propagates errors from the KEM and from the key schedule.
    pub fn setup_recv(
        mode: Mode<'_, &K::EncapKey>,
        enc: &K::Encap,
        skR: &K::DecapKey,
        info: &[u8],
    ) -> Result<RecvCtx<K, F, A>, HpkeError> {
        let shared_secret = match mode {
            Mode::Auth(pkS) | Mode::AuthPsk(pkS, _) => K::auth_decap(enc, skR, pkS)?,
            Mode::Base | Mode::Psk(_) => K::decap(enc, skR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok(ctx.into_recv_ctx())
    }
    /// `suite_id = "HPKE" || I2OSP(kem_id, 2) || I2OSP(kdf_id, 2) || I2OSP(aead_id, 2)`.
    ///
    /// NOTE(review): only bytes `[0]` and `[1]` of each `i2osp!(X::ID)` are
    /// used, which assumes the `ID` constants are `u16`. TODO confirm.
    // Kept unformatted so each algorithm ID's two bytes stay on one line.
    #[rustfmt::skip]
    const HPKE_SUITE_ID: [u8; 10] = [
        b'H',
        b'P',
        b'K',
        b'E',
        i2osp!(K::ID)[0], i2osp!(K::ID)[1],
        i2osp!(F::ID)[0], i2osp!(F::ID)[1],
        i2osp!(A::ID)[0], i2osp!(A::ID)[1],
    ];
    /// RFC 9180 `KeySchedule`: derives the AEAD key, base nonce, and exporter
    /// secret from the KEM shared secret, the mode (and its PSK), and `info`.
    ///
    /// Modes without a PSK use the empty default PSK and PSK ID from
    /// `Mode::psk`. `Psk::new` rejects empty values, so the empty pair only
    /// ever appears for non-PSK modes.
    fn key_schedule<T>(
        mode: Mode<'_, T>,
        shared_secret: &K::Secret,
        info: &[u8],
    ) -> Result<Schedule<K, F, A>, HpkeError> {
        let Psk { psk, psk_id } = mode.psk();
        let psk_id_hash = Self::labeled_extract(b"", "psk_id_hash", psk_id);
        let info_hash = Self::labeled_extract(b"", "info_hash", info);
        // key_schedule_context = mode || psk_id_hash || info_hash
        let ks_ctx = [&[mode.id()], psk_id_hash.as_bytes(), info_hash.as_bytes()];
        // The shared secret is the salt and the PSK is the IKM.
        let secret = Self::labeled_extract(shared_secret.as_ref(), "secret", psk);
        let key = Self::labeled_expand(&secret, "key", &ks_ctx)?;
        let base_nonce = Self::labeled_expand(&secret, "base_nonce", &ks_ctx)?;
        let exporter_secret = Self::labeled_expand(&secret, "exp", &ks_ctx)?;
        Ok(Schedule {
            key,
            base_nonce,
            exporter_secret,
            _kem: PhantomData,
        })
    }
    /// Labeling context: domain `"HPKE-v1"` plus this suite's ID. Presumably
    /// `Context` prepends these to each label as in RFC 9180
    /// `LabeledExtract`/`LabeledExpand`. Confirm in `kdf::Context`.
    const HPKE_CTX: Context = Context {
        domain: "HPKE-v1",
        suite_ids: &Self::HPKE_SUITE_ID,
    };
    /// `LabeledExtract(salt, label, ikm)` under `HPKE_CTX`.
    fn labeled_extract(salt: &[u8], label: &'static str, ikm: &[u8]) -> Prk<F::PrkSize> {
        Self::HPKE_CTX.labeled_extract::<F>(salt, label, ikm)
    }
    /// `LabeledExpand(prk, label, info, L)` under `HPKE_CTX`, where the output
    /// length `L` is determined by `T`.
    fn labeled_expand<T: Expand>(
        prk: &Prk<F::PrkSize>,
        label: &'static str,
        info: &[&[u8]],
    ) -> Result<T, KdfError> {
        let key = Self::HPKE_CTX.labeled_expand::<F, T>(prk, label, info)?;
        Ok(key)
    }
    /// `LabeledExpand` under `HPKE_CTX`, writing `out.len()` bytes into `out`.
    fn labeled_expand_into(
        out: &mut [u8],
        prk: &Prk<F::PrkSize>,
        label: &'static str,
        info: &[&[u8]],
    ) -> Result<(), KdfError> {
        Self::HPKE_CTX.labeled_expand_into::<F>(out, prk, label, info)
    }
}
/// Output of `Hpke::key_schedule` for a single context.
struct Schedule<K: Kem, F: Kdf, A: Aead + IndCca2> {
    /// AEAD key.
    key: KeyData<A>,
    /// Nonce that the sequence number is XORed into.
    base_nonce: Nonce<A::NonceSize>,
    /// Secret used by the exporter interface.
    exporter_secret: Prk<F::PrkSize>,
    _kem: PhantomData<K>,
}

impl<K: Kem, F: Kdf, A: Aead + IndCca2> Schedule<K, F, A> {
    /// Turns the schedule into a sending context. The AEAD key stays raw
    /// until the first seal.
    fn into_send_ctx(self) -> SendCtx<K, F, A> {
        let Self {
            key,
            base_nonce,
            exporter_secret,
            ..
        } = self;
        SendCtx {
            export: ExportCtx::new(exporter_secret),
            seal: Either::Right((key, base_nonce)),
        }
    }

    /// Turns the schedule into a receiving context. The AEAD key stays raw
    /// until the first open.
    fn into_recv_ctx(self) -> RecvCtx<K, F, A> {
        let Self {
            key,
            base_nonce,
            exporter_secret,
            ..
        } = self;
        RecvCtx {
            export: ExportCtx::new(exporter_secret),
            open: Either::Right((key, base_nonce)),
        }
    }
}
/// A lazily-initialized slot: `Right` holds the inputs, `Left` holds the
/// value built from them.
enum Either<L, R> {
    Left(L),
    Right(R),
}

impl<L, R> Either<L, R> {
    /// Returns the `Left` value, first building it from the `Right` inputs
    /// with `f` if necessary.
    ///
    /// # Errors
    ///
    /// Returns `f`'s error unchanged, in which case `self` stays `Right`.
    fn get_or_insert_left<F, E>(&mut self, f: F) -> Result<&mut L, E>
    where
        F: FnOnce(&R) -> Result<L, E>,
        E: From<Bug>,
    {
        if let Self::Right(inputs) = self {
            let built = f(inputs)?;
            *self = Self::Left(built);
        }
        match self {
            Self::Left(left) => Ok(left),
            Self::Right(_) => bug!("we just assigned `Self::Left`"),
        }
    }
}
/// Raw seal/open key material from the key schedule: the AEAD key and the
/// base nonce. It is kept until the first seal/open, when it is imported
/// into a `SealCtx`/`OpenCtx`.
type RawKey<A> = (KeyData<A>, Nonce<<A as Aead>::NonceSize>);
/// Sender-side HPKE context: encrypts messages and exports secrets.
///
/// The AEAD key is imported lazily on the first `seal`/`seal_in_place`.
pub struct SendCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    seal: Either<SealCtx<A>, RawKey<A>>,
    export: ExportCtx<K, F, A>,
}

impl<K: Kem, F: Kdf, A: Aead + IndCca2> SendCtx<K, F, A> {
    /// Ciphertext expansion per message (same as `SealCtx::<A>::OVERHEAD`).
    pub const OVERHEAD: usize = SealCtx::<A>::OVERHEAD;

    /// Returns the raw AEAD key and base nonce. Returns `None` once the
    /// context has already been used to seal.
    pub(crate) fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
        if let Either::Right(raw) = self.seal {
            Some(raw)
        } else {
            None
        }
    }

    /// Returns the seal context, importing the raw key at sequence zero on
    /// first use.
    fn seal_ctx(&mut self) -> Result<&mut SealCtx<A>, ImportError> {
        self.seal
            .get_or_insert_left(|raw| SealCtx::new(&raw.0, &raw.1, Seq::ZERO))
    }

    /// Encrypts `plaintext` into `dst` with the next sequence number.
    ///
    /// Returns the sequence number used for this message.
    pub fn seal(
        &mut self,
        dst: &mut [u8],
        plaintext: &[u8],
        additional_data: &[u8],
    ) -> Result<Seq, HpkeError> {
        let ctx = self.seal_ctx()?;
        ctx.seal(dst, plaintext, additional_data)
    }

    /// Encrypts `data` in place, writing the tag to `tag`.
    ///
    /// Returns the sequence number used for this message.
    pub fn seal_in_place(
        &mut self,
        data: impl AsMut<[u8]>,
        tag: &mut [u8],
        additional_data: &[u8],
    ) -> Result<Seq, HpkeError> {
        let ctx = self.seal_ctx()?;
        ctx.seal_in_place(data, tag, additional_data)
    }

    /// Exports a secret of type `T` bound to `context`.
    pub fn export<T: Expand>(&self, context: &[u8]) -> Result<T, KdfError> {
        self.export.export(context)
    }

    /// Exports `out.len()` bytes of secret bound to `context` into `out`.
    pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
        self.export.export_into(out, context)
    }
}
pub struct SealCtx<A: Aead + IndCca2> {
aead: A,
base_nonce: Nonce<A::NonceSize>,
seq: Seq,
}
impl<A: Aead + IndCca2> SealCtx<A> {
pub const OVERHEAD: usize = A::OVERHEAD;
pub(crate) fn new(
key: &KeyData<A>,
base_nonce: &Nonce<A::NonceSize>,
seq: Seq,
) -> Result<Self, ImportError> {
let key = A::Key::import(key.as_bytes())?;
Ok(Self {
aead: A::new(&key),
base_nonce: base_nonce.clone(),
seq,
})
}
fn compute_nonce(&self) -> Result<Nonce<A::NonceSize>, MessageLimitReached> {
self.seq.compute_nonce::<A::NonceSize>(&self.base_nonce)
}
fn increment_seq(&mut self) -> Result<Seq, Bug> {
self.seq.increment::<A::NonceSize>()
}
pub fn seal(
&mut self,
dst: &mut [u8],
plaintext: &[u8],
additional_data: &[u8],
) -> Result<Seq, HpkeError> {
let nonce = self.compute_nonce()?;
self.aead.seal(dst, &nonce, plaintext, additional_data)?;
let prev = self.increment_seq()?;
Ok(prev)
}
pub fn seal_in_place(
&mut self,
mut data: impl AsMut<[u8]>,
tag: &mut [u8],
additional_data: &[u8],
) -> Result<Seq, HpkeError> {
let nonce = self.compute_nonce()?;
self.aead
.seal_in_place(&nonce, data.as_mut(), tag, additional_data)?;
let prev = self.increment_seq()?;
Ok(prev)
}
pub fn seq(&self) -> Seq {
self.seq
}
}
/// Receiver-side HPKE context: decrypts messages and exports secrets.
///
/// The AEAD key is imported lazily on the first open.
pub struct RecvCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    open: Either<OpenCtx<A>, RawKey<A>>,
    export: ExportCtx<K, F, A>,
}

impl<K: Kem, F: Kdf, A: Aead + IndCca2> RecvCtx<K, F, A> {
    /// Ciphertext expansion per message (same as `OpenCtx::<A>::OVERHEAD`).
    pub const OVERHEAD: usize = OpenCtx::<A>::OVERHEAD;

    /// Returns the raw AEAD key and base nonce. Returns `None` once the
    /// context has already been used to open.
    pub(crate) fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
        if let Either::Right(raw) = self.open {
            Some(raw)
        } else {
            None
        }
    }

    /// Returns the open context, importing the raw key at sequence zero on
    /// first use.
    fn open_ctx(&mut self) -> Result<&mut OpenCtx<A>, ImportError> {
        self.open
            .get_or_insert_left(|raw| OpenCtx::new(&raw.0, &raw.1, Seq::ZERO))
    }

    /// Decrypts `ciphertext` into `dst` using the next sequence number.
    pub fn open(
        &mut self,
        dst: &mut [u8],
        ciphertext: &[u8],
        additional_data: &[u8],
    ) -> Result<(), HpkeError> {
        let ctx = self.open_ctx()?;
        ctx.open(dst, ciphertext, additional_data)
    }

    /// Decrypts `ciphertext` into `dst` using the explicit sequence number
    /// `seq`, without advancing the context's own sequence number.
    pub fn open_at(
        &mut self,
        dst: &mut [u8],
        ciphertext: &[u8],
        additional_data: &[u8],
        seq: Seq,
    ) -> Result<(), HpkeError> {
        let ctx = self.open_ctx()?;
        ctx.open_at(dst, ciphertext, additional_data, seq)
    }

    /// Decrypts `data` in place using the next sequence number.
    pub fn open_in_place(
        &mut self,
        data: impl AsMut<[u8]>,
        tag: &[u8],
        additional_data: &[u8],
    ) -> Result<(), HpkeError> {
        let ctx = self.open_ctx()?;
        ctx.open_in_place(data, tag, additional_data)
    }

    /// Decrypts `data` in place using the explicit sequence number `seq`.
    pub fn open_in_place_at(
        &mut self,
        data: impl AsMut<[u8]>,
        tag: &[u8],
        additional_data: &[u8],
        seq: Seq,
    ) -> Result<(), HpkeError> {
        let ctx = self.open_ctx()?;
        ctx.open_in_place_at(data, tag, additional_data, seq)
    }

    /// Exports a secret of type `T` bound to `context`.
    pub fn export<T: Expand>(&self, context: &[u8]) -> Result<T, KdfError> {
        self.export.export(context)
    }

    /// Exports `out.len()` bytes of secret bound to `context` into `out`.
    pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
        self.export.export_into(out, context)
    }
}
/// An AEAD decryption context with a base nonce and a running sequence
/// number.
///
/// `open`/`open_in_place` use and then advance the internal sequence number.
/// The `*_at` variants take the sequence number explicitly and leave the
/// internal one untouched.
pub struct OpenCtx<A: Aead + IndCca2> {
    aead: A,
    base_nonce: Nonce<A::NonceSize>,
    seq: Seq,
}

impl<A: Aead + IndCca2> OpenCtx<A> {
    /// Ciphertext expansion added by the AEAD.
    pub const OVERHEAD: usize = A::OVERHEAD;

    /// Builds a context from raw key material, starting at `seq`.
    ///
    /// # Errors
    ///
    /// Returns `ImportError` if `key` cannot be imported as an `A::Key`.
    pub(crate) fn new(
        key: &KeyData<A>,
        base_nonce: &Nonce<A::NonceSize>,
        seq: Seq,
    ) -> Result<Self, ImportError> {
        let imported = A::Key::import(key.as_bytes())?;
        let aead = A::new(&imported);
        Ok(Self {
            aead,
            base_nonce: base_nonce.clone(),
            seq,
        })
    }

    /// Advances the sequence number and returns the previous value.
    fn increment_seq(&mut self) -> Result<Seq, Bug> {
        self.seq.increment::<A::NonceSize>()
    }

    /// Decrypts with the current sequence number. The sequence number is
    /// advanced only if decryption succeeds.
    pub fn open(
        &mut self,
        dst: &mut [u8],
        ciphertext: &[u8],
        additional_data: &[u8],
    ) -> Result<(), HpkeError> {
        let current = self.seq;
        self.open_at(dst, ciphertext, additional_data, current)?;
        self.increment_seq()?;
        Ok(())
    }

    /// Decrypts with the explicit sequence number `seq`.
    pub fn open_at(
        &self,
        dst: &mut [u8],
        ciphertext: &[u8],
        additional_data: &[u8],
        seq: Seq,
    ) -> Result<(), HpkeError> {
        let per_msg_nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
        self.aead
            .open(dst, &per_msg_nonce, ciphertext, additional_data)?;
        Ok(())
    }

    /// Decrypts in place with the current sequence number. The sequence
    /// number is advanced only if decryption succeeds.
    pub fn open_in_place(
        &mut self,
        mut data: impl AsMut<[u8]>,
        tag: &[u8],
        additional_data: &[u8],
    ) -> Result<(), HpkeError> {
        let current = self.seq;
        self.open_in_place_at(data.as_mut(), tag, additional_data, current)?;
        self.increment_seq()?;
        Ok(())
    }

    /// Decrypts in place with the explicit sequence number `seq`.
    pub fn open_in_place_at(
        &self,
        mut data: impl AsMut<[u8]>,
        tag: &[u8],
        additional_data: &[u8],
        seq: Seq,
    ) -> Result<(), HpkeError> {
        let per_msg_nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
        let buf = data.as_mut();
        self.aead
            .open_in_place(&per_msg_nonce, buf, tag, additional_data)?;
        Ok(())
    }
}
/// Error returned when a sequence number has reached the limit for the
/// nonce size, so no further nonce can be derived.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MessageLimitReached;

impl Display for MessageLimitReached {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "message limit reached")
    }
}

impl core::error::Error for MessageLimitReached {}
/// A per-context message sequence number.
///
/// Nonces are derived by XORing the big-endian encoding of the sequence
/// number into the context's base nonce (`Seq::compute_nonce`). Distinct
/// sequence numbers below `Seq::max` therefore yield distinct nonces.
#[derive(Copy, Clone, Debug, Default, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Seq {
    seq: u64,
}

impl Seq {
    /// The initial sequence number.
    pub const ZERO: Self = Self::new(0);

    /// Creates a sequence number from a raw `u64`.
    #[inline]
    pub const fn new(seq: u64) -> Self {
        Self { seq }
    }

    /// Returns the raw `u64` value.
    #[inline]
    pub const fn to_u64(self) -> u64 {
        self.seq
    }

    /// Returns the exclusive upper bound on sequence numbers for an `N`-byte
    /// nonce: `2^(8*N) - 1`, saturating at `u64::MAX` when the nonce is eight
    /// bytes or wider.
    pub(crate) const fn max<N: ArrayLength>() -> u64 {
        // Nonce width in bits; saturate rather than overflow for huge `N`.
        let bits = 8usize.saturating_mul(N::USIZE);
        // Check the range *before* narrowing to `u32`. A bare `bits as u32`
        // wraps for very large `N` (e.g. 2^32 -> 0), which would make
        // `checked_shl` succeed with a tiny shift and return a bogus limit.
        // `u64::BITS` (64) fits in any `usize`, so that cast is lossless.
        if bits >= u64::BITS as usize {
            return u64::MAX;
        }
        // Here `bits < 64`, so the cast is lossless and the shift is in range.
        match 1u64.checked_shl(bits as u32) {
            Some(v) => v.saturating_sub(1),
            None => u64::MAX,
        }
    }

    /// Advances the sequence number and returns the previous value.
    ///
    /// # Errors
    ///
    /// Returns a `Bug` if called at or past `Seq::max::<N>()`. Callers check
    /// the limit via `compute_nonce` first, so reaching it here indicates a
    /// logic error.
    fn increment<N: ArrayLength>(&mut self) -> Result<Self, Bug> {
        if self.seq >= Self::max::<N>() {
            bug!("`Seq::increment` called after limit reached");
        }
        let prev = self.seq;
        self.seq = prev
            .checked_add(1)
            .assume("`Seq` overflow should be impossible")?;
        Ok(Self { seq: prev })
    }

    /// Computes `base_nonce XOR I2OSP(seq, N)`.
    ///
    /// # Errors
    ///
    /// Returns `MessageLimitReached` if `seq >= Seq::max::<N>()`.
    fn compute_nonce<N: ArrayLength>(
        self,
        base_nonce: &Nonce<N>,
    ) -> Result<Nonce<N>, MessageLimitReached> {
        if self.seq >= Self::max::<N>() {
            Err(MessageLimitReached)
        } else {
            let seq_bytes = i2osp!(self.seq, N);
            Ok(base_nonce ^ &Nonce::from_bytes(seq_bytes))
        }
    }
}

impl Display for Seq {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.seq)
    }
}
/// The exporter half of an HPKE context. It derives secrets from the
/// key schedule's `exporter_secret`.
struct ExportCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    exporter_secret: Prk<F::PrkSize>,
    _etc: PhantomData<(K, A)>,
}

impl<K: Kem, F: Kdf, A: Aead + IndCca2> ExportCtx<K, F, A> {
    /// Wraps the exporter secret.
    fn new(exporter_secret: Prk<F::PrkSize>) -> Self {
        Self {
            _etc: PhantomData,
            exporter_secret,
        }
    }

    /// `LabeledExpand(exporter_secret, "sec", context, L)`, where `L` is
    /// determined by `T`.
    fn export<T: Expand>(&self, context: &[u8]) -> Result<T, KdfError> {
        let info: [&[u8]; 1] = [context];
        Hpke::<K, F, A>::labeled_expand(&self.exporter_secret, "sec", &info)
    }

    /// `LabeledExpand(exporter_secret, "sec", context, out.len())`, written
    /// into `out`.
    fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
        let info: [&[u8]; 1] = [context];
        Hpke::<K, F, A>::labeled_expand_into(out, &self.exporter_secret, "sec", &info)
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::panic)]

    use std::{collections::HashSet, ops::RangeInclusive};

    use postcard::experimental::max_size::MaxSize;
    use typenum::{U1, U12, U2, U8};

    use super::*;

    /// `compute_nonce` XORs the big-endian sequence number into the base
    /// nonce and rejects sequence numbers at or above `Seq::max`.
    #[test]
    fn test_seq_compute_nonce() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");
        let cases = [
            (0, Ok(&[0xfe])),
            (1, Ok(&[0xff])),
            (2, Ok(&[0xfc])),
            (4, Ok(&[0xfa])),
            (254, Ok(&[0x00])),
            (255, Err(MessageLimitReached)),
            (256, Err(MessageLimitReached)),
            (257, Err(MessageLimitReached)),
            (u64::MAX, Err(MessageLimitReached)),
        ];
        for (input, output) in cases {
            let got = Seq::new(input).compute_nonce::<U1>(&base);
            let want = output.map(|s| Nonce::try_from_slice(s).expect("unable to create nonce"));
            assert_eq!(got, want, "seq = {input}");
        }
    }

    /// `Seq::max` is `2^(8*N) - 1`, saturating at `u64::MAX` for nonces of
    /// eight bytes or more.
    #[test]
    fn test_seq_max() {
        assert_eq!(Seq::max::<U1>(), 0xff);
        assert_eq!(Seq::max::<U2>(), 0xffff);
        assert_eq!(Seq::max::<U8>(), u64::MAX);
        assert_eq!(Seq::max::<U12>(), u64::MAX);
    }

    /// Every valid sequence number for a two-byte nonce yields a distinct
    /// nonce.
    #[test]
    fn test_seq_unique_nonce() {
        let base =
            Nonce::<U2>::try_from_slice(&[0xfe, 0xfe]).expect("should be able to create nonce");
        let mut seen = HashSet::new();
        for v in 0..u16::MAX {
            let got = Seq::new(u64::from(v))
                .compute_nonce::<U2>(&base)
                .expect("unable to create nonce");
            assert!(seen.insert(got), "duplicate nonce: {got:?}");
        }
    }

    #[test]
    fn test_invalid_psk() {
        let err = Psk::new(&[], &[]).expect_err("should get `InvalidPsk`");
        assert_eq!(err, InvalidPsk);
    }

    #[test]
    fn test_psk_ct_eq() {
        let cases = [
            (true, ("abc", "123"), ("abc", "123")),
            (false, ("a", "b"), ("a", "x")),
            (false, ("a", "b"), ("x", "b")),
            (false, ("a", "b"), ("c", "d")),
        ];
        for (pass, lhs, rhs) in cases {
            let lhs = Psk::new(lhs.0.as_bytes(), lhs.1.as_bytes()).expect("should not fail");
            let rhs = Psk::new(rhs.0.as_bytes(), rhs.1.as_bytes()).expect("should not fail");
            assert_eq!(pass, bool::from(lhs.ct_eq(&rhs)));
        }
    }

    #[test]
    fn test_aead_id() {
        // 0x0001..=0x0003 and 0xFFFD..=0xFFFF are assigned; everything in
        // between must decode as `AeadId::Other`.
        let unassigned = 0x0004..=0xFFFC;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: AeadId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_aead_id_json() {
        // Same unassigned range as `test_aead_id`.
        let unassigned = 0x0004..=0xFFFC;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = serde_json::to_string(&id).expect("should be able to encode `u16`");
            let got: AeadId = serde_json::from_str(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_kdf_id() {
        let unassigned = 0x0004..=0xFFFF;
        for id in unassigned {
            let want = KdfId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: KdfId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KdfId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_kem_id() {
        // Assigned: 0x0010..=0x0016, 0x0020, 0x0021, 0x0030. Every gap
        // between them (including 0x0017..=0x001F) must decode as `Other`.
        let unassigned: [RangeInclusive<u16>; 4] = [
            0x0001..=0x000F,
            0x0017..=0x001F,
            0x0022..=0x002F,
            0x0031..=0xFFFF,
        ];
        for id in unassigned.into_iter().flatten() {
            let want = KemId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: KemId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KemId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }
}