use roaring::RoaringBitmap;
use super::collection::{PredicateSignature, SieveCollection};
use crate::hnsw::graph::{HnswIndex, SearchResult};
use crate::navix::traversal::{NavixSearchOptions, navix_search};
use nodedb_types::vector_distance::DistanceMetric;
/// Routes vector queries either to a predicate-specific subindex held in a
/// [`SieveCollection`] or to a fallback [`HnswIndex`].
///
/// The router only borrows both structures; it owns no index data itself.
pub struct SieveRouter<'r> {
    /// Subindexes, looked up by predicate signature when routing.
    pub collection: &'r SieveCollection,
    /// Index searched when no matching subindex is available.
    pub fallback: &'r HnswIndex,
}
impl<'a> SieveRouter<'a> {
    /// Answers a k-nearest-neighbour query, preferring a predicate-specific
    /// subindex when one exists.
    ///
    /// * If `predicate_signature` is `Some` and `self.collection` holds a
    ///   subindex for it, that subindex is searched with `k` and `ef_search`.
    ///   `allowed` and `metric` are not consulted on this path.
    ///   NOTE(review): presumably every vector in the subindex already matches
    ///   the predicate, so no `allowed` filtering is needed — confirm.
    /// * Otherwise `navix_search` runs against `self.fallback`. It receives
    ///   `k`, `ef_search`, `allowed` and a `brute_force_threshold` fixed at
    ///   `0.001`, plus `metric`. Its hits are converted into [`SearchResult`]s.
    pub fn route(
        &self,
        query: &[f32],
        predicate_signature: Option<&PredicateSignature>,
        allowed: RoaringBitmap,
        k: usize,
        ef_search: usize,
        metric: DistanceMetric,
    ) -> Vec<SearchResult> {
        let matching_subindex = predicate_signature.and_then(|sig| self.collection.get(sig));
        match matching_subindex {
            Some(subindex) => subindex.search(query, k, ef_search),
            None => {
                let options = NavixSearchOptions {
                    k,
                    ef_search,
                    allowed,
                    brute_force_threshold: 0.001,
                };
                let hits = navix_search(self.fallback, query, &options, metric);
                hits.into_iter()
                    .map(|hit| SearchResult {
                        id: hit.id,
                        distance: hit.distance,
                    })
                    .collect()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::{HnswIndex, HnswParams};
    use crate::sieve::collection::SieveCollection;
    use nodedb_types::vector_distance::DistanceMetric;

    /// Builds a 3-D L2 HNSW index containing the points `[i, 0, 0]` for `i` in `0..n`.
    /// The assertions below assume ids follow insertion order (id `i` is point `[i, 0, 0]`).
    fn build_fallback(n: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            3,
            HnswParams {
                m: 8,
                m0: 16,
                ef_construction: 50,
                metric: DistanceMetric::L2,
            },
            99,
        );
        for i in 0..n {
            idx.insert(vec![i as f32, 0.0, 0.0])
                .expect("insert into fallback index");
        }
        idx
    }

    /// Bitmap containing every id in `0..n`.
    fn all_allowed(n: u32) -> RoaringBitmap {
        (0..n).collect()
    }

    /// Collection with a single subindex (signature `"T"`) over the points
    /// `[i, 0, 0]` with ids `0..5`.
    fn collection_with_t_subindex() -> SieveCollection {
        let mut coll = SieveCollection::new(8);
        let sub_vecs: Vec<(u32, Vec<f32>)> =
            (0u32..5).map(|i| (i, vec![i as f32, 0.0, 0.0])).collect();
        coll.build_subindex("T".to_string(), &sub_vecs, 3, DistanceMetric::L2)
            .expect("build subindex");
        coll
    }

    #[test]
    fn route_hits_subindex() {
        let coll = collection_with_t_subindex();
        let fallback = build_fallback(20);
        let router = SieveRouter {
            collection: &coll,
            fallback: &fallback,
        };
        // Query far from the subindex's points. The fallback's nearest
        // neighbours here would have ids >= 5 (around 15), while the subindex
        // can only return ids 0..5. A query near the origin would pass even if
        // the router ignored the subindex, because the fallback's answers
        // would also be < 5.
        let results = router.route(
            &[15.0, 0.0, 0.0],
            Some(&"T".to_string()),
            all_allowed(20),
            3,
            32,
            DistanceMetric::L2,
        );
        assert!(!results.is_empty());
        for r in &results {
            assert!(r.id < 5, "expected subindex id < 5, got {}", r.id);
        }
    }

    #[test]
    fn route_falls_back_to_navix() {
        let coll = SieveCollection::new(8);
        let fallback = build_fallback(20);
        let router = SieveRouter {
            collection: &coll,
            fallback: &fallback,
        };
        let allowed = all_allowed(20);
        let results = router.route(
            &[10.0, 0.0, 0.0],
            Some(&"unknown_sig".to_string()),
            allowed,
            3,
            64,
            DistanceMetric::L2,
        );
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 10);
    }

    #[test]
    fn route_no_signature_uses_navix() {
        let coll = collection_with_t_subindex();
        let fallback = build_fallback(20);
        let router = SieveRouter {
            collection: &coll,
            fallback: &fallback,
        };
        let allowed = all_allowed(20);
        // With no signature the subindex must be skipped. If it were used, the
        // nearest id to [5, 0, 0] would be 4, not 5.
        let results = router.route(&[5.0, 0.0, 0.0], None, allowed, 3, 64, DistanceMetric::L2);
        assert!(!results.is_empty());
        assert_eq!(results[0].id, 5);
    }
}