use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::sync::Arc;
use crate::canonical::Canonicalizer;
use crate::classical::Fingerprinter;
use crate::classical::hash::{HashFamily, hash128};
use crate::error::{Error, Result};
use crate::tokenize::Tokenizer;
use super::sig::SimHash64;
/// Default seed mixed into every per-token hash; fingerprints produced with
/// different seeds are not comparable to one another.
pub const DEFAULT_SEED: u64 = 0x00C0_FFEE_5EED;
/// Token-weighting scheme used when accumulating SimHash bit votes.
#[derive(Clone, Debug)]
pub enum Weighting {
    /// Every *distinct* token casts a single vote of weight one.
    Uniform,
    /// Every token *occurrence* casts a vote (term frequency).
    Tf,
    /// Votes are weighted by term frequency times a per-token IDF value.
    IdfWeighted(IdfTable),
}

impl Default for Weighting {
    /// Term-frequency weighting is the default scheme.
    fn default() -> Self {
        Weighting::Tf
    }
}
/// Shared, immutable lookup table of per-token IDF weights.
///
/// Cloning is cheap: the underlying map lives behind an [`Arc`].
/// Tokens absent from the table fall back to a weight of `1.0`.
#[derive(Clone, Debug, Default)]
pub struct IdfTable {
    inner: Arc<BTreeMap<String, f32>>,
}

impl IdfTable {
    /// Builds a table from `(token, idf)` pairs.
    ///
    /// When the same token appears more than once, the last pair wins.
    pub fn from_pairs<I, S>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (S, f32)>,
        S: Into<String>,
    {
        let table: BTreeMap<String, f32> = pairs
            .into_iter()
            .map(|(token, idf)| (token.into(), idf))
            .collect();
        Self {
            inner: Arc::new(table),
        }
    }

    /// Returns the IDF weight for `token`, defaulting to `1.0` when absent.
    #[inline]
    #[must_use]
    pub fn get(&self, token: &str) -> f32 {
        match self.inner.get(token) {
            Some(&idf) => idf,
            None => 1.0,
        }
    }

    /// Number of tokens with an explicit weight.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when no explicit weights are stored.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
/// Fluent builder for [`SimHashFingerprinter`].
#[derive(Clone, Debug)]
pub struct SimHashFingerprinterBuilder {
    seed: u64,
    weighting: Weighting,
    hasher: HashFamily,
}

impl Default for SimHashFingerprinterBuilder {
    /// Defaults mirror [`SimHashFingerprinter::new`]: [`DEFAULT_SEED`],
    /// term-frequency weighting, and the XXH3-64 hash family.
    fn default() -> Self {
        SimHashFingerprinterBuilder {
            seed: DEFAULT_SEED,
            weighting: Weighting::Tf,
            hasher: HashFamily::Xxh3_64,
        }
    }
}

impl SimHashFingerprinterBuilder {
    /// Overrides the hash seed.
    #[must_use]
    pub fn seed(self, seed: u64) -> Self {
        Self { seed, ..self }
    }

    /// Overrides the token-weighting scheme.
    #[must_use]
    pub fn weighting(self, weighting: Weighting) -> Self {
        Self { weighting, ..self }
    }

    /// Overrides the hash family used for token hashing.
    #[must_use]
    pub fn hasher(self, hasher: HashFamily) -> Self {
        Self { hasher, ..self }
    }

    /// Consumes the builder and produces a configured fingerprinter.
    #[must_use]
    pub fn build<T: Tokenizer>(
        self,
        canonicalizer: Canonicalizer,
        tokenizer: T,
    ) -> SimHashFingerprinter<T> {
        let Self {
            seed,
            weighting,
            hasher,
        } = self;
        SimHashFingerprinter {
            canonicalizer,
            tokenizer,
            seed,
            weighting,
            hasher,
        }
    }
}
/// Computes 64-bit SimHash fingerprints from canonicalized, tokenized text.
#[derive(Clone, Debug)]
pub struct SimHashFingerprinter<T: Tokenizer> {
    // Normalizer applied to raw input before tokenization.
    canonicalizer: Canonicalizer,
    // Splits canonical text into tokens.
    tokenizer: T,
    // Seed passed to every per-token hash call.
    seed: u64,
    // Vote-weighting scheme (uniform / tf / tf-idf).
    weighting: Weighting,
    // Hash family used for per-token hashing.
    hasher: HashFamily,
}
impl<T: Tokenizer> SimHashFingerprinter<T> {
    /// Creates a fingerprinter with [`DEFAULT_SEED`], term-frequency
    /// weighting, and the XXH3-64 hash family.
    pub fn new(canonicalizer: Canonicalizer, tokenizer: T) -> Self {
        Self {
            canonicalizer,
            tokenizer,
            seed: DEFAULT_SEED,
            weighting: Weighting::Tf,
            hasher: HashFamily::Xxh3_64,
        }
    }

    /// Overrides the seed mixed into every token hash.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Overrides the hash family used for token hashing.
    #[must_use]
    pub fn with_hasher(mut self, hasher: HashFamily) -> Self {
        self.hasher = hasher;
        self
    }

    /// Overrides the token-weighting scheme.
    #[must_use]
    pub fn with_weighting(mut self, w: Weighting) -> Self {
        self.weighting = w;
        self
    }

    /// Canonicalizer applied before tokenization.
    pub fn canonicalizer(&self) -> &Canonicalizer {
        &self.canonicalizer
    }

    /// Tokenizer used to split canonical text.
    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    /// Active weighting scheme.
    pub fn weighting(&self) -> &Weighting {
        &self.weighting
    }

    /// Active hash family.
    pub fn hasher(&self) -> HashFamily {
        self.hasher
    }

    /// Computes the 64-bit SimHash sketch of already-canonicalized text.
    ///
    /// # Errors
    /// Returns [`Error::InvalidInput`] when the text yields no tokens.
    pub(super) fn sketch_canonical(&self, canonical: &str) -> Result<SimHash64> {
        // Fixed-point scale (10 fractional bits) applied to real-valued
        // weights before casting into the i64 vote accumulator. Previously
        // the weight was cast directly, so fractional IDF weights such as
        // 0.1 truncated to 0 and low-IDF tokens were silently *dropped*
        // from the sketch instead of down-weighted. Scaling every vote in
        // a branch by the same constant leaves the sign of each slot — and
        // hence the Tf/Uniform output bits — unchanged; only IdfWeighted
        // sketches differ from earlier versions (that is the fix).
        const WEIGHT_SCALE: f64 = 1024.0;

        let mut acc: [i64; 64] = [0; 64];
        let mut any = false;
        match &self.weighting {
            Weighting::Tf => {
                // One ±1 vote per token occurrence; no counting pass needed.
                let hasher = self.hasher;
                let seed = self.seed;
                self.tokenizer.for_each_token(canonical, &mut |tok| {
                    any = true;
                    let (lo, _hi) = hash128(hasher, tok.as_bytes(), seed);
                    accumulate_bits(&mut acc, lo, 1);
                });
            }
            Weighting::Uniform | Weighting::IdfWeighted(_) => {
                // Both schemes need per-token counts, so tally first.
                // HashMap when std is available, BTreeMap under no_std.
                #[cfg(feature = "std")]
                let mut counts: std::collections::HashMap<String, u32> =
                    std::collections::HashMap::new();
                #[cfg(not(feature = "std"))]
                let mut counts: alloc::collections::BTreeMap<String, u32> =
                    alloc::collections::BTreeMap::new();
                self.tokenizer.for_each_token(canonical, &mut |tok| {
                    any = true;
                    // get_mut-then-insert avoids allocating an owned String
                    // for tokens that are already present.
                    if let Some(c) = counts.get_mut(tok) {
                        *c += 1;
                    } else {
                        counts.insert(tok.into(), 1);
                    }
                });
                for (tok, tf) in &counts {
                    let weight = match &self.weighting {
                        // Uniform: one vote per *distinct* token.
                        Weighting::Uniform => 1.0_f64,
                        Weighting::IdfWeighted(table) => {
                            f64::from(*tf) * f64::from(table.get(tok))
                        }
                        Weighting::Tf => unreachable!("handled by the outer match arm"),
                    };
                    // Guard against NaN/inf from a pathological IDF table.
                    let weight = if weight.is_finite() { weight } else { 1.0 };
                    // `as` saturates at i64::MIN/MAX for out-of-range floats,
                    // so no explicit clamp is needed.
                    let w_int = (weight * WEIGHT_SCALE) as i64;
                    let (lo, _hi) = hash128(self.hasher, tok.as_bytes(), self.seed);
                    accumulate_bits(&mut acc, lo, w_int);
                }
            }
        }
        if !any {
            return Err(Error::InvalidInput("empty document".into()));
        }
        // Collapse the vote vector: bit b is set iff its net vote is positive.
        let mut bits: u64 = 0;
        for (b, &slot) in acc.iter().enumerate() {
            if slot > 0 {
                bits |= 1u64 << b;
            }
        }
        Ok(SimHash64(bits))
    }
}
/// Applies one SimHash vote to the 64-slot accumulator: for each bit
/// position, adds `w` when the corresponding bit of `lo` is set and
/// subtracts `w` when it is clear. Saturating arithmetic keeps extreme
/// weights from wrapping.
#[inline]
fn accumulate_bits(acc: &mut [i64; 64], lo: u64, w: i64) {
    let mut remaining = lo;
    for slot in acc.iter_mut() {
        *slot = if remaining & 1 == 1 {
            slot.saturating_add(w)
        } else {
            slot.saturating_sub(w)
        };
        remaining >>= 1;
    }
}
impl<T: Tokenizer> Fingerprinter for SimHashFingerprinter<T> {
type Output = SimHash64;
fn fingerprint(&self, input: &str) -> Result<Self::Output> {
if input.is_empty() {
return Err(Error::InvalidInput("empty document".into()));
}
let canonical = self.canonicalizer.canonicalize(input);
self.sketch_canonical(&canonical)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::canonical::Canonicalizer;
    use crate::classical::simhash::distance::hamming;
    use crate::tokenize::WordTokenizer;

    // Convenience constructor: default canonicalizer + word tokenizer.
    fn fp() -> SimHashFingerprinter<WordTokenizer> {
        SimHashFingerprinter::new(Canonicalizer::default(), WordTokenizer)
    }

    // Empty input is rejected rather than producing a degenerate sketch.
    #[test]
    fn empty_input_errors() {
        assert!(matches!(fp().fingerprint(""), Err(Error::InvalidInput(_))));
    }

    // Same input + same config must always yield the same fingerprint.
    #[test]
    fn deterministic() {
        let f = fp();
        let a = f
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let b = f
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        assert_eq!(a, b);
    }

    // Core SimHash property: a one-word edit keeps Hamming distance small.
    #[test]
    fn similar_docs_have_small_hamming() {
        let f = fp();
        let a = f
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let b = f
            .fingerprint("the quick brown fox leaps over the lazy dog")
            .unwrap();
        let h = hamming(a, b);
        assert!(h < 16, "expected hamming < 16, got {h}");
    }

    // Unrelated documents should land far apart in Hamming space.
    #[test]
    fn different_docs_have_large_hamming() {
        let f = fp();
        let a = f
            .fingerprint("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let b = f
            .fingerprint("astronomers map cosmic background radiation")
            .unwrap();
        let h = hamming(a, b);
        assert!(h > 16, "expected hamming > 16, got {h}");
    }

    // Repeated tokens distinguish Tf (per-occurrence votes) from Uniform
    // (per-distinct-token votes).
    #[test]
    fn uniform_vs_tf_can_differ() {
        let canon = Canonicalizer::default();
        let f1 = SimHashFingerprinter::new(canon.clone(), WordTokenizer)
            .with_weighting(Weighting::Uniform);
        let f2 = SimHashFingerprinter::new(canon, WordTokenizer).with_weighting(Weighting::Tf);
        let a = f1.fingerprint("the the the the cat").unwrap();
        let b = f2.fingerprint("the the the the cat").unwrap();
        assert_ne!(a, b);
    }

    // Table lookups return stored weights, with 1.0 for absent tokens.
    #[test]
    fn idf_table_lookup() {
        let table = IdfTable::from_pairs([("the", 0.1f32), ("cat", 4.0f32)]);
        assert!((table.get("the") - 0.1).abs() < 1e-6);
        assert!((table.get("cat") - 4.0).abs() < 1e-6);
        assert!((table.get("absent") - 1.0).abs() < 1e-6);
        assert_eq!(table.len(), 2);
        assert!(!table.is_empty());
    }

    // Smoke test: IDF weighting produces a usable (non-zero) sketch.
    #[test]
    fn idf_weighting_runs_end_to_end() {
        let table = IdfTable::from_pairs([("the", 0.1f32), ("dog", 4.0f32)]);
        let f = fp().with_weighting(Weighting::IdfWeighted(table));
        let s = f.fingerprint("the dog the dog the dog").unwrap();
        assert_ne!(s, SimHash64::new(0));
    }

    // Byte-level round trip through the bytemuck POD representation.
    #[test]
    fn schema_round_trip() {
        let f = fp();
        let s = f.fingerprint("hello world").unwrap();
        let bytes = s.as_bytes();
        let s2: SimHash64 = *bytemuck::from_bytes(bytes);
        assert_eq!(s, s2);
    }

    // The XXH3-64 family is explicitly selectable.
    #[test]
    fn xxh3_hasher_works() {
        let f = fp().with_hasher(HashFamily::Xxh3_64);
        let s = f.fingerprint("the quick brown fox jumps").unwrap();
        assert_ne!(s, SimHash64::new(0));
    }

    // Builder defaults must agree with the plain constructor.
    #[test]
    fn builder_default_matches_constructor() {
        let canon = Canonicalizer::default();
        let a = SimHashFingerprinterBuilder::default().build(canon.clone(), WordTokenizer);
        let b = SimHashFingerprinter::new(canon, WordTokenizer);
        let s_a = a.fingerprint("hello world").unwrap();
        let s_b = b.fingerprint("hello world").unwrap();
        assert_eq!(s_a, s_b);
    }
}