use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// SimHash fingerprint generator.
///
/// Collapses a weighted feature set into a `bits`-wide signature such that
/// similar feature sets tend to produce signatures with a small Hamming distance.
#[derive(Debug, Clone)]
pub struct SimHash {
    /// Width of the produced fingerprints, in bits (1..=128 when built via `new`).
    bits: usize,
}

impl SimHash {
    /// Creates a generator producing `bits`-wide fingerprints.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is zero or greater than 128.
    pub fn new(bits: usize) -> Self {
        assert!(bits > 0 && bits <= 128, "bits must be 1-128");
        Self { bits }
    }

    /// Creates a generator producing 64-bit fingerprints.
    pub fn new_64() -> Self {
        Self { bits: 64 }
    }

    /// Computes the fingerprint of a weighted feature set.
    ///
    /// Every feature is hashed; for each bit position below `bits` the feature's
    /// weight is added to a running tally when the hash bit is 1 and subtracted
    /// when it is 0. Output bit `i` is set only when tally `i` ends strictly
    /// positive, so an empty feature slice yields an all-zero fingerprint.
    pub fn fingerprint<T: Hash>(&self, features: &[(T, f64)]) -> SimHashFingerprint {
        let mut tallies = vec![0.0f64; self.bits];
        for (feature, weight) in features {
            let hash = self.hash_feature(feature);
            for (i, tally) in tallies.iter_mut().enumerate() {
                if (hash >> i) & 1 == 1 {
                    *tally += *weight;
                } else {
                    *tally -= *weight;
                }
            }
        }

        let mut value = 0u128;
        for (i, tally) in tallies.iter().enumerate() {
            if *tally > 0.0 {
                value |= 1u128 << i;
            }
        }

        SimHashFingerprint {
            value,
            bits: self.bits,
        }
    }

    /// Computes the fingerprint of a feature set in which every feature has weight `1.0`.
    pub fn fingerprint_unweighted<T: Hash, I: IntoIterator<Item = T>>(
        &self,
        features: I,
    ) -> SimHashFingerprint {
        let weighted: Vec<(T, f64)> = features.into_iter().map(|f| (f, 1.0)).collect();
        self.fingerprint(&weighted)
    }

    /// Computes the fingerprint of `text` from character n-grams and words.
    ///
    /// Features are:
    /// * every window of `ngram_size` consecutive `char`s, case preserved, weight `1.0`;
    /// * every whitespace-separated word, lowercased, weight `2.0`.
    ///
    /// When `ngram_size` is zero, or `text` has fewer than `ngram_size` characters,
    /// no n-gram features are produced and only the word features contribute.
    pub fn fingerprint_text(&self, text: &str, ngram_size: usize) -> SimHashFingerprint {
        let chars: Vec<char> = text.chars().collect();
        let mut features: Vec<(String, f64)> = Vec::new();

        // `slice::windows` panics on a zero window size, so guard it explicitly.
        if ngram_size > 0 {
            features.extend(
                chars
                    .windows(ngram_size)
                    .map(|window| (window.iter().collect::<String>(), 1.0)),
            );
        }
        features.extend(
            text.split_whitespace()
                .map(|word| (word.to_lowercase(), 2.0)),
        );

        self.fingerprint(&features)
    }

    /// Hashes one feature to 128 bits.
    ///
    /// The low 64 bits are the `DefaultHasher` digest of the feature; the high
    /// 64 bits are the `DefaultHasher` digest of that first digest. The standard
    /// library does not specify `DefaultHasher`'s algorithm, so fingerprints are
    /// only comparable when produced by the same toolchain.
    fn hash_feature<T: Hash>(&self, feature: &T) -> u128 {
        let mut hasher = DefaultHasher::new();
        feature.hash(&mut hasher);
        let low = hasher.finish();

        // Fingerprints of at most 64 bits never read the upper half.
        if self.bits <= 64 {
            return u128::from(low);
        }

        let mut rehasher = DefaultHasher::new();
        low.hash(&mut rehasher);
        u128::from(low) | (u128::from(rehasher.finish()) << 64)
    }

    /// Returns the fingerprint width in bits.
    pub fn bits(&self) -> usize {
        self.bits
    }
}

impl Default for SimHash {
    /// Returns a 64-bit generator, same as [`SimHash::new_64`].
    fn default() -> Self {
        Self::new_64()
    }
}
/// A SimHash signature: a bit pattern plus the width it was generated with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SimHashFingerprint {
    /// Signature bits; bit `i` of the signature is bit `i` of this integer.
    value: u128,
    /// Width, in bits, the signature was generated with.
    bits: usize,
}

impl SimHashFingerprint {
    /// Returns the number of bit positions in which the two signatures differ.
    pub fn hamming_distance(&self, other: &SimHashFingerprint) -> usize {
        (self.value ^ other.value).count_ones() as usize
    }

    /// Estimates cosine similarity as `cos(pi * d / bits)`, where `d` is the
    /// Hamming distance to `other` and `bits` is the width of `self`.
    pub fn estimated_cosine(&self, other: &SimHashFingerprint) -> f64 {
        let differing = self.hamming_distance(other) as f64;
        let width = self.bits as f64;
        (std::f64::consts::PI * differing / width).cos()
    }

    /// Returns `true` when the Hamming distance to `other` is at most `max_distance`.
    pub fn is_similar(&self, other: &SimHashFingerprint, max_distance: usize) -> bool {
        max_distance >= self.hamming_distance(other)
    }

    /// Returns the full signature as a 128-bit integer.
    pub fn value(&self) -> u128 {
        self.value
    }

    /// Returns the low 64 bits of the signature; any higher bits are discarded.
    pub fn value_64(&self) -> u64 {
        self.value as u64
    }

    /// Returns the width of the signature in bits.
    pub fn bits(&self) -> usize {
        self.bits
    }
}
/// Multi-table index that buckets [`SimHashFingerprint`]s by sampled bit positions.
///
/// Each table keys a fingerprint by the values of `bits_per_table` bit positions
/// chosen at construction time. Fingerprints that share a key with the query in
/// at least one table are reported as candidates.
#[derive(Debug)]
pub struct SimHashLSH {
    /// Number of tables requested at construction.
    num_tables: usize,
    /// Number of fingerprint bits packed into each table key.
    bits_per_table: usize,
    /// One bucket map per table: packed key -> ids of fingerprints with that key.
    tables: Vec<std::collections::HashMap<u64, Vec<usize>>>,
    /// Per table: the sampled bit positions (ascending) and the same set as a bit mask.
    masks: Vec<(Vec<usize>, u128)>,
    /// Inserted fingerprints, indexed by the id returned from `insert`.
    fingerprints: Vec<SimHashFingerprint>,
}

impl SimHashLSH {
    /// Creates an empty index with `num_tables` tables, each keyed on
    /// `bits_per_table` distinct bit positions drawn from `0..total_bits`.
    ///
    /// Positions are drawn from a fixed-seed generator, so the same arguments
    /// always produce the same tables.
    ///
    /// # Panics
    ///
    /// Panics if `total_bits > 128` (fingerprints are stored in a `u128`), if
    /// `bits_per_table > 64` (table keys are `u64`), or if
    /// `bits_per_table > total_bits` (that many distinct positions do not exist).
    pub fn new(num_tables: usize, bits_per_table: usize, total_bits: usize) -> Self {
        assert!(total_bits <= 128, "total_bits must be at most 128");
        assert!(bits_per_table <= 64, "bits_per_table must be at most 64");
        assert!(
            bits_per_table <= total_bits,
            "bits_per_table must not exceed total_bits"
        );

        let mut masks = Vec::with_capacity(num_tables);
        let mut rng_state = 12345u64;
        for _ in 0..num_tables {
            let mut bit_indices = Vec::with_capacity(bits_per_table);
            let mut mask = 0u128;
            // Rejection-sample until the table has enough distinct positions; the
            // assertions above guarantee that this terminates.
            //
            // NOTE(review): the position comes from the LOW bits of the LCG state.
            // When `total_bits` is a power of two those bits repeat with period
            // `total_bits`, so e.g. `new(10, 8, 64)` yields only 8 distinct tables
            // (tables 8 and 9 repeat tables 0 and 1). The sequence is kept unchanged
            // here so existing table layouts are reproduced; consider sampling from
            // the high bits instead.
            while bit_indices.len() < bits_per_table {
                rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
                let bit = (rng_state as usize) % total_bits;
                let flag = 1u128 << bit;
                if mask & flag == 0 {
                    bit_indices.push(bit);
                    mask |= flag;
                }
            }
            bit_indices.sort_unstable();
            masks.push((bit_indices, mask));
        }

        Self {
            num_tables,
            bits_per_table,
            tables: (0..num_tables)
                .map(|_| std::collections::HashMap::new())
                .collect(),
            masks,
            fingerprints: Vec::new(),
        }
    }

    /// Adds `fingerprint` to every table and returns its id.
    ///
    /// Ids are assigned sequentially starting at 0.
    pub fn insert(&mut self, fingerprint: SimHashFingerprint) -> usize {
        let doc_id = self.fingerprints.len();
        for (table, (bit_indices, _)) in self.tables.iter_mut().zip(&self.masks) {
            let key = Self::extract_bits(&fingerprint, bit_indices);
            table.entry(key).or_default().push(doc_id);
        }
        self.fingerprints.push(fingerprint);
        doc_id
    }

    /// Returns the ids of all stored fingerprints that share a bucket with
    /// `fingerprint` in at least one table, without duplicates, in ascending order.
    pub fn query(&self, fingerprint: &SimHashFingerprint) -> Vec<usize> {
        let mut candidates: Vec<usize> = Vec::new();
        for (table, (bit_indices, _)) in self.tables.iter().zip(&self.masks) {
            let key = Self::extract_bits(fingerprint, bit_indices);
            if let Some(docs) = table.get(&key) {
                candidates.extend_from_slice(docs);
            }
        }
        // Sorting makes the result deterministic and lets `dedup` drop ids that
        // matched in more than one table.
        candidates.sort_unstable();
        candidates.dedup();
        candidates
    }

    /// Returns `(id, hamming_distance)` for every candidate from [`SimHashLSH::query`],
    /// ordered by increasing distance; equal distances are ordered by id.
    pub fn query_with_distance(&self, fingerprint: &SimHashFingerprint) -> Vec<(usize, usize)> {
        let mut results: Vec<(usize, usize)> = self
            .query(fingerprint)
            .into_iter()
            .map(|id| (id, fingerprint.hamming_distance(&self.fingerprints[id])))
            .collect();
        // Candidates arrive in ascending id order and the sort is stable, so ties
        // on distance stay ordered by id.
        results.sort_by_key(|&(_, dist)| dist);
        results
    }

    /// Packs the fingerprint bits at `bit_indices` into a key: the bit at
    /// `bit_indices[i]` becomes bit `i` of the result.
    fn extract_bits(fingerprint: &SimHashFingerprint, bit_indices: &[usize]) -> u64 {
        bit_indices
            .iter()
            .enumerate()
            .filter(|&(_, &bit_idx)| (fingerprint.value >> bit_idx) & 1 == 1)
            .fold(0u64, |key, (i, _)| key | (1u64 << i))
    }

    /// Returns the number of fingerprints inserted so far.
    pub fn len(&self) -> usize {
        self.fingerprints.len()
    }

    /// Returns `true` when no fingerprint has been inserted.
    pub fn is_empty(&self) -> bool {
        self.fingerprints.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The same text must always map to the same fingerprint.
    #[test]
    fn test_simhash_identical() {
        let hasher = SimHash::new_64();
        let first = hasher.fingerprint_text("hello world", 3);
        let second = hasher.fingerprint_text("hello world", 3);
        assert_eq!(first.hamming_distance(&second), 0);
    }

    /// Texts differing in a single word should land close together.
    #[test]
    fn test_simhash_similar_text() {
        let hasher = SimHash::new_64();
        let fox = hasher.fingerprint_text("the quick brown fox jumps", 3);
        let dog = hasher.fingerprint_text("the quick brown dog jumps", 3);
        assert!(fox.hamming_distance(&dog) < 20);
    }

    /// Unrelated texts should land far apart.
    #[test]
    fn test_simhash_different_text() {
        let hasher = SimHash::new_64();
        let left = hasher.fingerprint_text("the quick brown fox", 3);
        let right = hasher.fingerprint_text("completely different text here", 3);
        assert!(left.hamming_distance(&right) > 15);
    }

    /// A query related to two of three indexed documents should yield candidates.
    #[test]
    fn test_simhash_lsh() {
        let hasher = SimHash::new_64();
        let mut index = SimHashLSH::new(10, 8, 64);
        let documents = [
            "document about machine learning",
            "document about deep learning",
            "recipe for chocolate cake",
        ];
        for text in documents.iter() {
            index.insert(hasher.fingerprint_text(text, 3));
        }
        let probe = hasher.fingerprint_text("document about neural networks", 3);
        let hits = index.query_with_distance(&probe);
        assert!(!hits.is_empty());
    }

    /// Identical fingerprints have an estimated cosine of one.
    #[test]
    fn test_estimated_cosine() {
        let hasher = SimHash::new_64();
        let first = hasher.fingerprint_text("hello world", 3);
        let second = hasher.fingerprint_text("hello world", 3);
        let cosine = first.estimated_cosine(&second);
        assert!((cosine - 1.0).abs() < 0.001);
    }
}