use std::collections::HashMap;
/// Small deterministic pseudo-random generator using the xorshift scheme on a
/// 64-bit state (shift triple 13 / 7 / 17).
struct XorShift64 {
    /// Current generator state; replaced by each output of `next`.
    state: u64,
}

impl XorShift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by `1`: with an all-zero state every shift/xor
    /// step yields zero again, so the generator would only ever emit zeros.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state by one xorshift round and returns the new state.
    fn next(&mut self) -> u64 {
        let after_left = self.state ^ (self.state << 13);
        let after_right = after_left ^ (after_left >> 7);
        let mixed = after_right ^ (after_right << 17);
        self.state = mixed;
        mixed
    }

    /// Draws the next value and maps it into the closed range `[-1.0, 1.0]`.
    fn next_f64_signed(&mut self) -> f64 {
        // First scale the raw 64-bit draw into [0, 1], then stretch to [-1, 1].
        let unit = (self.next() as f64) / (u64::MAX as f64);
        unit * 2.0 - 1.0
    }
}
/// One locality-sensitive hash function built from random hyperplanes.
///
/// Every stored vector acts as the normal of a hyperplane; hashing a point
/// records, one bit per hyperplane, whether its dot product with that normal
/// is non-negative.
#[derive(Debug, Clone)]
pub struct LshHasher {
    /// Hyperplane normals, one per signature bit.
    pub random_vectors: Vec<Vec<f64>>,
    /// Number of components generated for each hyperplane normal.
    pub dim: usize,
}

impl LshHasher {
    /// Builds `num_hashes` hyperplane normals of `dim` components each, drawing
    /// every component from `rng` and passing each normal through
    /// `normalize_vec`.
    fn new_with_rng(dim: usize, num_hashes: usize, rng: &mut XorShift64) -> Self {
        let random_vectors = (0..num_hashes)
            .map(|_| {
                let mut plane: Vec<f64> = (0..dim).map(|_| rng.next_f64_signed()).collect();
                normalize_vec(&mut plane);
                plane
            })
            .collect();
        Self { random_vectors, dim }
    }

    /// Computes the signature of `v`.
    ///
    /// Bit `i` is set when the dot product of `v` with hyperplane `i` is
    /// `>= 0.0`. Only the first 64 hyperplanes contribute, because the
    /// signature is a `u64`. If `v` and a hyperplane differ in length, the dot
    /// product covers only their common prefix.
    pub fn hash(&self, v: &[f64]) -> u64 {
        self.random_vectors
            .iter()
            .take(64)
            .enumerate()
            .fold(0u64, |signature, (bit, plane)| {
                let projection: f64 = v.iter().zip(plane).map(|(a, b)| a * b).sum();
                if projection >= 0.0 {
                    signature | (1u64 << bit)
                } else {
                    signature
                }
            })
    }
}
/// Buckets of a single hash table: signature -> ids of the vectors stored under it.
pub type LshBucket = HashMap<u64, Vec<usize>>;

/// Multi-table locality-sensitive-hashing index that answers approximate
/// nearest-neighbour queries ranked by cosine similarity.
pub struct LshIndex {
    /// Stored vectors addressed by id; an empty `Vec` marks a slot with no vector.
    pub vectors: Vec<Vec<f64>>,
    /// One bucket map per hash table, parallel to `hashers`.
    pub buckets: Vec<LshBucket>,
    /// One hasher per hash table.
    pub hashers: Vec<LshHasher>,
    /// Dimension used to generate the hyperplanes; `insert` and `search` do not
    /// check input lengths against it.
    pub dim: usize,
    /// Number of hash tables requested at construction.
    pub num_tables: usize,
    /// Number of hyperplanes per table requested at construction.
    pub num_hashes: usize,
}

impl LshIndex {
    /// Creates an empty index with `num_tables` tables of `num_hashes`
    /// hyperplanes each.
    ///
    /// All hyperplanes are drawn from one generator seeded with `seed`, so the
    /// same arguments always produce the same index.
    pub fn new(dim: usize, num_tables: usize, num_hashes: usize, seed: u64) -> Self {
        let mut rng = XorShift64::new(seed);
        let hashers: Vec<LshHasher> = (0..num_tables)
            .map(|_| LshHasher::new_with_rng(dim, num_hashes, &mut rng))
            .collect();
        let buckets: Vec<LshBucket> = (0..num_tables).map(|_| LshBucket::new()).collect();
        Self {
            vectors: Vec::new(),
            buckets,
            hashers,
            dim,
            num_tables,
            num_hashes,
        }
    }

    /// Stores `vector` under `id`, replacing any vector previously stored there.
    ///
    /// Storage is dense: slots between the current end and `id` are filled with
    /// empty placeholders. When `id` already has a slot, its entries under the
    /// old vector's signatures are removed first so no stale or duplicate ids
    /// remain in the buckets.
    pub fn insert(&mut self, id: usize, vector: &[f64]) {
        if id < self.vectors.len() {
            // The slot exists, so `id` may still be filed under the signature of
            // whatever it held before. Drop those entries before re-indexing.
            let previous = std::mem::take(&mut self.vectors[id]);
            for (hasher, bucket) in self.hashers.iter().zip(self.buckets.iter_mut()) {
                let stale = hasher.hash(&previous);
                if let Some(ids) = bucket.get_mut(&stale) {
                    ids.retain(|&other| other != id);
                    if ids.is_empty() {
                        bucket.remove(&stale);
                    }
                }
            }
        } else {
            // `saturating_add` keeps `id == usize::MAX` from wrapping to a
            // zero-length resize that would silently drop every stored vector.
            self.vectors.resize_with(id.saturating_add(1), Vec::new);
        }
        self.vectors[id] = vector.to_vec();
        for (hasher, bucket) in self.hashers.iter().zip(self.buckets.iter_mut()) {
            bucket.entry(hasher.hash(vector)).or_default().push(id);
        }
    }

    /// Returns the cosine of the angle between `a` and `b`.
    ///
    /// Returns `0.0` when either magnitude is below `f64::EPSILON`. The dot
    /// product covers only the common prefix when the lengths differ, while
    /// each magnitude uses the full slice.
    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
        let mag_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
        if mag_a < f64::EPSILON || mag_b < f64::EPSILON {
            return 0.0;
        }
        dot / (mag_a * mag_b)
    }

    /// Returns up to `k` `(id, similarity)` pairs for the stored vectors that
    /// share a bucket with `query` in at least one table.
    ///
    /// Results are ordered by descending cosine similarity, ties broken by
    /// ascending id. Vectors that collide with `query` in no table are never
    /// returned, so the result may hold fewer than `k` entries.
    pub fn search(&self, query: &[f64], k: usize) -> Vec<(usize, f64)> {
        let mut candidates = std::collections::HashSet::new();
        for (hasher, bucket) in self.hashers.iter().zip(&self.buckets) {
            if let Some(ids) = bucket.get(&hasher.hash(query)) {
                candidates.extend(ids.iter().copied());
            }
        }
        let mut scored: Vec<(usize, f64)> = candidates
            .into_iter()
            .filter_map(|id| {
                let stored = self.vectors.get(id)?;
                // Empty slots are placeholders, not real vectors.
                (!stored.is_empty()).then(|| (id, Self::cosine_similarity(query, stored)))
            })
            .collect();
        scored.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                // `partial_cmp` fails only when a score is NaN. Ranking NaN after
                // every real score keeps the comparator a total order, which
                // `sort_by` requires.
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
                .then_with(|| a.0.cmp(&b.0))
        });
        scored.truncate(k);
        scored
    }

    /// Number of ids that currently hold a non-empty vector.
    pub fn len(&self) -> usize {
        self.vectors.iter().filter(|v| !v.is_empty()).count()
    }

    /// `true` when no id holds a non-empty vector.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every stored vector and empties all buckets; the hashers are kept.
    pub fn clear(&mut self) {
        self.vectors.clear();
        for bucket in &mut self.buckets {
            bucket.clear();
        }
    }
}
/// Rescales `v` in place to unit Euclidean length.
///
/// Vectors whose length is not greater than `f64::EPSILON` (including empty
/// and all-zero vectors) are left untouched to avoid dividing by ~zero.
fn normalize_vec(v: &mut [f64]) {
    let squared_sum: f64 = v.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    // A NaN norm also fails this comparison, so such input is left as-is.
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|c| *c /= norm);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `dim`-dimensional unit vector along `axis`.
    fn unit_vec(dim: usize, axis: usize) -> Vec<f64> {
        let mut v = vec![0.0_f64; dim];
        v[axis] = 1.0;
        v
    }

    /// Small 4-dimensional index (4 tables, 8 bits each) shared by several tests.
    fn new_index() -> LshIndex {
        LshIndex::new(4, 4, 8, 42)
    }

    #[test]
    fn test_xorshift64_deterministic() {
        let mut rng1 = XorShift64::new(123);
        let mut rng2 = XorShift64::new(123);
        for _ in 0..100 {
            assert_eq!(rng1.next(), rng2.next());
        }
    }

    #[test]
    fn test_xorshift64_nonzero_seed() {
        let mut rng = XorShift64::new(0);
        let v = rng.next();
        assert_ne!(v, 0);
    }

    #[test]
    fn test_xorshift64_different_seeds() {
        let mut rng1 = XorShift64::new(1);
        let mut rng2 = XorShift64::new(2);
        let v1 = rng1.next();
        let v2 = rng2.next();
        assert_ne!(v1, v2);
    }

    #[test]
    fn test_normalize_vec_unit_length() {
        let mut v = vec![3.0_f64, 4.0_f64];
        normalize_vec(&mut v);
        let mag: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((mag - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_normalize_zero_vec_safe() {
        let mut v = vec![0.0_f64; 4];
        normalize_vec(&mut v);
        // A zero vector must come back unchanged rather than filled with NaN.
        assert!(v.iter().all(|x| *x == 0.0));
    }

    #[test]
    fn test_hasher_deterministic() {
        let mut rng = XorShift64::new(42);
        let h1 = LshHasher::new_with_rng(4, 8, &mut rng);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        let hash1 = h1.hash(&v);
        let mut rng2 = XorShift64::new(42);
        let h2 = LshHasher::new_with_rng(4, 8, &mut rng2);
        let hash2 = h2.hash(&v);
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hasher_similar_vectors_same_bucket() {
        let mut rng = XorShift64::new(42);
        let h = LshHasher::new_with_rng(4, 4, &mut rng);
        let v1 = vec![1.0_f64, 0.001, 0.001, 0.001];
        let v2 = vec![1.0_f64, 0.001, 0.001, 0.002];
        let hash1 = h.hash(&v1);
        let hash2 = h.hash(&v2);
        // Whether two merely *nearby* vectors collide depends on the random
        // planes, so only seed-independent properties are asserted here.
        assert!(hash1 < 16 && hash2 < 16, "only 4 signature bits may be set");
        // Scaling by 2 is exact in floating point and keeps the direction, so
        // the signature cannot change.
        let doubled: Vec<f64> = v1.iter().map(|x| x * 2.0).collect();
        assert_eq!(h.hash(&doubled), hash1);
    }

    #[test]
    fn test_hasher_opposite_vectors_different_bits() {
        let mut rng = XorShift64::new(99);
        let h = LshHasher::new_with_rng(4, 8, &mut rng);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        let neg_v = vec![-1.0_f64, 0.0, 0.0, 0.0];
        let h1 = h.hash(&v);
        let h2 = h.hash(&neg_v);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_cosine_identical_vectors() {
        let v = vec![1.0_f64, 2.0, 3.0];
        let sim = LshIndex::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let v1 = vec![1.0_f64, 0.0, 0.0];
        let v2 = vec![0.0_f64, 1.0, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let v1 = vec![1.0_f64, 0.0];
        let v2 = vec![-1.0_f64, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!((sim + 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_zero_vector() {
        let v1 = vec![0.0_f64, 0.0];
        let v2 = vec![1.0_f64, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!((sim).abs() < 1e-9);
    }

    #[test]
    fn test_index_new_dimensions() {
        let idx = LshIndex::new(8, 4, 16, 1);
        assert_eq!(idx.dim, 8);
        assert_eq!(idx.num_tables, 4);
        assert_eq!(idx.num_hashes, 16);
        assert_eq!(idx.hashers.len(), 4);
        assert_eq!(idx.buckets.len(), 4);
    }

    #[test]
    fn test_index_empty() {
        let idx = new_index();
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn test_insert_single_vector() {
        let mut idx = new_index();
        idx.insert(0, &[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_insert_multiple_vectors() {
        let mut idx = new_index();
        for i in 0..10 {
            idx.insert(i, &unit_vec(4, i % 4));
        }
        assert_eq!(idx.len(), 10);
    }

    #[test]
    fn test_search_empty_index() {
        let idx = new_index();
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_exact_match() {
        let mut idx = LshIndex::new(4, 8, 16, 42);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        idx.insert(0, &v);
        let results = idx.search(&v, 1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_search_k_limits_results() {
        let mut idx = LshIndex::new(4, 8, 4, 77);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        for i in 0..5 {
            let mut vv = v.clone();
            vv[0] = 1.0 - i as f64 * 0.01;
            idx.insert(i, &vv);
        }
        let results = idx.search(&v, 2);
        // All five vectors are positive multiples of the query, so they share
        // its signature in every table and `k` is the only limit on the result.
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_search_returns_closer_vector() {
        let mut idx = LshIndex::new(2, 8, 16, 1);
        idx.insert(0, &[1.0_f64, 0.01]);
        idx.insert(1, &[0.0_f64, 1.0]);
        let results = idx.search(&[1.0_f64, 0.0], 2);
        // Neither vector is guaranteed to collide with the query, so ranking
        // can only be checked when both were retrieved.
        if results.len() >= 2 {
            assert_eq!(results[0].0, 0);
            assert!(results[0].1 >= results[1].1);
        }
    }

    #[test]
    fn test_search_sorted_descending() {
        let mut idx = LshIndex::new(4, 8, 16, 7);
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.insert(1, &[0.9_f64, 0.1, 0.0, 0.0]);
        idx.insert(2, &[0.5_f64, 0.5, 0.0, 0.0]);
        let query = [1.0_f64, 0.0, 0.0, 0.0];
        let results = idx.search(&query, 3);
        // Id 0 equals the query, so at least that entry must be present.
        assert!(!results.is_empty());
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1, "Results not sorted descending");
        }
    }

    #[test]
    fn test_search_k_greater_than_num_vectors() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.insert(1, &[0.0_f64, 1.0, 0.0, 0.0]);
        let results = idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 100);
        assert!(!results.is_empty() && results.len() <= 2);
        // The exact match always collides with the query and has the top score.
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_clear() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.clear();
        assert!(idx.is_empty());
    }

    #[test]
    fn test_clear_then_insert() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.clear();
        idx.insert(0, &[0.0_f64, 1.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_multi_table_improves_recall() {
        let mut idx = LshIndex::new(4, 16, 8, 2024);
        let target = vec![1.0_f64, 0.0, 0.0, 0.0];
        idx.insert(42, &target);
        for i in 0..20 {
            let mut v = vec![0.0_f64; 4];
            v[i % 4] = 1.0;
            v[(i + 1) % 4] = 0.1;
            idx.insert(i, &v);
        }
        let results = idx.search(&target, 5);
        let found = results.iter().any(|(id, _)| *id == 42);
        assert!(found, "Target vector should be found with 16 tables");
    }

    #[test]
    fn test_high_dimensional_search() {
        let dim = 64;
        let mut idx = LshIndex::new(dim, 8, 16, 99);
        let target: Vec<f64> = (0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
        idx.insert(0, &target);
        let results = idx.search(&target, 1);
        // A query identical to a stored vector hashes to the same bucket in
        // every table, so the result can never be empty.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_is_empty_after_inserts() {
        let mut idx = new_index();
        assert!(idx.is_empty());
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        assert!(!idx.is_empty());
    }

    #[test]
    fn test_results_contain_similarity() {
        let mut idx = LshIndex::new(4, 8, 16, 55);
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        let results = idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 1);
        // Exact match: guaranteed to be retrieved, so assert unconditionally.
        assert_eq!(results.len(), 1);
        assert!(results[0].1 >= 0.0 && results[0].1 <= 1.0 + 1e-9);
    }
}