use alloc::vec::Vec;
use zeroize::Zeroize;
use crate::{tag_input, Call, Error};
/// Operations a [`Sponge`] needs from the underlying permutation.
///
/// `T` is the element type of the state and `W` is the number of elements in
/// the state array.
pub trait Safe<T, const W: usize>
where
    T: Default + Copy + Zeroize,
{
    /// Applies the permutation to `state` in place.
    ///
    /// The sponge calls this whenever its absorb or squeeze position has
    /// reached the rate.
    fn permute(&mut self, state: &mut [T; W]);

    /// Turns the encoded tag input into a single element of type `T`.
    ///
    /// [`Sponge::start`] passes the bytes produced by `tag_input` for the
    /// IO pattern and the domain separator.
    fn tag(&mut self, input: &[u8]) -> T;

    /// Combines two elements and returns the result.
    ///
    /// When absorbing, the sponge passes the current state element as `right`
    /// and the input element as `left`, and stores the returned value back
    /// into the state.
    fn add(&mut self, right: &T, left: &T) -> T;

    /// Builds the initial state: every element is `T::default()` except the
    /// first one, which is set to `tag`.
    ///
    /// # Panics
    ///
    /// Panics if `W` is zero, because index 0 is then out of bounds.
    fn initialized_state(tag: T) -> [T; W] {
        let mut initial = [T::default(); W];
        initial[0] = tag;
        initial
    }
}
/// Sponge over a `W`-element state of `T`, driven by a [`Safe`] implementation
/// and constrained to a fixed sequence of absorb/squeeze calls.
#[derive(Debug, Clone, PartialEq)]
pub struct Sponge<S, T, const W: usize>
where
S: Safe<T, W>,
T: Default + Copy + Zeroize,
{
// Sponge state; created by `Safe::initialized_state` with the tag at index 0.
state: [T; W],
// Provider of the permutation, tag and addition operations.
pub(crate) safe: S,
// Offset inside the rate part of the state where the next input is added.
pos_absorb: usize,
// Offset inside the rate part of the state of the next element to squeeze.
pos_squeeze: usize,
// Number of absorb/squeeze calls completed so far; indexes into `iopattern`.
io_count: usize,
// Sequence of calls the sponge has to follow.
iopattern: Vec<Call>,
// Domain separator passed to `start`; fed into the tag computation there.
domain_sep: u64,
// Elements collected by `squeeze`; handed out by `finish`.
pub(crate) output: Vec<T>,
}
impl<S, T, const W: usize> Sponge<S, T, W>
where
S: Safe<T, W>,
T: Default + Copy + Zeroize,
{
const CAPACITY: usize = 1;
const RATE: usize = W - Self::CAPACITY;
pub fn start(
safe: S,
iopattern: impl Into<Vec<Call>>,
domain_sep: u64,
) -> Result<Self, Error> {
let iopattern: Vec<Call> = iopattern.into();
let mut safe = safe;
let tag = safe.tag(&tag_input(&iopattern, domain_sep)?);
let state = S::initialized_state(tag);
Ok(Self {
state,
safe,
pos_absorb: 0,
pos_squeeze: 0,
io_count: 0,
iopattern,
domain_sep,
output: Vec::new(),
})
}
pub fn finish(mut self) -> Result<Vec<T>, Error> {
let ret = match self.io_count == self.iopattern.len() {
true => Ok(self.output.clone()),
false => Err(Error::IOPatternViolation),
};
self.zeroize();
ret
}
pub fn absorb(
&mut self,
len: usize,
input: impl AsRef<[T]>,
) -> Result<(), Error> {
if input.as_ref().len() < len {
self.zeroize();
return Err(Error::TooFewInputElements);
}
match self.iopattern.get(self.io_count) {
Some(Call::Absorb(call_len)) if *call_len == len => {}
Some(Call::Absorb(_)) => {
self.zeroize();
return Err(Error::IOPatternViolation);
}
_ => {
self.zeroize();
return Err(Error::IOPatternViolation);
}
}
for element in input.as_ref().iter().take(len) {
if self.pos_absorb == Self::RATE {
self.safe.permute(&mut self.state);
self.pos_absorb = 0;
}
let pos = self.pos_absorb + Self::CAPACITY;
let previous_value = self.state[pos];
let sum = self.safe.add(&previous_value, element);
self.state[pos] = sum;
self.pos_absorb += 1;
}
self.pos_squeeze = Self::RATE;
self.io_count += 1;
Ok(())
}
pub fn squeeze(&mut self, len: usize) -> Result<(), Error> {
match self.iopattern.get(self.io_count) {
Some(Call::Squeeze(call_len)) if *call_len == len => {}
Some(Call::Squeeze(_)) => {
self.zeroize();
return Err(Error::IOPatternViolation);
}
_ => {
self.zeroize();
return Err(Error::IOPatternViolation);
}
}
for _ in 0..len {
if self.pos_squeeze == Self::RATE {
self.safe.permute(&mut self.state);
self.pos_squeeze = 0;
self.pos_absorb = 0;
}
self.output
.push(self.state[self.pos_squeeze + Self::CAPACITY]);
self.pos_squeeze += 1;
}
self.io_count += 1;
Ok(())
}
}
impl<S, T, const W: usize> Drop for Sponge<S, T, W>
where
S: Safe<T, W>,
T: Default + Copy + Zeroize,
{
fn drop(&mut self) {
self.zeroize();
}
}
impl<S, T, const W: usize> Zeroize for Sponge<S, T, W>
where
    S: Safe<T, W>,
    T: Default + Copy + Zeroize,
{
    /// Zeroizes the state, both positions and the collected output.
    ///
    /// The IO pattern, the call counter, the domain separator and the `safe`
    /// instance are left untouched.
    fn zeroize(&mut self) {
        let Self {
            state,
            pos_absorb,
            pos_squeeze,
            output,
            ..
        } = self;

        state.zeroize();
        pos_absorb.zeroize();
        pos_squeeze.zeroize();
        output.zeroize();
    }
}