use core::{hash::BuildHasher, marker::PhantomData};
use alloc::{vec, vec::Vec};
use crate::{
hash::{DefaultHashBuilder, HashPair},
Error,
};
/// A MinHash sketch: a fixed-length signature holding, for each slot, the
/// smallest hash value observed among the inserted items.
///
/// The fraction of slots on which two sketches agree (see
/// [`MinHash::similarity`]) estimates the Jaccard similarity of the two item
/// sets. `T` is the item type accepted by [`MinHash::insert`]; `S` is the
/// [`BuildHasher`] used to hash items.
///
/// `Clone` and `Debug` are implemented by hand instead of being derived: the
/// derives would add `T: Clone` / `T: Debug` bounds although no `T` is ever
/// stored, which would make sketches over unsized items such as `str`
/// impossible to clone.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MinHash<T: ?Sized, S = DefaultHashBuilder> {
    /// Per-slot minimum hash value; `u64::MAX` marks a slot that no inserted
    /// item has lowered yet.
    signature: Vec<u64>,
    /// Hasher factory used for every inserted item. Not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    hasher: S,
    /// Ties the sketch to its item type without owning a `T`.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<fn(&T)>,
}

impl<T: ?Sized, S: Clone> Clone for MinHash<T, S> {
    fn clone(&self) -> Self {
        Self {
            signature: self.signature.clone(),
            hasher: self.hasher.clone(),
            _marker: PhantomData,
        }
    }

    fn clone_from(&mut self, source: &Self) {
        // Reuses the existing signature allocation when it is large enough.
        self.signature.clone_from(&source.signature);
        self.hasher.clone_from(&source.hasher);
    }
}

impl<T: ?Sized, S: core::fmt::Debug> core::fmt::Debug for MinHash<T, S> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Same field layout the derived implementation would print.
        f.debug_struct("MinHash")
            .field("signature", &self.signature)
            .field("hasher", &self.hasher)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<T: ?Sized> MinHash<T, DefaultHashBuilder> {
    /// Creates an empty sketch with `num_hashes` slots that hashes items with
    /// [`DefaultHashBuilder`].
    ///
    /// This is a convenience wrapper around [`MinHash::with_hasher`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`MinHash::with_hasher`] returns for the given
    /// `num_hashes`; in particular a value of zero is rejected with
    /// [`Error::InvalidParameter`].
    pub fn new(num_hashes: usize) -> Result<Self, Error> {
        let default_hasher = DefaultHashBuilder;
        Self::with_hasher(num_hashes, default_hasher)
    }
}
impl<T: ?Sized, S: BuildHasher> MinHash<T, S> {
pub fn with_hasher(num_hashes: usize, hasher: S) -> Result<Self, Error> {
if num_hashes == 0 {
return Err(Error::InvalidParameter {
param: "num_hashes",
reason: "must be greater than zero",
});
}
Ok(Self {
signature: vec![u64::MAX; num_hashes],
hasher,
_marker: PhantomData,
})
}
pub fn insert(&mut self, item: &T)
where
T: core::hash::Hash,
{
let pair = HashPair::new(item, &self.hasher);
for (i, slot) in self.signature.iter_mut().enumerate() {
let value = pair.nth(i as u64);
if value < *slot {
*slot = value;
}
}
}
pub fn similarity(&self, other: &Self) -> Result<f64, Error> {
if self.signature.len() != other.signature.len() {
return Err(Error::IncompatibleParameters);
}
let matches = self
.signature
.iter()
.zip(other.signature.iter())
.filter(|(a, b)| a == b)
.count();
Ok(matches as f64 / self.signature.len() as f64)
}
pub fn merge(&mut self, other: &Self) -> Result<(), Error> {
if self.signature.len() != other.signature.len() {
return Err(Error::IncompatibleParameters);
}
for (dst, src) in self.signature.iter_mut().zip(other.signature.iter()) {
*dst = (*dst).min(*src);
}
Ok(())
}
#[inline]
#[must_use]
pub fn num_hashes(&self) -> usize {
self.signature.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.signature.iter().all(|&slot| slot == u64::MAX)
}
pub fn clear(&mut self) {
self.signature.iter_mut().for_each(|slot| *slot = u64::MAX);
}
}
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a failed construction should abort the test.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Builds a sketch with `num_hashes` slots containing every integer in
    /// `items`, inserted in ascending order.
    fn sketch_of(num_hashes: usize, items: core::ops::Range<u32>) -> MinHash<u32> {
        let mut sketch = MinHash::new(num_hashes).unwrap();
        for item in items {
            sketch.insert(&item);
        }
        sketch
    }

    #[test]
    fn test_new_rejects_zero() {
        let outcome = MinHash::<&str>::new(0);
        assert!(matches!(outcome, Err(Error::InvalidParameter { .. })));
    }

    #[test]
    fn test_identical_sets_are_fully_similar() {
        let first = sketch_of(256, 0..1_000);
        let second = sketch_of(256, 0..1_000);
        assert_eq!(first.similarity(&second).unwrap(), 1.0);
    }

    #[test]
    fn test_disjoint_sets_are_dissimilar() {
        let low = sketch_of(256, 0..1_000);
        let high = sketch_of(256, 10_000..11_000);
        let similarity = low.similarity(&high).unwrap();
        assert!(similarity < 0.1, "disjoint sets too similar: {similarity}");
    }

    #[test]
    fn test_partial_overlap_estimate() {
        // 500 shared items out of 1 500 distinct ones: true Jaccard is 1/3.
        let left = sketch_of(512, 0..1_000);
        let right = sketch_of(512, 500..1_500);
        let similarity = left.similarity(&right).unwrap();
        assert!(
            (0.27..=0.40).contains(&similarity),
            "estimate {similarity} far from 1/3"
        );
    }

    #[test]
    fn test_similarity_rejects_mismatched_lengths() {
        let narrow = MinHash::<u32>::new(128).unwrap();
        let wide = MinHash::<u32>::new(256).unwrap();
        assert_eq!(
            narrow.similarity(&wide),
            Err(Error::IncompatibleParameters)
        );
    }

    #[test]
    fn test_merge_forms_union() {
        let mut merged = sketch_of(256, 0..500);
        let upper_half = sketch_of(256, 500..1_000);
        let direct_union = sketch_of(256, 0..1_000);
        merged.merge(&upper_half).unwrap();
        assert_eq!(merged.similarity(&direct_union).unwrap(), 1.0);
    }

    #[test]
    fn test_merge_rejects_mismatched_lengths() {
        let mut target = MinHash::<u32>::new(128).unwrap();
        let source = MinHash::<u32>::new(64).unwrap();
        assert_eq!(target.merge(&source), Err(Error::IncompatibleParameters));
    }

    #[test]
    fn test_clear() {
        let mut sketch = MinHash::new(64).unwrap();
        sketch.insert("x");
        assert!(!sketch.is_empty());
        sketch.clear();
        assert!(sketch.is_empty());
    }
}