use crate::{
blake3::{Blake3, CoreBlake3, Digest},
Hasher as _,
};
use bytes::{Buf, BufMut};
use commonware_codec::{Error as CodecError, FixedSize, Read, ReadExt, Write};
/// Size in bytes of the LtHash state (2048 bytes).
const LTHASH_SIZE: usize = 2048;
/// Number of `u16` elements held in the LtHash state: two bytes per element.
const LTHASH_ELEMENTS: usize = LTHASH_SIZE / core::mem::size_of::<u16>();
#[derive(Clone)]
pub struct LtHash {
state: [u16; LTHASH_ELEMENTS],
}
impl LtHash {
    /// Creates an accumulator whose state elements are all zero.
    pub fn new() -> Self {
        Self {
            state: [0u16; LTHASH_ELEMENTS],
        }
    }

    /// Adds `data` to the accumulator.
    ///
    /// `data` is expanded with the BLAKE3 extendable-output function into
    /// `LTHASH_ELEMENTS` little-endian `u16` values. Each value is then added to the
    /// matching state element with wrapping arithmetic.
    pub fn add(&mut self, data: &[u8]) {
        let expanded = Self::expand_to_state(data);
        self.apply(&expanded, u16::wrapping_add);
    }

    /// Removes `data` from the accumulator, undoing a previous [`LtHash::add`] of the same bytes.
    ///
    /// This is the element-wise wrapping inverse of `add`. Calling `add(x)` and then
    /// `subtract(x)` restores the previous state exactly.
    pub fn subtract(&mut self, data: &[u8]) {
        let expanded = Self::expand_to_state(data);
        self.apply(&expanded, u16::wrapping_sub);
    }

    /// Merges `other` into `self` by adding its state element-wise with wrapping arithmetic.
    ///
    /// The result is the same as applying every addition and subtraction that produced
    /// `other` directly to `self`.
    pub fn combine(&mut self, other: &Self) {
        self.apply(&other.state, u16::wrapping_add);
    }

    /// Returns a BLAKE3 digest of the state.
    ///
    /// The state is serialised as consecutive little-endian `u16` values in element order.
    /// The whole buffer is hashed in a single `update` call. BLAKE3 is a streaming hash, so
    /// this gives the same digest as feeding each element separately, with far less per-call
    /// overhead.
    pub fn checksum(&self) -> Digest {
        let mut bytes = [0u8; LTHASH_SIZE];
        for (chunk, val) in bytes.chunks_exact_mut(2).zip(self.state.iter()) {
            chunk.copy_from_slice(&val.to_le_bytes());
        }
        let mut hasher = Blake3::new();
        hasher.update(&bytes);
        hasher.finalize()
    }

    /// Resets every state element to zero, making the accumulator equal to [`LtHash::new`].
    pub fn reset(&mut self) {
        self.state = [0u16; LTHASH_ELEMENTS];
    }

    /// Returns `true` if every state element is zero.
    pub fn is_zero(&self) -> bool {
        self.state.iter().all(|&val| val == 0)
    }

    /// Applies `op(current, incoming)` to each pair of matching state and `other` elements.
    ///
    /// Zipping the iterators avoids per-index bounds checks. Taking `op` as a generic `Fn`
    /// rather than a function pointer lets the compiler inline it into the loop.
    fn apply(&mut self, other: &[u16; LTHASH_ELEMENTS], op: impl Fn(u16, u16) -> u16) {
        for (acc, &val) in self.state.iter_mut().zip(other.iter()) {
            *acc = op(*acc, val);
        }
    }

    /// Expands `data` into `LTHASH_ELEMENTS` `u16` values.
    ///
    /// The `LTHASH_SIZE` bytes are read from the BLAKE3 XOF output, and each consecutive
    /// byte pair is decoded as a little-endian `u16`.
    fn expand_to_state(data: &[u8]) -> [u16; LTHASH_ELEMENTS] {
        let mut bytes = [0u8; LTHASH_SIZE];
        let mut hasher = CoreBlake3::new();
        hasher.update(data);
        let mut output_reader = hasher.finalize_xof();
        output_reader.fill(&mut bytes);

        let mut result = [0u16; LTHASH_ELEMENTS];
        // `chunks_exact` guarantees every chunk has exactly two bytes.
        for (out, pair) in result.iter_mut().zip(bytes.chunks_exact(2)) {
            *out = u16::from_le_bytes([pair[0], pair[1]]);
        }
        result
    }
}
impl Default for LtHash {
fn default() -> Self {
Self::new()
}
}
impl Write for LtHash {
    /// Encodes the state as its `u16` elements, one after another in element order,
    /// each written with the codec's own `u16` encoding.
    fn write(&self, buf: &mut impl BufMut) {
        self.state.iter().for_each(|element| element.write(buf));
    }
}
impl Read for LtHash {
    type Cfg = ();

    /// Decodes `LTHASH_ELEMENTS` consecutive `u16` values into a new accumulator.
    ///
    /// Any error from decoding an element (for example, too few remaining bytes) is
    /// returned immediately.
    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
        let mut decoded = Self::new();
        for slot in &mut decoded.state {
            *slot = u16::read(buf)?;
        }
        Ok(decoded)
    }
}
impl FixedSize for LtHash {
const SIZE: usize = LTHASH_SIZE;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Hasher;

    #[test]
    fn test_new() {
        let lthash = LtHash::new();
        assert!(lthash.is_zero());
    }

    #[test]
    fn test_add() {
        let mut lthash = LtHash::new();
        lthash.add(b"hello");
        assert!(!lthash.is_zero());
    }

    #[test]
    fn test_commutativity() {
        let mut lthash1 = LtHash::new();
        lthash1.add(b"hello");
        lthash1.add(b"world");
        let hash1 = lthash1.checksum();

        let mut lthash2 = LtHash::new();
        lthash2.add(b"world");
        lthash2.add(b"hello");
        let hash2 = lthash2.checksum();

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_associativity() {
        let mut lthash1 = LtHash::new();
        lthash1.add(b"a");
        lthash1.add(b"b");
        lthash1.add(b"c");
        let hash1 = lthash1.checksum();

        let mut lthash2 = LtHash::new();
        let mut temp = LtHash::new();
        temp.add(b"b");
        temp.add(b"c");
        lthash2.add(b"a");
        lthash2.combine(&temp);
        let hash2 = lthash2.checksum();

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_subtraction() {
        let mut lthash1 = LtHash::new();
        lthash1.add(b"hello");
        let hash1 = lthash1.checksum();

        let mut lthash2 = LtHash::new();
        lthash2.add(b"hello");
        lthash2.add(b"world");
        lthash2.subtract(b"world");
        let hash2 = lthash2.checksum();

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_add_then_subtract_is_identity() {
        // Adding and then removing the same item must return to the all-zero state.
        let mut lthash = LtHash::new();
        lthash.add(b"hello");
        lthash.subtract(b"hello");
        assert!(lthash.is_zero());
        assert_eq!(lthash.checksum(), LtHash::new().checksum());
    }

    #[test]
    fn test_distinct_inputs() {
        let mut a = LtHash::new();
        a.add(b"hello");
        let mut b = LtHash::new();
        b.add(b"world");
        assert_ne!(a.checksum(), b.checksum());
    }

    #[test]
    fn test_empty() {
        let lthash = LtHash::new();
        let empty_hash = lthash.checksum();

        let mut hasher = Blake3::new();
        for _ in 0..LTHASH_ELEMENTS {
            hasher.update(&0u16.to_le_bytes());
        }
        let expected = hasher.finalize();

        assert_eq!(empty_hash, expected);
    }

    #[test]
    fn test_reset() {
        let mut lthash = LtHash::new();
        lthash.add(b"hello");
        assert!(!lthash.is_zero());
        lthash.reset();
        assert!(lthash.is_zero());
    }

    #[test]
    fn test_deterministic() {
        let mut lthash = LtHash::new();
        lthash.add(b"test");
        let mut lthash2 = LtHash::new();
        lthash2.add(b"test");
        assert_eq!(lthash.checksum(), lthash2.checksum());
    }

    #[test]
    fn test_large_data() {
        let mut lthash = LtHash::new();
        let large_data = vec![0xAB; 10000];
        lthash.add(&large_data);
        // The original version of this test asserted nothing; check that the input actually moved the state.
        assert!(!lthash.is_zero());
        assert_ne!(lthash.checksum(), LtHash::new().checksum());
    }

    #[test]
    fn test_snake() {
        // Adding many items in forward and reverse order must give the same checksum.
        let mut lthash1 = LtHash::new();
        for i in 0..100u32 {
            lthash1.add(&i.to_le_bytes());
        }
        let hash1 = lthash1.checksum();

        let mut lthash2 = LtHash::new();
        for i in (0..100u32).rev() {
            lthash2.add(&i.to_le_bytes());
        }
        let hash2 = lthash2.checksum();

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_codec() {
        let mut lthash = LtHash::new();
        lthash.add(b"hello");
        let hash = lthash.checksum();

        let mut buf = Vec::new();
        lthash.write(&mut buf);
        // The encoding must match the advertised fixed size.
        assert_eq!(buf.len(), <LtHash as FixedSize>::SIZE);

        let mut reader = &buf[..];
        let lthash2 = LtHash::read_cfg(&mut reader, &()).unwrap();
        // Decoding must consume exactly the encoded bytes.
        assert!(reader.is_empty());
        assert_eq!(hash, lthash2.checksum());
    }

    #[test]
    fn test_codec_truncated() {
        let mut buf = Vec::new();
        LtHash::new().write(&mut buf);
        // One byte short of a full encoding must be rejected, not silently accepted.
        assert!(LtHash::read_cfg(&mut &buf[..buf.len() - 1], &()).is_err());
    }
}