use roaring::RoaringBitmap;
use crate::distance::distance;
use crate::hnsw::{HnswIndex, SearchResult};
/// Selectivity cut-offs used by [`select_strategy`] to pick a filtered-search
/// strategy.
///
/// "Selectivity" is the value computed by [`estimate_selectivity`]: the
/// fraction of indexed vectors that the filter bitmap does *not* contain, so
/// `0.0` means every vector matches and `1.0` means none do.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FilterThresholds {
    /// Selectivity at or above which [`FilterStrategy::PostFilter`] is chosen
    /// instead of [`FilterStrategy::PreFilter`].
    pub high_selectivity: f64,
    /// Selectivity at or above which [`FilterStrategy::BruteForceMatching`]
    /// is chosen. This bound is tested before `high_selectivity`.
    pub extreme_selectivity: f64,
}
impl Default for FilterThresholds {
fn default() -> Self {
Self {
high_selectivity: 0.50,
extreme_selectivity: 0.95,
}
}
}
/// How a filtered nearest-neighbour query is executed.
///
/// Every field is a plain integer, so the type is fully `Eq` and `Hash` and
/// can be used as a map or set key (for example for per-strategy counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FilterStrategy {
    /// Hand the filter bitmap to the index's own filtered search.
    PreFilter,
    /// Run an unfiltered search for `top_k * over_fetch_factor` candidates
    /// and then drop the candidates the filter bitmap does not contain.
    PostFilter {
        /// Multiplier applied to `top_k` to decide how many candidates to
        /// fetch before filtering.
        over_fetch_factor: usize,
    },
    /// Skip the graph and compute the distance to every id in the filter
    /// bitmap.
    BruteForceMatching,
}
pub fn estimate_selectivity(bitmap: &RoaringBitmap, total_vectors: usize) -> f64 {
if total_vectors == 0 {
return 0.0;
}
let matching = bitmap.len() as usize;
1.0 - (matching as f64 / total_vectors as f64)
}
/// Maps an estimated selectivity onto the strategy used to run the query.
///
/// The bounds are checked from most to least selective:
/// * `selectivity >= thresholds.extreme_selectivity` gives
///   [`FilterStrategy::BruteForceMatching`];
/// * otherwise `selectivity >= thresholds.high_selectivity` gives
///   [`FilterStrategy::PostFilter`] with an over-fetch factor of 10;
/// * anything else, including `NaN` (which fails both comparisons), gives
///   [`FilterStrategy::PreFilter`].
pub fn select_strategy(selectivity: f64, thresholds: &FilterThresholds) -> FilterStrategy {
    match selectivity {
        s if s >= thresholds.extreme_selectivity => FilterStrategy::BruteForceMatching,
        s if s >= thresholds.high_selectivity => FilterStrategy::PostFilter {
            over_fetch_factor: 10,
        },
        _ => FilterStrategy::PreFilter,
    }
}
/// Runs a filtered nearest-neighbour search, picking the execution strategy
/// from the estimated selectivity of `bitmap`.
///
/// # Arguments
/// * `index` - the HNSW index to query.
/// * `query` - the query vector.
/// * `top_k` - maximum number of results to return.
/// * `ef` - search breadth forwarded to the index searches.
/// * `bitmap` - ids accepted by the filter.
/// * `thresholds` - selectivity cut-offs passed to [`select_strategy`].
///
/// # Returns
/// At most `top_k` results. Results from the post-filter and brute-force
/// paths only contain ids present in `bitmap`; the pre-filter path returns
/// whatever `HnswIndex::search_filtered` produces for that bitmap. Brute-force
/// results are sorted by ascending distance, NaN distances last, ties broken
/// by ascending id.
pub fn adaptive_search(
    index: &HnswIndex,
    query: &[f32],
    top_k: usize,
    ef: usize,
    bitmap: &RoaringBitmap,
    thresholds: &FilterThresholds,
) -> Vec<SearchResult> {
    // Nothing can be returned for `top_k == 0`, so skip the index entirely.
    if top_k == 0 {
        return Vec::new();
    }

    let total = index.len();
    let selectivity = estimate_selectivity(bitmap, total);

    match select_strategy(selectivity, thresholds) {
        FilterStrategy::PreFilter => index.search_filtered(query, top_k, ef, bitmap),
        FilterStrategy::PostFilter { over_fetch_factor } => {
            // Saturate instead of overflowing (panic in debug builds,
            // wrap-around in release) for very large `top_k`.
            let fetch_k = top_k.saturating_mul(over_fetch_factor);
            let mut filtered: Vec<SearchResult> = index
                .search(query, fetch_k, ef.max(fetch_k))
                .into_iter()
                .filter(|r| bitmap.contains(r.id))
                .collect();
            // Keeps the first `top_k` survivors; this presumably relies on
            // `search` returning candidates nearest-first (TODO confirm).
            filtered.truncate(top_k);

            // Over-fetching is only a heuristic: the unfiltered candidates
            // may contain fewer than `top_k` matches even though the filter
            // accepts more ids than were found. In that case retry with the
            // index's filtered search and keep whichever answer is longer,
            // so the caller never gets fewer hits than before.
            let found = filtered.len();
            if found < top_k && (found as u64) < bitmap.len() {
                let fallback = index.search_filtered(query, top_k, ef, bitmap);
                if fallback.len() > found {
                    return fallback;
                }
            }
            filtered
        }
        FilterStrategy::BruteForceMatching => {
            let metric = index.params().metric;

            // Total order over results: ascending distance, NaN distances
            // last, ties broken by id. Mapping incomparable (NaN) pairs to
            // `Equal` is not a consistent order, which the slice sorting
            // routines are allowed to panic on.
            let by_distance = |a: &SearchResult, b: &SearchResult| {
                a.distance
                    .partial_cmp(&b.distance)
                    .unwrap_or_else(|| a.distance.is_nan().cmp(&b.distance.is_nan()))
                    .then_with(|| a.id.cmp(&b.id))
            };

            let mut results: Vec<SearchResult> = bitmap
                .iter()
                .filter_map(|id| {
                    // Look the vector up first: ids for which the index has
                    // no vector are skipped before `is_deleted` is consulted.
                    let v = index.get_vector(id)?;
                    if index.is_deleted(id) {
                        return None;
                    }
                    Some(SearchResult {
                        id,
                        distance: distance(query, v, metric),
                    })
                })
                .collect();

            // Partition so the `top_k` best come first, then sort only those.
            if results.len() > top_k {
                results.select_nth_unstable_by(top_k, by_distance);
                results.truncate(top_k);
            }
            // Bitmap ids are unique, so the comparator never reports two
            // distinct results as equal and an unstable sort is deterministic.
            results.sort_unstable_by(by_distance);
            results
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Builds a filter bitmap containing every id in `ids`.
    fn bitmap_of(ids: std::ops::Range<u32>) -> RoaringBitmap {
        let mut bitmap = RoaringBitmap::new();
        for id in ids {
            bitmap.insert(id);
        }
        bitmap
    }

    /// Builds a seeded 3-dimensional L2 index holding the points
    /// `(0, 0, 0)`, `(1, 0, 0)`, ..., `(999, 0, 0)`.
    fn build_test_index() -> HnswIndex {
        let params = HnswParams {
            m: 8,
            m0: 16,
            ef_construction: 50,
            metric: DistanceMetric::L2,
        };
        let mut index = HnswIndex::with_seed(3, params, 42);
        for x in 0..1000 {
            index.insert(vec![x as f32, 0.0, 0.0]).unwrap();
        }
        index
    }

    #[test]
    fn low_selectivity_uses_prefilter() {
        let chosen = select_strategy(0.2, &FilterThresholds::default());
        assert_eq!(chosen, FilterStrategy::PreFilter);
    }

    #[test]
    fn high_selectivity_uses_postfilter() {
        let chosen = select_strategy(0.8, &FilterThresholds::default());
        assert!(matches!(chosen, FilterStrategy::PostFilter { .. }));
    }

    #[test]
    fn extreme_selectivity_uses_bruteforce() {
        let chosen = select_strategy(0.99, &FilterThresholds::default());
        assert_eq!(chosen, FilterStrategy::BruteForceMatching);
    }

    #[test]
    fn adaptive_search_extreme_filter() {
        let index = build_test_index();
        // 10 of 1000 ids match, which is above the default extreme cut-off.
        let bitmap = bitmap_of(500..510);
        let hits = adaptive_search(
            &index,
            &[505.0, 0.0, 0.0],
            3,
            64,
            &bitmap,
            &FilterThresholds::default(),
        );
        assert_eq!(hits.len(), 3);
        for hit in &hits {
            assert!(bitmap.contains(hit.id), "got filtered-out id {}", hit.id);
        }
        assert_eq!(hits[0].id, 505);
    }

    #[test]
    fn adaptive_search_low_filter() {
        let index = build_test_index();
        // 800 of 1000 ids match, which is below the default high cut-off.
        let bitmap = bitmap_of(0..800);
        let hits = adaptive_search(
            &index,
            &[100.0, 0.0, 0.0],
            5,
            64,
            &bitmap,
            &FilterThresholds::default(),
        );
        assert_eq!(hits.len(), 5);
        for hit in &hits {
            assert!(bitmap.contains(hit.id));
        }
    }

    #[test]
    fn selectivity_estimation() {
        let bitmap = bitmap_of(0..100);
        let sel = estimate_selectivity(&bitmap, 1000);
        assert!((sel - 0.9).abs() < 0.01);
    }
}