use crate::{
kdf::{Kdf as KdfTrait, LabeledExpand, SimpleHkdf},
kem::Kem as KemTrait,
setup::ExporterSecret,
util::{enforce_equal_len, enforce_outbuf_len, full_suite_id, write_u64_be, FullSuiteId},
Deserializable, HpkeError, Serializable,
};
use core::{default::Default, marker::PhantomData};
use aead::{AeadCore as BaseAeadCore, AeadInPlace as BaseAeadInPlace, KeyInit as BaseKeyInit};
use generic_array::GenericArray;
use zeroize::Zeroize;
/// Describes an authenticated-encryption algorithm that the encryption contexts in this module
/// can be parameterized over.
pub trait Aead {
/// The underlying AEAD implementation. Hidden from the rendered documentation.
#[doc(hidden)]
type AeadImpl: BaseAeadCore + BaseAeadInPlace + BaseKeyInit + Clone + Send + Sync;
/// The 16-bit identifier of this algorithm.
// NOTE(review): presumably the AEAD ID assigned by the HPKE specification — confirm.
const AEAD_ID: u16;
}
/// A nonce for the AEAD `A`. Its length is the implementation's `NonceSize`. The bytes are
/// zeroized when the value is dropped (see the `Drop` impl).
pub(crate) struct AeadNonce<A: Aead>(
pub(crate) GenericArray<u8, <A::AeadImpl as BaseAeadCore>::NonceSize>,
);
#[cfg(test)]
impl<A: Aead> Clone for AeadNonce<A> {
    /// Duplicates the nonce bytes. Only available in test builds.
    fn clone(&self) -> Self {
        let bytes = self.0.clone();
        Self(bytes)
    }
}
impl<A: Aead> Default for AeadNonce<A> {
    /// Builds a nonce whose bytes are all `u8::default()`, i.e. zero.
    fn default() -> Self {
        Self(Default::default())
    }
}
// Wipe the nonce bytes when the value goes out of scope.
impl<A: Aead> Drop for AeadNonce<A> {
fn drop(&mut self) {
self.0.zeroize();
}
}
/// A key for the AEAD `A`. Its length is the implementation's `KeySize`. The bytes are zeroized
/// when the value is dropped (see the `Drop` impl).
pub(crate) struct AeadKey<A: Aead>(
pub(crate) GenericArray<u8, <A::AeadImpl as aead::KeySizeUser>::KeySize>,
);
impl<A: Aead> Default for AeadKey<A> {
    /// Builds a key whose bytes are all `u8::default()`, i.e. zero.
    fn default() -> Self {
        Self(Default::default())
    }
}
// Wipe the key bytes when the value goes out of scope.
impl<A: Aead> Drop for AeadKey<A> {
fn drop(&mut self) {
self.0.zeroize();
}
}
/// Message sequence counter. Starts at the `u64` default (zero) and is zeroized on drop via
/// `#[zeroize(drop)]`.
#[derive(Clone, Default, Zeroize)]
#[zeroize(drop)]
struct Seq(u64);
fn increment_seq(seq: &Seq) -> Option<Seq> {
seq.0.checked_add(1).map(Seq)
}
/// Derives a per-message nonce by XORing `base_nonce` with the big-endian encoding of `seq`,
/// where the sequence number is left-padded with zeros to the nonce length.
///
/// Panics if the nonce is shorter than the sequence counter.
fn mix_nonce<A: Aead>(base_nonce: &AeadNonce<A>, seq: &Seq) -> AeadNonce<A> {
    // Start from an all-zero buffer and write the counter into its trailing bytes
    let mut mixed = AeadNonce::<A>::default();
    let counter_width = core::mem::size_of::<Seq>();
    let total_width = base_nonce.0.len();
    write_u64_be(&mut mixed.0[total_width - counter_width..], seq.0);

    // Fold the base nonce into the padded counter, byte by byte
    for (out_byte, base_byte) in mixed.0.iter_mut().zip(base_nonce.0.iter()) {
        *out_byte ^= *base_byte;
    }

    mixed
}
/// An authentication tag produced by the AEAD `A`. Its length is the implementation's `TagSize`.
pub struct AeadTag<A: Aead>(GenericArray<u8, <A::AeadImpl as BaseAeadCore>::TagSize>);

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an `A: Clone` bound on
// the algorithm marker type, even though only the tag bytes need to be cloned.
impl<A: Aead> Clone for AeadTag<A> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}
impl<A: Aead> Default for AeadTag<A> {
    /// Builds a tag whose bytes are all `u8::default()`, i.e. zero.
    fn default() -> Self {
        Self(Default::default())
    }
}
impl<A: Aead> Serializable for AeadTag<A> {
    // A serialized tag is exactly as long as the AEAD's tag
    type OutputSize = <A::AeadImpl as BaseAeadCore>::TagSize;

    /// Copies the tag bytes into `buf`.
    ///
    /// Panics if `buf` is not exactly the tag length.
    fn write_exact(&self, buf: &mut [u8]) {
        enforce_outbuf_len::<Self>(buf);
        let tag_bytes: &[u8] = self.0.as_slice();
        buf.copy_from_slice(tag_bytes);
    }
}
impl<A: Aead> Deserializable for AeadTag<A> {
fn from_bytes(encoded: &[u8]) -> Result<Self, HpkeError> {
enforce_equal_len(Self::size(), encoded.len())?;
let mut arr = <GenericArray<u8, Self::OutputSize> as Default>::default();
arr.copy_from_slice(encoded);
Ok(AeadTag(arr))
}
}
/// The state shared by the sending and receiving encryption contexts.
pub(crate) struct AeadCtx<A: Aead, Kdf: KdfTrait, Kem: KemTrait> {
/// Set once the sequence counter can no longer be incremented. Sealing and opening are refused
/// from then on.
overflowed: bool,
/// The keyed AEAD instance used for encryption and decryption.
encryptor: A::AeadImpl,
/// The nonce that is XORed with the sequence counter to form each per-message nonce.
base_nonce: AeadNonce<A>,
/// The secret used as the PRK when exporting.
exporter_secret: ExporterSecret<Kdf>,
/// The current message sequence counter.
seq: Seq,
/// Ties the context to a KEM at the type level; carries no data.
src_kem: PhantomData<Kem>,
/// The suite identifier passed to the labeled expansion in `export`.
suite_id: FullSuiteId,
}
#[cfg(test)]
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> Clone for AeadCtx<A, Kdf, Kem> {
    /// Field-by-field copy of the context. Only available in test builds.
    fn clone(&self) -> Self {
        let Self {
            overflowed,
            encryptor,
            base_nonce,
            exporter_secret,
            seq,
            src_kem: _,
            suite_id,
        } = self;

        Self {
            overflowed: *overflowed,
            encryptor: encryptor.clone(),
            base_nonce: base_nonce.clone(),
            exporter_secret: exporter_secret.clone(),
            seq: seq.clone(),
            src_kem: PhantomData,
            suite_id: *suite_id,
        }
    }
}
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> AeadCtx<A, Kdf, Kem> {
pub(crate) fn new(
key: &AeadKey<A>,
base_nonce: AeadNonce<A>,
exporter_secret: ExporterSecret<Kdf>,
) -> AeadCtx<A, Kdf, Kem> {
let suite_id = full_suite_id::<A, Kdf, Kem>();
AeadCtx {
overflowed: false,
encryptor: <A::AeadImpl as aead::KeyInit>::new(&key.0),
base_nonce,
exporter_secret,
seq: <Seq as Default>::default(),
src_kem: PhantomData,
suite_id,
}
}
pub fn export(&self, exporter_ctx: &[u8], out_buf: &mut [u8]) -> Result<(), HpkeError> {
let hkdf_ctx = SimpleHkdf::<Kdf>::from_prk(self.exporter_secret.0.as_slice()).unwrap();
hkdf_ctx
.labeled_expand(&self.suite_id, b"sec", exporter_ctx, out_buf)
.map_err(|_| HpkeError::KdfOutputTooLong)
}
#[cfg(test)]
pub(crate) fn current_nonce(&self) -> AeadNonce<A> {
mix_nonce::<A>(&self.base_nonce, &self.seq)
}
}
/// The receiving side of an encryption context. Wraps an `AeadCtx` and exposes the opening and
/// exporting operations.
pub struct AeadCtxR<A: Aead, Kdf: KdfTrait, Kem: KemTrait>(AeadCtx<A, Kdf, Kem>);
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> From<AeadCtx<A, Kdf, Kem>> for AeadCtxR<A, Kdf, Kem> {
    /// Wraps a raw context as a receiver context.
    fn from(ctx: AeadCtx<A, Kdf, Kem>) -> Self {
        Self(ctx)
    }
}
#[cfg(test)]
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> Clone for AeadCtxR<A, Kdf, Kem> {
    /// Clones the wrapped context. Only available in test builds.
    fn clone(&self) -> Self {
        let inner = self.0.clone();
        Self(inner)
    }
}
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> AeadCtxR<A, Kdf, Kem> {
    /// Decrypts `ciphertext` in place, authenticating it together with `aad` against `tag`.
    ///
    /// Returns `HpkeError::MessageLimitReached` if the sequence counter has already overflowed,
    /// and `HpkeError::OpenError` if decryption fails. The sequence counter only advances after a
    /// successful decryption.
    pub fn open_in_place_detached(
        &mut self,
        ciphertext: &mut [u8],
        aad: &[u8],
        tag: &AeadTag<A>,
    ) -> Result<(), HpkeError> {
        // Refuse to do anything once the counter has been exhausted
        if self.0.overflowed {
            return Err(HpkeError::MessageLimitReached);
        }

        let nonce = mix_nonce::<A>(&self.0.base_nonce, &self.0.seq);
        self.0
            .encryptor
            .decrypt_in_place_detached(&nonce.0, aad, ciphertext, &tag.0)
            .map_err(|_| HpkeError::OpenError)?;

        // Success. If the counter cannot advance, this was the last message we may open.
        if let Some(next_seq) = increment_seq(&self.0.seq) {
            self.0.seq = next_seq;
        } else {
            self.0.overflowed = true;
        }

        Ok(())
    }

    /// Decrypts `ciphertext`, whose trailing bytes are the tag, and returns the plaintext in a
    /// new buffer.
    ///
    /// Returns `HpkeError::OpenError` if the input is too short to contain a tag. Otherwise fails
    /// in the same ways as `open_in_place_detached`.
    #[cfg_attr(docsrs, doc(cfg(any(feature = "alloc", feature = "std"))))]
    #[cfg(any(feature = "alloc", feature = "std"))]
    pub fn open(&mut self, ciphertext: &[u8], aad: &[u8]) -> Result<crate::Vec<u8>, HpkeError> {
        let tag_len = AeadTag::<A>::size();
        if ciphertext.len() < tag_len {
            return Err(HpkeError::OpenError);
        }

        // Split the input into the encrypted message and the trailing tag
        let (body, tag_bytes) = ciphertext.split_at(ciphertext.len() - tag_len);
        let mut tag = AeadTag::<A>::default();
        tag.0.copy_from_slice(tag_bytes);

        let mut plaintext = body.to_vec();
        self.open_in_place_detached(&mut plaintext, aad, &tag)?;
        Ok(plaintext)
    }

    /// Fills `out_buf` with exported secret bytes for the context string `info`. Delegates to the
    /// wrapped context's `export`.
    pub fn export(&self, info: &[u8], out_buf: &mut [u8]) -> Result<(), HpkeError> {
        let inner = &self.0;
        inner.export(info, out_buf)
    }
}
/// The sending side of an encryption context. Wraps an `AeadCtx` and exposes the sealing and
/// exporting operations.
pub struct AeadCtxS<A: Aead, Kdf: KdfTrait, Kem: KemTrait>(pub(crate) AeadCtx<A, Kdf, Kem>);
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> From<AeadCtx<A, Kdf, Kem>> for AeadCtxS<A, Kdf, Kem> {
    /// Wraps a raw context as a sender context.
    fn from(ctx: AeadCtx<A, Kdf, Kem>) -> Self {
        Self(ctx)
    }
}
#[cfg(test)]
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> Clone for AeadCtxS<A, Kdf, Kem> {
    /// Clones the wrapped context. Only available in test builds.
    fn clone(&self) -> Self {
        let inner = self.0.clone();
        Self(inner)
    }
}
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> AeadCtxS<A, Kdf, Kem> {
    /// Encrypts `plaintext` in place, authenticating it together with `aad`, and returns the
    /// resulting tag.
    ///
    /// Returns `HpkeError::MessageLimitReached` if the sequence counter has already overflowed,
    /// and `HpkeError::SealError` if encryption fails. The sequence counter only advances after a
    /// successful encryption.
    pub fn seal_in_place_detached(
        &mut self,
        plaintext: &mut [u8],
        aad: &[u8],
    ) -> Result<AeadTag<A>, HpkeError> {
        // Refuse to do anything once the counter has been exhausted
        if self.0.overflowed {
            return Err(HpkeError::MessageLimitReached);
        }

        let nonce = mix_nonce::<A>(&self.0.base_nonce, &self.0.seq);
        let sealed = self
            .0
            .encryptor
            .encrypt_in_place_detached(&nonce.0, aad, plaintext);
        let raw_tag = match sealed {
            Ok(t) => t,
            Err(_) => return Err(HpkeError::SealError),
        };

        // Success. If the counter cannot advance, this was the last message we may seal.
        if let Some(next_seq) = increment_seq(&self.0.seq) {
            self.0.seq = next_seq;
        } else {
            self.0.overflowed = true;
        }

        Ok(AeadTag(raw_tag))
    }

    /// Encrypts `plaintext` and returns a new buffer holding the ciphertext followed by the tag.
    ///
    /// Fails in the same ways as `seal_in_place_detached`.
    #[cfg_attr(docsrs, doc(cfg(any(feature = "alloc", feature = "std"))))]
    #[cfg(any(feature = "alloc", feature = "std"))]
    pub fn seal(&mut self, plaintext: &[u8], aad: &[u8]) -> Result<crate::Vec<u8>, HpkeError> {
        let tag_len = AeadTag::<A>::size();

        // Reserve room for message + tag up front, then encrypt the copied message in place
        let mut out = crate::Vec::with_capacity(plaintext.len() + tag_len);
        out.extend_from_slice(plaintext);
        let tag = self.seal_in_place_detached(out.as_mut_slice(), aad)?;
        out.extend_from_slice(&tag.0);

        Ok(out)
    }

    /// Fills `out_buf` with exported secret bytes for the context string `info`. Delegates to the
    /// wrapped context's `export`.
    pub fn export(&self, info: &[u8], out_buf: &mut [u8]) -> Result<(), HpkeError> {
        let inner = &self.0;
        inner.export(info, out_buf)
    }
}
// The concrete algorithm implementations live in submodules and are re-exported from here.
// NOTE(review): `aes_gcm` and its re-export are only compiled under `cfg(test)`, so AES-GCM is
// unavailable outside test builds — confirm this is intentional.
#[cfg(test)]
mod aes_gcm;
mod chacha20_poly1305;
mod export_only;
#[cfg(test)]
pub use crate::aead::aes_gcm::*;
#[doc(inline)]
pub use crate::aead::{chacha20_poly1305::*, export_only::*};
#[cfg(test)]
mod test {
    use super::{AeadTag, AesGcm128, AesGcm256, ChaCha20Poly1305, ExportOnlyAead, Seq};
    use crate::{
        kdf::HkdfSha256, test_util::gen_ctx_simple_pair, Deserializable, HpkeError, Serializable,
    };

    /// Generates a test asserting that deserializing a 5-byte input as an `AeadTag` fails with
    /// `IncorrectInputLength`.
    macro_rules! test_invalid_nonce {
        ($test_name:ident, $aead_ty:ty) => {
            #[test]
            fn $test_name() {
                type A = $aead_ty;

                // No supported tag is 5 bytes long, so this must be rejected
                let tag_res = AeadTag::<A>::from_bytes(&[0; 5]);
                if let Err(e) = tag_res {
                    assert_eq!(e, HpkeError::IncorrectInputLength(AeadTag::<A>::size(), 5));
                } else {
                    panic!("AeadTag was unexpectedly valid");
                }
            }
        };
    }

    /// Generates a test asserting that `export()` returns the same bytes before and after a
    /// `seal()` call on the same context.
    #[cfg(any(feature = "alloc", feature = "std"))]
    macro_rules! test_export_idempotence {
        ($test_name:ident, $kem_ty:ty) => {
            #[test]
            fn $test_name() {
                type Kem = $kem_ty;
                type Kdf = HkdfSha256;
                type A = ChaCha20Poly1305;

                let (mut sender_ctx, _) = gen_ctx_simple_pair::<A, Kdf, Kem>();

                // Export once
                let mut secret1 = [0u8; 16];
                sender_ctx
                    .export(b"test_export_idempotence", &mut secret1)
                    .unwrap();

                // Advance the context's internal state by sealing a message
                let plaintext = b"back hand";
                sender_ctx.seal(plaintext, b"").expect("seal() failed");

                // Export again with the same context string; the output must not have changed
                let mut secret2 = [0u8; 16];
                sender_ctx
                    .export(b"test_export_idempotence", &mut secret2)
                    .unwrap();

                assert_eq!(secret1, secret2);
            }
        };
    }

    /// Generates two tests asserting that `seal()` and `open()` panic when the context uses
    /// `ExportOnlyAead`.
    #[cfg(any(feature = "alloc", feature = "std"))]
    macro_rules! test_exportonly_panics {
        ($test_name1:ident, $test_name2:ident, $kem_ty:ty) => {
            #[should_panic]
            #[test]
            fn $test_name1() {
                type Kem = $kem_ty;
                type Kdf = HkdfSha256;
                type A = ExportOnlyAead;

                let (mut sender_ctx, _) = gen_ctx_simple_pair::<A, Kdf, Kem>();
                let plaintext = b"back hand";
                let _ = sender_ctx.seal(plaintext, b"");
            }

            #[should_panic]
            #[test]
            fn $test_name2() {
                type Kem = $kem_ty;
                type Kdf = HkdfSha256;
                type A = ExportOnlyAead;

                let (_, mut receiver_ctx) = gen_ctx_simple_pair::<A, Kdf, Kem>();
                let invalid_ciphertext = vec![0u8; 60];
                let aad = b"with my prayers";
                let _ = receiver_ctx.open(&invalid_ciphertext, aad);
            }
        };
    }

    /// Generates a test asserting that a context whose sequence counter is at `u64::MAX` handles
    /// exactly one more message and then reports `MessageLimitReached`.
    #[cfg(any(feature = "alloc", feature = "std"))]
    macro_rules! test_overflow {
        ($test_name:ident, $kem_ty:ty) => {
            #[test]
            fn $test_name() {
                type Kem = $kem_ty;
                type Kdf = HkdfSha256;
                type A = ChaCha20Poly1305;

                // A sequence counter that cannot be incremented any further
                let big_seq = {
                    let mut seq = <Seq as Default>::default();
                    seq.0 = u64::MAX;
                    seq
                };

                let (mut sender_ctx, mut receiver_ctx) = gen_ctx_simple_pair::<A, Kdf, Kem>();
                sender_ctx.0.seq = big_seq.clone();
                receiver_ctx.0.seq = big_seq.clone();

                let msg = b"draxx them sklounst";
                let aad = b"you have to have the kebapi";

                // The last permitted message still round-trips
                {
                    let mut buf = msg.clone();
                    let ciphertext = sender_ctx.seal(&mut buf, aad).expect("seal() failed");
                    let roundtrip_plaintext =
                        receiver_ctx.open(&ciphertext, aad).expect("open() failed");
                    assert_eq!(msg, roundtrip_plaintext.as_slice());
                }

                // Every operation after that must be refused
                {
                    match sender_ctx.seal(msg, aad) {
                        Err(HpkeError::MessageLimitReached) => {}
                        Err(e) => panic!("seal() should have overflowed. Instead got {}", e),
                        _ => panic!("seal() should have overflowed. Instead it succeeded"),
                    }

                    // The contents do not matter; the overflow check happens before decryption
                    let placeholder_ciphertext = [0u8; 32];
                    match receiver_ctx.open(&placeholder_ciphertext, aad) {
                        Err(HpkeError::MessageLimitReached) => {}
                        Err(e) => panic!("open() should have overflowed. Instead got {}", e),
                        _ => panic!("open() should have overflowed. Instead it succeeded"),
                    }
                }
            }
        };
    }

    /// Generates a test asserting that seal/open round-trips, that an invalid ciphertext is
    /// rejected, and that a failed open leaves the receiver able to open the next message.
    #[cfg(any(feature = "alloc", feature = "std"))]
    macro_rules! test_ctx_correctness {
        ($test_name:ident, $aead_ty:ty, $kem_ty:ty) => {
            #[test]
            fn $test_name() {
                type A = $aead_ty;
                type Kdf = HkdfSha256;
                type Kem = $kem_ty;

                let (mut sender_ctx, mut receiver_ctx) = gen_ctx_simple_pair::<A, Kdf, Kem>();

                let msg = b"Love it or leave it, you better gain way";
                let aad = b"You better hit bull's eye, the kid don't play";

                // Basic round trip
                let ciphertext = sender_ctx.seal(msg, aad).expect("seal() failed");
                assert_ne!(&ciphertext, msg);
                let decrypted = receiver_ctx.open(&ciphertext, aad).expect("open() failed");
                assert_eq!(&decrypted, msg);

                // Garbage input must be rejected
                let invalid_ciphertext = [0x00; 32];
                assert!(receiver_ctx.open(&invalid_ciphertext, aad).is_err());

                // The failed open must not have advanced the receiver's state
                let ciphertext = sender_ctx.seal(msg, aad).expect("second seal() failed");
                let decrypted = receiver_ctx
                    .open(&ciphertext, aad)
                    .expect("second open() failed");
                assert_eq!(&decrypted, msg);
            }
        };
    }

    test_invalid_nonce!(test_invalid_nonce_aes128, AesGcm128);
    // This test previously used AesGcm128 again, leaving AES-256-GCM untested
    test_invalid_nonce!(test_invalid_nonce_aes256, AesGcm256);
    test_invalid_nonce!(test_invalid_nonce_chacha, ChaCha20Poly1305);

    #[cfg(all(feature = "secp", any(feature = "alloc", feature = "std")))]
    mod secp_tests {
        use super::*;

        test_export_idempotence!(test_export_idempotence_k256, crate::kem::SecpK256HkdfSha256);
        test_exportonly_panics!(
            test_exportonly_panics_k256_seal,
            test_exportonly_panics_k256_open,
            crate::kem::SecpK256HkdfSha256
        );
        test_overflow!(test_overflow_k256, crate::kem::SecpK256HkdfSha256);

        test_ctx_correctness!(
            test_ctx_correctness_aes128_k256,
            AesGcm128,
            crate::kem::SecpK256HkdfSha256
        );
        test_ctx_correctness!(
            test_ctx_correctness_aes256_k256,
            AesGcm256,
            crate::kem::SecpK256HkdfSha256
        );
        test_ctx_correctness!(
            test_ctx_correctness_chacha_k256,
            ChaCha20Poly1305,
            crate::kem::SecpK256HkdfSha256
        );
    }

    /// `write_exact()` must panic when the output buffer is not exactly the tag length.
    #[should_panic]
    #[test]
    fn test_write_exact() {
        // The buffer is deliberately longer than the tag
        let tag = AeadTag::<ChaCha20Poly1305>::default();
        let mut buf = [0u8; 33];
        tag.write_exact(&mut buf);
    }
}