use std::collections::HashMap;
use crate::error::VectorError;
use crate::hnsw::graph::HnswIndex;
use nodedb_types::hnsw::HnswParams;
use nodedb_types::vector_distance::DistanceMetric;
/// Key under which a sub-index is stored in a [`SieveCollection`].
///
/// This is a plain `String` alias; nothing in this file parses or validates
/// the contents. The unit tests use keys such as `"tenant_id=42"` and
/// `"lang=en"`, but the canonical format is decided by callers.
pub type PredicateSignature = String;
/// A set of HNSW sub-indexes, one per [`PredicateSignature`].
///
/// Sub-indexes are created with `build_subindex`, looked up with `get` /
/// `has`, and discarded with `drop`.
pub struct SieveCollection {
/// Sub-indexes keyed by the signature they were built under.
subindices: HashMap<PredicateSignature, HnswIndex>,
/// Value used for `HnswParams::m` of every sub-index built by this
/// collection; `HnswParams::m0` is set to twice this value.
sub_m: usize,
}
impl SieveCollection {
    /// Creates an empty collection.
    ///
    /// `sub_m` is remembered and used as the HNSW `m` parameter (and, doubled,
    /// as `m0`) for every sub-index later built through
    /// [`SieveCollection::build_subindex`].
    pub fn new(sub_m: usize) -> Self {
        let subindices = HashMap::new();
        Self { subindices, sub_m }
    }

    /// Builds a new HNSW sub-index from `vectors` and stores it under
    /// `signature`.
    ///
    /// The index is created with dimension `dim`, distance `metric`,
    /// `m = sub_m`, `m0 = sub_m * 2` and `ef_construction = 200`. Vectors are
    /// inserted in slice order. If a sub-index is already stored under
    /// `signature` it is replaced.
    ///
    /// Only the vector of each `(id, vector)` pair is handed to the index; the
    /// `u32` id is not used here.
    /// NOTE(review): ids reported by searches on the sub-index are therefore
    /// presumably whatever `HnswIndex::insert` assigns — confirm that callers
    /// map them back to the original ids.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `HnswIndex::insert`. In that case
    /// the collection is left untouched, because the finished index is only
    /// stored after every insert has succeeded.
    pub fn build_subindex(
        &mut self,
        signature: PredicateSignature,
        vectors: &[(u32, Vec<f32>)],
        dim: usize,
        metric: DistanceMetric,
    ) -> Result<(), VectorError> {
        let m = self.sub_m;
        let mut subindex = HnswIndex::new(
            dim,
            HnswParams {
                m,
                m0: m * 2,
                ef_construction: 200,
                metric,
            },
        );

        // The index takes ownership of each vector, so every one is cloned
        // out of the borrowed slice right before insertion.
        for embedding in vectors.iter().map(|entry| &entry.1) {
            subindex.insert(embedding.clone())?;
        }

        self.subindices.insert(signature, subindex);
        Ok(())
    }

    /// Returns `true` when a sub-index is stored under `signature`.
    pub fn has(&self, signature: &PredicateSignature) -> bool {
        self.subindices.contains_key(signature)
    }

    /// Returns the sub-index stored under `signature`, or `None` if there is
    /// none.
    pub fn get(&self, signature: &PredicateSignature) -> Option<&HnswIndex> {
        let found = self.subindices.get(signature);
        found
    }

    /// Removes the sub-index stored under `signature`, if any.
    ///
    /// Removing a signature that is not present is a no-op.
    pub fn drop(&mut self, signature: &PredicateSignature) {
        // The removed index (if there was one) is intentionally discarded.
        let _ = self.subindices.remove(signature);
    }

    /// Returns references to every stored signature.
    ///
    /// The order follows the iteration order of the underlying `HashMap` and
    /// should not be relied upon.
    pub fn signatures(&self) -> Vec<&PredicateSignature> {
        let mut keys: Vec<&PredicateSignature> = Vec::with_capacity(self.subindices.len());
        keys.extend(self.subindices.keys());
        keys
    }
}
/// Re-export of the HNSW graph's `SearchResult` type under the name
/// `SubindexSearchResult`.
pub use crate::hnsw::graph::SearchResult as SubindexSearchResult;
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` `(id, vector)` pairs where pair `i` has id `i` and a
    /// `dim`-dimensional vector whose components all equal `i`.
    fn sample_vectors(n: usize, dim: usize) -> Vec<(u32, Vec<f32>)> {
        let mut samples = Vec::with_capacity(n);
        for i in 0..n {
            samples.push((i as u32, vec![i as f32; dim]));
        }
        samples
    }

    /// A freshly built sub-index is reported by `has`, returned by `get`, and
    /// contains every inserted vector.
    #[test]
    fn build_subindex_has_and_get() {
        let signature = "tenant_id=42".to_string();
        let vecs = sample_vectors(5, 3);
        let mut coll = SieveCollection::new(8);

        coll.build_subindex(signature.clone(), &vecs, 3, DistanceMetric::L2)
            .expect("build should succeed");

        assert!(coll.has(&signature));
        let idx = coll.get(&signature);
        assert!(idx.is_some());
        assert_eq!(idx.unwrap().len(), 5);
    }

    /// After `drop`, the signature is no longer visible through `has` or `get`.
    #[test]
    fn drop_removes_subindex() {
        let signature = "lang=en".to_string();
        let vecs = sample_vectors(5, 3);
        let mut coll = SieveCollection::new(8);

        coll.build_subindex(signature.clone(), &vecs, 3, DistanceMetric::Cosine)
            .expect("build should succeed");
        assert!(coll.has(&signature));

        coll.drop(&signature);

        assert!(!coll.has(&signature));
        assert!(coll.get(&signature).is_none());
    }

    /// `signatures` returns exactly the keys that were built (order-insensitive).
    #[test]
    fn signatures_lists_all_keys() {
        let vecs = sample_vectors(3, 2);
        let mut coll = SieveCollection::new(8);

        coll.build_subindex("a".to_string(), &vecs, 2, DistanceMetric::L2)
            .expect("build a");
        coll.build_subindex("b".to_string(), &vecs, 2, DistanceMetric::L2)
            .expect("build b");

        // Map iteration order is unspecified, so sort before comparing.
        let mut sigs: Vec<String> = coll
            .signatures()
            .iter()
            .map(|sig| sig.to_string())
            .collect();
        sigs.sort();
        assert_eq!(sigs, vec!["a".to_string(), "b".to_string()]);
    }

    /// A sub-index obtained from the collection can be searched and yields at
    /// least one result for a query matching a stored vector.
    #[test]
    fn search_on_subindex() {
        let signature = "tenant_id=1".to_string();
        let vecs = sample_vectors(5, 3);
        let mut coll = SieveCollection::new(8);

        coll.build_subindex(signature.clone(), &vecs, 3, DistanceMetric::L2)
            .expect("build");

        let idx = coll.get(&signature).unwrap();
        let query = [2.0_f32; 3];
        let results = idx.search(&query, 2, 32);
        assert!(!results.is_empty());
    }
}