use core::ops::Sub;
use aead::{AeadCore, Key, KeySizeUser, array::ArraySize, consts::U32};
use digest::OutputSizeUser;
#[cfg(feature = "getrandom")]
use getrandom::SysRng;
use rand_core::CryptoRng;
#[cfg(feature = "getrandom")]
use rand_core::UnwrapErr;
use zerocopy::{FromBytes, Immutable, IntoBytes, Unaligned};
use crate::{
EncryptionError, FloeAead, FloeKdf,
keys::{FloeKey, MessageKey},
types::{Header, floe_iv::FloeIv, segment::SegmentMut},
utils::{check_segment_size, plaintext_size},
};
/// Segment-by-segment FLOE encryptor bound to one key, one IV and one piece of
/// associated data.
///
/// Type parameters:
/// - `A`: the AEAD used to seal each segment.
/// - `K`: the key-derivation function used to derive header tags, message keys
///   and per-segment epoch keys.
/// - `N`: const parameter forwarded to [`Header`] and to the key-derivation calls
///   (NOTE(review): presumably the IV length in bytes — confirm against `Header`).
/// - `S`: the segment size, validated by `check_segment_size` on construction.
pub struct FloeEncryptor<'a, A: FloeAead, K: FloeKdf, const N: usize, const S: u32> {
    /// Header generated at construction; it carries the FLOE IV and header tag.
    header: Header<N>,
    /// Key derived from the caller's key, the IV and the associated data; every
    /// segment's epoch key is derived from it.
    message_key: MessageKey<A, K>,
    /// Associated data mixed into every key derivation.
    associated_data: &'a [u8],
}
impl<'a, A, K, const N: usize, const S: u32> FloeEncryptor<'a, A, K, N, S>
where
    A: FloeAead,
    K: FloeKdf,
    <<A as AeadCore>::TagSize as ArraySize>::ArrayType<u8>: FromBytes + Immutable + IntoBytes,
    <<A as AeadCore>::NonceSize as ArraySize>::ArrayType<u8>:
        FromBytes + Immutable + IntoBytes + Unaligned,
    <K as OutputSizeUser>::OutputSize: Sub<<A as KeySizeUser>::KeySize>,
    <<K as OutputSizeUser>::OutputSize as Sub<<A as KeySizeUser>::KeySize>>::Output: ArraySize,
    <K as OutputSizeUser>::OutputSize: Sub<U32>,
    <<K as OutputSizeUser>::OutputSize as Sub<U32>>::Output: ArraySize,
    <K as OutputSizeUser>::OutputSize: Sub<<K as FloeKdf>::KeySize>,
    <<K as OutputSizeUser>::OutputSize as Sub<<K as FloeKdf>::KeySize>>::Output: ArraySize,
{
    /// Creates an encryptor whose FLOE IV is drawn from the system RNG.
    ///
    /// # Panics
    ///
    /// Panics if the system RNG fails. `UnwrapErr` converts RNG errors into panics.
    #[cfg(feature = "getrandom")]
    // `UnwrapErr` already panics on RNG failure, so the `expect` below only
    // documents the invariant rather than handling a recoverable error.
    #[allow(clippy::expect_used)]
    pub fn new(key: &Key<A>, associated_data: &'a [u8]) -> Self {
        let mut system_rng = UnwrapErr(SysRng);
        Self::with_rng(key, associated_data, &mut system_rng)
            .expect("should be able to generate enough randomness for the Floe IV")
    }

    /// Creates an encryptor whose FLOE IV is drawn from `rng`.
    ///
    /// Construction proceeds in four steps:
    /// 1. Validate the segment size `S` for `A` via `check_segment_size`.
    /// 2. Generate a fresh IV.
    /// 3. Derive the header tag and the message key from `key`, the IV and
    ///    `associated_data`.
    /// 4. Build the [`Header`].
    ///
    /// # Errors
    ///
    /// Returns `rng`'s error if IV generation fails.
    pub fn with_rng<R: CryptoRng>(
        key: &Key<A>,
        associated_data: &'a [u8],
        rng: &mut R,
    ) -> Result<Self, R::Error> {
        check_segment_size::<A, S>();

        let root_key = FloeKey::new(key);
        let iv = FloeIv::generate(rng)?;

        // Both derivations borrow the IV, so they must run before the IV is
        // moved into the header.
        let tag = root_key.derive_header_tag::<N, S>(&iv, associated_data);
        let message_key = root_key.derive_message_key::<N, S>(&iv, associated_data);

        Ok(Self {
            header: Header::new::<A, K, S>(iv, tag),
            message_key,
            associated_data,
        })
    }

    /// Returns the plaintext length of one segment.
    ///
    /// Non-final segments must be exactly this long. The final segment may be
    /// shorter.
    pub fn input_size(&self) -> usize {
        plaintext_size::<A, S>()
    }

    /// Returns the number of output bytes produced for `plaintext`, as reported
    /// by `SegmentMut::output_size`.
    ///
    /// # Panics
    ///
    /// Panics if `plaintext` is longer than [`Self::input_size`].
    pub fn output_size(&self, plaintext: &[u8]) -> usize {
        let limit = self.input_size();
        assert!(
            plaintext.len() <= limit,
            "The plaintext size can't be bigger than the input size"
        );
        SegmentMut::<A>::output_size(plaintext)
    }

    /// Returns the header created for this encryptor. The header contains the FLOE IV.
    pub fn header(&self) -> &Header<N> {
        &self.header
    }

    /// Encrypts one segment into `buffer`, using the system RNG for any
    /// per-segment randomness.
    ///
    /// See [`Self::encrypt_segment_with_rng`] for the validation rules.
    ///
    /// # Panics
    ///
    /// Panics if the system RNG fails. `UnwrapErr` converts RNG errors into panics.
    #[cfg(feature = "getrandom")]
    pub fn encrypt_segment(
        &self,
        plaintext: &[u8],
        buffer: &mut [u8],
        segment_number: u64,
        is_final: bool,
    ) -> Result<(), EncryptionError> {
        self.encrypt_segment_with_rng(
            plaintext,
            buffer,
            segment_number,
            is_final,
            &mut UnwrapErr(SysRng),
        )
    }

    /// Encrypts segment number `segment_number` of `plaintext` into `buffer`.
    ///
    /// The segment is sealed with an epoch key derived from the message key, the
    /// header IV, the associated data, `segment_number` and `is_final`.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::InvalidPlaintextLength`] in either of these cases:
    ///   - a non-final segment is not exactly [`Self::input_size`] bytes;
    ///   - a final segment is longer than [`Self::input_size`] bytes.
    /// - [`EncryptionError::MaxSegmentsReached`] in either of these cases:
    ///   - a final segment's number is `>= A::AEAD_MAX_SEGMENTS`;
    ///   - a non-final segment's number is `>= A::AEAD_MAX_SEGMENTS - 1`.
    ///     This reserves the last slot for the final segment.
    /// - Any error from `SegmentMut::from_buffer_and_plaintext` (NOTE(review):
    ///   presumably an undersized `buffer` — confirm) or from sealing the segment.
    pub fn encrypt_segment_with_rng<R>(
        &self,
        plaintext: &[u8],
        buffer: &mut [u8],
        segment_number: u64,
        is_final: bool,
        rng: &mut R,
    ) -> Result<(), EncryptionError>
    where
        R: CryptoRng,
    {
        let expected = self.input_size();
        let got = plaintext.len();

        // The final segment may be partial (or empty). Every other segment must
        // be completely full.
        let length_ok = if is_final { got <= expected } else { got == expected };
        if !length_ok {
            return Err(EncryptionError::InvalidPlaintextLength { expected, got });
        }

        // A non-final segment must leave at least one slot for the final segment.
        let first_rejected = if is_final {
            A::AEAD_MAX_SEGMENTS
        } else {
            A::AEAD_MAX_SEGMENTS - 1
        };
        if segment_number >= first_rejected {
            return Err(EncryptionError::MaxSegmentsReached(A::AEAD_MAX_SEGMENTS));
        }

        let segment = SegmentMut::from_buffer_and_plaintext(plaintext, buffer)?;
        let epoch = self.message_key.derive_epoch_key::<N, S>(
            self.header.iv(),
            self.associated_data,
            segment_number,
            is_final,
        );
        epoch.encrypt_segment(segment, rng)
    }
}