use rand::{rngs::StdRng, Rng, SeedableRng};
/// The Mersenne prime `2^61 - 1`, used as the default modulus by both the
/// builder and the unkeyed streaming state.
const PRIME_61: u64 = (1 << 61) - 1;
/// Incremental (streaming) form of [`UniversalHash64`].
///
/// Bytes passed to `write` are folded into `partial` with the same Horner
/// recurrence used by `UniversalHash64::hash`; `finish` adds `b` at the end.
/// `Clone` lets a caller hash a shared prefix once and fork the state.
#[derive(Debug, Clone)]
pub struct UniversalHash64State {
    /// Modulus every intermediate value is reduced by.
    p: u64,
    /// Multiplier applied to the running value before each byte is added.
    a: u64,
    /// Offset added once by `finish`.
    b: u64,
    /// Running polynomial value over all bytes written so far.
    partial: u64,
}
impl Default for UniversalHash64State {
fn default() -> Self {
Self::new()
}
}
/// Configures and constructs a [`UniversalHash64`].
///
/// If explicit parameters are set via `params`, they take precedence over
/// `seed`; otherwise `(a, b)` are drawn from an RNG (seeded or from entropy).
#[derive(Debug, Clone)]
pub struct UniversalHashBuilder {
    /// Optional RNG seed for reproducible random `(a, b)`.
    seed: Option<u64>,
    /// Explicit multiplier; only used when `b` is also set.
    a: Option<u64>,
    /// Explicit offset; only used when `a` is also set.
    b: Option<u64>,
    /// Modulus for the built hash.
    p: u64,
}
impl Default for UniversalHashBuilder {
fn default() -> Self {
Self {
seed: None,
a: None,
b: None,
p: PRIME_61,
}
}
}
impl UniversalHashBuilder {
    /// Creates a builder with the default configuration (see `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the modulus `p`.
    ///
    /// `p` is not checked for primality; the universal-hashing guarantee only
    /// holds when it is actually prime.
    ///
    /// # Panics
    /// Panics if `p <= 1`.
    pub fn prime(self, p: u64) -> Self {
        assert!(1 < p, "prime must be > 1");
        Self { p, ..self }
    }

    /// Seeds the RNG used to draw `(a, b)` when explicit params are not given.
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Fixes the multiplier `a` and offset `b` explicitly (overrides `seed`).
    /// They are validated against `p` in `build_64`.
    pub fn params(self, a: u64, b: u64) -> Self {
        Self {
            a: Some(a),
            b: Some(b),
            ..self
        }
    }

    /// Builds the hash.
    ///
    /// With explicit params, those are used as-is. Otherwise `a` is drawn
    /// uniformly from `[1, p-1]` and then `b` from `[0, p-1]`, using a
    /// seeded RNG when `seed` was set and OS entropy otherwise.
    ///
    /// # Panics
    /// Panics if explicit `a` is 0 or `>= p`, or explicit `b` is `>= p`.
    pub fn build_64(self) -> UniversalHash64 {
        let Self { seed, a, b, p } = self;
        let (a, b) = match (a, b) {
            (Some(a), Some(b)) => {
                assert!(a != 0 && a < p, "param a must be in [1..p-1].");
                assert!(b < p, "param b must be in [0..p-1].");
                // Both are already < p, so no further reduction is needed.
                (a, b)
            }
            _ => {
                let mut rng = match seed {
                    Some(s) => StdRng::seed_from_u64(s),
                    None => StdRng::from_entropy(),
                };
                // Draw order (a first, then b) matters for seeded reproducibility.
                let drawn_a = rng.gen_range(1..p);
                let drawn_b = rng.gen_range(0..p);
                (drawn_a, drawn_b)
            }
        };
        UniversalHash64 { p, a, b }
    }
}
/// A keyed polynomial hash over bytes, reduced modulo `p`.
///
/// Constructed by [`UniversalHashBuilder::build_64`], which ensures
/// `a` is in `[1, p-1]` and `b` is in `[0, p-1]`.
#[derive(Clone, Debug)]
pub struct UniversalHash64 {
    /// Modulus all arithmetic is reduced by.
    p: u64,
    /// Multiplier of the Horner recurrence.
    a: u64,
    /// Additive offset applied to the final value.
    b: u64,
}
impl UniversalHash64 {
    /// Hashes `data` in one shot.
    ///
    /// Computes `(m_1*a^(n-1) + m_2*a^(n-2) + ... + m_n + b) mod p` over the
    /// bytes `m_i` of `data` via Horner's rule. This is exactly the streaming
    /// computation run over the whole slice, so it is implemented that way.
    ///
    /// NOTE(review): the running value starts at 0, so leading zero bytes do
    /// not affect the result (e.g. `b"abc"` and `b"\0abc"` collide for every
    /// key). Changing this would change hash values, so it is left as-is.
    pub fn hash(&self, data: &[u8]) -> u64 {
        let mut state = self.hasher();
        state.write(data);
        state.finish()
    }

    /// Returns a streaming state carrying this hash's `p`, `a` and `b`.
    ///
    /// Writing the same bytes through it, in any chunking, and calling
    /// `finish` gives the same value as `hash`.
    pub fn hasher(&self) -> UniversalHash64State {
        let &Self { p, a, b } = self;
        UniversalHash64State { p, a, b, partial: 0 }
    }
}
impl UniversalHash64State {
    /// Creates an unkeyed state: `p = PRIME_61`, `a = 1`, `b = 0`.
    ///
    /// NOTE(review): with `a = 1` and `b = 0` the result is just the sum of
    /// the bytes modulo `p`, not a keyed hash. Use `UniversalHash64::hasher`
    /// to get a state carrying a real key.
    pub fn new() -> Self {
        Self {
            partial: 0,
            b: 0,
            a: 1,
            p: PRIME_61,
        }
    }

    /// Absorbs `data`: one Horner step (`partial = partial * a + byte`,
    /// mod `p`) per byte.
    pub fn write(&mut self, data: &[u8]) {
        let (a, p) = (self.a, self.p);
        self.partial = data.iter().fold(self.partial, |acc, &byte| {
            add_mod(mul_mod(acc, a, p), u64::from(byte), p)
        });
    }

    /// Consumes the state and returns `(partial + b) mod p`.
    pub fn finish(self) -> u64 {
        let Self { p, b, partial, .. } = self;
        add_mod(partial, b, p)
    }
}
/// Returns `(x + y) mod p`.
///
/// Both operands are reduced first, so their sum is `< 2p`. When `p > 2^63`
/// that sum can exceed `u64::MAX`, so overflow is detected explicitly: in that
/// case the true sum is `sum + 2^64`, and `sum.wrapping_sub(p)` yields the
/// correct residue `true_sum - p` (which is `< p`).
///
/// # Panics
/// Panics if `p == 0`.
#[inline]
fn add_mod(x: u64, y: u64, p: u64) -> u64 {
    let x = x % p;
    let y = y % p;
    let (sum, overflowed) = x.overflowing_add(y);
    if overflowed || sum >= p {
        sum.wrapping_sub(p)
    } else {
        sum
    }
}
/// Returns `(x * y) mod p`.
///
/// The full product of two `u64`s always fits in a `u128`, so the product is
/// formed there and reduced once; the residue is `< p` and therefore fits
/// back into a `u64` without truncation.
///
/// # Panics
/// Panics if `p == 0`.
#[inline]
fn mul_mod(x: u64, y: u64, p: u64) -> u64 {
    let wide = u128::from(x) * u128::from(y);
    (wide % u128::from(p)) as u64
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_universal_hash() {
        let uh = UniversalHashBuilder::new().prime(31).seed(12345).build_64();
        let h1 = uh.hash(b"abcd");
        let h2 = uh.hash(b"abce");
        assert!(h1 < 31);
        assert!(h2 < 31);
        // The last byte has coefficient a^0 = 1, so inputs differing by 1 in
        // only the last byte always hash to values 1 apart (mod p).
        assert_ne!(
            h1, h2,
            "inputs differing only in the last byte must not collide"
        );
    }

    #[test]
    fn test_known_values_with_explicit_params() {
        // p = 31, a = 3, b = 7: hash([1, 2]) = (1*3 + 2) + 7 = 12.
        let uh = UniversalHashBuilder::new().prime(31).params(3, 7).build_64();
        assert_eq!(uh.hash(b""), 7);
        assert_eq!(uh.hash(&[1, 2]), 12);
        let mut st = uh.hasher();
        st.write(&[1]);
        st.write(&[2]);
        assert_eq!(st.finish(), 12);
    }

    #[test]
    fn test_streaming_equiv() {
        let data = b"Hello, universal hashing test data!";
        let uh = UniversalHashBuilder::new().build_64();
        let direct = uh.hash(data);
        let mut st = uh.hasher();
        st.write(&data[..10]);
        st.write(&data[10..]);
        let streaming = st.finish();
        assert_eq!(direct, streaming, "Streaming must match one-shot result");
    }

    #[test]
    fn test_zero_data() {
        let uh = UniversalHashBuilder::new().build_64();
        let val = uh.hash(b"");
        // Empty input contributes nothing, leaving only the offset `b`, which
        // is always below `p`; the streaming path must agree.
        assert!(val < PRIME_61);
        assert_eq!(val, uh.hasher().finish());
    }

    #[test]
    #[should_panic(expected = "param a must be in [1..p-1].")]
    fn test_rejects_zero_a() {
        let _ = UniversalHashBuilder::new().prime(31).params(0, 1).build_64();
    }

    #[test]
    #[should_panic(expected = "param b must be in [0..p-1].")]
    fn test_rejects_b_out_of_range() {
        let _ = UniversalHashBuilder::new().prime(31).params(1, 31).build_64();
    }
}