1use std::collections::HashMap;
8
/// Minimal xorshift64 pseudo-random generator (shift triple 13, 7, 17).
///
/// Deterministic for a given seed; not suitable for cryptographic use.
struct XorShift64 {
    state: u64,
}

impl XorShift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero state would make every subsequent output zero, so seed `0` is replaced by `1`.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state and returns the new 64-bit value (never zero for a nonzero state).
    fn next(&mut self) -> u64 {
        let after_left = self.state ^ (self.state << 13);
        let after_right = after_left ^ (after_left >> 7);
        self.state = after_right ^ (after_right << 17);
        self.state
    }

    /// Returns a value in `[-1.0, 1.0]` derived from the next raw output.
    fn next_f64_signed(&mut self) -> f64 {
        let unit = self.next() as f64 / u64::MAX as f64;
        2.0 * unit - 1.0
    }
}
44
45#[derive(Debug, Clone)]
54pub struct LshHasher {
55 pub random_vectors: Vec<Vec<f64>>,
57 pub dim: usize,
59}
60
61impl LshHasher {
62 fn new_with_rng(dim: usize, num_hashes: usize, rng: &mut XorShift64) -> Self {
66 let mut random_vectors = Vec::with_capacity(num_hashes);
67 for _ in 0..num_hashes {
68 let mut v: Vec<f64> = (0..dim).map(|_| rng.next_f64_signed()).collect();
69 normalize_vec(&mut v);
70 random_vectors.push(v);
71 }
72 Self {
73 random_vectors,
74 dim,
75 }
76 }
77
78 pub fn hash(&self, v: &[f64]) -> u64 {
80 let mut h: u64 = 0;
81 for (bit, rv) in self.random_vectors.iter().enumerate() {
82 if bit >= 64 {
83 break;
84 }
85 let dot: f64 = v.iter().zip(rv.iter()).map(|(a, b)| a * b).sum();
86 if dot >= 0.0 {
87 h |= 1u64 << bit;
88 }
89 }
90 h
91 }
92}
93
94pub type LshBucket = HashMap<u64, Vec<usize>>;
96
97pub struct LshIndex {
103 pub vectors: Vec<Vec<f64>>,
105 pub buckets: Vec<LshBucket>,
107 pub hashers: Vec<LshHasher>,
109 pub dim: usize,
111 pub num_tables: usize,
113 pub num_hashes: usize,
115}
116
117impl LshIndex {
118 pub fn new(dim: usize, num_tables: usize, num_hashes: usize, seed: u64) -> Self {
125 let mut rng = XorShift64::new(seed);
126 let mut hashers = Vec::with_capacity(num_tables);
127 let mut buckets = Vec::with_capacity(num_tables);
128 for _ in 0..num_tables {
129 hashers.push(LshHasher::new_with_rng(dim, num_hashes, &mut rng));
130 buckets.push(LshBucket::new());
131 }
132 Self {
133 vectors: Vec::new(),
134 buckets,
135 hashers,
136 dim,
137 num_tables,
138 num_hashes,
139 }
140 }
141
142 pub fn insert(&mut self, id: usize, vector: &[f64]) {
144 while self.vectors.len() <= id {
146 self.vectors.push(vec![]);
147 }
148 self.vectors[id] = vector.to_vec();
149
150 for (table_idx, hasher) in self.hashers.iter().enumerate() {
151 let h = hasher.hash(vector);
152 self.buckets[table_idx].entry(h).or_default().push(id);
153 }
154 }
155
156 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
160 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
161 let mag_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
162 let mag_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
163 if mag_a < f64::EPSILON || mag_b < f64::EPSILON {
164 return 0.0;
165 }
166 dot / (mag_a * mag_b)
167 }
168
169 pub fn search(&self, query: &[f64], k: usize) -> Vec<(usize, f64)> {
175 let mut candidate_set = std::collections::HashSet::new();
176
177 for (table_idx, hasher) in self.hashers.iter().enumerate() {
178 let h = hasher.hash(query);
179 if let Some(ids) = self.buckets[table_idx].get(&h) {
180 for &id in ids {
181 candidate_set.insert(id);
182 }
183 }
184 }
185
186 let mut scored: Vec<(usize, f64)> = candidate_set
187 .into_iter()
188 .filter_map(|id| {
189 let v = self.vectors.get(id)?;
190 if v.is_empty() {
191 return None;
192 }
193 Some((id, Self::cosine_similarity(query, v)))
194 })
195 .collect();
196
197 scored.sort_by(|a, b| {
199 b.1.partial_cmp(&a.1)
200 .unwrap_or(std::cmp::Ordering::Equal)
201 .then_with(|| a.0.cmp(&b.0))
202 });
203
204 scored.truncate(k);
205 scored
206 }
207
208 pub fn len(&self) -> usize {
210 self.vectors.iter().filter(|v| !v.is_empty()).count()
211 }
212
213 pub fn is_empty(&self) -> bool {
215 self.len() == 0
216 }
217
218 pub fn clear(&mut self) {
220 self.vectors.clear();
221 for bucket in &mut self.buckets {
222 bucket.clear();
223 }
224 }
225}
226
/// Rescales `v` in place to unit Euclidean length.
///
/// The vector is left untouched when its length is not greater than `f64::EPSILON`.
/// This covers the zero vector, the empty slice, and a NaN length.
fn normalize_vec(v: &mut [f64]) {
    let squared_norm: f64 = v.iter().map(|c| c * c).sum();
    let norm = squared_norm.sqrt();
    if norm > f64::EPSILON {
        v.iter_mut().for_each(|c| *c /= norm);
    }
}
239
#[cfg(test)]
mod tests {
    use super::*;

    fn unit_vec(dim: usize, axis: usize) -> Vec<f64> {
        let mut v = vec![0.0_f64; dim];
        v[axis] = 1.0;
        v
    }

    fn new_index() -> LshIndex {
        LshIndex::new(4, 4, 8, 42)
    }

    // ---- XorShift64 ----

    #[test]
    fn test_xorshift64_deterministic() {
        let mut rng1 = XorShift64::new(123);
        let mut rng2 = XorShift64::new(123);
        for _ in 0..100 {
            assert_eq!(rng1.next(), rng2.next());
        }
    }

    #[test]
    fn test_xorshift64_nonzero_seed() {
        // A zero state is a fixed point of xorshift; the constructor must avoid it.
        let mut rng = XorShift64::new(0);
        let v = rng.next();
        assert_ne!(v, 0);
    }

    #[test]
    fn test_xorshift64_different_seeds() {
        let mut rng1 = XorShift64::new(1);
        let mut rng2 = XorShift64::new(2);
        let v1 = rng1.next();
        let v2 = rng2.next();
        assert_ne!(v1, v2);
    }

    // ---- normalize_vec ----

    #[test]
    fn test_normalize_vec_unit_length() {
        let mut v = vec![3.0_f64, 4.0_f64];
        normalize_vec(&mut v);
        let mag: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((mag - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_normalize_zero_vec_safe() {
        let mut v = vec![0.0_f64; 4];
        normalize_vec(&mut v);
        // Must neither divide by zero (producing NaN) nor change the vector.
        assert!(v.iter().all(|&x| x == 0.0));
    }

    // ---- LshHasher ----

    #[test]
    fn test_hasher_deterministic() {
        let mut rng = XorShift64::new(42);
        let h1 = LshHasher::new_with_rng(4, 8, &mut rng);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        let hash1 = h1.hash(&v);

        let mut rng2 = XorShift64::new(42);
        let h2 = LshHasher::new_with_rng(4, 8, &mut rng2);
        let hash2 = h2.hash(&v);

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hasher_similar_vectors_same_bucket() {
        let mut rng = XorShift64::new(42);
        let h = LshHasher::new_with_rng(4, 4, &mut rng);
        let v1 = vec![1.0_f64, 0.001, 0.001, 0.001];
        let v2 = vec![1.0_f64, 0.001, 0.001, 0.002];
        // Nearby-but-different vectors are only *likely* to collide, so this pair is
        // exercised without asserting equality.
        let _ = (h.hash(&v1), h.hash(&v2));

        // Scaling by 2.0 doubles every dot product exactly, so the sign pattern (and hence
        // the hash) of a positively scaled vector is guaranteed to match.
        let scaled: Vec<f64> = v1.iter().map(|x| x * 2.0).collect();
        assert_eq!(h.hash(&v1), h.hash(&scaled));
    }

    #[test]
    fn test_hasher_opposite_vectors_different_bits() {
        let mut rng = XorShift64::new(99);
        let h = LshHasher::new_with_rng(4, 8, &mut rng);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        let neg_v = vec![-1.0_f64, 0.0, 0.0, 0.0];
        let h1 = h.hash(&v);
        let h2 = h.hash(&neg_v);
        assert_ne!(h1, h2);
    }

    // ---- cosine_similarity ----

    #[test]
    fn test_cosine_identical_vectors() {
        let v = vec![1.0_f64, 2.0, 3.0];
        let sim = LshIndex::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_orthogonal_vectors() {
        let v1 = vec![1.0_f64, 0.0, 0.0];
        let v2 = vec![0.0_f64, 1.0, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_opposite_vectors() {
        let v1 = vec![1.0_f64, 0.0];
        let v2 = vec![-1.0_f64, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!((sim + 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_zero_vector() {
        let v1 = vec![0.0_f64, 0.0];
        let v2 = vec![1.0_f64, 0.0];
        let sim = LshIndex::cosine_similarity(&v1, &v2);
        assert!(sim.abs() < 1e-9);
    }

    // ---- LshIndex construction / insertion ----

    #[test]
    fn test_index_new_dimensions() {
        let idx = LshIndex::new(8, 4, 16, 1);
        assert_eq!(idx.dim, 8);
        assert_eq!(idx.num_tables, 4);
        assert_eq!(idx.num_hashes, 16);
        assert_eq!(idx.hashers.len(), 4);
        assert_eq!(idx.buckets.len(), 4);
    }

    #[test]
    fn test_index_empty() {
        let idx = new_index();
        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn test_insert_single_vector() {
        let mut idx = new_index();
        idx.insert(0, &[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_insert_multiple_vectors() {
        let mut idx = new_index();
        for i in 0..10 {
            idx.insert(i, &unit_vec(4, i % 4));
        }
        assert_eq!(idx.len(), 10);
    }

    #[test]
    fn test_insert_overwrite_replaces_vector() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        let replacement = [0.0_f64, 1.0, 0.0, 0.0];
        idx.insert(0, &replacement);

        assert_eq!(idx.len(), 1);
        assert_eq!(idx.vectors[0], replacement.to_vec());

        let results = idx.search(&replacement, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-9);
    }

    // ---- search ----

    #[test]
    fn test_search_empty_index() {
        let idx = new_index();
        let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_zero_k_is_empty() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        assert!(idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_search_exact_match() {
        let mut idx = LshIndex::new(4, 8, 16, 42);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        idx.insert(0, &v);
        let results = idx.search(&v, 1);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_search_k_limits_results() {
        let mut idx = LshIndex::new(4, 8, 4, 77);
        let v = vec![1.0_f64, 0.0, 0.0, 0.0];
        for i in 0..5 {
            let mut vv = v.clone();
            vv[0] = 1.0 - i as f64 * 0.01;
            idx.insert(i, &vv);
        }
        let results = idx.search(&v, 2);
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_search_returns_closer_vector() {
        let mut idx = LshIndex::new(2, 8, 16, 1);
        idx.insert(0, &[1.0_f64, 0.01]);
        idx.insert(1, &[0.0_f64, 1.0]);

        // Neither stored vector equals the query, so retrieval of both is not guaranteed;
        // when both come back they must be ordered by similarity.
        let results = idx.search(&[1.0_f64, 0.0], 2);
        if results.len() >= 2 {
            assert!(results[0].1 >= results[1].1);
        }
    }

    #[test]
    fn test_search_sorted_descending() {
        let mut idx = LshIndex::new(4, 8, 16, 7);
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.insert(1, &[0.9_f64, 0.1, 0.0, 0.0]);
        idx.insert(2, &[0.5_f64, 0.5, 0.0, 0.0]);

        let query = [1.0_f64, 0.0, 0.0, 0.0];
        let results = idx.search(&query, 3);
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1, "Results not sorted descending");
        }
    }

    #[test]
    fn test_search_k_greater_than_num_vectors() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.insert(1, &[0.0_f64, 1.0, 0.0, 0.0]);
        let results = idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 100);
        assert!(results.len() <= 2);
        // The query equals vector 0, so it shares every bucket with it.
        assert!(results.iter().any(|(id, _)| *id == 0));
    }

    // ---- clear ----

    #[test]
    fn test_clear() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.clear();
        assert!(idx.is_empty());
    }

    #[test]
    fn test_clear_then_insert() {
        let mut idx = new_index();
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        idx.clear();
        idx.insert(0, &[0.0_f64, 1.0, 0.0, 0.0]);
        assert_eq!(idx.len(), 1);
    }

    // ---- recall / misc ----

    #[test]
    fn test_multi_table_improves_recall() {
        let mut idx = LshIndex::new(4, 16, 8, 2024);
        let target = vec![1.0_f64, 0.0, 0.0, 0.0];
        idx.insert(42, &target);

        for i in 0..20 {
            let mut v = vec![0.0_f64; 4];
            v[i % 4] = 1.0;
            v[(i + 1) % 4] = 0.1;
            idx.insert(i, &v);
        }

        let results = idx.search(&target, 5);
        let found = results.iter().any(|(id, _)| *id == 42);
        assert!(found, "Target vector should be found with 16 tables");
    }

    #[test]
    fn test_high_dimensional_search() {
        let dim = 64;
        let mut idx = LshIndex::new(dim, 8, 16, 99);
        let target: Vec<f64> = (0..dim).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
        idx.insert(0, &target);
        let results = idx.search(&target, 1);
        // Querying with the stored vector itself hits its bucket in every table.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_is_empty_after_inserts() {
        let mut idx = new_index();
        assert!(idx.is_empty());
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        assert!(!idx.is_empty());
    }

    #[test]
    fn test_results_contain_similarity() {
        let mut idx = LshIndex::new(4, 8, 16, 55);
        idx.insert(0, &[1.0_f64, 0.0, 0.0, 0.0]);
        let results = idx.search(&[1.0_f64, 0.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert!(results[0].1 >= 0.0 && results[0].1 <= 1.0 + 1e-9);
    }
}