use super::{Block, Error};
use core::{fmt, slice};
use crypto_common::{BlockSizeUser, BlockSizes};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
/// Buffer for reading block-generated data.
pub struct ReadBuffer<BS: BlockSizes> {
    // Byte 0 stores the read position (kept in `1..=BS::USIZE`); the bytes
    // from that position to the end of the block are the unread data.
    buffer: Block<Self>,
}
// Exposes the type parameter `BS` as the block size of `ReadBuffer`.
impl<BS: BlockSizes> BlockSizeUser for ReadBuffer<BS> {
    type BlockSize = BS;
}
impl<BS: BlockSizes> fmt::Debug for ReadBuffer<BS> {
    /// Formats the number of unread bytes held in the buffer.
    ///
    /// The buffered bytes themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ReadBuffer")
            // Report the amount of unread data, matching the field label
            // (previously this printed the read position instead).
            .field("remaining_data", &self.remaining())
            .finish()
    }
}
impl<BS: BlockSizes> Default for ReadBuffer<BS> {
    /// Creates a buffer with no unread data.
    ///
    /// The position byte is set to the block size, so the next non-empty
    /// `read` obtains its bytes from the block generator.
    #[inline]
    fn default() -> Self {
        let mut res = Self {
            buffer: Default::default(),
        };
        res.buffer[0] = BS::U8;
        res
    }
}
impl<BS: BlockSizes> Clone for ReadBuffer<BS> {
    /// Duplicates the buffer, including its position byte and unread data.
    #[inline]
    fn clone(&self) -> Self {
        let buffer = Clone::clone(&self.buffer);
        Self { buffer }
    }
}
impl<BS: BlockSizes> ReadBuffer<BS> {
    /// Returns the current read position inside the buffer.
    ///
    /// Byte 0 of the buffer stores this position, so valid values are
    /// `1..=BS::USIZE`; a value equal to the block size means no unread
    /// data is left.
    #[inline(always)]
    pub fn get_pos(&self) -> usize {
        let pos = self.buffer[0];
        if pos == 0 || pos > BS::U8 {
            debug_assert!(false, "ReadBuffer position byte is out of range");
            // SAFETY: this relies on the type invariant that `buffer[0]` is
            // always in `1..=BS::U8`. It is established by `Default`, checked
            // by `deserialize`, and maintained by `read`; any other code that
            // writes `buffer[0]` must uphold it as well.
            unsafe {
                core::hint::unreachable_unchecked();
            }
        }
        pos as usize
    }

    /// Returns the block size in bytes.
    #[inline(always)]
    pub fn size(&self) -> usize {
        BS::USIZE
    }

    /// Returns the number of buffered bytes that have not been read yet.
    #[inline(always)]
    pub fn remaining(&self) -> usize {
        self.size() - self.get_pos()
    }

    /// Stores `pos` as the new read position without runtime validation.
    ///
    /// Callers must pass a value in `1..=BS::USIZE`, the range `get_pos`
    /// relies on; this is only checked in debug builds.
    #[inline(always)]
    fn set_pos_unchecked(&mut self, pos: usize) {
        // Position 0 is just as invalid as an oversized one: `get_pos`
        // treats both as unreachable.
        debug_assert!(pos != 0 && pos <= BS::USIZE);
        self.buffer[0] = pos as u8;
    }

    /// Fills `data`, first draining unread buffered bytes and then calling
    /// `gen_block` to produce new blocks.
    ///
    /// Whole blocks are generated in place inside `data`. If a trailing
    /// partial block is needed, one more block is generated, its prefix is
    /// copied into `data`, and that block becomes the new buffer with the
    /// rest left unread for later calls.
    #[inline]
    pub fn read(&mut self, mut data: &mut [u8], mut gen_block: impl FnMut(&mut Block<Self>)) {
        let pos = self.get_pos();
        let r = self.remaining();
        let n = data.len();

        if r != 0 {
            if n < r {
                // The request is fully served by the unread bytes.
                data.copy_from_slice(&self.buffer[pos..][..n]);
                self.set_pos_unchecked(pos + n);
                return;
            }
            // Drain all unread bytes into the front of `data`.
            let (left, right) = data.split_at_mut(r);
            data = right;
            left.copy_from_slice(&self.buffer[pos..]);
        }

        let (blocks, leftover) = Self::to_blocks_mut(data);
        for block in blocks {
            gen_block(block);
        }

        let n = leftover.len();
        if n != 0 {
            let mut block = Default::default();
            gen_block(&mut block);
            leftover.copy_from_slice(&block[..n]);
            // Overwriting byte 0 with the position is fine: `n >= 1`, so that
            // byte has already been handed out to the caller.
            self.buffer = block;
            self.set_pos_unchecked(n);
        } else {
            // Nothing is left unread.
            self.set_pos_unchecked(BS::USIZE);
        }
    }

    /// Serializes the buffer into a block.
    ///
    /// Already-read bytes (indices `1..pos`) are zeroed, so the result holds
    /// only the position byte and the unread data.
    #[inline]
    pub fn serialize(&self) -> Block<Self> {
        let mut res = self.buffer.clone();
        let pos = self.get_pos();
        res[1..pos].fill(0);
        res
    }

    /// Deserializes a buffer previously produced by [`serialize`](Self::serialize).
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if the position byte is outside `1..=BS::USIZE`, or
    /// if any already-read byte (indices `1..pos`) is non-zero.
    #[inline]
    pub fn deserialize(buffer: &Block<Self>) -> Result<Self, Error> {
        let pos = buffer[0];
        if pos == 0 || pos > BS::U8 || buffer[1..pos as usize].iter().any(|&b| b != 0) {
            Err(Error)
        } else {
            Ok(Self {
                buffer: buffer.clone(),
            })
        }
    }

    /// Splits `data` into a slice of whole blocks and the trailing bytes
    /// that do not fill a complete block.
    #[inline(always)]
    fn to_blocks_mut(data: &mut [u8]) -> (&mut [Block<Self>], &mut [u8]) {
        let nb = data.len() / BS::USIZE;
        let (left, right) = data.split_at_mut(nb * BS::USIZE);
        let p = left.as_mut_ptr() as *mut Block<Self>;
        // SAFETY: `left` is exactly `nb * BS::USIZE` bytes long, exclusively
        // borrowed from `data`, and disjoint from `right`. This assumes
        // `Block<Self>` is a plain `BS::USIZE`-byte array with alignment 1,
        // so `p` is suitably aligned and `nb` blocks cover exactly `left`.
        let blocks = unsafe { slice::from_raw_parts_mut(p, nb) };
        (blocks, right)
    }
}
#[cfg(feature = "zeroize")]
impl<BS: BlockSizes> Zeroize for ReadBuffer<BS> {
    /// Zeroes all buffered data and resets the buffer to the empty state.
    ///
    /// Zeroing the whole block also clears the position byte. Position 0 is
    /// treated as unreachable by `get_pos` (undefined behavior in release
    /// builds) if the buffer is used afterwards. Restoring the position to
    /// the block size keeps that invariant while leaving no unread data.
    #[inline]
    fn zeroize(&mut self) {
        self.buffer.zeroize();
        self.buffer[0] = BS::U8;
    }
}