use crate::{
errors::StreamCipherError, Block, OverflowError, SeekNum, StreamCipher, StreamCipherCore,
StreamCipherSeek, StreamCipherSeekCore,
};
use core::fmt;
use crypto_common::{typenum::Unsigned, Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser};
use inout::InOutBuf;
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Buffered wrapper that turns a block-level [`StreamCipherCore`] into a
/// byte-level [`StreamCipher`].
///
/// Keystream is generated one block at a time into `buffer`. Between calls,
/// the first byte of `buffer` does not hold keystream. It stores the position
/// of the next unused keystream byte instead, which must always lie in
/// `1..=BlockSize`; a value equal to the block size means the buffer is
/// exhausted. `get_pos` relies on this invariant via `unreachable_unchecked`.
pub struct StreamCipherCoreWrapper<T: StreamCipherCore> {
    /// The wrapped cipher core that produces keystream blocks.
    core: T,
    /// One block of keystream; `buffer[0]` holds the current position.
    buffer: Block<T>,
}
impl<T: StreamCipherCore + Default> Default for StreamCipherCoreWrapper<T> {
    /// Wraps a default-constructed core, starting with an exhausted keystream buffer.
    #[inline]
    fn default() -> Self {
        let core: T = Default::default();
        Self::from_core(core)
    }
}
impl<T: StreamCipherCore + Clone> Clone for StreamCipherCoreWrapper<T> {
    /// Clones the core and copies the keystream buffer, including its position
    /// byte. The copy therefore resumes at the same buffer position as `self`.
    #[inline]
    fn clone(&self) -> Self {
        let core = T::clone(&self.core);
        let buffer = Clone::clone(&self.buffer);
        Self { core, buffer }
    }
}
impl<T: StreamCipherCore + fmt::Debug> fmt::Debug for StreamCipherCoreWrapper<T> {
    /// Formats the wrapped core together with the buffered bytes that have not
    /// been consumed yet (everything from the current position to the block end).
    ///
    /// NOTE(review): those bytes are keystream produced by
    /// `write_keystream_block`. Consider whether `Debug` output of this type may
    /// end up in logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let unused = &self.buffer[usize::from(self.get_pos())..];
        let mut out = f.debug_struct("StreamCipherCoreWrapper");
        out.field("core", &self.core);
        out.field("buffer_data", &unused);
        out.finish()
    }
}
impl<T: StreamCipherCore> StreamCipherCoreWrapper<T> {
    /// Returns a reference to the wrapped block-level cipher core.
    pub fn get_core(&self) -> &T {
        &self.core
    }

    /// Wraps `core`, starting with an exhausted keystream buffer.
    ///
    /// The position byte (`buffer[0]`) is set to the block size. That marks
    /// every buffered byte as consumed, so the first keystream application
    /// generates fresh keystream from `core`.
    pub fn from_core(core: T) -> Self {
        let mut buffer: Block<T> = Default::default();
        buffer[0] = T::BlockSize::U8;
        Self { core, buffer }
    }

    /// Returns the position of the next unused keystream byte in `buffer`.
    ///
    /// The position is stored in `buffer[0]` and is always in `1..=BlockSize`.
    /// Position 0 is never valid because byte 0 holds the position itself.
    #[inline]
    fn get_pos(&self) -> u8 {
        let pos = self.buffer[0];
        if pos == 0 || pos > T::BlockSize::U8 {
            // Surface invariant violations loudly in debug builds.
            debug_assert!(false);
            // SAFETY: every writer of `buffer[0]` in this type (`from_core`, the
            // `KeyInit`/`KeyIvInit` constructors, and `set_pos_unchecked`, whose
            // callers uphold its contract) stores a value in `1..=BlockSize`.
            // Code that temporarily overwrites `buffer[0]` with keystream restores
            // it before `get_pos` is called again, so this branch is unreachable.
            // Presumably this exists so the optimizer can drop checks on later
            // uses of `pos`.
            unsafe {
                core::hint::unreachable_unchecked();
            }
        }
        pos
    }

    /// Stores `pos` as the new buffer position in `buffer[0]`.
    ///
    /// # Safety
    ///
    /// `pos` must be in `1..=BlockSize`. `get_pos` treats any other stored
    /// value as unreachable (`unreachable_unchecked`), so violating this is
    /// undefined behavior on the next position read. The precondition is only
    /// checked in debug builds.
    #[inline]
    unsafe fn set_pos_unchecked(&mut self, pos: usize) {
        debug_assert!(pos != 0 && pos <= T::BlockSize::USIZE);
        // Assumes the block size fits in a `u8` (the type relies on
        // `T::BlockSize::U8` throughout). Under the contract above, this cast
        // does not truncate.
        self.buffer[0] = pos as u8;
    }

    /// Number of buffered keystream bytes that have not been consumed yet.
    #[inline]
    fn remaining(&self) -> u8 {
        T::BlockSize::U8 - self.get_pos()
    }

    /// Checks whether the core can supply keystream for `data_len` more bytes.
    ///
    /// Bytes still available in the buffer are subtracted first. The rest is
    /// rounded up to whole blocks and compared with the core's
    /// `remaining_blocks()`. A core that reports `None` is treated as having no
    /// limit.
    ///
    /// # Errors
    ///
    /// Returns `StreamCipherError` if more blocks would be needed than the core
    /// reports as remaining.
    fn check_remaining(&self, data_len: usize) -> Result<(), StreamCipherError> {
        let rem_blocks = match self.core.remaining_blocks() {
            Some(v) => v,
            None => return Ok(()),
        };
        let buf_rem = usize::from(self.remaining());
        // Bytes that must come from newly generated blocks. `None` (shorter than
        // the buffered remainder) and `Some(0)` (exactly fits) need no new blocks.
        let data_len = match data_len.checked_sub(buf_rem) {
            Some(0) | None => return Ok(()),
            Some(res) => res,
        };
        let bs = T::BlockSize::USIZE;
        // Ceiling division: a partial trailing block still consumes a whole
        // block from the core. Assumes `data_len` comes from a slice length, so
        // adding `bs - 1` cannot overflow.
        let blocks = (data_len + bs - 1) / bs;
        if blocks > rem_blocks {
            Err(StreamCipherError)
        } else {
            Ok(())
        }
    }
}
impl<T: StreamCipherCore> StreamCipher for StreamCipherCoreWrapper<T> {
    /// XORs keystream into `data`, continuing from the current position.
    ///
    /// Work happens in three phases:
    /// 1. Unused bytes left in the internal buffer are consumed.
    /// 2. As many whole blocks as possible go directly through the core.
    /// 3. Any partial tail is XORed with a freshly generated keystream block,
    ///    and the unused remainder of that block stays buffered for the next call.
    ///
    /// # Errors
    ///
    /// Returns `StreamCipherError` if `check_remaining` reports that the core
    /// cannot produce enough keystream for `data`. The check runs before any
    /// data or cipher state is modified.
    #[inline]
    fn try_apply_keystream_inout(
        &mut self,
        mut data: InOutBuf<'_, '_, u8>,
    ) -> Result<(), StreamCipherError> {
        self.check_remaining(data.len())?;
        let pos = usize::from(self.get_pos());
        let rem = usize::from(self.remaining());
        let data_len = data.len();
        // Phase 1: consume keystream left over from a previous call.
        if rem != 0 {
            if data_len <= rem {
                // Everything fits into the buffered keystream; no new blocks needed.
                data.xor_in2out(&self.buffer[pos..][..data_len]);
                // SAFETY: `pos >= 1` and `data_len <= rem == BlockSize - pos`,
                // so `pos + data_len` lies in `1..=BlockSize`.
                unsafe {
                    self.set_pos_unchecked(pos + data_len);
                }
                return Ok(());
            }
            let (mut left, right) = data.split_at(rem);
            data = right;
            left.xor_in2out(&self.buffer[pos..]);
        }
        // Phase 2: whole blocks bypass the buffer and go straight through the core.
        let (blocks, mut tail) = data.into_chunks();
        self.core.apply_keystream_blocks_inout(blocks);
        // Phase 3: handle a partial trailing block, if there is one.
        let new_pos = if tail.is_empty() {
            // Nothing buffered for later: mark the buffer as exhausted.
            T::BlockSize::USIZE
        } else {
            // This overwrites `buffer[0]` (the position byte) with keystream,
            // which temporarily breaks the position invariant. The byte is used
            // as keystream for the first tail byte and is then replaced by
            // `set_pos_unchecked` below; `get_pos` is not called in between.
            self.core.write_keystream_block(&mut self.buffer);
            tail.xor_in2out(&self.buffer[..tail.len()]);
            tail.len()
        };
        // SAFETY: `new_pos` is either `BlockSize` or a non-zero `tail.len()`.
        // `into_chunks` presumably yields a tail shorter than one block (the
        // slice `buffer[..tail.len()]` above would panic otherwise) — TODO
        // confirm. Either way `new_pos` lies in `1..=BlockSize`.
        unsafe {
            self.set_pos_unchecked(new_pos);
        }
        Ok(())
    }
}
impl<T: StreamCipherSeekCore> StreamCipherSeek for StreamCipherCoreWrapper<T> {
    /// Reports the current keystream position.
    ///
    /// Combines the core's block counter with the buffer position byte. The
    /// `SeekNum` conversion decides how the pair maps to an offset and whether
    /// it overflows `SN`.
    fn try_current_pos<SN: SeekNum>(&self) -> Result<SN, OverflowError> {
        let byte_in_block = self.get_pos();
        let block_counter = self.core.get_block_pos();
        SN::from_block_byte(block_counter, byte_in_block, T::BlockSize::U8)
    }

    /// Moves the keystream position to `new_pos`.
    ///
    /// The core is repositioned to the block that `new_pos` falls in. If the
    /// target lies inside that block, the block's keystream is generated into
    /// the buffer and the position byte points at the requested byte.
    /// Otherwise the buffer is marked exhausted.
    ///
    /// # Errors
    ///
    /// Propagates the failure of `SeekNum::into_block_byte`.
    ///
    /// # Panics
    ///
    /// Panics if `into_block_byte` yields a byte offset that is not smaller
    /// than the block size.
    fn try_seek<SN: SeekNum>(&mut self, new_pos: SN) -> Result<(), StreamCipherError> {
        let block_size = T::BlockSize::U8;
        let (block, byte) = new_pos.into_block_byte(block_size)?;
        // Guards the `set_pos_unchecked` contract below; it should never fire
        // for a correct `SeekNum` implementation.
        assert!(byte < block_size);
        self.core.set_block_pos(block);
        let pos = match byte {
            0 => T::BlockSize::USIZE,
            offset => {
                // Overwrites the position byte with keystream; it is restored
                // just below. Because `offset >= 1`, the clobbered `buffer[0]`
                // is never handed out as keystream.
                self.core.write_keystream_block(&mut self.buffer);
                usize::from(offset)
            }
        };
        // SAFETY: `byte < BlockSize` was asserted, and 0 is mapped to
        // `BlockSize`, so `pos` lies in `1..=BlockSize`.
        unsafe { self.set_pos_unchecked(pos) };
        Ok(())
    }
}
impl<T: KeySizeUser + StreamCipherCore> KeySizeUser for StreamCipherCoreWrapper<T> {
    /// The wrapper accepts exactly the key size of its core.
    type KeySize = <T as KeySizeUser>::KeySize;
}
impl<T: IvSizeUser + StreamCipherCore> IvSizeUser for StreamCipherCoreWrapper<T> {
    /// The wrapper accepts exactly the IV size of its core.
    type IvSize = <T as IvSizeUser>::IvSize;
}
impl<T: KeyIvInit + StreamCipherCore> KeyIvInit for StreamCipherCoreWrapper<T> {
    /// Creates a wrapper around `T::new(key, iv)`, starting with an exhausted
    /// keystream buffer.
    ///
    /// Delegates to `from_core` rather than re-initializing the buffer by hand.
    /// The position byte in `buffer[0]` must satisfy an invariant that `get_pos`
    /// assumes via `unreachable_unchecked`, so it should be established in one
    /// place only.
    #[inline]
    fn new(key: &Key<Self>, iv: &Iv<Self>) -> Self {
        Self::from_core(T::new(key, iv))
    }
}
impl<T: KeyInit + StreamCipherCore> KeyInit for StreamCipherCoreWrapper<T> {
    /// Creates a wrapper around `T::new(key)`, starting with an exhausted
    /// keystream buffer.
    ///
    /// Delegates to `from_core` rather than re-initializing the buffer by hand.
    /// The position byte in `buffer[0]` must satisfy an invariant that `get_pos`
    /// assumes via `unreachable_unchecked`, so it should be established in one
    /// place only.
    #[inline]
    fn new(key: &Key<Self>) -> Self {
        Self::from_core(T::new(key))
    }
}
#[cfg(feature = "zeroize")]
impl<T: StreamCipherCore> Drop for StreamCipherCoreWrapper<T> {
    /// Wipes the buffered keystream, including the position byte, on drop.
    ///
    /// Only the buffer is cleared here. Wiping the core's own state is left
    /// to `T`.
    fn drop(&mut self) {
        let keystream = &mut self.buffer;
        keystream.zeroize();
    }
}
/// The wrapper zeroizes on drop whenever its core does. The buffer is wiped by
/// the wrapper's `Drop` impl, and the core by its own `ZeroizeOnDrop` behavior.
#[cfg(feature = "zeroize")]
impl<T> ZeroizeOnDrop for StreamCipherCoreWrapper<T> where T: StreamCipherCore + ZeroizeOnDrop {}