use crate::RetrieveError;
use smallvec::SmallVec;
/// Nearest-neighbor index that stores `f32` vectors together with a neighbor
/// graph built over them.
///
/// Usage follows a fixed order: create with `new`, insert vectors with `add`,
/// call `build` once, then query with `search`. `add` is rejected once the
/// index has been built, and `search` is rejected until it has been.
pub struct SNGIndex {
/// Flat vector storage: vector `i` occupies
/// `vectors[i * dimension..(i + 1) * dimension]`.
pub(crate) vectors: Vec<f32>,
/// Length that every added vector and every query must have; `new` rejects 0.
pub(crate) dimension: usize,
/// Number of vectors added so far.
pub(crate) num_vectors: usize,
/// Caller-supplied document id of each vector, in insertion order; `search`
/// uses it to translate internal ids back to document ids.
doc_ids: Vec<u32>,
/// Construction parameters. They are stored on the index but not read by the
/// code visible here, hence the `dead_code` allowance.
#[allow(dead_code)]
params: SNGParams,
/// Set to `true` once `build` has succeeded.
built: bool,
/// Adjacency lists: `neighbors[i]` holds the internal ids linked to vector `i`.
pub(crate) neighbors: Vec<SmallVec<[u32; 16]>>,
/// Truncation parameter computed at the start of `build` and handed to
/// candidate pruning during graph construction.
truncation_r: f32,
}
/// Tuning parameters accepted by `SNGIndex::new`.
///
/// NOTE(review): nothing visible in this file reads either field; confirm
/// where (or whether) they take effect.
#[derive(Clone, Debug)]
pub struct SNGParams {
/// Presumably an upper bound on the number of neighbors per node, with `None`
/// (the default) meaning no explicit cap — TODO confirm against the consumer.
pub max_degree: Option<usize>,
/// Number of hash functions; defaults to 10. Its consumer is not visible here.
pub num_hash_functions: usize,
}
impl Default for SNGParams {
    /// Returns the baseline configuration: no `max_degree` set and 10 hash
    /// functions.
    fn default() -> Self {
        let max_degree = None;
        let num_hash_functions = 10;
        Self {
            max_degree,
            num_hash_functions,
        }
    }
}
impl SNGIndex {
    /// Creates an empty, unbuilt index for vectors of length `dimension`.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::InvalidParameter` when `dimension` is 0.
    pub fn new(dimension: usize, params: SNGParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            vectors: Vec::new(),
            dimension,
            num_vectors: 0,
            doc_ids: Vec::new(),
            params,
            built: false,
            neighbors: Vec::new(),
            truncation_r: 0.0,
        })
    }

    /// Appends `vector` to the index under the external id `doc_id`.
    ///
    /// The vector receives the next internal id (its insertion position).
    ///
    /// # Errors
    ///
    /// * `RetrieveError::InvalidParameter` if the index has already been built,
    ///   or if the next internal id would not fit in a `u32`.
    /// * `RetrieveError::DimensionMismatch` if `vector.len()` differs from the
    ///   index dimension.
    pub fn add(&mut self, doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::InvalidParameter(
                "cannot add vectors after index is built".into(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dimension,
            });
        }
        // The graph stores internal ids as `u32` (via `as u32` casts during
        // construction). The id about to be assigned equals the current count,
        // so refuse the insert instead of letting a later cast wrap silently.
        if self.num_vectors > u32::MAX as usize {
            return Err(RetrieveError::InvalidParameter(
                "index is full: internal vector ids must fit in u32".into(),
            ));
        }
        self.vectors.extend_from_slice(&vector);
        self.doc_ids.push(doc_id);
        self.num_vectors += 1;
        Ok(())
    }

    /// Computes the truncation parameter and constructs the neighbor graph.
    ///
    /// Calling `build` on an already built index is a no-op that returns `Ok`.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::EmptyIndex` if no vector has been added, and
    /// propagates any error from truncation optimization or graph construction.
    /// The index is only marked as built when every step succeeded.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        self.truncation_r =
            crate::sng::optimization::optimize_truncation_r(self.num_vectors, self.dimension)?;
        self.construct_graph()?;
        self.built = true;
        Ok(())
    }

    /// Searches the built index for `query` and returns `(doc_id, distance)` pairs.
    ///
    /// The search itself is delegated to `crate::sng::search::search_sng`, which
    /// receives `k` (presumably the maximum number of results — confirm there).
    /// Internal ids in its result are translated to the caller-supplied
    /// document ids; a result whose internal id has no stored document id is
    /// dropped.
    ///
    /// # Errors
    ///
    /// * `RetrieveError::InvalidParameter` if the index has not been built.
    /// * `RetrieveError::DimensionMismatch` if `query.len()` differs from the
    ///   index dimension.
    /// * Any error returned by the underlying search.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "index must be built before search".into(),
            ));
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        let results = crate::sng::search::search_sng(self, query, k)?;
        Ok(results
            .into_iter()
            .filter_map(|(internal_id, dist)| {
                self.doc_ids
                    .get(internal_id as usize)
                    .map(|&doc_id| (doc_id, dist))
            })
            .collect())
    }

    /// Returns the stored vector with internal id `idx` as a slice of length
    /// `dimension`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not refer to a stored vector (the slice range falls
    /// outside `vectors`).
    pub(crate) fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dimension;
        let end = start + self.dimension;
        &self.vectors[start..end]
    }

    /// Builds the neighbor graph by exhaustive comparison.
    ///
    /// For every vector, the distance `1.0 - dot(current, other)` to every other
    /// vector is computed (quadratic in the number of vectors), the candidates
    /// are pruned with `prune_candidates_martingale`, and each surviving
    /// neighbor is linked in both directions. Every adjacency list holds a given
    /// id at most once.
    ///
    /// NOTE(review): `1.0 - dot` equals cosine distance only for unit-length
    /// vectors; normalization is not enforced by this type.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::EmptyIndex` if the index holds no vectors and
    /// propagates any error from pruning.
    fn construct_graph(&mut self) -> Result<(), RetrieveError> {
        use crate::simd;
        use crate::sng::martingale;
        use smallvec::SmallVec;

        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        self.neighbors = vec![SmallVec::new(); self.num_vectors];
        let mut evolution = martingale::CandidateEvolution::new();
        // One buffer reused for every node; each node has exactly
        // `num_vectors - 1` candidates, so it never needs to grow.
        let mut candidates = Vec::with_capacity(self.num_vectors - 1);

        for current_id in 0..self.num_vectors {
            candidates.clear();
            let current_vector = self.get_vector(current_id);
            for other_id in 0..self.num_vectors {
                if other_id == current_id {
                    continue;
                }
                let other_vector = self.get_vector(other_id);
                let dist = 1.0 - simd::dot(current_vector, other_vector);
                candidates.push((other_id as u32, dist));
            }

            let pruned = martingale::prune_candidates_martingale(
                &candidates,
                self.truncation_r,
                &self.vectors,
                self.dimension,
            )?;
            evolution.update(pruned.len());

            let current = current_id as u32;
            for &neighbor_id in &pruned {
                // An earlier node may already have linked itself to `current`
                // through its reverse edge, so the forward edge needs the same
                // membership check as the reverse one to avoid duplicates.
                let forward = &mut self.neighbors[current_id];
                if !forward.contains(&neighbor_id) {
                    forward.push(neighbor_id);
                }
                let reverse = &mut self.neighbors[neighbor_id as usize];
                if !reverse.contains(&current) {
                    reverse.push(current);
                }
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::RetrieveError;

    /// Divides every component of `v` by the Euclidean norm of `v`, in place.
    /// The slice is left untouched when the norm is not positive.
    fn normalize(v: &mut [f32]) {
        let length = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if length > 0.0 {
            for component in v.iter_mut() {
                *component /= length;
            }
        }
    }

    /// A freshly created index reports the requested dimension and no vectors.
    #[test]
    fn test_create_index() {
        let created = SNGIndex::new(4, SNGParams::default());
        assert!(created.is_ok());
        let index = created.unwrap();
        assert_eq!(index.dimension, 4);
        assert_eq!(index.num_vectors, 0);
    }

    /// Ten normalized vectors can be added, built and searched; the search
    /// returns between one and three results for `k = 3`.
    #[test]
    fn test_add_and_search() {
        let mut index = SNGIndex::new(4, SNGParams::default()).unwrap();
        for doc_id in 0..10u32 {
            let mut vector = vec![doc_id as f32 + 1.0, (doc_id as f32) * 0.5, 1.0, 0.5];
            normalize(&mut vector);
            index.add(doc_id, vector).unwrap();
        }
        index.build().unwrap();

        let mut query = vec![1.0, 0.0, 1.0, 0.5];
        normalize(&mut query);

        let hits = index.search(&query, 3).unwrap();
        assert!(!hits.is_empty());
        assert!(hits.len() <= 3);
    }

    /// Creating an index with dimension 0 fails with `InvalidParameter`.
    #[test]
    fn test_zero_dimension_error() {
        match SNGIndex::new(0, SNGParams::default()) {
            Ok(_) => panic!("Expected error for dimension 0"),
            Err(RetrieveError::InvalidParameter(_)) => {}
            Err(other) => panic!("Expected InvalidParameter, got {:?}", other),
        }
    }
}