use commonware_utils::GOLDEN_RATIO;
use core::hash::{BuildHasher, Hash, Hasher};
/// Maps arbitrary byte keys onto a fixed-size internal representation.
///
/// Every translator is also a [`BuildHasher`], is cloneable, and can be shared across threads.
pub trait Translator: BuildHasher + Clone + Send + Sync + 'static {
    /// Fixed-size, ordered, hashable representation produced by [`Translator::transform`].
    type Key: Copy + Ord + Hash + Send + Sync;

    /// Converts `key` into its internal representation.
    fn transform(&self, key: &[u8]) -> Self::Key;
}
/// Hasher state for keys that already fit in a 64-bit integer.
///
/// The [`Hasher`] implementation for this type stores the written value here and only
/// scrambles it when `finish` is called.
#[derive(Clone, Default)]
pub struct UintIdentity {
    /// Most recently written value, widened to 64 bits.
    value: u64,
}
impl Hasher for UintIdentity {
    /// Records up to 8 bytes as a little-endian integer, zero-padding shorter inputs.
    ///
    /// Each call replaces the previously recorded value; writes are not combined.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is longer than 8 bytes.
    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        assert!(bytes.len() <= 8, "UintIdentity hasher cannot handle >8 bytes");
        self.value = u64::from_le_bytes(cap::<8>(bytes));
    }

    /// Records `i`, replacing any previously recorded value.
    #[inline]
    fn write_u8(&mut self, i: u8) {
        self.value = u64::from(i);
    }

    /// Records `i`, replacing any previously recorded value.
    #[inline]
    fn write_u16(&mut self, i: u16) {
        self.value = u64::from(i);
    }

    /// Records `i`, replacing any previously recorded value.
    #[inline]
    fn write_u32(&mut self, i: u32) {
        self.value = u64::from(i);
    }

    /// Records `i`, replacing any previously recorded value.
    #[inline]
    fn write_u64(&mut self, i: u64) {
        self.value = i;
    }

    /// Returns the recorded value multiplied (with wrap-around) by `GOLDEN_RATIO`.
    #[inline]
    fn finish(&self) -> u64 {
        // The tests in this file expect even tiny inputs to yield non-zero high bits, which a
        // plain identity would not provide.
        self.value.wrapping_mul(GOLDEN_RATIO)
    }
}
/// Returns the first `N` bytes of `key` as a fixed-size array.
///
/// Keys shorter than `N` are padded with trailing zero bytes; longer keys are truncated.
fn cap<const N: usize>(key: &[u8]) -> [u8; N] {
    let prefix = &key[..N.min(key.len())];
    let mut padded = [0u8; N];
    padded[..prefix.len()].copy_from_slice(prefix);
    padded
}
macro_rules! define_cap_translator {
($name:ident, $size:expr, $int:ty) => {
#[doc = concat!("Translator that caps the key to ", stringify!($size), " byte(s) and returns it packed in a ", stringify!($int), ".")]
#[derive(Clone, Default)]
pub struct $name;
impl Translator for $name {
type Key = $int;
#[inline]
fn transform(&self, key: &[u8]) -> Self::Key {
let capped = cap::<$size>(key);
<$int>::from_be_bytes(capped)
}
}
impl BuildHasher for $name {
type Hasher = UintIdentity;
#[inline]
fn build_hasher(&self) -> Self::Hasher {
UintIdentity::default()
}
}
};
}
// Concrete translators that cap keys to 1, 2, 4, and 8 bytes, each keyed by the unsigned integer
// type of the same width.
define_cap_translator!(OneCap, 1, u8);
define_cap_translator!(TwoCap, 2, u16);
define_cap_translator!(FourCap, 4, u32);
define_cap_translator!(EightCap, 8, u64);
/// Fixed-size byte-array key with value semantics (copyable, comparable, and ordered).
///
/// `Hash` is implemented by hand for this type rather than derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct UnhashedArray<const N: usize> {
    /// Raw key bytes.
    pub inner: [u8; N],
}
impl<const N: usize> Hash for UnhashedArray<N> {
    /// Feeds the raw key bytes to `state` in a single `write` call.
    #[inline]
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        state.write(self.inner.as_slice());
    }
}
impl<const N: usize> PartialEq<[u8; N]> for UnhashedArray<N> {
fn eq(&self, other: &[u8; N]) -> bool {
&self.inner == other
}
}
/// Zero-sized translator parameterised by the number of key bytes, `N`, that it keeps.
#[derive(Copy, Clone)]
pub struct Cap<const N: usize>;
impl<const N: usize> Cap<N> {
    /// Creates a `Cap<N>`.
    ///
    /// The range check lives in a `const` block, so an `N` outside `1..=8` is rejected when
    /// the code is compiled rather than at run time.
    pub const fn new() -> Self {
        const {
            assert!(0 < N && N <= 8, "Cap must be between 1 and 8");
        };
        Self
    }
}
impl<const N: usize> Default for Cap<N> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> Translator for Cap<N> {
    type Key = UnhashedArray<N>;

    /// Keeps the first `N` bytes of `key`, zero-padding shorter keys.
    #[inline]
    fn transform(&self, key: &[u8]) -> Self::Key {
        // `Cap` is a public unit struct and can be created without `new`, so the size bound is
        // checked here as well.
        const {
            assert!(0 < N && N <= 8, "Cap must be between 1 and 8");
        };
        let inner = cap::<N>(key);
        UnhashedArray { inner }
    }
}
impl<const N: usize> BuildHasher for Cap<N> {
type Hasher = UintIdentity;
#[inline]
fn build_hasher(&self) -> Self::Hasher {
UintIdentity::default()
}
}
/// Translator adaptor that hashes a key with `ahash` before delegating to an inner translator.
#[derive(Clone)]
pub struct Hashed<T>
where
    T: Translator,
{
    /// Hash state applied to every incoming key.
    random_state: ahash::RandomState,
    /// Translator that receives the hashed key bytes.
    inner: T,
}
#[cfg(feature = "std")]
impl<T> Default for Hashed<T>
where
    T: Translator + Default,
{
    /// Wraps `T::default()` using [`Hashed::new`].
    fn default() -> Self {
        let inner = T::default();
        Self::new(inner)
    }
}
impl<T> Hashed<T>
where
    T: Translator,
{
    /// Wraps `inner` with a hash state obtained from `ahash::RandomState::new()`.
    ///
    /// Only available with the `std` feature.
    #[cfg(feature = "std")]
    pub fn new(inner: T) -> Self {
        let random_state = ahash::RandomState::new();
        Self {
            random_state,
            inner,
        }
    }

    /// Wraps `inner` with a hash state built from the fixed seeds `(seed, 0, 0, 0)`.
    pub const fn from_seed(seed: u64, inner: T) -> Self {
        let random_state = ahash::RandomState::with_seeds(seed, 0, 0, 0);
        Self {
            random_state,
            inner,
        }
    }
}
impl<T> Translator for Hashed<T>
where
    T: Translator,
{
    type Key = T::Key;

    /// Hashes `key`, then hands the hash's little-endian bytes to the inner translator.
    #[inline]
    fn transform(&self, key: &[u8]) -> T::Key {
        let digest = self.random_state.hash_one(key).to_le_bytes();
        self.inner.transform(&digest)
    }
}
impl<T: Translator> BuildHasher for Hashed<T> {
type Hasher = <T as BuildHasher>::Hasher;
#[inline]
fn build_hasher(&self) -> Self::Hasher {
self.inner.build_hasher()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::hash::Hasher;

    /// `OneCap` keeps only the first byte of the key.
    #[test]
    fn test_one_cap() {
        let t = OneCap;
        assert_eq!(t.transform(b""), 0);
        assert_eq!(t.transform(b"a"), b'a');
        assert_eq!(t.transform(b"ab"), b'a');
        assert_eq!(t.transform(b"abc"), b'a');
    }

    /// `TwoCap` truncates to two bytes and preserves byte-wise ordering.
    #[test]
    fn test_two_cap() {
        let t = TwoCap;
        assert_eq!(t.transform(b""), 0);
        assert_eq!(t.transform(b"abc"), t.transform(b"ab"));
        assert!(t.transform(b"") < t.transform(b"a"));
        assert!(t.transform(b"a") < t.transform(b"b"));
        assert!(t.transform(b"ab") < t.transform(b"ac"));
        assert!(t.transform(b"z") < t.transform(b"zz"));
        assert_eq!(t.transform(b"zz"), t.transform(b"zzabc"));
    }

    /// `FourCap` truncates to four bytes and preserves byte-wise ordering.
    #[test]
    fn test_four_cap() {
        let t = FourCap;
        let t1 = t.transform(b"");
        let t2 = t.transform(b"a");
        let t3 = t.transform(b"abcd");
        let t4 = t.transform(b"abcdef");
        let t5 = t.transform(b"b");
        assert_eq!(t1, 0);
        assert!(t1 < t2);
        assert!(t2 < t3);
        assert_eq!(t3, t4);
        assert!(t3 < t5);
        assert!(t4 < t5);
    }

    /// `Cap<3>` pads short keys with zeros and truncates long ones.
    #[test]
    fn test_cap_3() {
        let t = Cap::<3>::new();
        assert_eq!(t.transform(b""), [0; 3]);
        assert_eq!(t.transform(b"abc"), *b"abc");
        assert_eq!(t.transform(b"abcdef"), *b"abc");
        assert_eq!(t.transform(b"ab"), [b'a', b'b', 0]);
    }

    /// `Cap<6>` pads short keys with zeros and truncates long ones.
    #[test]
    fn test_cap_6() {
        let t = Cap::<6>::new();
        assert_eq!(t.transform(b""), [0; 6]);
        assert_eq!(t.transform(b"abcdef"), *b"abcdef");
        assert_eq!(t.transform(b"abcdefghi"), *b"abcdef");
        assert_eq!(t.transform(b"abc"), [b'a', b'b', b'c', 0, 0, 0]);
    }

    /// `EightCap` truncates to eight bytes and preserves byte-wise ordering.
    #[test]
    fn test_eight_cap() {
        let t = EightCap;
        let t1 = t.transform(b"");
        let t2 = t.transform(b"a");
        let t3 = t.transform(b"abcdefghaaaaaaa");
        let t4 = t.transform(b"abcdefghijkzzzzzzzzzzzzzzzzzz");
        let t5 = t.transform(b"b");
        assert_eq!(t1, 0);
        assert!(t1 < t2);
        assert!(t2 < t3);
        assert_eq!(t3, t4);
        assert!(t3 < t5);
        assert!(t4 < t5);
    }

    /// Distinct short slices must produce distinct hashes.
    #[test]
    fn identity_hasher_small_slices_differ() {
        let hash = |bytes: &[u8]| {
            let mut h = UintIdentity::default();
            h.write(bytes);
            h.finish()
        };
        assert_ne!(hash(b"abc"), hash(b"abd"));
        assert_ne!(hash(b"a"), hash(b"b"));
        assert_ne!(hash(b""), hash(b"a"));
    }

    /// Small inputs must still influence the top seven bits of the hash.
    #[test]
    fn identity_hasher_sets_high_bits() {
        for i in [1u64, 7, 17, 255] {
            let mut h = UintIdentity::default();
            h.write_u64(i);
            assert_ne!(h.finish() >> 57, 0, "high bits all zero for input {i}");
        }
    }

    /// Integer writes distinguish values and agree across widths for the same value.
    #[test]
    fn identity_hasher_integer_writes_differ() {
        let hash_u8 = |v: u8| {
            let mut h = UintIdentity::default();
            h.write_u8(v);
            h.finish()
        };
        let hash_u16 = |v: u16| {
            let mut h = UintIdentity::default();
            h.write_u16(v);
            h.finish()
        };
        let hash_u32 = |v: u32| {
            let mut h = UintIdentity::default();
            h.write_u32(v);
            h.finish()
        };
        let hash_u64 = |v: u64| {
            let mut h = UintIdentity::default();
            h.write_u64(v);
            h.finish()
        };
        assert_ne!(hash_u8(0), hash_u8(1));
        assert_ne!(hash_u16(0), hash_u16(1));
        assert_ne!(hash_u32(0), hash_u32(1));
        assert_ne!(hash_u64(0), hash_u64(1));
        assert_eq!(hash_u8(7), hash_u16(7));
        assert_eq!(hash_u16(7), hash_u32(7));
        assert_eq!(hash_u32(7), hash_u64(7));
    }

    /// Writing more than eight bytes must hit the size assertion, not some other panic.
    #[test]
    #[should_panic(expected = "cannot handle >8 bytes")]
    fn identity_hasher_panics_on_large_write_slice() {
        let mut h = UintIdentity::default();
        h.write(b"too big for an int");
    }

    /// The same translator instance maps equal keys to equal outputs.
    #[test]
    fn test_hashed_consistency() {
        let t = Hashed::from_seed(42, TwoCap);
        assert_eq!(t.transform(b"hello"), t.transform(b"hello"));
        assert_eq!(t.transform(b""), t.transform(b""));
        assert_eq!(t.transform(b"abcdef"), t.transform(b"abcdef"));
    }

    /// Two translators built from the same seed agree.
    #[test]
    fn test_hashed_seed_determinism() {
        let t1 = Hashed::from_seed(42, TwoCap);
        let t2 = Hashed::from_seed(42, TwoCap);
        assert_eq!(t1.transform(b"hello"), t2.transform(b"hello"));
        assert_eq!(t1.transform(b"world"), t2.transform(b"world"));
    }

    /// Different seeds yield different outputs for the same key.
    #[test]
    fn test_hashed_seed_independence() {
        let t1 = Hashed::from_seed(1, EightCap);
        let t2 = Hashed::from_seed(2, EightCap);
        assert_ne!(t1.transform(b"hello"), t2.transform(b"hello"));
    }

    /// Keys sharing a two-byte prefix no longer collide once hashed.
    #[test]
    fn test_hashed_prefix_collisions_avoided() {
        let t = Hashed::from_seed(99, TwoCap);
        let k1 = t.transform(b"abXXX");
        let k2 = t.transform(b"abYYY");
        assert_ne!(k1, k2);
    }

    /// `Hashed` composes with every cap translator in this file.
    #[test]
    fn test_hashed_all_cap_sizes() {
        let t1 = Hashed::from_seed(7, OneCap);
        assert_eq!(t1.transform(b"test"), t1.transform(b"test"));
        let t4 = Hashed::from_seed(7, FourCap);
        assert_eq!(t4.transform(b"test"), t4.transform(b"test"));
        let t8 = Hashed::from_seed(7, EightCap);
        assert_eq!(t8.transform(b"test"), t8.transform(b"test"));
        let tc = Hashed::from_seed(7, Cap::<3>::new());
        assert_eq!(tc.transform(b"test"), tc.transform(b"test"));
    }

    /// Independently constructed translators use different random states.
    // `Hashed::new` only exists with the `std` feature, so this test is gated the same way to
    // keep the test module compiling without it.
    #[cfg(feature = "std")]
    #[test]
    fn test_hashed_random_seed() {
        let t1 = Hashed::new(EightCap);
        let t2 = Hashed::new(EightCap);
        assert_ne!(t1.transform(b"hello"), t2.transform(b"hello"));
    }
}