use oxicrypto_core::{Aead, CryptoError};
use super::ids::AeadId;
use super::labeled::HpkeKdf;
/// AEAD algorithm selected by an HPKE ciphersuite.
///
/// [`HpkeAead::ExportOnly`] is the export-only identifier: a context built with it can
/// still export secrets, but every seal/open attempt is rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HpkeAead {
    /// AES-128-GCM.
    Aes128Gcm,
    /// AES-256-GCM.
    Aes256Gcm,
    /// ChaCha20-Poly1305.
    ChaCha20Poly1305,
    /// Export-only mode: no message encryption is available.
    ExportOnly,
}
impl HpkeAead {
#[must_use]
pub const fn from_id(id: AeadId) -> Self {
match id {
AeadId::Aes128Gcm => HpkeAead::Aes128Gcm,
AeadId::Aes256Gcm => HpkeAead::Aes256Gcm,
AeadId::ChaCha20Poly1305 => HpkeAead::ChaCha20Poly1305,
AeadId::ExportOnly => HpkeAead::ExportOnly,
}
}
#[must_use]
const fn is_export_only(self) -> bool {
matches!(self, HpkeAead::ExportOnly)
}
fn seal(
self,
key: &[u8],
nonce: &[u8],
aad: &[u8],
pt: &[u8],
ct_out: &mut [u8],
) -> Result<usize, CryptoError> {
match self {
HpkeAead::Aes128Gcm => oxicrypto_aead::Aes128Gcm.seal(key, nonce, aad, pt, ct_out),
HpkeAead::Aes256Gcm => oxicrypto_aead::Aes256Gcm.seal(key, nonce, aad, pt, ct_out),
HpkeAead::ChaCha20Poly1305 => {
oxicrypto_aead::ChaCha20Poly1305.seal(key, nonce, aad, pt, ct_out)
}
HpkeAead::ExportOnly => Err(CryptoError::UnsupportedAlgorithm),
}
}
fn open(
self,
key: &[u8],
nonce: &[u8],
aad: &[u8],
ct: &[u8],
pt_out: &mut [u8],
) -> Result<usize, CryptoError> {
match self {
HpkeAead::Aes128Gcm => oxicrypto_aead::Aes128Gcm.open(key, nonce, aad, ct, pt_out),
HpkeAead::Aes256Gcm => oxicrypto_aead::Aes256Gcm.open(key, nonce, aad, ct, pt_out),
HpkeAead::ChaCha20Poly1305 => {
oxicrypto_aead::ChaCha20Poly1305.open(key, nonce, aad, ct, pt_out)
}
HpkeAead::ExportOnly => Err(CryptoError::UnsupportedAlgorithm),
}
}
}
/// State shared by the sender ([`HpkeContextS`]) and recipient ([`HpkeContextR`]) contexts.
struct ContextInner {
    // --- message encryption state ---
    /// AEAD used for seal/open.
    aead: HpkeAead,
    /// Nonce length in bytes (Nn).
    nn: usize,
    /// Authentication tag length in bytes (Nt); ciphertexts are `pt.len() + nt` bytes.
    nt: usize,
    /// AEAD key.
    key: Vec<u8>,
    /// Base nonce XORed with the sequence number to form each per-message nonce.
    base_nonce: Vec<u8>,
    /// Message counter; advanced only after a successful seal/open.
    seq: u128,
    // --- secret export state ---
    /// KDF used by the secret exporter.
    kdf: HpkeKdf,
    /// Suite identifier passed to `labeled_expand` when exporting.
    suite_id: Vec<u8>,
    /// Secret from which exported values are derived.
    exporter_secret: Vec<u8>,
}
impl ContextInner {
    /// Computes the per-message nonce `base_nonce XOR I2OSP(seq, Nn)` (RFC 9180 §5.2).
    ///
    /// `I2OSP` left-pads with zeros, so the big-endian counter lines up with the *end*
    /// of the nonce. At most 16 bytes (the width of the `u128` counter) can change;
    /// any leading bytes of a longer nonce are XORed with zero and kept as-is.
    fn compute_nonce(&self) -> Vec<u8> {
        let seq_be = self.seq.to_be_bytes();
        let mut nonce = self.base_nonce.clone();
        // Clamp the overlap so indexing stays in bounds: `seq_be.len() - nn` would
        // underflow (and panic) for Nn > 16, and a short `base_nonce` would be
        // indexed past its end.
        let width = self.nn.min(nonce.len()).min(seq_be.len());
        let nonce_start = nonce.len() - width;
        let seq_start = seq_be.len() - width;
        for (byte, seq_byte) in nonce[nonce_start..].iter_mut().zip(&seq_be[seq_start..]) {
            *byte ^= *seq_byte;
        }
        nonce
    }

    /// Fails once the sequence number reaches `2^(8*Nn) - 1`, the RFC 9180 message limit.
    ///
    /// For `Nn >= 16` the limit is capped at `u128::MAX`. The counter cannot represent a
    /// larger bound, and the cap guarantees that the `seq += 1` after a successful
    /// operation never overflows. An overflow would panic in debug builds, or wrap to 0
    /// and reuse a nonce in release builds.
    ///
    /// # Errors
    /// `CryptoError::Kex` when the limit has been reached.
    fn check_no_overflow(&self) -> Result<(), CryptoError> {
        let bits = self.nn.saturating_mul(8);
        let limit = if bits >= 128 {
            u128::MAX
        } else {
            (1u128 << bits) - 1
        };
        if self.seq >= limit {
            return Err(CryptoError::Kex);
        }
        Ok(())
    }

    /// Encrypts `pt` (authenticating `aad`) with the nonce for the current sequence number.
    ///
    /// The sequence number advances only when encryption succeeds.
    ///
    /// # Errors
    /// `CryptoError::UnsupportedAlgorithm` for an export-only context, `CryptoError::Kex`
    /// once the message limit is reached, or any error from the underlying AEAD.
    fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if self.aead.is_export_only() {
            return Err(CryptoError::UnsupportedAlgorithm);
        }
        self.check_no_overflow()?;
        let nonce = self.compute_nonce();
        // Room for the ciphertext plus the Nt-byte tag; trimmed to what was written.
        let mut ct = vec![0u8; pt.len() + self.nt];
        let written = self.aead.seal(&self.key, &nonce, aad, pt, &mut ct)?;
        ct.truncate(written);
        self.seq += 1;
        Ok(ct)
    }

    /// Decrypts `ct` (authenticating `aad`) with the nonce for the current sequence number.
    ///
    /// The sequence number advances only when decryption succeeds.
    ///
    /// # Errors
    /// `CryptoError::UnsupportedAlgorithm` for an export-only context, `CryptoError::Kex`
    /// once the message limit is reached, `CryptoError::InvalidTag` if `ct` is shorter
    /// than the tag, or any error from the underlying AEAD.
    fn open(&mut self, aad: &[u8], ct: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if self.aead.is_export_only() {
            return Err(CryptoError::UnsupportedAlgorithm);
        }
        self.check_no_overflow()?;
        // A ciphertext shorter than the tag cannot be authentic.
        if ct.len() < self.nt {
            return Err(CryptoError::InvalidTag);
        }
        let nonce = self.compute_nonce();
        let mut pt = vec![0u8; ct.len() - self.nt];
        let written = self.aead.open(&self.key, &nonce, aad, ct, &mut pt)?;
        pt.truncate(written);
        self.seq += 1;
        Ok(pt)
    }

    /// Derives `l` bytes from the exporter secret.
    ///
    /// This is `LabeledExpand(exporter_secret, "sec", exporter_context, l)`, evaluated
    /// with this context's `suite_id`.
    ///
    /// # Errors
    /// Propagates any error from the KDF's `labeled_expand`.
    fn export(&self, exporter_context: &[u8], l: usize) -> Result<Vec<u8>, CryptoError> {
        self.kdf.labeled_expand(
            &self.suite_id,
            &self.exporter_secret,
            b"sec",
            exporter_context,
            l,
        )
    }
}
/// Inputs for a new sender/recipient context (presumably produced by the HPKE key
/// schedule — confirm against the callers of `HpkeContextS::new` / `HpkeContextR::new`).
pub(crate) struct ContextConfig {
    // --- message encryption parameters ---
    /// AEAD algorithm for seal/open.
    pub aead: HpkeAead,
    /// Nonce length in bytes (Nn).
    pub nn: usize,
    /// Authentication tag length in bytes (Nt).
    pub nt: usize,
    /// AEAD key.
    pub key: Vec<u8>,
    /// Base nonce combined with the sequence number for every message.
    pub base_nonce: Vec<u8>,
    // --- secret export parameters ---
    /// KDF used when exporting secrets.
    pub kdf: HpkeKdf,
    /// Suite identifier passed to `labeled_expand` during export.
    pub suite_id: Vec<u8>,
    /// Secret from which exported values are derived.
    pub exporter_secret: Vec<u8>,
}
impl ContextInner {
fn from_config(config: ContextConfig) -> Self {
ContextInner {
aead: config.aead,
kdf: config.kdf,
suite_id: config.suite_id,
key: config.key,
base_nonce: config.base_nonce,
exporter_secret: config.exporter_secret,
seq: 0,
nn: config.nn,
nt: config.nt,
}
}
}
/// Sender-side HPKE context: seals messages in sequence and exports secrets.
pub struct HpkeContextS {
    inner: ContextInner,
}

impl HpkeContextS {
    /// Creates a sender context from `config`, starting at sequence number zero.
    pub(crate) fn new(config: ContextConfig) -> Self {
        let inner = ContextInner::from_config(config);
        Self { inner }
    }

    /// Encrypts `pt` (authenticating `aad`) as the next message in sequence.
    ///
    /// The sequence number advances only when encryption succeeds.
    ///
    /// # Errors
    /// Fails for an export-only context, once the message limit is reached, or when the
    /// underlying AEAD reports an error.
    pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Result<Vec<u8>, CryptoError> {
        ContextInner::seal(&mut self.inner, aad, pt)
    }

    /// Derives `l` bytes of exported secret bound to `exporter_context`.
    ///
    /// # Errors
    /// Propagates any error from the KDF expansion.
    pub fn export(&self, exporter_context: &[u8], l: usize) -> Result<Vec<u8>, CryptoError> {
        ContextInner::export(&self.inner, exporter_context, l)
    }

    /// Current sequence number: how many messages have been sealed successfully.
    #[must_use]
    pub fn sequence_number(&self) -> u128 {
        let Self { inner } = self;
        inner.seq
    }

    /// Test hook: forces the sequence number, e.g. to exercise the message limit.
    #[cfg(test)]
    pub(crate) fn set_sequence_number(&mut self, seq: u128) {
        let Self { inner } = self;
        inner.seq = seq;
    }
}
/// Recipient-side HPKE context: opens messages in sequence and exports secrets.
pub struct HpkeContextR {
    inner: ContextInner,
}

impl HpkeContextR {
    /// Creates a recipient context from `config`, starting at sequence number zero.
    pub(crate) fn new(config: ContextConfig) -> Self {
        let inner = ContextInner::from_config(config);
        Self { inner }
    }

    /// Decrypts `ct` (authenticating `aad`) as the next message in sequence.
    ///
    /// The sequence number advances only when decryption succeeds.
    ///
    /// # Errors
    /// Fails for an export-only context, once the message limit is reached, when `ct` is
    /// shorter than the tag, or when the underlying AEAD rejects the ciphertext.
    pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Result<Vec<u8>, CryptoError> {
        ContextInner::open(&mut self.inner, aad, ct)
    }

    /// Derives `l` bytes of exported secret bound to `exporter_context`.
    ///
    /// # Errors
    /// Propagates any error from the KDF expansion.
    pub fn export(&self, exporter_context: &[u8], l: usize) -> Result<Vec<u8>, CryptoError> {
        ContextInner::export(&self.inner, exporter_context, l)
    }

    /// Current sequence number: how many messages have been opened successfully.
    #[must_use]
    pub fn sequence_number(&self) -> u128 {
        let Self { inner } = self;
        inner.seq
    }
}
#[cfg(test)]
mod context_tests {
    use super::*;

    /// Builds an AES-128-GCM context (Nn = 12, Nt = 16) with a zero key.
    fn aes128_inner(base_nonce: Vec<u8>, seq: u128) -> ContextInner {
        ContextInner {
            aead: HpkeAead::Aes128Gcm,
            kdf: HpkeKdf::HkdfSha256,
            suite_id: Vec::new(),
            key: vec![0u8; 16],
            base_nonce,
            exporter_secret: vec![0u8; 32],
            seq,
            nn: 12,
            nt: 16,
        }
    }

    #[test]
    fn export_only_aead_rejects_seal_open() {
        assert_eq!(
            HpkeAead::ExportOnly.seal(&[], &[], &[], &[], &mut []),
            Err(CryptoError::UnsupportedAlgorithm)
        );
        assert_eq!(
            HpkeAead::ExportOnly.open(&[], &[], &[], &[], &mut []),
            Err(CryptoError::UnsupportedAlgorithm)
        );
    }

    #[test]
    fn from_id_mapping() {
        assert_eq!(HpkeAead::from_id(AeadId::Aes128Gcm), HpkeAead::Aes128Gcm);
        assert_eq!(HpkeAead::from_id(AeadId::Aes256Gcm), HpkeAead::Aes256Gcm);
        assert_eq!(
            HpkeAead::from_id(AeadId::ChaCha20Poly1305),
            HpkeAead::ChaCha20Poly1305
        );
        assert_eq!(HpkeAead::from_id(AeadId::ExportOnly), HpkeAead::ExportOnly);
    }

    #[test]
    fn compute_nonce_matches_xor() {
        let base = vec![
            0x56, 0xd8, 0x90, 0xe5, 0xac, 0xca, 0xaf, 0x01, 0x1c, 0xff, 0x4b, 0x7d,
        ];
        // seq = 0 leaves the base nonce untouched.
        assert_eq!(aes128_inner(base.clone(), 0).compute_nonce(), base);
        // seq = 1 flips the lowest bit of the last (big-endian) byte.
        let mut expected = base.clone();
        let last = expected.len() - 1;
        expected[last] ^= 0x01;
        assert_eq!(aes128_inner(base, 1).compute_nonce(), expected);
    }

    #[test]
    fn sequence_limit_is_enforced_at_nn_12() {
        // With Nn = 12 the RFC 9180 limit is 2^96 - 1.
        let limit = (1u128 << 96) - 1;
        assert_eq!(
            aes128_inner(vec![0u8; 12], limit - 1).check_no_overflow(),
            Ok(())
        );
        assert_eq!(
            aes128_inner(vec![0u8; 12], limit).check_no_overflow(),
            Err(CryptoError::Kex)
        );
    }

    #[test]
    fn open_rejects_truncated_ciphertext_without_advancing() {
        let mut inner = aes128_inner(vec![0u8; 12], 0);
        assert_eq!(inner.open(&[], &[0u8; 15]), Err(CryptoError::InvalidTag));
        assert_eq!(inner.seq, 0);
    }

    #[test]
    fn export_only_context_rejects_seal_and_open() {
        let mut inner = aes128_inner(vec![0u8; 12], 0);
        inner.aead = HpkeAead::ExportOnly;
        assert_eq!(inner.seal(&[], &[]), Err(CryptoError::UnsupportedAlgorithm));
        assert_eq!(
            inner.open(&[], &[0u8; 16]),
            Err(CryptoError::UnsupportedAlgorithm)
        );
        assert_eq!(inner.seq, 0);
    }
}