oxirs_vec/
lsh.rs

1//! Locality Sensitive Hashing (LSH) for approximate nearest neighbor search
2//!
//! This module implements several LSH families:
//! - Random projection LSH for cosine similarity
//! - MinHash for Jaccard similarity
//! - SimHash (signs of 64 random Gaussian projections) for cosine similarity
//! - P-stable LSH for Lp distances
//!
//! It also supports multi-probe querying for improved recall.
8
9use crate::random_utils::NormalSampler as Normal;
10use crate::{Vector, VectorIndex};
11use anyhow::{anyhow, Result};
12#[allow(unused_imports)]
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16
17/// Configuration for LSH index
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct LshConfig {
20    /// Number of hash tables (L parameter)
21    pub num_tables: usize,
22    /// Number of hash functions per table (K parameter)
23    pub num_hash_functions: usize,
24    /// LSH family to use
25    pub lsh_family: LshFamily,
26    /// Random seed for reproducibility
27    pub seed: u64,
28    /// Enable multi-probe LSH
29    pub multi_probe: bool,
30    /// Number of probes for multi-probe LSH
31    pub num_probes: usize,
32}
33
34impl Default for LshConfig {
35    fn default() -> Self {
36        Self {
37            num_tables: 10,
38            num_hash_functions: 8,
39            lsh_family: LshFamily::RandomProjection,
40            seed: 42,
41            multi_probe: true,
42            num_probes: 3,
43        }
44    }
45}
46
/// LSH family types.
///
/// The family selects both the hash function used to bucket vectors and the
/// exact distance used to re-rank candidates at query time.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LshFamily {
    /// Random projection for cosine similarity: one sign bit per Gaussian
    /// projection; candidates are re-ranked by cosine distance.
    RandomProjection,
    /// MinHash for Jaccard similarity. A vector is treated as the set of
    /// indices whose value is `> 0.0`; candidates are re-ranked by Jaccard
    /// distance over those sets.
    MinHash,
    /// SimHash: signs of 64 random Gaussian projections; candidates are
    /// re-ranked by cosine distance.
    SimHash,
    /// P-stable distributions for Lp distance. The payload is `p`: within 0.1
    /// of 1 uses Cauchy projections and Manhattan distance, within 0.1 of 2
    /// uses Gaussian projections and Euclidean distance, and any other value
    /// uses Gaussian projections with a general Minkowski distance.
    PStable(f32), // p value
}
59
/// A single LSH hash function: maps a dense `f32` vector to a 64-bit bucket key.
///
/// Implementors must be `Send + Sync`.
trait HashFunction: Send + Sync {
    /// Returns the bucket key for `vector`.
    fn hash(&self, vector: &[f32]) -> u64;

    /// Returns `num_hashes` hash values for `vector`.
    ///
    /// `hash` takes no per-call index, so this default evaluates
    /// `hash(vector)` exactly `num_hashes` times; for a deterministic `hash`
    /// every element of the result is the same value.
    fn hash_multi(&self, vector: &[f32], num_hashes: usize) -> Vec<u64> {
        std::iter::repeat_with(|| self.hash(vector))
            .take(num_hashes)
            .collect()
    }
}
70
71/// Random projection hash function for cosine similarity
72struct RandomProjectionHash {
73    projections: Vec<Vec<f32>>,
74    dimensions: usize,
75}
76
77impl RandomProjectionHash {
78    fn new(dimensions: usize, num_projections: usize, seed: u64) -> Self {
79        let mut rng = Random::seed(seed);
80        let normal =
81            Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
82
83        let mut projections = Vec::with_capacity(num_projections);
84        for _ in 0..num_projections {
85            let projection: Vec<f32> = (0..dimensions).map(|_| normal.sample(&mut rng)).collect();
86            projections.push(projection);
87        }
88
89        Self {
90            projections,
91            dimensions,
92        }
93    }
94}
95
96impl HashFunction for RandomProjectionHash {
97    fn hash(&self, vector: &[f32]) -> u64 {
98        let mut hash_value = 0u64;
99
100        for (i, projection) in self.projections.iter().enumerate() {
101            // Compute dot product
102            use oxirs_core::simd::SimdOps;
103            let dot_product = f32::dot(vector, projection);
104
105            // Set bit if positive
106            if dot_product > 0.0 {
107                hash_value |= 1 << (i % 64);
108            }
109        }
110
111        hash_value
112    }
113}
114
115/// MinHash for Jaccard similarity
116struct MinHashFunction {
117    a: Vec<u64>,
118    b: Vec<u64>,
119    prime: u64,
120}
121
122impl MinHashFunction {
123    fn new(num_hashes: usize, seed: u64) -> Self {
124        let mut rng = Random::seed(seed);
125        let prime = 4294967311u64; // Large prime
126
127        let a: Vec<u64> = (0..num_hashes)
128            .map(|_| rng.random_range(1..prime))
129            .collect();
130        let b: Vec<u64> = (0..num_hashes)
131            .map(|_| rng.random_range(0..prime))
132            .collect();
133
134        Self { a, b, prime }
135    }
136
137    fn minhash_signature(&self, set_elements: &[u32]) -> Vec<u64> {
138        let mut signature = vec![u64::MAX; self.a.len()];
139
140        for &element in set_elements {
141            for (i, sig_val) in signature.iter_mut().enumerate().take(self.a.len()) {
142                let hash = (self.a[i] * element as u64 + self.b[i]) % self.prime;
143                *sig_val = (*sig_val).min(hash);
144            }
145        }
146
147        signature
148    }
149}
150
151impl HashFunction for MinHashFunction {
152    fn hash(&self, vector: &[f32]) -> u64 {
153        // Convert vector to set of indices where value > threshold
154        let threshold = 0.0;
155        let set_elements: Vec<u32> = vector
156            .iter()
157            .enumerate()
158            .filter(|&(_, &v)| v > threshold)
159            .map(|(i, _)| i as u32)
160            .collect();
161
162        let signature = self.minhash_signature(&set_elements);
163
164        // Combine signature into single hash
165        let mut hash = 0u64;
166        for (i, &sig) in signature.iter().enumerate() {
167            hash ^= sig.rotate_left((i * 7) as u32);
168        }
169
170        hash
171    }
172}
173
174/// SimHash for binary similarity
175struct SimHashFunction {
176    random_vectors: Vec<Vec<f32>>,
177}
178
179impl SimHashFunction {
180    fn new(dimensions: usize, seed: u64) -> Self {
181        let mut rng = Random::seed(seed);
182        let normal =
183            Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
184
185        let random_vectors: Vec<Vec<f32>> = (0..64)
186            .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
187            .collect();
188
189        Self { random_vectors }
190    }
191}
192
193impl HashFunction for SimHashFunction {
194    fn hash(&self, vector: &[f32]) -> u64 {
195        let mut hash = 0u64;
196
197        for (i, random_vec) in self.random_vectors.iter().enumerate() {
198            // Weighted sum
199            let mut sum = 0.0;
200            for (j, &v) in vector.iter().enumerate() {
201                if j < random_vec.len() {
202                    sum += v * random_vec[j];
203                }
204            }
205
206            if sum > 0.0 {
207                hash |= 1 << i;
208            }
209        }
210
211        hash
212    }
213}
214
215/// P-stable LSH for Lp distance
216struct PStableHash {
217    projections: Vec<Vec<f32>>,
218    offsets: Vec<f32>,
219    width: f32,
220    p: f32,
221}
222
223impl PStableHash {
224    fn new(dimensions: usize, num_projections: usize, width: f32, p: f32, seed: u64) -> Self {
225        let mut rng = Random::seed(seed);
226
227        // Use Cauchy distribution for L1, Normal for L2
228        let projections: Vec<Vec<f32>> = if (p - 1.0).abs() < 0.1 {
229            // L1 distance - use Cauchy distribution
230            (0..num_projections)
231                .map(|_| {
232                    (0..dimensions)
233                        .map(|_| {
234                            let u: f32 = rng
235                                .gen_range(-std::f32::consts::PI / 2.0..std::f32::consts::PI / 2.0);
236                            u.tan()
237                        })
238                        .collect()
239                })
240                .collect()
241        } else if (p - 2.0).abs() < 0.1 {
242            // L2 distance - use Normal distribution
243            let normal =
244                Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
245            (0..num_projections)
246                .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
247                .collect()
248        } else {
249            // General case - approximate with Normal
250            let normal =
251                Normal::new(0.0, 1.0).expect("standard normal distribution parameters are valid");
252            (0..num_projections)
253                .map(|_| (0..dimensions).map(|_| normal.sample(&mut rng)).collect())
254                .collect()
255        };
256
257        let offsets: Vec<f32> = (0..num_projections)
258            .map(|_| rng.gen_range(0.0..width))
259            .collect();
260
261        Self {
262            projections,
263            offsets,
264            width,
265            p,
266        }
267    }
268}
269
270impl HashFunction for PStableHash {
271    fn hash(&self, vector: &[f32]) -> u64 {
272        let mut hash = 0u64;
273
274        for (i, (projection, &offset)) in self.projections.iter().zip(&self.offsets).enumerate() {
275            use oxirs_core::simd::SimdOps;
276            let dot_product = f32::dot(vector, projection);
277            let bucket = ((dot_product + offset) / self.width).floor() as i32;
278
279            // Map bucket to bit position
280            if bucket > 0 {
281                hash |= 1 << (i % 64);
282            }
283        }
284
285        hash
286    }
287}
288
289/// LSH table storing hash buckets
290struct LshTable {
291    buckets: HashMap<u64, Vec<usize>>,
292    hash_function: Box<dyn HashFunction>,
293}
294
295impl LshTable {
296    fn new(hash_function: Box<dyn HashFunction>) -> Self {
297        Self {
298            buckets: HashMap::new(),
299            hash_function,
300        }
301    }
302
303    fn insert(&mut self, id: usize, vector: &[f32]) {
304        let hash = self.hash_function.hash(vector);
305        self.buckets.entry(hash).or_default().push(id);
306    }
307
308    fn query(&self, vector: &[f32]) -> Vec<usize> {
309        let hash = self.hash_function.hash(vector);
310        self.buckets.get(&hash).cloned().unwrap_or_default()
311    }
312
313    fn query_multi_probe(&self, vector: &[f32], num_probes: usize) -> Vec<usize> {
314        let main_hash = self.hash_function.hash(vector);
315        let mut candidates = HashSet::new();
316
317        // Add exact match
318        if let Some(ids) = self.buckets.get(&main_hash) {
319            candidates.extend(ids);
320        }
321
322        // Probe nearby buckets by flipping bits
323        for probe in 1..=num_probes {
324            for bit in 0..64 {
325                let probed_hash = main_hash ^ (1 << bit);
326                if let Some(ids) = self.buckets.get(&probed_hash) {
327                    candidates.extend(ids);
328                }
329
330                // Stop if we have enough probes
331                if candidates.len() >= probe * 10 {
332                    break;
333                }
334            }
335        }
336
337        candidates.into_iter().collect()
338    }
339}
340
341/// LSH index implementation
342pub struct LshIndex {
343    config: LshConfig,
344    tables: Vec<LshTable>,
345    vectors: Vec<(String, Vector)>,
346    uri_to_id: HashMap<String, usize>,
347    dimensions: Option<usize>,
348}
349
350impl LshIndex {
351    /// Create a new LSH index
352    pub fn new(config: LshConfig) -> Self {
353        let tables = Self::create_tables(&config, 0);
354
355        Self {
356            config,
357            tables,
358            vectors: Vec::new(),
359            uri_to_id: HashMap::new(),
360            dimensions: None,
361        }
362    }
363
364    /// Create hash tables based on configuration
365    fn create_tables(config: &LshConfig, dimensions: usize) -> Vec<LshTable> {
366        let mut tables = Vec::with_capacity(config.num_tables);
367
368        for table_idx in 0..config.num_tables {
369            let seed = config.seed.wrapping_add(table_idx as u64);
370
371            let hash_function: Box<dyn HashFunction> = match config.lsh_family {
372                LshFamily::RandomProjection => Box::new(RandomProjectionHash::new(
373                    dimensions,
374                    config.num_hash_functions,
375                    seed,
376                )),
377                LshFamily::MinHash => {
378                    Box::new(MinHashFunction::new(config.num_hash_functions, seed))
379                }
380                LshFamily::SimHash => Box::new(SimHashFunction::new(dimensions, seed)),
381                LshFamily::PStable(p) => {
382                    Box::new(PStableHash::new(
383                        dimensions,
384                        config.num_hash_functions,
385                        1.0, // Default width
386                        p,
387                        seed,
388                    ))
389                }
390            };
391
392            tables.push(LshTable::new(hash_function));
393        }
394
395        tables
396    }
397
398    /// Rebuild tables with known dimensions
399    fn rebuild_tables(&mut self) {
400        if let Some(dims) = self.dimensions {
401            self.tables = Self::create_tables(&self.config, dims);
402
403            // Re-insert all vectors
404            for (id, (_, vector)) in self.vectors.iter().enumerate() {
405                let vector_f32 = vector.as_f32();
406                for table in &mut self.tables {
407                    table.insert(id, &vector_f32);
408                }
409            }
410        }
411    }
412
413    /// Query for approximate nearest neighbors
414    fn query_candidates(&self, vector: &[f32]) -> Vec<usize> {
415        let mut candidates = HashSet::new();
416
417        if self.config.multi_probe {
418            // Multi-probe LSH
419            for table in &self.tables {
420                let table_candidates = table.query_multi_probe(vector, self.config.num_probes);
421                candidates.extend(table_candidates);
422            }
423        } else {
424            // Standard LSH
425            for table in &self.tables {
426                let table_candidates = table.query(vector);
427                candidates.extend(table_candidates);
428            }
429        }
430
431        candidates.into_iter().collect()
432    }
433
434    /// Get statistics about the index
435    pub fn stats(&self) -> LshStats {
436        let avg_bucket_size = if self.tables.is_empty() {
437            0.0
438        } else {
439            let total_buckets: usize = self.tables.iter().map(|t| t.buckets.len()).sum();
440            let total_items: usize = self
441                .tables
442                .iter()
443                .map(|t| t.buckets.values().map(|v| v.len()).sum::<usize>())
444                .sum();
445
446            if total_buckets > 0 {
447                total_items as f64 / total_buckets as f64
448            } else {
449                0.0
450            }
451        };
452
453        LshStats {
454            num_vectors: self.vectors.len(),
455            num_tables: self.tables.len(),
456            avg_bucket_size,
457            memory_usage: self.estimate_memory_usage(),
458        }
459    }
460
461    fn estimate_memory_usage(&self) -> usize {
462        let vector_memory =
463            self.vectors.len() * (std::mem::size_of::<String>() + std::mem::size_of::<Vector>());
464
465        let table_memory: usize = self
466            .tables
467            .iter()
468            .map(|t| {
469                t.buckets.len() * (std::mem::size_of::<u64>() + std::mem::size_of::<Vec<usize>>())
470                    + t.buckets
471                        .values()
472                        .map(|v| v.len() * std::mem::size_of::<usize>())
473                        .sum::<usize>()
474            })
475            .sum();
476
477        vector_memory + table_memory
478    }
479}
480
481impl VectorIndex for LshIndex {
482    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
483        // Initialize dimensions if needed
484        if self.dimensions.is_none() {
485            self.dimensions = Some(vector.dimensions);
486            self.rebuild_tables();
487        } else if Some(vector.dimensions) != self.dimensions {
488            return Err(anyhow!(
489                "Vector dimensions ({}) don't match index dimensions ({:?})",
490                vector.dimensions,
491                self.dimensions
492            ));
493        }
494
495        let id = self.vectors.len();
496        let vector_f32 = vector.as_f32();
497
498        // Insert into all tables
499        for table in &mut self.tables {
500            table.insert(id, &vector_f32);
501        }
502
503        self.uri_to_id.insert(uri.clone(), id);
504        self.vectors.push((uri, vector));
505
506        Ok(())
507    }
508
509    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
510        if self.vectors.is_empty() {
511            return Ok(Vec::new());
512        }
513
514        let query_f32 = query.as_f32();
515        let candidates = self.query_candidates(&query_f32);
516
517        // Compute exact distances for candidates
518        let mut results: Vec<(usize, f32)> = candidates
519            .into_iter()
520            .filter_map(|id| {
521                self.vectors.get(id).map(|(_, vec)| {
522                    let vec_f32 = vec.as_f32();
523                    let distance = match self.config.lsh_family {
524                        LshFamily::RandomProjection | LshFamily::SimHash => {
525                            // Cosine distance
526                            use oxirs_core::simd::SimdOps;
527                            f32::cosine_distance(&query_f32, &vec_f32)
528                        }
529                        LshFamily::MinHash => {
530                            // Jaccard distance
531                            let threshold = 0.0;
532                            let set1: HashSet<usize> = query_f32
533                                .iter()
534                                .enumerate()
535                                .filter(|&(_, &v)| v > threshold)
536                                .map(|(i, _)| i)
537                                .collect();
538                            let set2: HashSet<usize> = vec_f32
539                                .iter()
540                                .enumerate()
541                                .filter(|&(_, &v)| v > threshold)
542                                .map(|(i, _)| i)
543                                .collect();
544
545                            let intersection = set1.intersection(&set2).count();
546                            let union = set1.union(&set2).count();
547
548                            if union > 0 {
549                                1.0 - (intersection as f32 / union as f32)
550                            } else {
551                                1.0
552                            }
553                        }
554                        LshFamily::PStable(p) => {
555                            // Lp distance
556                            use oxirs_core::simd::SimdOps;
557                            if (p - 1.0).abs() < 0.1 {
558                                f32::manhattan_distance(&query_f32, &vec_f32)
559                            } else if (p - 2.0).abs() < 0.1 {
560                                f32::euclidean_distance(&query_f32, &vec_f32)
561                            } else {
562                                // General Minkowski distance
563                                query_f32
564                                    .iter()
565                                    .zip(&vec_f32)
566                                    .map(|(a, b)| (a - b).abs().powf(p))
567                                    .sum::<f32>()
568                                    .powf(1.0 / p)
569                            }
570                        }
571                    };
572                    (id, distance)
573                })
574            })
575            .collect();
576
577        // Sort by distance and take top k
578        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
579        results.truncate(k);
580
581        // Convert to final result format
582        Ok(results
583            .into_iter()
584            .map(|(id, dist)| (self.vectors[id].0.clone(), dist))
585            .collect())
586    }
587
588    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
589        if self.vectors.is_empty() {
590            return Ok(Vec::new());
591        }
592
593        let query_f32 = query.as_f32();
594        let candidates = self.query_candidates(&query_f32);
595
596        // Filter candidates by threshold
597        let mut results: Vec<(String, f32)> = candidates
598            .into_iter()
599            .filter_map(|id| {
600                self.vectors.get(id).and_then(|(uri, vec)| {
601                    let vec_f32 = vec.as_f32();
602                    let distance = match self.config.lsh_family {
603                        LshFamily::RandomProjection | LshFamily::SimHash => {
604                            use oxirs_core::simd::SimdOps;
605                            f32::cosine_distance(&query_f32, &vec_f32)
606                        }
607                        LshFamily::MinHash => {
608                            // Jaccard distance
609                            let threshold_val = 0.0;
610                            let set1: HashSet<usize> = query_f32
611                                .iter()
612                                .enumerate()
613                                .filter(|&(_, &v)| v > threshold_val)
614                                .map(|(i, _)| i)
615                                .collect();
616                            let set2: HashSet<usize> = vec_f32
617                                .iter()
618                                .enumerate()
619                                .filter(|&(_, &v)| v > threshold_val)
620                                .map(|(i, _)| i)
621                                .collect();
622
623                            let intersection = set1.intersection(&set2).count();
624                            let union = set1.union(&set2).count();
625
626                            if union > 0 {
627                                1.0 - (intersection as f32 / union as f32)
628                            } else {
629                                1.0
630                            }
631                        }
632                        LshFamily::PStable(p) => {
633                            use oxirs_core::simd::SimdOps;
634                            if (p - 1.0).abs() < 0.1 {
635                                f32::manhattan_distance(&query_f32, &vec_f32)
636                            } else if (p - 2.0).abs() < 0.1 {
637                                f32::euclidean_distance(&query_f32, &vec_f32)
638                            } else {
639                                query_f32
640                                    .iter()
641                                    .zip(&vec_f32)
642                                    .map(|(a, b)| (a - b).abs().powf(p))
643                                    .sum::<f32>()
644                                    .powf(1.0 / p)
645                            }
646                        }
647                    };
648
649                    if distance <= threshold {
650                        Some((uri.clone(), distance))
651                    } else {
652                        None
653                    }
654                })
655            })
656            .collect();
657
658        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
659        Ok(results)
660    }
661
662    fn get_vector(&self, uri: &str) -> Option<&Vector> {
663        self.uri_to_id
664            .get(uri)
665            .and_then(|&id| self.vectors.get(id))
666            .map(|(_, v)| v)
667    }
668}
669
/// LSH index statistics, as returned by `LshIndex::stats`.
#[derive(Debug, Clone)]
pub struct LshStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Number of hash tables.
    pub num_tables: usize,
    /// Total bucket entries divided by total buckets across all tables
    /// (0.0 when there are no buckets).
    pub avg_bucket_size: f64,
    /// Rough memory estimate in bytes. It counts inline sizes of stored
    /// entries and bucket ids, not heap data owned by strings and vectors.
    pub memory_usage: usize,
}
678
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_projection_lsh() {
        let config = LshConfig {
            num_tables: 5,
            num_hash_functions: 4,
            lsh_family: LshFamily::RandomProjection,
            seed: 42,
            multi_probe: false,
            num_probes: 0,
        };

        let mut index = LshIndex::new(config);

        // Insert vectors
        let v1 = Vector::new(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::new(vec![0.0, 1.0, 0.0]);
        let v3 = Vector::new(vec![0.0, 0.0, 1.0]);
        let v_similar = Vector::new(vec![0.9, 0.1, 0.0]); // Similar to v1

        index.insert("v1".to_string(), v1.clone()).unwrap();
        index.insert("v2".to_string(), v2.clone()).unwrap();
        index.insert("v3".to_string(), v3.clone()).unwrap();
        index
            .insert("v_similar".to_string(), v_similar.clone())
            .unwrap();

        // Search for similar vectors
        let results = index.search_knn(&v1, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        // The query is identical to v1, so it hashes into v1's bucket in every
        // table and, at cosine distance ~0, must rank first.
        assert_eq!(results[0].0, "v1");
    }

    #[test]
    fn test_minhash_lsh() {
        let config = LshConfig {
            num_tables: 3,
            num_hash_functions: 64,
            lsh_family: LshFamily::MinHash,
            seed: 42,
            multi_probe: false,
            num_probes: 0,
        };

        let mut index = LshIndex::new(config);

        // Create sparse binary vectors
        let mut v1 = vec![0.0; 100];
        v1[0] = 1.0;
        v1[10] = 1.0;
        v1[20] = 1.0;

        let mut v2 = vec![0.0; 100];
        v2[0] = 1.0;
        v2[10] = 1.0;
        v2[30] = 1.0; // 2/4 overlap with v1

        let mut v3 = vec![0.0; 100];
        v3[50] = 1.0;
        v3[60] = 1.0;
        v3[70] = 1.0; // No overlap with v1

        index
            .insert("v1".to_string(), Vector::new(v1.clone()))
            .unwrap();
        index.insert("v2".to_string(), Vector::new(v2)).unwrap();
        index.insert("v3".to_string(), Vector::new(v3)).unwrap();

        // Search for similar vectors
        let results = index.search_knn(&Vector::new(v1), 2).unwrap();

        // v1 should be first, v2 should be second (due to overlap)
        assert!(!results.is_empty());
        assert_eq!(results[0].0, "v1");
        if results.len() > 1 {
            assert_eq!(results[1].0, "v2");
        }
    }

    #[test]
    fn test_multi_probe_lsh() {
        let config = LshConfig {
            num_tables: 3,
            num_hash_functions: 4,
            lsh_family: LshFamily::RandomProjection,
            seed: 42,
            multi_probe: true,
            num_probes: 2,
        };

        let mut index = LshIndex::new(config);

        // Insert many vectors
        for i in 0..50 {
            let angle = i as f32 * std::f32::consts::PI / 25.0;
            let vec = Vector::new(vec![angle.cos(), angle.sin(), 0.0]);
            index.insert(format!("v{i}"), vec).unwrap();
        }

        // Search with multi-probe should find more candidates
        let query = Vector::new(vec![1.0, 0.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        // Results should be ordered by distance
        for i in 1..results.len() {
            assert!(results[i - 1].1 <= results[i].1);
        }
    }

    #[test]
    fn test_pstable_lsh_finds_exact_match() {
        let config = LshConfig {
            num_tables: 4,
            num_hash_functions: 4,
            lsh_family: LshFamily::PStable(2.0),
            seed: 7,
            multi_probe: false,
            num_probes: 0,
        };

        let mut index = LshIndex::new(config);
        for i in 0..20 {
            let x = i as f32;
            index
                .insert(format!("v{i}"), Vector::new(vec![x, -x, 0.5 * x]))
                .unwrap();
        }

        // Identical to "v3": same key in every table, Euclidean distance 0.
        let query = Vector::new(vec![3.0, -3.0, 1.5]);

        let results = index.search_knn(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "v3");
        assert!(results[0].1.abs() < 1e-6);

        let within = index.search_threshold(&query, 0.0).unwrap();
        assert!(within.iter().any(|(uri, _)| uri == "v3"));

        assert!(index.search_knn(&query, 0).unwrap().is_empty());
    }
}