use core::{cmp::Ordering, fmt};
use buggy::Bug;
use byteorder::{ByteOrder as _, LittleEndian};
pub use spideroak_crypto::hpke::MessageLimitReached;
use spideroak_crypto::{
aead,
hpke::{self, HpkeError, OpenCtx, SealCtx},
import::ImportError,
};
use crate::{
afc::shared::{RawOpenKey, RawSealKey},
ciphersuite::CipherSuite,
policy::LabelId,
};
/// Sequence number used when sealing and opening messages.
///
/// The wrapped HPKE value is private; convert with [`Seq::new`] /
/// [`Seq::to_u64`] or the `From` impls.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Seq(hpke::Seq);
impl Seq {
pub const ZERO: Self = Self(hpke::Seq::ZERO);
pub const fn new(seq: u64) -> Self {
Self(hpke::Seq::new(seq))
}
pub const fn to_u64(&self) -> u64 {
self.0.to_u64()
}
#[cfg(any(test, feature = "test_util"))]
pub(crate) fn max<N: crate::generic_array::ArrayLength>() -> u64 {
hpke::Seq::max::<N>()
}
}
impl From<Seq> for u64 {
fn from(seq: Seq) -> Self {
seq.to_u64()
}
}
impl From<u64> for Seq {
fn from(seq: u64) -> Self {
Self::new(seq)
}
}
impl PartialEq<u64> for Seq {
    /// Compares the raw value of this sequence number with a `u64`.
    fn eq(&self, other: &u64) -> bool {
        self.to_u64() == *other
    }
}
impl PartialOrd<u64> for Seq {
    /// Orders this sequence number against a `u64` by its raw value.
    ///
    /// Always returns `Some`, since `u64` is totally ordered.
    fn partial_cmp(&self, other: &u64) -> Option<Ordering> {
        let ours = self.to_u64();
        Some(ours.cmp(other))
    }
}
impl fmt::Display for Seq {
    /// Formats using the wrapped HPKE value's `Display` implementation,
    /// passing the caller's formatter (and thus its flags) straight through.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
/// Declares a struct and gives it a `PACKED_SIZE` associated constant.
///
/// `PACKED_SIZE` is the `size_of` a `#[repr(C, packed)]` shadow copy of the
/// same struct definition. That is the total size of the fields with no
/// padding between or after them. The real struct keeps Rust's default
/// layout; only the shadow copy is packed.
///
/// Caller-supplied attributes (derives, docs, ...) are applied only to the
/// real struct, not to the shadow copy.
///
/// NOTE(review): the generated `impl $name` has no generic parameters, so
/// this presumably only supports non-generic structs.
macro_rules! packed {
    (
        $(#[$meta:meta])*
        $vis:vis struct $name:ident $($tokens:tt)*
    ) => {
        // The real struct, with all caller-supplied attributes.
        $(#[$meta])*
        $vis struct $name $($tokens)*

        impl $name {
            /// Size in bytes of this struct's fields laid out back-to-back
            /// with no padding.
            $vis const PACKED_SIZE: usize = {
                // Shadow copy of the struct, scoped to this const block so
                // it does not clash with the real one. `repr(C, packed)`
                // removes all padding, so its size is the sum of the field
                // sizes.
                #[repr(C, packed)]
                // Never constructed or read; it exists only to be measured
                // by `size_of` below, hence the lint suppression.
                #[allow(dead_code)]
                $vis struct $name $($tokens)*
                ::core::mem::size_of::<$name>()
            };
        }
    };
}
packed! {
    /// Additional authenticated data passed to the `SealKey` / `OpenKey`
    /// seal and open operations.
    ///
    /// Serialized by `AuthData::to_bytes` as the little-endian `version`
    /// followed by the raw bytes of `label_id`.
    pub struct AuthData {
        /// Version number, serialized as 4 little-endian bytes.
        /// NOTE(review): presumably a protocol/format version — confirm.
        pub version: u32,
        /// Label identifier, serialized via `LabelId::as_bytes`.
        pub label_id: LabelId,
    }
}
impl AuthData {
    /// Serializes the authenticated data into a fixed `PACKED_SIZE` buffer.
    ///
    /// The output is the 4-byte little-endian `version` followed by the
    /// bytes of `label_id`.
    ///
    /// # Panics
    ///
    /// Panics if `label_id.as_bytes()` is not exactly `PACKED_SIZE - 4`
    /// bytes long. The `packed!` layout presumably guarantees that it is.
    fn to_bytes(&self) -> [u8; Self::PACKED_SIZE] {
        let mut out = [0u8; Self::PACKED_SIZE];
        let (version_bytes, label_bytes) = out.split_at_mut(4);
        LittleEndian::write_u32(version_bytes, self.version);
        label_bytes.copy_from_slice(self.label_id.as_bytes());
        out
    }
}
/// Seals (encrypts and authenticates) messages by wrapping an HPKE
/// [`SealCtx`] for the cipher suite's AEAD.
pub struct SealKey<CS: CipherSuite> {
    /// The HPKE sealing context. It also tracks the current sequence number
    /// (exposed via `SealKey::seq`).
    ctx: SealCtx<CS::Aead>,
}
impl<CS: CipherSuite> SealKey<CS> {
pub const OVERHEAD: usize = SealCtx::<CS::Aead>::OVERHEAD;
pub fn from_raw(key: &RawSealKey<CS>, seq: Seq) -> Result<Self, ImportError> {
let RawSealKey { key, base_nonce } = key;
let ctx = SealCtx::new(key, base_nonce, seq.0)?;
Ok(Self { ctx })
}
pub fn seal(
&mut self,
dst: &mut [u8],
plaintext: &[u8],
ad: &AuthData,
) -> Result<Seq, SealError> {
let seq = self.ctx.seal(dst, plaintext, &ad.to_bytes())?;
Ok(Seq(seq))
}
pub fn seal_in_place(
&mut self,
data: impl AsMut<[u8]>,
tag: &mut [u8],
ad: &AuthData,
) -> Result<Seq, SealError> {
let seq = self.ctx.seal_in_place(data, tag, &ad.to_bytes())?;
Ok(Seq(seq))
}
#[inline]
pub fn seq(&self) -> Seq {
Seq(self.ctx.seq())
}
}
/// Errors returned by [`SealKey`] sealing operations.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SealError {
    /// The HPKE context reported that its message limit was reached
    /// (mapped from `HpkeError::MessageLimitReached`).
    #[error("message limit reached")]
    MessageLimitReached,
    /// Any other HPKE error, passed through unchanged.
    #[error(transparent)]
    Other(HpkeError),
    /// An internal bug.
    #[error(transparent)]
    Bug(#[from] Bug),
}
impl From<HpkeError> for SealError {
    /// Lifts `HpkeError::MessageLimitReached` into its own variant and
    /// wraps every other HPKE error in [`SealError::Other`].
    fn from(err: HpkeError) -> Self {
        if matches!(err, HpkeError::MessageLimitReached) {
            Self::MessageLimitReached
        } else {
            Self::Other(err)
        }
    }
}
/// Opens (decrypts and verifies) sealed messages by wrapping an HPKE
/// [`OpenCtx`] for the cipher suite's AEAD.
///
/// The caller supplies each message's [`Seq`] explicitly, and opening only
/// needs `&self`.
pub struct OpenKey<CS: CipherSuite> {
    /// The HPKE opening context.
    ctx: OpenCtx<CS::Aead>,
}
impl<CS: CipherSuite> OpenKey<CS> {
pub const OVERHEAD: usize = OpenCtx::<CS::Aead>::OVERHEAD;
pub fn from_raw(key: &RawOpenKey<CS>) -> Result<Self, ImportError> {
let RawOpenKey { key, base_nonce } = key;
let ctx = OpenCtx::new(key, base_nonce, Seq::ZERO.0)?;
Ok(Self { ctx })
}
pub fn open(
&self,
dst: &mut [u8],
ciphertext: &[u8],
ad: &AuthData,
seq: Seq,
) -> Result<(), OpenError> {
self.ctx.open_at(dst, ciphertext, &ad.to_bytes(), seq.0)?;
Ok(())
}
pub fn open_in_place(
&self,
data: impl AsMut<[u8]>,
tag: &[u8],
ad: &AuthData,
seq: Seq,
) -> Result<(), OpenError> {
self.ctx
.open_in_place_at(data, tag, &ad.to_bytes(), seq.0)?;
Ok(())
}
}
/// Errors returned by [`OpenKey`] opening operations.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum OpenError {
    /// AEAD authentication failed (mapped from
    /// `HpkeError::Open(aead::OpenError::Authentication)`).
    #[error("authentication error")]
    Authentication,
    /// The HPKE context reported that its message limit was reached
    /// (mapped from `HpkeError::MessageLimitReached`).
    #[error("message limit reached")]
    MessageLimitReached,
    /// Any other HPKE error, passed through unchanged.
    #[error(transparent)]
    Other(HpkeError),
    /// An internal bug.
    #[error(transparent)]
    Bug(#[from] Bug),
}
impl From<HpkeError> for OpenError {
    /// Lifts AEAD authentication failures and message-limit errors into
    /// their own variants. Every other HPKE error is wrapped in
    /// [`OpenError::Other`].
    fn from(err: HpkeError) -> Self {
        match err {
            HpkeError::MessageLimitReached => Self::MessageLimitReached,
            HpkeError::Open(aead::OpenError::Authentication) => Self::Authentication,
            other => Self::Other(other),
        }
    }
}