mod bct;
mod bct_curlp;
use crate::ternary::sponge::{CurlP, CurlPRounds, Sponge, HASH_LENGTH};
use bct::{BcTrit, BcTritArr, BcTritBuf};
use bct_curlp::BctCurlP;
use bee_ternary::{
raw::{RawEncoding, RawEncodingBuf},
Btrit, T1B1Buf, TritBuf,
};
/// Maximum number of inputs hashed in one batch: one input per bit of a `usize` word, because
/// every binary-coded trit packs one input per bit lane of its `lo`/`hi` words.
pub const BATCH_SIZE: usize = 8 * std::mem::size_of::<usize>();
/// A `usize` word with every bit lane set to `1`.
// NOTE(review): not read in this file; presumably used by the `bct`/`bct_curlp` submodules.
const HIGH_BITS: usize = usize::MAX;
/// Hashes up to [`BATCH_SIZE`] inputs of identical length with Curl-P.
///
/// The inputs can be hashed together in one bit-sliced pass (`hash_batched`) or one at a time
/// (`hash_unbatched`).
pub struct BatchHasher<B>
where
    B: RawEncodingBuf,
{
    /// Inputs queued by `add`, kept in insertion order.
    trit_inputs: Vec<TritBuf<B>>,
    /// The queued inputs packed as binary-coded trits, one bit lane per input.
    bct_inputs: BcTritBuf,
    /// Bit-packed output of the batched sponge, `HASH_LENGTH` binary-coded trits long.
    bct_hashes: BcTritArr<HASH_LENGTH>,
    /// Scratch buffer that a single lane of `bct_hashes` is unpacked into.
    buf_demux: TritBuf,
    /// Sponge used by the batched code path.
    bct_curlp: BctCurlP,
    /// Sponge used by the unbatched code path.
    curlp: CurlP,
}
impl<B> BatchHasher<B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    /// Creates a hasher for inputs of exactly `input_length` trits.
    ///
    /// Both the batched and the unbatched sponges run `rounds` Curl-P rounds.
    pub fn new(input_length: usize, rounds: CurlPRounds) -> Self {
        Self {
            trit_inputs: Vec::with_capacity(BATCH_SIZE),
            bct_inputs: BcTritBuf::zeros(input_length),
            bct_hashes: BcTritArr::<HASH_LENGTH>::zeros(),
            buf_demux: TritBuf::zeros(HASH_LENGTH),
            bct_curlp: BctCurlP::new(rounds),
            curlp: CurlP::new(rounds),
        }
    }

    /// Queues `input` for hashing.
    ///
    /// # Panics
    ///
    /// Panics if the batch already holds [`BATCH_SIZE`] inputs, or if `input` does not have the
    /// length given to [`BatchHasher::new`].
    pub fn add(&mut self, input: TritBuf<B>) {
        assert!(self.trit_inputs.len() < BATCH_SIZE, "Batch is full.");
        // This length check is what makes the unchecked indexing in `mux` sound.
        assert_eq!(input.len(), self.bct_inputs.len(), "Input has an incorrect size.");
        self.trit_inputs.push(input);
    }

    /// Returns the number of queued inputs.
    pub fn len(&self) -> usize {
        self.trit_inputs.len()
    }

    /// Returns `true` if no inputs are queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Packs the queued inputs into `bct_inputs`.
    ///
    /// At trit position `i`, bit `j` of the `lo`/`hi` words encodes trit `i` of input `j`:
    /// `-1` sets only `lo`, `+1` sets only `hi`, and `0` sets both. Bits are only OR-ed in, so
    /// `bct_inputs` must be all zeros on entry. `hash_batched` re-zeroes it after every batch.
    fn mux(&mut self) {
        let count = self.trit_inputs.len();
        for i in 0..self.bct_inputs.len() {
            // SAFETY: `i < self.bct_inputs.len()` by the loop bound.
            let BcTrit(lo, hi) = unsafe { self.bct_inputs.get_unchecked_mut(i) };
            for j in 0..count {
                // SAFETY: `j < self.trit_inputs.len()` by the loop bound. `add` guarantees every
                // input has length `self.bct_inputs.len()`, so `i` is in bounds as well.
                match unsafe { self.trit_inputs.get_unchecked(j).get_unchecked(i) } {
                    Btrit::NegOne => *lo |= 1 << j,
                    Btrit::PlusOne => *hi |= 1 << j,
                    Btrit::Zero => {
                        *lo |= 1 << j;
                        *hi |= 1 << j;
                    }
                }
            }
        }
    }

    /// Unpacks bit lane `index` of `bct_hashes` into an owned `TritBuf` of `HASH_LENGTH` trits.
    ///
    /// `buf_demux` is used as scratch space and is then cloned.
    ///
    /// # Panics
    ///
    /// Panics if the lane holds the `(lo, hi) = (0, 0)` pattern. That pattern is not a valid
    /// binary-coded trit, and it cannot occur in a lane that `mux` filled with an input.
    fn demux(&mut self, index: usize) -> TritBuf {
        for (bc_trit, btrit) in self.bct_hashes.iter().zip(self.buf_demux.iter_mut()) {
            let lo = (bc_trit.lo() >> index) & 1;
            let hi = (bc_trit.hi() >> index) & 1;
            *btrit = match (lo, hi) {
                (1, 0) => Btrit::NegOne,
                (0, 1) => Btrit::PlusOne,
                (1, 1) => Btrit::Zero,
                // Only `(0, 0)` remains. Silently mapping it to `Zero` would hide a corrupted
                // state behind a plausible-looking but wrong hash.
                _ => unreachable!("invalid binary-coded trit in batch lane {}", index),
            };
        }
        self.buf_demux.clone()
    }

    /// Hashes every queued input in a single bit-sliced pass of the batched sponge.
    ///
    /// The returned iterator yields one hash per input, in the order the inputs were added.
    /// The queue is cleared before the iterator is returned, so the hasher can be refilled
    /// once the iterator is dropped.
    pub fn hash_batched(&mut self) -> impl Iterator<Item = TritBuf> + '_ {
        let total = self.trit_inputs.len();
        self.bct_curlp.reset();
        self.mux();
        self.bct_curlp.absorb(&self.bct_inputs);
        self.bct_curlp.squeeze_into(&mut self.bct_hashes);
        // Restore the preconditions for the next batch: an empty queue and zeroed lanes,
        // because `mux` only ever ORs bits in.
        self.trit_inputs.clear();
        self.bct_inputs.fill(0);
        BatchedHashes {
            hasher: self,
            range: 0..total,
        }
    }

    /// Hashes each queued input independently with the plain Curl-P sponge.
    ///
    /// Inputs are drained lazily, so each call to `next` hashes one input. Because this uses
    /// `Vec::drain`, dropping the iterator early still removes every remaining input from the
    /// queue.
    pub fn hash_unbatched(&mut self) -> impl Iterator<Item = TritBuf> + '_ {
        self.curlp.reset();
        UnbatchedHashes {
            curl: &mut self.curlp,
            trit_inputs: self.trit_inputs.drain(..),
        }
    }
}
/// Iterator over the hashes computed by `BatchHasher::hash_batched`.
///
/// Each call to `next` unpacks one bit lane of the batched hash state. Lanes are visited in
/// ascending order, which matches the order the inputs were added.
struct BatchedHashes<'a, B: RawEncodingBuf> {
    hasher: &'a mut BatchHasher<B>,
    /// Lanes that have not been yielded yet.
    range: std::ops::Range<usize>,
}

impl<'a, B> Iterator for BatchedHashes<'a, B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    type Item = TritBuf;

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.range.next()?;
        Some(self.hasher.demux(index))
    }

    // The number of remaining hashes is exactly the number of remaining lanes. Reporting it
    // lets `collect` and similar consumers preallocate instead of growing repeatedly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.range.size_hint()
    }
}
/// Iterator over the hashes computed by `BatchHasher::hash_unbatched`.
///
/// Each call to `next` takes one queued input and hashes it with the plain Curl-P sponge.
struct UnbatchedHashes<'a, B: RawEncodingBuf> {
    curl: &'a mut CurlP,
    trit_inputs: std::vec::Drain<'a, TritBuf<B>>,
}

impl<'a, B> Iterator for UnbatchedHashes<'a, B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    type Item = TritBuf;

    fn next(&mut self) -> Option<Self::Item> {
        let buf = self.trit_inputs.next()?;
        // The sponge is shared across inputs, so each hash relies on `digest` leaving it in a
        // reset state.
        // NOTE(review): assumes `Sponge::digest` resets after squeezing, and that its error
        // cannot occur for `CurlP` — confirm against the sponge module.
        Some(self.curl.digest(&buf.encode::<T1B1Buf>()).unwrap())
    }

    // Exactly one hash is produced per remaining input, and `Drain` reports that count exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.trit_inputs.size_hint()
    }
}