use std::marker::PhantomData;
use digest::ExtendableOutput;
use crate::{
utils::{into_bytes, read_u32, HexDisplayRef32},
LtHash,
};
#[derive(Clone, Copy)]
pub struct LtHash32<H> {
pub(crate) checksum: [u32; 1024],
hasher: PhantomData<H>,
}
static_assertions::assert_impl_all!(LtHash32<()>: Send, Sync, Unpin);
impl<H> LtHash32<H> {
pub(crate) const fn name(&self) -> &'static str {
"LtHash32"
}
}
impl<H> LtHash32<H>
where
H: ExtendableOutput + Default,
{
#[inline(always)]
pub fn new() -> Self {
Self::default()
}
fn hash_object(&mut self, object: impl AsRef<[u8]>) -> [u8; 4096] {
let mut output = [0u8; 4096];
H::digest_xof(object, output.as_mut());
output
}
#[inline(always)]
fn display_hex_ref(&self) -> HexDisplayRef32<'_> {
HexDisplayRef32(&self.checksum[..])
}
}
impl<H> Default for LtHash32<H>
where
H: ExtendableOutput + Default,
{
#[inline(always)]
fn default() -> Self {
Self {
checksum: [0; 1024],
hasher: Default::default(),
}
}
}
impl<H> LtHash for LtHash32<H>
where
H: ExtendableOutput + Default,
{
fn insert(&mut self, element: impl AsRef<[u8]>) {
let hashed = self.hash_object(element);
let mut i = 0;
while i < 4096 {
let xi = &self.checksum[i / 4];
let yi = &hashed[i..i + 4];
let yi = read_u32(yi);
let sum = xi.wrapping_add(yi);
self.checksum[i / 4] = sum;
i += 4;
}
}
fn remove(&mut self, element: impl AsRef<[u8]>) {
let hashed = self.hash_object(element);
let mut i = 0;
while i < 4096 {
let xi = &self.checksum[i / 4];
let yi = &hashed[i..i + 4];
let yi = read_u32(yi);
let diff = xi.wrapping_sub(yi);
self.checksum[i / 4] = diff;
i += 4;
}
}
fn to_hex_string(&self) -> String {
self.display_hex_ref().to_string()
}
fn union(&self, rhs: &Self) -> Self {
let mut checksum = [0; 1024];
for (checksum, (&lhs, &rhs)) in checksum
.iter_mut()
.zip(self.checksum.iter().zip(rhs.checksum.iter()))
{
*checksum = lhs.wrapping_add(rhs);
}
Self {
checksum,
hasher: PhantomData,
}
}
fn difference(&self, rhs: &Self) -> Self {
let mut checksum = [0; 1024];
for (checksum, (&lhs, &rhs)) in checksum
.iter_mut()
.zip(self.checksum.iter().zip(rhs.checksum.iter()))
{
*checksum = lhs.wrapping_sub(rhs);
}
Self {
checksum,
hasher: PhantomData,
}
}
fn reset(&mut self) {
self.checksum.fill(0);
}
fn into_bytes(self) -> Vec<u8> {
into_bytes(self.checksum)
}
}
impl<H> TryFrom<&[u8]> for LtHash32<H> {
type Error = String;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() != 4096 {
return Err(String::from("Wrong number of bytes."));
}
let mut checksum = [0; 1024];
for (checksum, bytes) in checksum.iter_mut().zip(bytes.chunks_exact(4))
{
*checksum =
u32::from_le_bytes(bytes.try_into().map_err(|_| {
String::from("Error converting bytes to u32.")
})?);
}
Ok(Self {
checksum,
hasher: PhantomData,
})
}
}