use crate::StreamCipherCounter;
use super::{
OverflowError, SeekNum, StreamCipher, StreamCipherCore, StreamCipherSeek, StreamCipherSeekCore,
errors::StreamCipherError,
};
use block_buffer::{BlockSizes, ReadBuffer};
use common::{
Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser, array::Array, typenum::Unsigned,
};
use core::fmt;
use inout::InOutBuf;
#[cfg(feature = "zeroize")]
use zeroize::ZeroizeOnDrop;
/// Wrapper that pairs a [`StreamCipherCore`] implementation with a
/// [`ReadBuffer`] sized to the core's block size.
///
/// The trait implementations for this type use the buffer to serve
/// byte-granular requests on top of the block-oriented core.
pub struct StreamCipherCoreWrapper<T>
where
T: StreamCipherCore,
T::BlockSize: BlockSizes,
{
/// The wrapped block-level stream cipher core.
core: T,
/// Buffer of one core block; the wrapper reads cached keystream bytes from
/// it and refills it from `core` when a partial block is needed.
buffer: ReadBuffer<T::BlockSize>,
}
/// Field-wise clone, available whenever the wrapped core is itself `Clone`.
impl<T> Clone for StreamCipherCoreWrapper<T>
where
    T: StreamCipherCore + Clone,
    T::BlockSize: BlockSizes,
{
    /// Returns a wrapper holding clones of both the core and the buffer.
    #[inline]
    fn clone(&self) -> Self {
        let Self { core, buffer } = self;
        Self {
            core: core.clone(),
            buffer: buffer.clone(),
        }
    }
}
/// Opaque `Debug` output: only the type name is written, no field values.
impl<T> fmt::Debug for StreamCipherCoreWrapper<T>
where
    T: StreamCipherCore + fmt::Debug,
    T::BlockSize: BlockSizes,
{
    /// Formats the wrapper as a non-exhaustive struct without listing any
    /// fields, so neither the core nor the buffer contents are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("StreamCipherCoreWrapper");
        builder.finish_non_exhaustive()
    }
}
impl<T> StreamCipherCoreWrapper<T>
where
T: StreamCipherCore,
T::BlockSize: BlockSizes,
{
pub fn from_core(core: T) -> Self {
Self {
core,
buffer: Default::default(),
}
}
pub fn get_core(&self) -> &T {
&self.core
}
}
/// Slice-based [`StreamCipher`] implementation built on the block-level core
/// and the wrapper's buffer.
impl<T> StreamCipher for StreamCipherCoreWrapper<T>
where
    T: StreamCipherCore,
    T::BlockSize: BlockSizes,
{
    /// Checks whether `data_len` bytes can still be processed.
    ///
    /// # Errors
    /// Returns [`StreamCipherError`] when the number of blocks needed for the
    /// bytes not covered by the buffer exceeds the core's remaining blocks.
    #[inline]
    fn check_remaining(&self, data_len: usize) -> Result<(), StreamCipherError> {
        // When the core reports `None` for its remaining blocks, nothing is rejected.
        let Some(available) = self.core.remaining_blocks() else {
            return Ok(());
        };
        match data_len.checked_sub(self.buffer.remaining()) {
            // The buffered bytes alone cover the whole request.
            None => Ok(()),
            // Round the uncovered byte count up to whole blocks before comparing.
            Some(uncached) if uncached.div_ceil(T::BlockSize::USIZE) > available => {
                Err(StreamCipherError)
            }
            Some(_) => Ok(()),
        }
    }

    /// XORs keystream into `data`; this method does not call
    /// `check_remaining` itself.
    ///
    /// The data is handled in three parts: a prefix served from the buffered
    /// keystream, a run of whole blocks handed straight to the core, and a
    /// suffix shorter than one block that goes through the buffer.
    #[inline]
    fn unchecked_apply_keystream_inout(&mut self, data: InOutBuf<'_, '_, u8>) {
        let total = data.len();
        let cached_ks = self.buffer.read_cached(total);
        let (mut prefix, rest) = data.split_at(cached_ks.len());
        let (full_blocks, mut suffix) = rest.into_chunks();

        prefix.xor_in2out(cached_ks);
        self.core.apply_keystream_blocks_inout(full_blocks);

        // The buffer obtains a keystream block from the core and passes
        // `suffix_len` bytes of it to the second closure; presumably the
        // rest stays cached for the next call (see `ReadBuffer::write_block`).
        let suffix_len = suffix.len();
        self.buffer.write_block(
            suffix_len,
            |block| self.core.write_keystream_block(block),
            |suffix_ks| {
                suffix.xor_in2out(suffix_ks);
            },
        );
    }

    /// Writes raw keystream into `data`; this method does not call
    /// `check_remaining` itself.
    ///
    /// Uses the same prefix / whole-blocks / suffix split as
    /// `unchecked_apply_keystream_inout`, but copies keystream instead of
    /// XORing it.
    #[inline]
    fn unchecked_write_keystream(&mut self, data: &mut [u8]) {
        let cached_ks = self.buffer.read_cached(data.len());
        let (prefix, rest) = data.split_at_mut(cached_ks.len());
        let (full_blocks, suffix) = Array::slice_as_chunks_mut(rest);

        prefix.copy_from_slice(cached_ks);
        self.core.write_keystream_blocks(full_blocks);

        let suffix_len = suffix.len();
        self.buffer.write_block(
            suffix_len,
            |block| self.core.write_keystream_block(block),
            |suffix_ks| suffix.copy_from_slice(suffix_ks),
        );
    }
}
/// Byte-granular seeking built from the core's block position and the
/// buffer's position within the current block.
impl<T> StreamCipherSeek for StreamCipherCoreWrapper<T>
where
    T: StreamCipherSeekCore,
    T::BlockSize: BlockSizes,
{
    /// Returns the current position, combining the core's block position
    /// with the buffer position.
    ///
    /// # Errors
    /// Propagates the [`OverflowError`] produced by `SN::from_block_byte`.
    ///
    /// # Panics
    /// Panics if the buffer position does not fit in a `u8`.
    #[allow(clippy::unwrap_in_result)]
    fn try_current_pos<SN: SeekNum>(&self) -> Result<SN, OverflowError> {
        let byte_pos = u8::try_from(self.buffer.get_pos())
            .expect("buffer position is always smaller than 256");
        let block_pos = self.core.get_block_pos();
        SN::from_block_byte(block_pos, byte_pos, T::BlockSize::U8)
    }

    /// Moves the cipher to `new_pos`.
    ///
    /// # Errors
    /// Returns [`StreamCipherError`] if `new_pos` cannot be converted to a
    /// block/byte pair, or if it points inside the block at the counter's
    /// maximum value.
    ///
    /// # Panics
    /// Panics if the converted byte offset is not smaller than the block size.
    fn try_seek<SN: SeekNum>(&mut self, new_pos: SN) -> Result<(), StreamCipherError> {
        let (block_pos, byte_pos) = new_pos.into_block_byte::<T::Counter>(T::BlockSize::U8)?;

        // A non-zero offset inside the maximum-valued block is rejected.
        let inside_block = byte_pos != 0;
        if inside_block && block_pos.is_max() {
            return Err(StreamCipherError);
        }
        assert!(byte_pos < T::BlockSize::U8);

        self.core.set_block_pos(block_pos);
        self.buffer.reset();

        // Run the buffer's block-write path for `offset` bytes and discard
        // them, so the buffer ends up positioned inside the target block.
        let offset = usize::from(byte_pos);
        self.buffer.write_block(
            offset,
            |block| self.core.write_keystream_block(block),
            |_| {},
        );
        Ok(())
    }
}
/// The wrapper uses the same key size as the wrapped core.
impl<T> KeySizeUser for StreamCipherCoreWrapper<T>
where
T: KeySizeUser + StreamCipherCore,
T::BlockSize: BlockSizes,
{
// Forwarded unchanged from the core.
type KeySize = T::KeySize;
}
/// The wrapper uses the same IV size as the wrapped core.
impl<T> IvSizeUser for StreamCipherCoreWrapper<T>
where
T: IvSizeUser + StreamCipherCore,
T::BlockSize: BlockSizes,
{
// Forwarded unchanged from the core.
type IvSize = T::IvSize;
}
/// Key-and-IV construction, forwarded to the wrapped core.
impl<T> KeyIvInit for StreamCipherCoreWrapper<T>
where
    T: KeyIvInit + StreamCipherCore,
    T::BlockSize: BlockSizes,
{
    /// Builds the core from `key` and `iv` and wraps it with a
    /// default-initialized buffer.
    #[inline]
    fn new(key: &Key<Self>, iv: &Iv<Self>) -> Self {
        let core = <T as KeyIvInit>::new(key, iv);
        Self::from_core(core)
    }
}
/// Key-only construction, forwarded to the wrapped core.
impl<T> KeyInit for StreamCipherCoreWrapper<T>
where
    T: KeyInit + StreamCipherCore,
    T::BlockSize: BlockSizes,
{
    /// Builds the core from `key` and wraps it with a default-initialized
    /// buffer.
    #[inline]
    fn new(key: &Key<Self>) -> Self {
        let core = <T as KeyInit>::new(key);
        Self::from_core(core)
    }
}
// Marker impl, compiled only with the `zeroize` feature: the wrapper is
// declared `ZeroizeOnDrop` whenever the wrapped core is. The impl body is
// empty, so it adds no drop logic of its own and relies on the fields.
#[cfg(feature = "zeroize")]
impl<T> ZeroizeOnDrop for StreamCipherCoreWrapper<T>
where
T: StreamCipherCore + ZeroizeOnDrop,
T::BlockSize: BlockSizes,
{
}
// Compile-time check, compiled only with the `zeroize` feature: the function
// is never called (hence `dead_code`), but it must type-check, which
// requires `ReadBuffer<BS>` to implement `ZeroizeOnDrop` for every
// `BS: BlockSizes`.
#[cfg(feature = "zeroize")]
const _: () = {
#[allow(dead_code, trivial_casts)]
fn check_buffer<BS: BlockSizes>(v: &ReadBuffer<BS>) {
// The coercion to a trait object fails to compile if the impl is missing.
let _ = v as &dyn ZeroizeOnDrop;
}
};