use crate::{
kdf::{Kdf as KdfTrait, LabeledExpand},
kem::Kem as KemTrait,
kex::{Marshallable, Unmarshallable},
setup::ExporterSecret,
util::{full_suite_id, FullSuiteId},
HpkeError,
};
use core::{marker::PhantomData, u8};
use aead::{AeadInPlace as BaseAead, NewAead as BaseNewAead};
use digest::generic_array::GenericArray;
use hkdf::Hkdf;
/// An authenticated-encryption algorithm that can back the contexts in this module.
///
/// Implementors are zero-sized marker types that name a concrete in-place AEAD cipher
/// and its numeric algorithm identifier.
pub trait Aead {
    /// Numeric identifier of the algorithm (e.g. `0x0001` for [`AesGcm128`]).
    const AEAD_ID: u16;
    /// Concrete cipher used for sealing/opening; must support keyed construction,
    /// detached in-place encryption, and cloning.
    type AeadImpl: Clone + BaseAead + BaseNewAead;
}
/// AES-GCM with a 128-bit key, backed by `aes_gcm::Aes128Gcm` (AEAD identifier `0x0001`).
pub struct AesGcm128 {}

impl Aead for AesGcm128 {
    const AEAD_ID: u16 = 0x0001;
    type AeadImpl = aes_gcm::Aes128Gcm;
}
/// AES-GCM with a 256-bit key, backed by `aes_gcm::Aes256Gcm` (AEAD identifier `0x0002`).
pub struct AesGcm256 {}

impl Aead for AesGcm256 {
    const AEAD_ID: u16 = 0x0002;
    type AeadImpl = aes_gcm::Aes256Gcm;
}
/// ChaCha20-Poly1305, backed by `chacha20poly1305::ChaCha20Poly1305` (AEAD identifier `0x0003`).
pub struct ChaCha20Poly1305 {}

impl Aead for ChaCha20Poly1305 {
    const AEAD_ID: u16 = 0x0003;
    type AeadImpl = chacha20poly1305::ChaCha20Poly1305;
}
/// Adds one to the sequence number, treating its bytes as a big-endian unsigned integer.
///
/// Returns `Err(())` when the counter wraps around (every byte was `u8::MAX`). In that case
/// the counter is left at all zeros, so callers must stop using it.
fn increment_seq<A: Aead>(arr: &mut Seq<A>) -> Result<(), ()> {
    // Start at the least-significant (last) byte and push the carry leftwards.
    for digit in arr.0.iter_mut().rev() {
        let (sum, carry) = digit.overflowing_add(1);
        *digit = sum;
        if !carry {
            return Ok(());
        }
    }
    // The carry fell off the most-significant byte: the counter wrapped.
    Err(())
}
/// Computes the per-message nonce as `base_nonce XOR seq`.
///
/// Both operands have the AEAD's nonce length, so every byte of the base nonce is mixed.
fn mix_nonce<A: Aead>(base_nonce: &AeadNonce<A>, seq: &Seq<A>) -> AeadNonce<A> {
    let mut nonce = base_nonce.clone();
    for (n, s) in nonce.iter_mut().zip(seq.0.iter()) {
        *n ^= *s;
    }
    nonce
}
pub(crate) type AeadNonce<A> = GenericArray<u8, <<A as Aead>::AeadImpl as BaseAead>::NonceSize>;
pub(crate) type AeadKey<A> = GenericArray<u8, <<A as Aead>::AeadImpl as aead::NewAead>::KeySize>;
struct Seq<A: Aead>(AeadNonce<A>);
impl<A: Aead> Default for Seq<A> {
fn default() -> Seq<A> {
Seq(<AeadNonce<A> as Default>::default())
}
}
#[cfg(test)]
impl<A: Aead> Clone for Seq<A> {
fn clone(&self) -> Seq<A> {
Seq(self.0.clone())
}
}
pub struct AeadTag<A: Aead>(GenericArray<u8, <A::AeadImpl as BaseAead>::TagSize>);
impl<A: Aead> Marshallable for AeadTag<A> {
type OutputSize = <A::AeadImpl as BaseAead>::TagSize;
fn marshal(&self) -> GenericArray<u8, Self::OutputSize> {
self.0.clone()
}
}
impl<A: Aead> Unmarshallable for AeadTag<A> {
fn unmarshal(encoded: &[u8]) -> Result<Self, HpkeError> {
if encoded.len() != Self::size() {
Err(HpkeError::InvalidEncoding)
} else {
let mut arr = <GenericArray<u8, Self::OutputSize> as Default>::default();
arr.copy_from_slice(encoded);
Ok(AeadTag(arr))
}
}
}
/// State shared by the sender (`AeadCtxS`) and receiver (`AeadCtxR`) wrappers.
pub(crate) struct AeadCtx<A: Aead, Kdf: KdfTrait, Kem: KemTrait> {
/// Set once `increment_seq` reports a wrap; `seal`/`open` then fail with `SeqOverflow`.
overflowed: bool,
/// Cipher instance keyed once in `new`.
encryptor: A::AeadImpl,
/// Base nonce; XOR-ed with `seq` (via `mix_nonce`) to form each message nonce.
nonce: AeadNonce<A>,
/// Secret consumed by `export`.
exporter_secret: ExporterSecret<Kdf>,
/// Message counter, advanced after each successful `seal`/`open`.
seq: Seq<A>,
/// Carries the `Kem` type parameter; holds no data.
src_kem: PhantomData<Kem>,
/// Identifier from `full_suite_id::<A, Kdf, Kem>()`, passed to `labeled_expand` in `export`.
suite_id: FullSuiteId,
}
// Test-only: produces an independent copy of every field, including the current sequence
// number and overflow flag.
#[cfg(test)]
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> Clone for AeadCtx<A, Kdf, Kem> {
    fn clone(&self) -> Self {
        Self {
            encryptor: self.encryptor.clone(),
            nonce: self.nonce.clone(),
            seq: self.seq.clone(),
            overflowed: self.overflowed,
            exporter_secret: self.exporter_secret.clone(),
            suite_id: self.suite_id.clone(),
            // `PhantomData` is `Copy` for any type parameter.
            src_kem: self.src_kem,
        }
    }
}
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> AeadCtx<A, Kdf, Kem> {
pub(crate) fn new(
key: &AeadKey<A>,
nonce: AeadNonce<A>,
exporter_secret: ExporterSecret<Kdf>,
) -> AeadCtx<A, Kdf, Kem> {
let suite_id = full_suite_id::<A, Kdf, Kem>();
AeadCtx {
overflowed: false,
encryptor: <A::AeadImpl as aead::NewAead>::new(key),
nonce,
exporter_secret,
seq: <Seq<A> as Default>::default(),
src_kem: PhantomData,
suite_id,
}
}
pub fn export(&self, exporter_ctx: &[u8], out_buf: &mut [u8]) -> Result<(), HpkeError> {
let hkdf_ctx = Hkdf::<Kdf::HashImpl>::from_prk(self.exporter_secret.as_slice()).unwrap();
hkdf_ctx
.labeled_expand(&self.suite_id, b"sec", exporter_ctx, out_buf)
.map_err(|_| HpkeError::InvalidKdfLength)
}
}
/// Receiver-side context: opens messages in order and can export secrets.
pub struct AeadCtxR<A: Aead, Kdf: KdfTrait, Kem: KemTrait>(AeadCtx<A, Kdf, Kem>);

impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> From<AeadCtx<A, Kdf, Kem>> for AeadCtxR<A, Kdf, Kem> {
    /// Wraps a shared context for use on the receiving side.
    fn from(inner: AeadCtx<A, Kdf, Kem>) -> Self {
        Self(inner)
    }
}

// Test-only: clones the wrapped context.
#[cfg(test)]
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> Clone for AeadCtxR<A, Kdf, Kem> {
    fn clone(&self) -> Self {
        AeadCtxR::from(self.0.clone())
    }
}
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> AeadCtxR<A, Kdf, Kem> {
    /// Decrypts `ciphertext` in place, authenticating it together with `aad` against `tag`.
    ///
    /// The nonce is the base nonce XOR-ed with the current sequence number. The sequence
    /// number advances only when authentication succeeds, so a rejected message leaves the
    /// receiver's counter unchanged. When the counter wraps, the context is marked overflowed
    /// and every later call fails.
    ///
    /// # Errors
    /// - [`HpkeError::SeqOverflow`] if the sequence number has already wrapped.
    /// - [`HpkeError::InvalidTag`] if decryption/authentication fails.
    pub fn open(
        &mut self,
        ciphertext: &mut [u8],
        aad: &[u8],
        tag: &AeadTag<A>,
    ) -> Result<(), HpkeError> {
        if self.0.overflowed {
            return Err(HpkeError::SeqOverflow);
        }

        let nonce = mix_nonce(&self.0.nonce, &self.0.seq);
        self.0
            .encryptor
            .decrypt_in_place_detached(&nonce, aad, ciphertext, &tag.0)
            .map_err(|_| HpkeError::InvalidTag)?;

        // A wrap leaves the counter at zero, so the next nonce would repeat. Refuse further use.
        if increment_seq(&mut self.0.seq).is_err() {
            self.0.overflowed = true;
        }
        Ok(())
    }

    /// Derives `out_buf.len()` bytes of exported secret bound to `info`.
    ///
    /// Delegates to the shared context's exporter.
    ///
    /// # Errors
    /// Returns [`HpkeError::InvalidKdfLength`] if the requested length cannot be produced.
    pub fn export(&self, info: &[u8], out_buf: &mut [u8]) -> Result<(), HpkeError> {
        self.0.export(info, out_buf)
    }
}
/// Sender-side context: seals messages in order and can export secrets.
pub struct AeadCtxS<A: Aead, Kdf: KdfTrait, Kem: KemTrait>(AeadCtx<A, Kdf, Kem>);

impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> From<AeadCtx<A, Kdf, Kem>> for AeadCtxS<A, Kdf, Kem> {
    /// Wraps a shared context for use on the sending side.
    fn from(inner: AeadCtx<A, Kdf, Kem>) -> Self {
        Self(inner)
    }
}

// Test-only: clones the wrapped context.
#[cfg(test)]
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> Clone for AeadCtxS<A, Kdf, Kem> {
    fn clone(&self) -> Self {
        AeadCtxS::from(self.0.clone())
    }
}
impl<A: Aead, Kdf: KdfTrait, Kem: KemTrait> AeadCtxS<A, Kdf, Kem> {
    /// Encrypts `plaintext` in place, binding `aad`, and returns the detached tag.
    ///
    /// The nonce is the base nonce XOR-ed with the current sequence number. On success the
    /// sequence number advances. If that increment wraps, the context is marked overflowed:
    /// the message just sealed is still returned, but every later call fails.
    ///
    /// NOTE(review): RFC 9180's `IncrementSeq` raises before the ciphertext sealed under the
    /// all-ones sequence number is returned. This implementation returns it (the in-file
    /// overflow test relies on that). Confirm this is intended.
    ///
    /// # Errors
    /// - [`HpkeError::SeqOverflow`] if the sequence number has already wrapped.
    /// - [`HpkeError::Encryption`] if the underlying AEAD reports an error.
    pub fn seal(&mut self, plaintext: &mut [u8], aad: &[u8]) -> Result<AeadTag<A>, HpkeError> {
        if self.0.overflowed {
            return Err(HpkeError::SeqOverflow);
        }

        let nonce = mix_nonce(&self.0.nonce, &self.0.seq);
        let tag = self
            .0
            .encryptor
            .encrypt_in_place_detached(&nonce, aad, plaintext)
            .map_err(|_| HpkeError::Encryption)?;

        // A wrap leaves the counter at zero, so the next nonce would repeat. Refuse further use.
        if increment_seq(&mut self.0.seq).is_err() {
            self.0.overflowed = true;
        }
        Ok(AeadTag(tag))
    }

    /// Derives `out_buf.len()` bytes of exported secret bound to `info`.
    ///
    /// Delegates to the shared context's exporter.
    ///
    /// # Errors
    /// Returns [`HpkeError::InvalidKdfLength`] if the requested length cannot be produced.
    pub fn export(&self, info: &[u8], out_buf: &mut [u8]) -> Result<(), HpkeError> {
        self.0.export(info, out_buf)
    }
}
#[cfg(test)]
mod test {
// Tests for the sender/receiver AEAD contexts in the parent module. Each macro expands to one
// `#[test]` function per KEM. The invocations at the bottom are gated on the crate features
// that enable each KEM.
use super::{AeadTag, AesGcm128, AesGcm256, ChaCha20Poly1305, Seq};
use crate::{kdf::HkdfSha256, kex::Unmarshallable, test_util::gen_ctx_simple_pair, HpkeError};
use core::u8;
// Checks that `export` yields the same secret before and after a `seal`, i.e. exporting is not
// affected by advancing the sequence number.
macro_rules! test_export_idempotence {
($test_name:ident, $kem_ty:ty) => {
#[test]
fn $test_name() {
type Kem = $kem_ty;
type Kdf = HkdfSha256;
type A = ChaCha20Poly1305;
let (mut sender_ctx, _) = gen_ctx_simple_pair::<A, Kdf, Kem>();
let mut secret1 = [0u8; 16];
sender_ctx
.export(b"test_export_idempotence", &mut secret1)
.unwrap();
let mut plaintext = *b"back hand";
sender_ctx
.seal(&mut plaintext[..], b"")
.expect("seal() failed");
let mut secret2 = [0u8; 16];
sender_ctx
.export(b"test_export_idempotence", &mut secret2)
.unwrap();
assert_eq!(secret1, secret2);
}
};
}
// Preloads both contexts with an all-`u8::MAX` sequence number. The first seal/open must still
// succeed, using that final sequence number. The next seal and open must both fail with
// `HpkeError::SeqOverflow`.
macro_rules! test_overflow {
($test_name:ident, $kem_ty:ty) => {
#[test]
fn $test_name() {
type Kem = $kem_ty;
type Kdf = HkdfSha256;
type A = ChaCha20Poly1305;
let big_seq = {
let mut buf = <Seq<A> as Default>::default();
for byte in buf.0.iter_mut() {
*byte = u8::MAX;
}
buf
};
let (mut sender_ctx, mut receiver_ctx) = gen_ctx_simple_pair::<A, Kdf, Kem>();
sender_ctx.0.seq = big_seq.clone();
receiver_ctx.0.seq = big_seq.clone();
let msg = b"draxx them sklounst";
let aad = b"with my prayers";
// The all-ones sequence number is still usable exactly once.
{
let mut plaintext = *msg;
let tag = sender_ctx
.seal(&mut plaintext[..], aad)
.expect("seal() failed");
let mut ciphertext = plaintext;
receiver_ctx
.open(&mut ciphertext[..], aad, &tag)
.expect("open() failed");
let roundtrip_plaintext = ciphertext;
assert_eq!(msg, &roundtrip_plaintext);
}
// After the wrap both contexts must report overflow. The overflow flag is checked before
// any AEAD work, so the dummy ciphertext/tag below are never authenticated.
{
let mut plaintext = *msg;
match sender_ctx.seal(&mut plaintext[..], aad) {
Err(HpkeError::SeqOverflow) => {}
Err(e) => panic!("seal() should have overflowed. Instead got {}", e),
_ => panic!("seal() should have overflowed. Instead it succeeded"),
}
let mut dummy_ciphertext = [0u8; 32];
// 16 zero bytes: must match the tag size of `A` for `unmarshal` to succeed.
let dummy_tag = AeadTag::unmarshal(&[0; 16]).unwrap();
match receiver_ctx.open(&mut dummy_ciphertext[..], aad, &dummy_tag) {
Err(HpkeError::SeqOverflow) => {}
Err(e) => panic!("open() should have overflowed. Instead got {}", e),
_ => panic!("open() should have overflowed. Instead it succeeded"),
}
}
}
};
}
// Round-trips a message through `seal`/`open` for the given AEAD and KEM. Also checks that
// sealing actually changed the buffer.
macro_rules! test_ctx_correctness {
($test_name:ident, $aead_ty:ty, $kem_ty:ty) => {
#[test]
fn $test_name() {
type A = $aead_ty;
type Kdf = HkdfSha256;
type Kem = $kem_ty;
let (mut sender_ctx, mut receiver_ctx) = gen_ctx_simple_pair::<A, Kdf, Kem>();
let msg = b"Love it or leave it, you better gain way";
let aad = b"You better hit bull's eye, the kid don't play";
let mut ciphertext = msg.clone();
let tag = sender_ctx
.seal(&mut ciphertext[..], aad)
.expect("seal() failed");
assert!(&ciphertext[..] != &msg[..]);
receiver_ctx
.open(&mut ciphertext[..], aad, &tag)
.expect("open() failed");
let decrypted = ciphertext;
assert_eq!(&decrypted[..], &msg[..]);
}
};
}
// Exporter idempotence, per KEM.
#[cfg(feature = "x25519-dalek")]
test_export_idempotence!(test_export_idempotence_x25519, crate::kem::X25519HkdfSha256);
#[cfg(feature = "p256")]
test_export_idempotence!(test_export_idempotence_p256, crate::kem::DhP256HkdfSha256);
// Sequence-number overflow, per KEM.
#[cfg(feature = "x25519-dalek")]
test_overflow!(test_overflow_x25519, crate::kem::X25519HkdfSha256);
#[cfg(feature = "p256")]
test_overflow!(test_overflow_p256, crate::kem::DhP256HkdfSha256);
// Seal/open round trips for every AEAD x KEM combination.
#[cfg(feature = "x25519-dalek")]
test_ctx_correctness!(
test_ctx_correctness_aes128_x25519,
AesGcm128,
crate::kem::X25519HkdfSha256
);
#[cfg(feature = "p256")]
test_ctx_correctness!(
test_ctx_correctness_aes128_p256,
AesGcm128,
crate::kem::DhP256HkdfSha256
);
#[cfg(feature = "x25519-dalek")]
test_ctx_correctness!(
test_ctx_correctness_aes256_x25519,
AesGcm256,
crate::kem::X25519HkdfSha256
);
#[cfg(feature = "p256")]
test_ctx_correctness!(
test_ctx_correctness_aes256_p256,
AesGcm256,
crate::kem::DhP256HkdfSha256
);
#[cfg(feature = "x25519-dalek")]
test_ctx_correctness!(
test_ctx_correctness_chacha_x25519,
ChaCha20Poly1305,
crate::kem::X25519HkdfSha256
);
#[cfg(feature = "p256")]
test_ctx_correctness!(
test_ctx_correctness_chacha_p256,
ChaCha20Poly1305,
crate::kem::DhP256HkdfSha256
);
}