use crate::distance::FloatOrd;
use crate::RetrieveError;
use smallvec::SmallVec;
use std::collections::{BinaryHeap, HashSet};
/// A sparse vector stored as two parallel arrays.
///
/// `indices[i]` is the dimension of the `i`-th non-zero entry and `values[i]`
/// is its weight. [`sparse_distance`] walks both index arrays with a
/// two-pointer merge, so `indices` is expected to be sorted in ascending order
/// without duplicates; [`SparseVector::from_pairs`] produces that layout.
#[derive(Clone, Debug, Default)]
pub struct SparseVector {
    /// Dimension ids of the stored entries (ascending, unique).
    pub indices: Vec<u32>,
    /// Weights, parallel to `indices`.
    pub values: Vec<f32>,
}

impl SparseVector {
    /// Builds a vector from unordered `(index, value)` pairs.
    ///
    /// The pairs are sorted by index. When an index occurs more than once,
    /// the occurrence that comes first in `pairs` is kept and the later ones
    /// are dropped.
    pub fn from_pairs(mut pairs: Vec<(u32, f32)>) -> Self {
        // A stable sort is required here: `dedup_by_key` keeps the first
        // element of every run, and only a stable sort guarantees that this
        // is the first occurrence from the input.
        pairs.sort_by_key(|&(index, _)| index);
        pairs.dedup_by_key(|pair| pair.0);
        let (indices, values) = pairs.into_iter().unzip();
        Self { indices, values }
    }

    /// Euclidean (L2) norm of the stored values.
    #[inline]
    pub fn norm(&self) -> f32 {
        self.values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Number of stored entries (length of `indices`).
    #[inline]
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }
}
/// Tuning knobs for [`SparseMipsIndex`].
#[derive(Clone, Debug)]
pub struct SparseMipsParams {
    /// Upper bound on the number of neighbors kept per node after pruning.
    /// The initial k-NN graph is seeded with half this many neighbors.
    pub max_degree: usize,
    /// Beam width used for the per-node searches performed during `build`.
    pub ef_construction: usize,
    /// Minimum beam width used by `search` (raised to `k` when `k` is larger).
    pub ef_search: usize,
    /// Factor applied to the selected-to-candidate distance in the pruning test.
    pub alpha: f32,
}

impl Default for SparseMipsParams {
    /// Returns the stock configuration: degree 32, construction beam 200,
    /// search beam 100, `alpha` 1.2.
    fn default() -> Self {
        SparseMipsParams {
            alpha: 1.2,
            ef_search: 100,
            ef_construction: 200,
            max_degree: 32,
        }
    }
}
/// Graph index over sparse vectors, searched by negated inner product
/// (see `sparse_distance`).
///
/// Usage is two-phase: vectors are collected with `add`, then `build`
/// constructs the neighbor graph; `search` is only allowed afterwards and
/// `add` is rejected once the index is built.
pub struct SparseMipsIndex {
/// Construction and search parameters supplied to `new`.
params: SparseMipsParams,
/// Set to `true` at the end of a successful `build`.
built: bool,
/// Stored vectors; the position in this list is the internal node id.
vectors: Vec<SparseVector>,
/// Number of vectors added so far (incremented by `add`).
num_vectors: usize,
/// Caller-supplied document id for each internal node id.
doc_ids: Vec<u32>,
/// Adjacency list: `neighbors[i]` holds the internal ids linked from node `i`.
neighbors: Vec<SmallVec<[u32; 16]>>,
/// Internal id every graph traversal starts from (chosen in `build`).
entry_point: u32,
}
impl SparseMipsIndex {
pub fn new(params: SparseMipsParams) -> Self {
Self {
params,
built: false,
vectors: Vec::new(),
num_vectors: 0,
doc_ids: Vec::new(),
neighbors: Vec::new(),
entry_point: 0,
}
}
/// Stages `vector` under the caller-supplied `doc_id`.
///
/// The vector's `indices` are expected to be sorted ascending and unique
/// (as produced by `SparseVector::from_pairs`); that ordering is not checked
/// here.
///
/// # Errors
///
/// Returns `RetrieveError::InvalidParameter` if the index has already been
/// built, or if `vector.indices` and `vector.values` differ in length.
pub fn add(&mut self, doc_id: u32, vector: SparseVector) -> Result<(), RetrieveError> {
    if self.built {
        return Err(RetrieveError::InvalidParameter(
            "cannot add after build".into(),
        ));
    }
    // The distance kernel indexes `values` with positions taken from
    // `indices`; reject mismatched arrays here instead of panicking later.
    if vector.indices.len() != vector.values.len() {
        return Err(RetrieveError::InvalidParameter(
            "sparse vector indices and values must have the same length".into(),
        ));
    }
    self.vectors.push(vector);
    self.doc_ids.push(doc_id);
    self.num_vectors += 1;
    Ok(())
}
/// Constructs the search graph over all staged vectors.
///
/// Steps: pick the entry point, seed a k-NN graph (brute force for up to
/// 1000 vectors, NN-descent above that), then for every node re-select its
/// neighbors from a beam search followed by pruning and add the reverse
/// edges, and finally run the connectivity repair pass.
///
/// Calling `build` on an already built index is a no-op.
///
/// # Errors
///
/// Returns `RetrieveError::EmptyIndex` when no vectors have been added.
pub fn build(&mut self) -> Result<(), RetrieveError> {
    if self.built {
        return Ok(());
    }
    let total = self.num_vectors;
    if total == 0 {
        return Err(RetrieveError::EmptyIndex);
    }

    self.entry_point = self.find_entry_point();
    match total {
        0..=1000 => self.build_knn_bruteforce(),
        _ => self.build_knn_nndescent(),
    }

    let degree_cap = self.params.max_degree;
    let beam_width = self.params.ef_construction;
    for node in 0..total {
        let node_id = node as u32;

        // Forward edges: search the current graph, then prune the pool.
        let pool = self.beam_search_internal(node, beam_width);
        let kept = self.rng_prune(node, &pool);
        self.neighbors[node] = kept.iter().map(|&(id, _)| id).collect();

        // Reverse edges: make every chosen neighbor point back at `node`.
        for &(peer_id, _) in &kept {
            let peer = peer_id as usize;
            if self.neighbors[peer].contains(&node_id) {
                continue;
            }
            if self.neighbors[peer].len() < degree_cap {
                self.neighbors[peer].push(node_id);
                continue;
            }

            // The peer is full: re-prune its list with `node` included.
            let peer_vec = &self.vectors[peer];
            let mut overflow: Vec<(u32, f32)> = Vec::new();
            for &cid in self.neighbors[peer]
                .iter()
                .chain(std::iter::once(&node_id))
            {
                let dist = sparse_distance(peer_vec, &self.vectors[cid as usize]);
                overflow.push((cid, dist));
            }
            let trimmed = self.rng_prune(peer, &overflow);
            self.neighbors[peer] = trimmed.iter().map(|&(id, _)| id).collect();
        }
    }

    self.ensure_connectivity();
    self.built = true;
    Ok(())
}
/// Returns up to `k` results for `query` as `(doc_id, distance)` pairs.
///
/// `distance` is the negated inner product computed by `sparse_distance`,
/// so results are ordered from smallest distance (largest inner product)
/// upwards. The beam width is `max(ef_search, k)`. `k == 0` yields an empty
/// result.
///
/// # Errors
///
/// Returns `RetrieveError::InvalidParameter` if the index has not been built
/// and `RetrieveError::EmptyIndex` if it holds no vectors.
pub fn search(&self, query: &SparseVector, k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
    if !self.built {
        return Err(RetrieveError::InvalidParameter(
            "index must be built before search".into(),
        ));
    }
    if self.num_vectors == 0 {
        return Err(RetrieveError::EmptyIndex);
    }
    // Nothing was asked for: skip the traversal entirely. This also keeps a
    // zero beam width (k == 0 together with ef_search == 0) away from the
    // beam search.
    if k == 0 {
        return Ok(Vec::new());
    }
    let ef = self.params.ef_search.max(k);
    let hits = self
        .beam_search_query(query, ef)
        .into_iter()
        .take(k)
        .map(|(id, dist)| (self.doc_ids[id as usize], dist))
        .collect();
    Ok(hits)
}
/// Number of vectors that have been added to the index.
pub fn len(&self) -> usize {
self.num_vectors
}
/// Returns `true` when no vectors have been added.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Picks the internal id of the vector with the largest L2 norm.
///
/// On ties the lowest id wins, because only a strictly larger norm replaces
/// the current best. Panics if the index holds no vectors.
fn find_entry_point(&self) -> u32 {
    let seed = (0u32, self.vectors[0].norm());
    let (winner, _) = self.vectors[..self.num_vectors]
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |(best_idx, best_norm), (idx, vector)| {
            let norm = vector.norm();
            if norm > best_norm {
                (idx as u32, norm)
            } else {
                (best_idx, best_norm)
            }
        });
    winner
}
/// Seeds `neighbors` with an exact k-NN graph computed by comparing every
/// pair of vectors.
///
/// Each node keeps its `max(1, max_degree / 2)` closest other nodes, capped
/// at `num_vectors - 1`.
fn build_knn_bruteforce(&mut self) {
    let total = self.num_vectors;
    let keep = (self.params.max_degree / 2).max(1).min(total - 1);
    self.neighbors = vec![SmallVec::new(); total];

    for src in 0..total {
        let origin = &self.vectors[src];
        let mut ranked: Vec<(u32, f32)> = (0..total)
            .filter_map(|dst| {
                (dst != src).then(|| (dst as u32, sparse_distance(origin, &self.vectors[dst])))
            })
            .collect();
        // Closest first, then cut down to the seed degree.
        ranked.sort_unstable_by(|x, y| x.1.total_cmp(&y.1));
        ranked.truncate(keep);
        self.neighbors[src] = ranked.into_iter().map(|(id, _)| id).collect();
    }
}
/// Seeds `neighbors` by delegating to the shared NN-descent builder in
/// `graph_utils`, asking for `max(1, max_degree / 2)` neighbors per node and
/// supplying `sparse_distance` over the stored vectors as the metric.
fn build_knn_nndescent(&mut self) {
    let node_count = self.num_vectors;
    let fanout = (self.params.max_degree / 2).max(1);
    let store = &self.vectors;
    self.neighbors =
        crate::graph_utils::build_knn_graph_nndescent(node_count, fanout, |lhs, rhs| {
            sparse_distance(&store[lhs], &store[rhs])
        });
}
/// Selects at most `max_degree` neighbors for node `query_idx` from
/// `candidates`, given as `(internal id, distance to the query node)`.
///
/// Candidates are visited from closest to farthest. A candidate is dropped
/// when some already selected neighbor `s` satisfies
/// `alpha * distance(s, candidate) <= distance(query, candidate)`.
/// The query node itself is never returned, and each id appears at most once
/// in the result.
///
/// NOTE(review): distances are negated inner products and therefore usually
/// negative, so `alpha > 1` makes the left-hand side smaller and the test
/// easier to satisfy — confirm this is the intended direction.
fn rng_prune(&self, query_idx: usize, candidates: &[(u32, f32)]) -> Vec<(u32, f32)> {
    let mut sorted: Vec<(u32, f32)> = candidates.to_vec();
    sorted.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    sorted.dedup_by_key(|c| c.0);
    sorted.retain(|&(id, _)| id as usize != query_idx);

    let max_deg = self.params.max_degree;
    let alpha = self.params.alpha;
    let mut selected: Vec<(u32, f32)> = Vec::with_capacity(max_deg);

    for &(cand_id, cand_dist) in &sorted {
        if selected.len() >= max_deg {
            break;
        }
        let cand_vec = &self.vectors[cand_id as usize];
        let rejected = selected.iter().any(|&(sel_id, _)| {
            // `dedup_by_key` above only removes *adjacent* repeats, and the
            // list is ordered by distance, not by id. Guard explicitly so an
            // id can never be selected twice.
            sel_id == cand_id
                || alpha * sparse_distance(&self.vectors[sel_id as usize], cand_vec) <= cand_dist
        });
        if !rejected {
            selected.push((cand_id, cand_dist));
        }
    }
    selected
}
/// Runs a beam search using the stored vector `query_idx` as the query.
///
/// The vector is borrowed in place; no copy is made.
fn beam_search_internal(&self, query_idx: usize, ef: usize) -> Vec<(u32, f32)> {
    self.beam_search_query(&self.vectors[query_idx], ef)
}
/// Best-first traversal of the neighbor graph starting at `entry_point`.
///
/// Returns every node whose distance to `query` was evaluated, as
/// `(internal id, distance)` sorted by ascending distance. `ef` is the beam
/// width; a value of 0 is treated as 1.
///
/// The traversal stops when the frontier is empty, when the closest
/// unexpanded node is worse than the `ef`-th best distance by more than 50%
/// of that distance's magnitude, or once more than `10 * ef` nodes have been
/// visited.
fn beam_search_query(&self, query: &SparseVector, ef: usize) -> Vec<(u32, f32)> {
    use std::cmp::Reverse;

    if self.num_vectors == 0 {
        return Vec::new();
    }
    // `candidates[ef - 1]` below needs ef >= 1.
    let ef = ef.max(1);
    let visit_budget = ef.saturating_mul(10);

    let mut visited: HashSet<u32> = HashSet::new();
    let mut frontier: BinaryHeap<Reverse<(FloatOrd, u32)>> = BinaryHeap::new();
    let mut candidates: Vec<(u32, f32)> = Vec::new();

    let entry = self.entry_point;
    let entry_dist = sparse_distance(query, &self.vectors[entry as usize]);
    visited.insert(entry);
    frontier.push(Reverse((FloatOrd(entry_dist), entry)));
    candidates.push((entry, entry_dist));

    while let Some(Reverse((FloatOrd(current_dist), current_id))) = frontier.pop() {
        if candidates.len() >= ef {
            candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
            let kth = candidates[ef - 1].1;
            // Sign-aware slack. Distances are negated inner products and are
            // usually negative; multiplying a negative bound by 1.5 would
            // tighten it and stop the search while the frontier is still
            // better than the ef-th result. For kth >= 0 this equals
            // kth * 1.5.
            let cutoff = kth + 0.5 * kth.abs();
            if current_dist > cutoff {
                break;
            }
        }
        for &neighbor in &self.neighbors[current_id as usize] {
            if visited.insert(neighbor) {
                let dist = sparse_distance(query, &self.vectors[neighbor as usize]);
                candidates.push((neighbor, dist));
                frontier.push(Reverse((FloatOrd(dist), neighbor)));
            }
        }
        // Hard cap on the amount of work per query.
        if visited.len() > visit_budget {
            break;
        }
    }

    candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    candidates.dedup_by_key(|c| c.0);
    candidates
}
/// Hands the adjacency lists and the entry point to the shared
/// connectivity pass in `graph_utils`, with `sparse_distance` over the
/// stored vectors as the metric.
fn ensure_connectivity(&mut self) {
    let store = &self.vectors;
    let root = self.entry_point;
    crate::graph_utils::ensure_connectivity(&mut self.neighbors, root, |lhs, rhs| {
        sparse_distance(&store[lhs], &store[rhs])
    });
}
}
/// Negated inner product of two sparse vectors.
///
/// Both index arrays are walked in lock-step (two-pointer merge), which
/// assumes each is sorted in ascending order; only dimensions present in
/// both vectors contribute. A larger inner product gives a smaller
/// (more negative) distance.
pub fn sparse_distance(a: &SparseVector, b: &SparseVector) -> f32 {
    let (lhs_len, rhs_len) = (a.indices.len(), b.indices.len());
    let (mut i, mut j) = (0usize, 0usize);
    let mut acc = 0.0f32;

    while i < lhs_len && j < rhs_len {
        let (li, rj) = (a.indices[i], b.indices[j]);
        if li < rj {
            i += 1;
        } else if li > rj {
            j += 1;
        } else {
            // Shared dimension: accumulate the product and advance both sides.
            acc += a.values[i] * b.values[j];
            i += 1;
            j += 1;
        }
    }
    -acc
}
// Unit tests for the sparse vector type, the distance kernel and the graph index.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
// Deterministic fixture: builds `n` sparse vectors with up to `nnz` distinct
// indices below `max_dim`, driven by a small LCG seeded with `seed`.
fn make_sparse(n: usize, nnz: usize, max_dim: u32, seed: u64) -> Vec<SparseVector> {
let mut rng = seed;
let lcg = |state: &mut u64| -> u64 {
*state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
*state
};
(0..n)
.map(|_| {
let mut indices: Vec<u32> = Vec::with_capacity(nnz);
let mut attempts = 0usize;
// Bounded retry loop so a small `max_dim` cannot spin forever.
while indices.len() < nnz && attempts < nnz * 16 {
attempts += 1;
let idx = (lcg(&mut rng) >> 33) as u32 % max_dim;
if !indices.contains(&idx) {
indices.push(idx);
}
}
indices.sort_unstable();
let values: Vec<f32> = indices
.iter()
.map(|_| {
let r = lcg(&mut rng);
// The top 31 bits scaled by 2^31 give a value in [0, 1).
(r >> 33) as f32 / (1u64 << 31) as f32
})
.collect();
SparseVector { indices, values }
})
.collect()
}
// Smoke test: build a 50-vector index and check that searching with a stored
// vector returns at most k hits that include that vector's own id.
#[test]
fn build_and_search() {
let n = 50;
let vecs = make_sparse(n, 10, 100, 42);
let mut index = SparseMipsIndex::new(SparseMipsParams {
max_degree: 16,
ef_construction: 64,
ef_search: 32,
alpha: 1.2,
});
for (i, v) in vecs.iter().enumerate() {
index.add(i as u32, v.clone()).unwrap();
}
index.build().unwrap();
assert_eq!(index.len(), n);
let results = index.search(&vecs[0], 5).unwrap();
assert!(!results.is_empty());
assert!(results.len() <= 5);
assert!(results.iter().any(|&(id, _)| id == 0));
}
// Counts how often a stored vector retrieves its own id as the top-1 hit and
// requires that to happen for more than half of the vectors.
#[test]
fn self_search_recall() {
let n = 50;
let vecs = make_sparse(n, 15, 200, 7);
let mut index = SparseMipsIndex::new(SparseMipsParams {
max_degree: 16,
ef_construction: 100,
ef_search: 50,
alpha: 1.2,
});
for (i, v) in vecs.iter().enumerate() {
index.add(i as u32, v.clone()).unwrap();
}
index.build().unwrap();
let mut hits = 0usize;
for (i, v) in vecs.iter().enumerate() {
let results = index.search(v, 1).unwrap();
if results.first().map(|&(id, _)| id) == Some(i as u32) {
hits += 1;
}
}
let recall = hits as f64 / n as f64;
assert!(
recall > 0.5,
"self-search recall too low: {recall:.2} ({hits}/{n})"
);
}
// Vectors with no entries at all must still build and be searchable.
#[test]
fn empty_sparse_vectors() {
let mut index = SparseMipsIndex::new(SparseMipsParams::default());
for i in 0..5u32 {
index
.add(
i,
SparseVector {
indices: vec![],
values: vec![],
},
)
.unwrap();
}
index.build().unwrap();
let query = SparseVector {
indices: vec![],
values: vec![],
};
let results = index.search(&query, 3).unwrap();
assert!(!results.is_empty());
}
// Two groups of identical vectors on non-overlapping dimension ranges
// (ids 0..5 on dims 0..10, ids 5..10 on dims 100..110): a query matching
// the first group must only return ids from that group.
#[test]
fn disjoint_vectors() {
let mut index = SparseMipsIndex::new(SparseMipsParams {
max_degree: 8,
ef_construction: 32,
ef_search: 16,
alpha: 1.2,
});
for i in 0..5u32 {
let sv = SparseVector {
indices: (0u32..10).collect(),
values: vec![1.0; 10],
};
index.add(i, sv).unwrap();
}
for i in 5..10u32 {
let sv = SparseVector {
indices: (100u32..110).collect(),
values: vec![1.0; 10],
};
index.add(i, sv).unwrap();
}
index.build().unwrap();
let query_a = SparseVector {
indices: (0u32..10).collect(),
values: vec![1.0; 10],
};
let results = index.search(&query_a, 3).unwrap();
assert!(!results.is_empty());
let all_group_a = results.iter().all(|&(id, _)| id < 5);
assert!(all_group_a, "expected group A results, got: {results:?}");
}
// Building or searching an index without vectors is an error.
#[test]
fn empty_index_errors() {
let mut index = SparseMipsIndex::new(SparseMipsParams::default());
assert!(index.build().is_err());
assert!(index.search(&SparseVector::default(), 1).is_err());
}
// Once built, the index rejects further additions.
#[test]
fn add_after_build_errors() {
let mut index = SparseMipsIndex::new(SparseMipsParams::default());
index
.add(
0,
SparseVector {
indices: vec![0],
values: vec![1.0],
},
)
.unwrap();
index.build().unwrap();
let err = index.add(1, SparseVector::default());
assert!(err.is_err());
}
// The distance is the negated dot product over shared dimensions:
// only index 2 overlaps (2.0 * 6.0), and fully disjoint supports give 0.
#[test]
fn sparse_distance_correctness() {
let a = SparseVector {
indices: vec![0, 2, 4],
values: vec![1.0, 2.0, 3.0],
};
let b = SparseVector {
indices: vec![1, 2, 3],
values: vec![5.0, 6.0, 7.0],
};
let d = sparse_distance(&a, &b);
assert!((d - (-12.0f32)).abs() < 1e-5, "expected -12.0, got {d}");
let c = SparseVector {
indices: vec![10, 11],
values: vec![3.0, 4.0],
};
let d2 = sparse_distance(&a, &c);
assert!((d2).abs() < 1e-5, "expected 0.0, got {d2}");
}
// `from_pairs` sorts by index and, for the repeated index 3, keeps the value
// of the occurrence that came first in the input (1.0, not 5.0).
#[test]
fn from_pairs_sorts_and_deduplicates() {
let sv = SparseVector::from_pairs(vec![(3, 1.0), (1, 2.0), (3, 5.0), (0, 0.5)]);
assert_eq!(sv.indices, vec![0, 1, 3]);
assert_eq!(sv.values, vec![0.5, 2.0, 1.0]);
}
}