use std::collections::HashSet;
use rand::Rng;
use crate::{Node, Vector};
pub struct LshExternal<'a, const N: usize> {
trees: Vec<Node<N, u32>>,
values: &'a [Vector<N>],
}
impl<const N: usize> LshExternal<'_, N> {
pub fn inner_node_count(&self) -> usize {
self.trees
.iter()
.map(|tree| tree.inner_node_count())
.sum::<usize>()
}
}
fn check_unique<const N: usize>(vectors: &[Vector<N>]) -> bool {
let mut hashes_seen = HashSet::new();
for vector in vectors.iter() {
if !hashes_seen.insert(vector) {
return false;
}
}
true
}
impl<'a, const N: usize> ANNIndexExternal<'a, N> for LshExternal<'a, N> {
type Index = Self;
fn build<R: Rng>(
num_trees: usize,
max_leaf_size: usize,
vectors: &'a [Vector<N>],
rng: &mut R,
) -> Result<Self::Index, &'static str> {
if !check_unique(vectors) {
return Err("Vectors are not unique");
}
if vectors.len() > u32::MAX as usize {
return Err("Number of vectors exceeds u32::MAX");
}
let all_indexes: Vec<u32> = (0..vectors.len() as u32).collect();
let vector_fn = |idx: &u32| &vectors[*idx as usize];
let trees: Vec<_> = (0..num_trees)
.map(|_| Node::build_tree(max_leaf_size, &all_indexes, &vector_fn, rng))
.collect();
Ok(Self {
trees,
values: vectors,
})
}
fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
let mut candidates = HashSet::new();
for tree in self.trees.iter() {
tree.search(query, top_k, &mut candidates);
}
let mut results = candidates
.into_iter()
.map(|idx| (*idx as usize, self.values[*idx as usize].sq_euc_dist(query)))
.collect::<Vec<_>>();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
results.into_iter().take(top_k).collect()
}
fn memory_usage(&self) -> usize {
let tree_size = self
.trees
.iter()
.map(|tree| tree.memory_usage())
.sum::<usize>();
std::mem::size_of::<Self>() + tree_size
}
}
pub trait ANNIndexExternal<'a, const N: usize> {
type Index;
fn build<R: Rng>(
num_trees: usize,
max_leaf_size: usize,
vectors: &'a [Vector<N>],
rng: &mut R,
) -> Result<Self::Index, &'static str>;
fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)>;
fn memory_usage(&self) -> usize;
}
pub struct LinearSearchExternal<'a, const N: usize> {
values: &'a [Vector<N>],
}
impl<'a, const N: usize> ANNIndexExternal<'a, N> for LinearSearchExternal<'a, N> {
type Index = Self;
fn build<R: Rng>(
_num_trees: usize,
_max_leaf_size: usize,
vectors: &'a [Vector<N>],
_rng: &mut R,
) -> Result<Self::Index, &'static str> {
if vectors.is_empty() {
return Err("Cannot build index with empty vector set");
}
Ok(Self { values: vectors })
}
fn search(&self, query: &Vector<N>, top_k: usize) -> Vec<(usize, f32)> {
if top_k == 0 || self.values.is_empty() {
return Vec::new();
}
#[derive(PartialEq)]
struct Entry(f32, usize);
#[allow(clippy::non_canonical_partial_ord_impl)]
impl PartialOrd for Entry {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(&other.0)
}
}
impl Eq for Entry {}
impl Ord for Entry {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
}
}
let mut heap = std::collections::BinaryHeap::with_capacity(top_k);
for (idx, vec) in self.values.iter().enumerate() {
let dist = vec.sq_euc_dist(query);
if heap.len() < top_k {
heap.push(Entry(dist, idx));
} else if let Some(max_entry) = heap.peek() {
if dist < max_entry.0 {
heap.pop();
heap.push(Entry(dist, idx));
}
}
}
let mut result = Vec::with_capacity(heap.len());
while let Some(Entry(dist, idx)) = heap.pop() {
result.push((idx, dist));
}
result.reverse();
result
}
fn memory_usage(&self) -> usize {
std::mem::size_of::<Self>()
}
}
#[cfg(test)]
mod tests {
use rand::{SeedableRng, rngs::StdRng};
use super::*;
fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
(a - b).abs() < tol
}
#[test]
fn test_basic_nearest_neighbor() {
let mut seed_rng = StdRng::seed_from_u64(42);
let vectors = vec![
Vector::from([10.0, 20.0, 30.0]),
Vector::from([10.0, 30.0, 20.0]),
Vector::from([20.0, 10.0, 30.0]),
Vector::from([20.0, 30.0, 10.0]),
Vector::from([30.0, 20.0, 10.0]),
Vector::from([30.0, 10.0, 20.0]),
];
let index = LshExternal::build(1, 1, &vectors, &mut seed_rng).unwrap();
for (i, vector) in vectors.iter().enumerate() {
let results = index.search(vector, 1);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, i);
}
for (i, vector) in vectors.iter().enumerate() {
let query = vector.add(&Vector::from([0.1, 0.1, 0.1]));
let results = index.search(&query, 2);
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, i);
}
let index = LinearSearchExternal::build(1, 1, &vectors, &mut seed_rng).unwrap();
for (i, vector) in vectors.iter().enumerate() {
let results = index.search(vector, 1);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, i);
}
for (i, vector) in vectors.iter().enumerate() {
let query = vector.add(&Vector::from([0.1, 0.1, 0.1]));
let results = index.search(&query, 2);
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, i);
}
}
#[test]
fn test_top_2_nearest_neighbor() {
let mut seed_rng = StdRng::seed_from_u64(42);
let vectors = vec![
Vector::from([10.0, 20.0, 30.0]),
Vector::from([10.0, 20.0, 30.1]),
Vector::from([20.0, 30.0, 10.0]),
Vector::from([20.0, 30.1, 10.0]),
Vector::from([30.0, 20.0, 10.0]),
Vector::from([30.1, 20.0, 10.0]),
];
let index = LshExternal::build(1, 2, &vectors, &mut seed_rng).unwrap();
for (i, vector) in vectors.iter().enumerate() {
let results = index.search(vector, 2);
assert_eq!(results.len(), 2);
let id_bucket = i / 2 * 2;
assert_eq!(results[0].0 + results[1].0, id_bucket + id_bucket + 1);
}
let index = LinearSearchExternal::build(1, 2, &vectors, &mut seed_rng).unwrap();
for (i, vector) in vectors.iter().enumerate() {
let results = index.search(vector, 2);
let id_bucket = i / 2 * 2;
assert_eq!(results[0].0 + results[1].0, id_bucket + id_bucket + 1);
}
}
#[test]
fn test_top_k_exceeds_total_and_high_dim() {
let vectors = vec![
Vector::from([0.0, 0.0, 0.0, 0.0]),
Vector::from([1.0, 0.0, 0.0, 0.0]),
Vector::from([0.0, 1.0, 0.0, 0.0]),
Vector::from([0.0, 0.0, 1.0, 0.0]),
Vector::from([0.0, 0.0, 0.0, 1.0]),
];
let mut seed_rng = StdRng::seed_from_u64(42);
let index = LshExternal::build(3, 2, &vectors, &mut seed_rng).unwrap();
let query = Vector::from([0.5, 0.5, 0.5, 0.5]);
let top_k = 10; let results = index.search(&query, top_k);
assert_eq!(results.len(), 5);
for &(_, dist) in results.iter() {
assert!(approx_eq(dist, 1.0, 0.0001));
}
}
}