use aead::{AeadCore, AeadInOut, Generate, Key, KeyInit, Nonce, array::ArraySize};
use rand_core::CryptoRng;
use zerocopy::{BigEndian, FromBytes, Immutable, IntoBytes, KnownLayout, U64};
use crate::{
DecryptionError, EncryptionError,
types::segment::{NON_FINAL_SEGMENT_HEADER, Segment, SegmentMut},
};
/// Associated data authenticated together with every segment.
///
/// The raw bytes of this struct (via `IntoBytes::as_bytes`) are handed to the AEAD as
/// associated data. This binds each ciphertext to its segment number and to its
/// final-segment flag.
///
/// The field order, `#[repr(C)]` and the big-endian integer type together fix that byte
/// layout. Changing any of them changes the authenticated bytes.
#[derive(Debug, IntoBytes, Immutable, KnownLayout)]
#[repr(C)]
struct AssociatedData {
    /// Segment number, serialized as 8 big-endian bytes.
    segment_number: U64<BigEndian>,
    /// Final-segment flag, serialized as a single byte (0 or 1).
    is_final: bool,
}
/// Everything needed to seal or open exactly one segment with AEAD algorithm `A`.
///
/// The segment number and final-segment flag stored here are authenticated as associated
/// data whenever the key is used. With the `zeroize` feature enabled, the fields are
/// zeroized when the value is dropped.
#[cfg_attr(feature = "zeroize", derive(zeroize::ZeroizeOnDrop))]
pub(crate) struct EpochKey<A: AeadInOut + KeyInit> {
    /// Raw key material for `A`.
    pub(super) key: Key<A>,
    /// Number of the segment this key is used for.
    pub(super) segment_number: u64,
    /// Whether the segment is the final one.
    pub(super) is_final: bool,
}
impl<A> EpochKey<A>
where
A: AeadInOut + KeyInit,
<<A as AeadCore>::TagSize as ArraySize>::ArrayType<u8>: FromBytes + Immutable,
<<A as AeadCore>::NonceSize as ArraySize>::ArrayType<u8>: FromBytes + Immutable,
{
pub(crate) fn encrypt_segment<R>(
self,
segment: SegmentMut<'_, A>,
rng: &mut R,
) -> Result<(), EncryptionError>
where
R: CryptoRng,
{
let plaintext_buffer = segment.ciphertext;
let header = Self::segment_header(plaintext_buffer.len(), self.is_final);
let nonce = Nonce::<A>::try_generate_from_rng(rng)
.map_err(|_| EncryptionError::NonceGenerationFailed)?;
let aead = A::new(&self.key);
let associated_data = self.associated_data();
let tag = aead.encrypt_inout_detached(
&nonce,
associated_data.as_bytes(),
plaintext_buffer.into(),
)?;
segment.header.set(header);
segment.nonce.copy_from_slice(&nonce);
segment.tag.copy_from_slice(&tag);
Ok(())
}
pub(crate) fn decrypt_segment(
self,
segment: &Segment<'_, A>,
buffer: &mut [u8],
) -> Result<(), DecryptionError> {
debug_assert_eq!(
segment.ciphertext().len(),
buffer.len(),
"The ciphertext and output buffer for the plaintext should have the same size"
);
let aead = A::new(&self.key);
let associated_data = self.associated_data();
buffer.copy_from_slice(segment.ciphertext());
Ok(aead.decrypt_inout_detached(
segment.nonce(),
associated_data.as_bytes(),
buffer.into(),
segment.tag(),
)?)
}
fn associated_data(&self) -> AssociatedData {
AssociatedData { segment_number: U64::new(self.segment_number), is_final: self.is_final }
}
fn segment_header(plaintext_buffer_length: usize, is_final: bool) -> u32 {
if is_final {
#[allow(clippy::expect_used)]
let final_segment_length =
plaintext_buffer_length.checked_add(Segment::<A>::overhead()).expect(
"Adding the length of the encrypted segment overhead \
to the length of the final segment shouldn't overflow",
);
#[allow(clippy::expect_used)]
let final_segment_length: u32 = final_segment_length
.try_into()
.expect("The length of the final encrypted segment should fit into 32 bits");
final_segment_length
} else {
NON_FINAL_SEGMENT_HEADER
}
}
}