mod bct;
mod bct_curlp;
use alloc::vec::Vec;
use bct::{BcTrit, BcTritArr, BcTritBuf};
use bct_curlp::BctCurlP;
use crate::{
encoding::ternary::{
raw::{RawEncoding, RawEncodingBuf},
Btrit, TritBuf,
},
hashes::ternary::HASH_LENGTH,
};
/// The maximum number of inputs that can be hashed in a single batch.
///
/// This equals the number of bits in a `usize`: each queued input occupies one bit position of
/// the interleaved words built by the batch hasher.
pub const BATCH_SIZE: usize = 8 * core::mem::size_of::<usize>();
/// A `usize` with every bit set.
// NOTE(review): not referenced in this part of the file; presumably consumed by the `bct` /
// `bct_curlp` submodules - confirm.
const HIGH_BITS: usize = usize::MAX;
/// Round count constant with value 81.
// NOTE(review): not referenced in this part of the file; presumably the number of CurlP
// transformation rounds used by `bct_curlp` - confirm.
const NUM_ROUNDS: usize = 81;
/// A hasher that processes up to [`BATCH_SIZE`] equally sized inputs in one pass.
///
/// Inputs are queued with `add` and hashed together by `hash`, which interleaves them into
/// binary-coded trits, runs them through the `BctCurlP` sponge and yields one hash per input.
pub struct CurlPBatchHasher<B: RawEncodingBuf> {
/// The inputs queued through `add`, in insertion order; emptied by `hash`.
trit_inputs: Vec<TritBuf<B>>,
/// Interleaved representation of the queued inputs, filled by `mux`. Its length is the length
/// every input must have.
bct_inputs: BcTritBuf,
/// Interleaved representation of the resulting hashes, written by the squeeze step in `hash`.
bct_hashes: BcTritArr<HASH_LENGTH>,
/// Scratch buffer of `HASH_LENGTH` trits that `demux` overwrites and then clones for each hash.
buf_demux: TritBuf,
/// The sponge that operates on the interleaved (binary-coded) trits.
bct_curlp: BctCurlP,
}
impl<B> CurlPBatchHasher<B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    /// Creates an empty batch hasher.
    ///
    /// `input_length` sizes the interleaved input buffer; every input later passed to
    /// [`add`](Self::add) must have the same length as that buffer.
    pub fn new(input_length: usize) -> Self {
        let trit_inputs = Vec::with_capacity(BATCH_SIZE);
        let bct_inputs = BcTritBuf::zeros(input_length);
        let bct_hashes = BcTritArr::<HASH_LENGTH>::zeros();
        let buf_demux = TritBuf::zeros(HASH_LENGTH);
        let bct_curlp = BctCurlP::new();

        Self {
            trit_inputs,
            bct_inputs,
            bct_hashes,
            buf_demux,
            bct_curlp,
        }
    }

    /// Queues `input` so that it is hashed by the next call to [`hash`](Self::hash).
    ///
    /// # Panics
    ///
    /// Panics if the batch already holds [`BATCH_SIZE`] inputs, or if the length of `input`
    /// differs from the length of the interleaved input buffer.
    pub fn add(&mut self, input: TritBuf<B>) {
        let queued = self.trit_inputs.len();
        assert!(queued < BATCH_SIZE, "Batch is full.");

        let expected_len = self.bct_inputs.len();
        assert_eq!(input.len(), expected_len, "Input has an incorrect size.");

        self.trit_inputs.push(input);
    }

    /// Returns the number of inputs currently queued.
    pub fn len(&self) -> usize {
        self.trit_inputs.len()
    }

    /// Returns `true` when no input is queued.
    pub fn is_empty(&self) -> bool {
        self.trit_inputs.is_empty()
    }

    /// Interleaves the queued inputs into `bct_inputs`.
    ///
    /// The input at queue position `lane` contributes bit `lane` of every `lo`/`hi` word:
    /// `-1` sets only `lo`, `+1` sets only `hi` and `0` sets both. Bits are only ever OR-ed in,
    /// so the buffer must be all zeros on entry (`new` and `hash` take care of that).
    fn mux(&mut self) {
        let trit_count = self.bct_inputs.len();

        for pos in 0..trit_count {
            // SAFETY: `pos < self.bct_inputs.len()` by the bounds of the loop.
            let BcTrit(lo, hi) = unsafe { self.bct_inputs.get_unchecked_mut(pos) };

            for (lane, input) in self.trit_inputs.iter().enumerate() {
                // SAFETY: `add` only accepts inputs whose length equals
                // `self.bct_inputs.len()`, hence `pos < input.len()`.
                let trit = unsafe { input.get_unchecked(pos) };

                let (sets_lo, sets_hi) = match trit {
                    Btrit::NegOne => (true, false),
                    Btrit::PlusOne => (false, true),
                    Btrit::Zero => (true, true),
                };

                if sets_lo {
                    *lo |= 1 << lane;
                }
                if sets_hi {
                    *hi |= 1 << lane;
                }
            }
        }
    }

    /// Extracts the hash stored in bit position `index` of `bct_hashes`.
    ///
    /// The scratch buffer `buf_demux` is overwritten and a clone of it is returned. A
    /// `(lo, hi)` bit pair of `(1, 0)` decodes to `-1`, `(0, 1)` to `+1` and anything else to `0`.
    fn demux(&mut self, index: usize) -> TritBuf {
        let pairs = self.bct_hashes.iter().zip(self.buf_demux.iter_mut());

        for (bc_trit, btrit) in pairs {
            let bits = ((bc_trit.lo() >> index) & 1, (bc_trit.hi() >> index) & 1);

            *btrit = match bits {
                (0, 1) => Btrit::PlusOne,
                (1, 0) => Btrit::NegOne,
                _ => Btrit::Zero,
            };
        }

        self.buf_demux.clone()
    }

    /// Hashes every queued input and returns an iterator over the hashes.
    ///
    /// The iterator yields one hash per queued input, in the order the inputs were added. The
    /// queue and the interleaved input buffer are cleared before the iterator is returned, so
    /// the hasher can be refilled once the iterator (which mutably borrows `self`) is dropped.
    pub fn hash(&mut self) -> impl Iterator<Item = TritBuf> + '_ {
        // Capture the batch size first: the queue is emptied further down.
        let range = 0..self.trit_inputs.len();

        // Start from a fresh sponge state, then interleave the inputs.
        self.bct_curlp.reset();
        self.mux();

        // Regular sponge steps on the interleaved trits.
        self.bct_curlp.absorb(&self.bct_inputs);
        self.bct_curlp.squeeze_into(&mut self.bct_hashes);

        // Make the hasher reusable: `mux` relies on an all-zero input buffer.
        self.trit_inputs.clear();
        self.bct_inputs.fill(0);

        BatchedHashes { hasher: self, range }
    }
}
/// Iterator over the hashes of one batch, produced by `CurlPBatchHasher::hash`.
///
/// Each hash is extracted from the hasher on demand when the iterator is advanced.
struct BatchedHashes<'a, B: RawEncodingBuf> {
/// The hasher holding the interleaved hashes; borrowed mutably because `demux` writes to its
/// scratch buffer.
hasher: &'a mut CurlPBatchHasher<B>,
/// The bit positions (one per hashed input) that are still to be yielded.
range: core::ops::Range<usize>,
}
impl<'a, B> Iterator for BatchedHashes<'a, B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    type Item = TritBuf;

    /// Yields the hash for the next bit position, or `None` once every hash of the batch has
    /// been produced.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.range.next()?;
        Some(self.hasher.demux(index))
    }

    /// Reports the exact number of hashes left, taken from the remaining index range.
    ///
    /// Without this the default `(0, None)` hint is used, and consumers such as
    /// `collect::<Vec<_>>()` cannot reserve the right capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.range.size_hint()
    }
}