use crate::filter::ast::FilterExpr;
use crate::filter::error::FilterError;
use crate::filter::evaluator::evaluate;
use crate::filter::strategy::{
calculate_oversample, estimate_selectivity, is_contradiction, is_tautology, select_strategy,
FilterStrategy, MetadataStore, EF_CAP,
};
use crate::hnsw::graph::{GraphError, HnswIndex};
use crate::hnsw::search::{SearchContext, SearchResult};
use crate::metadata::MetadataValue;
use crate::storage::VectorStorage;
use std::collections::{HashMap, HashSet};
/// Outcome of a filtered search: the hits plus bookkeeping about how they were produced.
#[derive(Debug, Clone)]
pub struct FilteredSearchResult {
/// Hits that were kept, in the order the underlying index search yielded them.
pub results: Vec<SearchResult>,
/// Whether the query is considered fully answered. The search paths in this file
/// generally derive it from `results.len() >= k`; the `empty`/`full` constructors
/// always set it to `true`.
pub complete: bool,
/// Ratio of vectors that passed the filter to vectors considered, as computed by the
/// strategy that ran (`1.0` when no filter was applied).
pub observed_selectivity: f32,
/// Strategy recorded for this query.
pub strategy_used: FilterStrategy,
/// Count of vectors the executing strategy reports as considered.
pub vectors_evaluated: usize,
}
impl FilteredSearchResult {
    /// Builds a result with no hits: `complete` is `true`, selectivity is `0.0`,
    /// and zero vectors are reported as evaluated.
    #[must_use]
    pub fn empty(strategy: FilterStrategy) -> Self {
        Self {
            strategy_used: strategy,
            results: Vec::new(),
            vectors_evaluated: 0,
            observed_selectivity: 0.0,
            complete: true,
        }
    }

    /// Wraps `results` as a complete, unfiltered outcome: selectivity is `1.0` and
    /// `vectors_evaluated` equals the number of hits supplied.
    #[must_use]
    pub fn full(results: Vec<SearchResult>, strategy: FilterStrategy) -> Self {
        // Read the length before `results` is moved into the struct.
        let vectors_evaluated = results.len();
        Self {
            strategy_used: strategy,
            vectors_evaluated,
            observed_selectivity: 1.0,
            complete: true,
            results,
        }
    }
}
/// Error returned by the filtered-search entry points in this module.
#[derive(Debug, Clone)]
pub enum FilteredSearchError {
/// A [`FilterError`] was raised (converted via `From<FilterError>`).
Filter(FilterError),
/// A [`GraphError`] was raised (converted via `From<GraphError>`).
Graph(GraphError),
}
impl From<FilterError> for FilteredSearchError {
fn from(e: FilterError) -> Self {
FilteredSearchError::Filter(e)
}
}
impl From<GraphError> for FilteredSearchError {
fn from(e: GraphError) -> Self {
FilteredSearchError::Graph(e)
}
}
impl std::fmt::Display for FilteredSearchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FilteredSearchError::Filter(e) => write!(f, "filter error: {e}"),
FilteredSearchError::Graph(e) => write!(f, "graph error: {e}"),
}
}
}
// Marker impl: relies on the derived `Debug` and the hand-written `Display`;
// `source()` keeps its default implementation.
impl std::error::Error for FilteredSearchError {}
/// In-memory metadata table: one key/value map per vector, addressed by position.
#[derive(Debug, Clone, Default)]
pub struct VectorMetadataStore {
/// Slot `i` holds the metadata map stored at index `i`.
metadata: Vec<HashMap<String, MetadataValue>>,
}
impl VectorMetadataStore {
    /// Creates a store with no entries.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty store whose backing vector has room for `capacity` entries.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        let metadata = Vec::with_capacity(capacity);
        Self { metadata }
    }

    /// Appends `metadata` as a new entry and returns the index it was stored at.
    pub fn push(&mut self, metadata: HashMap<String, MetadataValue>) -> usize {
        self.metadata.push(metadata);
        self.metadata.len() - 1
    }

    /// Stores `metadata` at `idx`, overwriting any existing entry.
    ///
    /// When `idx` is past the current end, the gap (and the slot itself) is first
    /// filled with empty maps so that `idx` becomes a valid position.
    pub fn set(&mut self, idx: usize, metadata: HashMap<String, MetadataValue>) {
        if let Some(slot) = self.metadata.get_mut(idx) {
            *slot = metadata;
        } else {
            self.metadata.resize_with(idx + 1, HashMap::new);
            self.metadata[idx] = metadata;
        }
    }

    /// Returns the metadata stored at `idx`, or `None` when `idx` is out of range.
    #[must_use]
    pub fn get(&self, idx: usize) -> Option<&HashMap<String, MetadataValue>> {
        self.metadata.as_slice().get(idx)
    }
}
/// Exposes the positional table through the [`MetadataStore`] trait.
impl MetadataStore for VectorMetadataStore {
    /// Looks up the entry at position `id`; `None` when out of range.
    fn get_metadata(&self, id: usize) -> Option<&HashMap<String, MetadataValue>> {
        self.get(id)
    }

    /// Number of entries currently held.
    fn len(&self) -> usize {
        Vec::len(&self.metadata)
    }
}
/// Borrows an index, its vector storage, and a metadata store, and owns a reusable
/// [`SearchContext`] for the filtered-search methods.
pub struct FilteredSearcher<'idx, 'sto, 'meta, M: MetadataStore> {
/// Index that is queried.
index: &'idx HnswIndex,
/// Vector storage handed to every index search call.
storage: &'sto VectorStorage,
/// Source of per-vector metadata used for filter evaluation.
metadata: &'meta M,
/// Context passed by `&mut` to the `*_with_context` index searches.
search_ctx: SearchContext,
}
impl<'idx, 'sto, 'meta, M: MetadataStore> FilteredSearcher<'idx, 'sto, 'meta, M> {
/// Creates a searcher over the given index, storage, and metadata store, paired
/// with a freshly constructed [`SearchContext`].
#[must_use]
pub fn new(index: &'idx HnswIndex, storage: &'sto VectorStorage, metadata: &'meta M) -> Self {
    let search_ctx = SearchContext::new();
    Self {
        search_ctx,
        metadata,
        storage,
        index,
    }
}
/// Runs a `k`-nearest-neighbour query for an `f32` vector, optionally restricted by
/// `filter`, using the requested `strategy`.
///
/// * With no filter, the index is searched directly and selectivity is reported as `1.0`.
/// * With a filter, the edge cases (empty index, tautology, contradiction) are handled
///   first; otherwise `FilterStrategy::Auto` is resolved from a selectivity estimate
///   and the matching strategy-specific search is executed.
///
/// # Errors
///
/// Returns an error when `strategy.validate()` fails or when the underlying index
/// search fails.
pub fn search_filtered(
    &mut self,
    query: &[f32],
    k: usize,
    filter: Option<&FilterExpr>,
    strategy: FilterStrategy,
) -> Result<FilteredSearchResult, FilteredSearchError> {
    strategy.validate()?;

    let filter = match filter {
        Some(expr) => expr,
        None => {
            // Unfiltered query: plain index search.
            let results = self.index.search(query, k, self.storage)?;
            let complete = results.len() >= k || self.index.is_empty();
            let vectors_evaluated = k.min(self.index.len());
            return Ok(FilteredSearchResult {
                results,
                complete,
                observed_selectivity: 1.0,
                strategy_used: strategy,
                vectors_evaluated,
            });
        }
    };

    if let Some(shortcut) = self.handle_filter_edge_cases(filter, k, query)? {
        return Ok(shortcut);
    }

    // Replace `Auto` with a concrete strategy chosen from the estimated selectivity.
    let resolved = if matches!(strategy, FilterStrategy::Auto) {
        let estimate = estimate_selectivity(filter, self.metadata, Some(42));
        select_strategy(estimate.selectivity)
    } else {
        strategy
    };

    match resolved {
        FilterStrategy::Auto => unreachable!("Auto already resolved above"),
        FilterStrategy::PreFilter => self.search_prefilter(query, k, filter),
        FilterStrategy::Hybrid {
            oversample_min,
            oversample_max,
        } => self.search_hybrid(query, k, filter, oversample_min, oversample_max),
        FilterStrategy::PostFilter { oversample } => {
            self.search_postfilter(query, k, filter, oversample)
        }
    }
}
/// Short-circuits the cases that need no per-vector filtering.
///
/// Checked in order: an empty index yields an empty result; a tautological filter
/// falls back to a plain index search; a contradictory filter yields an empty result.
/// Returns `Ok(None)` when none of these apply. Every shortcut result is tagged with
/// `FilterStrategy::Auto`.
fn handle_filter_edge_cases(
    &mut self,
    filter: &FilterExpr,
    k: usize,
    query: &[f32],
) -> Result<Option<FilteredSearchResult>, FilteredSearchError> {
    let shortcut = if self.index.is_empty() {
        Some(FilteredSearchResult::empty(FilterStrategy::Auto))
    } else if is_tautology(filter) {
        let hits = self.index.search(query, k, self.storage)?;
        Some(FilteredSearchResult::full(hits, FilterStrategy::Auto))
    } else if is_contradiction(filter) {
        Some(FilteredSearchResult::empty(FilterStrategy::Auto))
    } else {
        None
    };
    Ok(shortcut)
}
fn search_prefilter(
&mut self,
query: &[f32],
k: usize,
filter: &FilterExpr,
) -> Result<FilteredSearchResult, FilteredSearchError> {
let mut passing_indices = HashSet::new();
let total = self.metadata.len();
for idx in 0..total {
if let Some(metadata) = self.metadata.get_metadata(idx) {
if evaluate(filter, metadata).unwrap_or(false) {
passing_indices.insert(idx);
}
}
}
let passed = passing_indices.len();
#[allow(clippy::cast_precision_loss)]
let selectivity = if total > 0 {
(passed as f32) / (total as f32)
} else {
0.0
};
if passing_indices.is_empty() {
return Ok(FilteredSearchResult {
results: vec![],
complete: true,
observed_selectivity: 0.0,
strategy_used: FilterStrategy::PreFilter,
vectors_evaluated: total,
});
}
let ef_effective = (k * 10).min(EF_CAP).max(k);
let all_results = self.index.search_with_context(
query,
ef_effective,
self.storage,
&mut self.search_ctx,
)?;
let mut results = Vec::with_capacity(k);
for result in all_results {
if results.len() >= k {
break;
}
#[allow(clippy::cast_possible_truncation)]
let idx = (result.vector_id.0 as usize).saturating_sub(1);
if passing_indices.contains(&idx) {
results.push(result);
}
}
Ok(FilteredSearchResult {
complete: results.len() >= k,
observed_selectivity: selectivity,
strategy_used: FilterStrategy::PreFilter,
vectors_evaluated: total,
results,
})
}
/// Post-filter strategy for `f32` queries.
///
/// Retrieves `ceil(k * oversample)` candidates (limited by `EF_CAP`, but never fewer
/// than `k`) from the index, then walks them in order, keeping those whose metadata
/// satisfies `filter` until `k` hits are collected.
///
/// * `observed_selectivity` is `kept / examined`, where `examined` counts only the
///   candidates actually tested before the walk stopped. Dividing by the full
///   candidate count would bias the estimate low whenever the walk ends early.
/// * `vectors_evaluated` is the number of candidates retrieved from the index.
/// * `complete` is `true` when `k` hits were kept.
///
/// # Errors
///
/// Propagates failures of the underlying index search.
fn search_postfilter(
    &mut self,
    query: &[f32],
    k: usize,
    filter: &FilterExpr,
    oversample: f32,
) -> Result<FilteredSearchResult, FilteredSearchError> {
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    #[allow(clippy::cast_precision_loss)]
    let ef_effective = ((k as f32) * oversample).ceil() as usize;
    let ef_effective = ef_effective.min(EF_CAP).max(k);

    let candidates = self.index.search_with_context(
        query,
        ef_effective,
        self.storage,
        &mut self.search_ctx,
    )?;
    let retrieved = candidates.len();

    // Never reserve more than can possibly be kept; guards against huge `k`.
    let mut results = Vec::with_capacity(k.min(retrieved));
    let mut examined: usize = 0;
    for candidate in candidates {
        if results.len() >= k {
            break;
        }
        examined += 1;
        // NOTE(review): ids are mapped to metadata slots by subtracting 1 —
        // presumably vector ids are 1-based; confirm against the index.
        #[allow(clippy::cast_possible_truncation)]
        let idx = (candidate.vector_id.0 as usize).saturating_sub(1);
        // Missing metadata and evaluation errors both count as "no match".
        let passes = matches!(
            self.metadata.get_metadata(idx),
            Some(metadata) if evaluate(filter, metadata).unwrap_or(false)
        );
        if passes {
            results.push(candidate);
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let selectivity = if examined > 0 {
        (results.len() as f32) / (examined as f32)
    } else {
        0.0
    };

    Ok(FilteredSearchResult {
        complete: results.len() >= k,
        observed_selectivity: selectivity,
        strategy_used: FilterStrategy::PostFilter { oversample },
        vectors_evaluated: retrieved,
        results,
    })
}
/// Hybrid strategy for `f32` queries: derives an oversampling factor from the
/// estimated selectivity, bounds it to `[oversample_min, oversample_max]`, and
/// delegates to the post-filter search. The returned result is re-labelled as
/// `FilterStrategy::Hybrid`.
///
/// # Errors
///
/// Propagates failures of the post-filter search.
fn search_hybrid(
    &mut self,
    query: &[f32],
    k: usize,
    filter: &FilterExpr,
    oversample_min: f32,
    oversample_max: f32,
) -> Result<FilteredSearchResult, FilteredSearchError> {
    let estimate = estimate_selectivity(filter, self.metadata, Some(42));
    let unbounded = calculate_oversample(estimate.selectivity);
    // Lower bound first, then upper bound (same order as `.max(min).min(max)`).
    let oversample = f32::min(f32::max(unbounded, oversample_min), oversample_max);

    self.search_postfilter(query, k, filter, oversample)
        .map(|mut outcome| {
            outcome.strategy_used = FilterStrategy::Hybrid {
                oversample_min,
                oversample_max,
            };
            outcome
        })
}
/// Runs a `k`-nearest-neighbour query for a binary (`u8`) vector, optionally
/// restricted by `filter`.
///
/// * With no filter, the index is searched directly and selectivity is reported as `1.0`.
/// * With a filter, the edge cases (empty index, tautology, contradiction) are handled
///   first; otherwise the binary pre-filter search is executed.
///
/// NOTE(review): once a filter is present, `strategy` is validated but otherwise
/// ignored — the pre-filter path always runs. Confirm this is intentional.
///
/// # Errors
///
/// Returns an error when `strategy.validate()` fails or when the underlying index
/// search fails.
pub fn search_binary_filtered(
    &mut self,
    query: &[u8],
    k: usize,
    filter: Option<&FilterExpr>,
    strategy: FilterStrategy,
) -> Result<FilteredSearchResult, FilteredSearchError> {
    strategy.validate()?;

    let filter = match filter {
        Some(expr) => expr,
        None => {
            // Unfiltered query: plain binary index search.
            let results = self.index.search_binary(query, k, self.storage)?;
            let complete = results.len() >= k || self.index.is_empty();
            let vectors_evaluated = k.min(self.index.len());
            return Ok(FilteredSearchResult {
                results,
                complete,
                observed_selectivity: 1.0,
                strategy_used: strategy,
                vectors_evaluated,
            });
        }
    };

    match self.handle_binary_filter_edge_cases(filter, k, query)? {
        Some(shortcut) => Ok(shortcut),
        None => {
            let _ = strategy;
            self.search_binary_prefilter(query, k, filter)
        }
    }
}
/// Binary-query counterpart of the filter edge-case handling.
///
/// Checked in order: an empty index yields an empty result; a tautological filter
/// falls back to a plain binary index search; a contradictory filter yields an empty
/// result. Returns `Ok(None)` when none of these apply. Every shortcut result is
/// tagged with `FilterStrategy::Auto`.
fn handle_binary_filter_edge_cases(
    &mut self,
    filter: &FilterExpr,
    k: usize,
    query: &[u8],
) -> Result<Option<FilteredSearchResult>, FilteredSearchError> {
    let shortcut = if self.index.is_empty() {
        Some(FilteredSearchResult::empty(FilterStrategy::Auto))
    } else if is_tautology(filter) {
        let hits = self.index.search_binary(query, k, self.storage)?;
        Some(FilteredSearchResult::full(hits, FilterStrategy::Auto))
    } else if is_contradiction(filter) {
        Some(FilteredSearchResult::empty(FilterStrategy::Auto))
    } else {
        None
    };
    Ok(shortcut)
}
fn search_binary_prefilter(
&mut self,
query: &[u8],
k: usize,
filter: &FilterExpr,
) -> Result<FilteredSearchResult, FilteredSearchError> {
let mut passing_indices = HashSet::new();
let total = self.metadata.len();
for idx in 0..total {
if let Some(metadata) = self.metadata.get_metadata(idx) {
if evaluate(filter, metadata).unwrap_or(false) {
passing_indices.insert(idx);
}
}
}
let passed = passing_indices.len();
#[allow(clippy::cast_precision_loss)]
let selectivity = if total > 0 {
(passed as f32) / (total as f32)
} else {
0.0
};
if passing_indices.is_empty() {
return Ok(FilteredSearchResult {
results: vec![],
complete: true,
observed_selectivity: 0.0,
strategy_used: FilterStrategy::PreFilter,
vectors_evaluated: total,
});
}
let ef_effective = (k * 10).min(EF_CAP).max(k);
let all_results = self.index.search_binary_with_context(
query,
ef_effective,
self.storage,
&mut self.search_ctx,
)?;
let mut results = Vec::with_capacity(k);
for result in all_results {
if results.len() >= k {
break;
}
#[allow(clippy::cast_possible_truncation)]
let idx = (result.vector_id.0 as usize).saturating_sub(1);
if passing_indices.contains(&idx) {
results.push(result);
}
}
Ok(FilteredSearchResult {
complete: results.len() >= k,
observed_selectivity: selectivity,
strategy_used: FilterStrategy::PreFilter,
vectors_evaluated: total,
results,
})
}
/// Post-filter strategy for binary (`u8`) queries.
///
/// Retrieves `ceil(k * oversample)` candidates (limited by `EF_CAP`, but never fewer
/// than `k`) from the index, then walks them in order, keeping those whose metadata
/// satisfies `filter` until `k` hits are collected.
///
/// * `observed_selectivity` is `kept / examined`, where `examined` counts only the
///   candidates actually tested before the walk stopped. Dividing by the full
///   candidate count would bias the estimate low whenever the walk ends early.
/// * `vectors_evaluated` is the number of candidates retrieved from the index.
/// * `complete` is `true` when `k` hits were kept.
///
/// # Errors
///
/// Propagates failures of the underlying index search.
#[allow(dead_code)]
fn search_binary_postfilter(
    &mut self,
    query: &[u8],
    k: usize,
    filter: &FilterExpr,
    oversample: f32,
) -> Result<FilteredSearchResult, FilteredSearchError> {
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    #[allow(clippy::cast_precision_loss)]
    let ef_effective = ((k as f32) * oversample).ceil() as usize;
    let ef_effective = ef_effective.min(EF_CAP).max(k);

    let candidates = self.index.search_binary_with_context(
        query,
        ef_effective,
        self.storage,
        &mut self.search_ctx,
    )?;
    let retrieved = candidates.len();

    // Never reserve more than can possibly be kept; guards against huge `k`.
    let mut results = Vec::with_capacity(k.min(retrieved));
    let mut examined: usize = 0;
    for candidate in candidates {
        if results.len() >= k {
            break;
        }
        examined += 1;
        // NOTE(review): ids are mapped to metadata slots by subtracting 1 —
        // presumably vector ids are 1-based; confirm against the index.
        #[allow(clippy::cast_possible_truncation)]
        let idx = (candidate.vector_id.0 as usize).saturating_sub(1);
        // Missing metadata and evaluation errors both count as "no match".
        let passes = matches!(
            self.metadata.get_metadata(idx),
            Some(metadata) if evaluate(filter, metadata).unwrap_or(false)
        );
        if passes {
            results.push(candidate);
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let selectivity = if examined > 0 {
        (results.len() as f32) / (examined as f32)
    } else {
        0.0
    };

    Ok(FilteredSearchResult {
        complete: results.len() >= k,
        observed_selectivity: selectivity,
        strategy_used: FilterStrategy::PostFilter { oversample },
        vectors_evaluated: retrieved,
        results,
    })
}
/// Hybrid strategy for binary (`u8`) queries: derives an oversampling factor from the
/// estimated selectivity, bounds it to `[oversample_min, oversample_max]`, and
/// delegates to the binary post-filter search. The returned result is re-labelled as
/// `FilterStrategy::Hybrid`.
///
/// # Errors
///
/// Propagates failures of the binary post-filter search.
#[allow(dead_code)]
fn search_binary_hybrid(
    &mut self,
    query: &[u8],
    k: usize,
    filter: &FilterExpr,
    oversample_min: f32,
    oversample_max: f32,
) -> Result<FilteredSearchResult, FilteredSearchError> {
    let estimate = estimate_selectivity(filter, self.metadata, Some(42));
    let unbounded = calculate_oversample(estimate.selectivity);
    // Lower bound first, then upper bound (same order as `.max(min).min(max)`).
    let oversample = f32::min(f32::max(unbounded, oversample_min), oversample_max);

    self.search_binary_postfilter(query, k, filter, oversample)
        .map(|mut outcome| {
            outcome.strategy_used = FilterStrategy::Hybrid {
                oversample_min,
                oversample_max,
            };
            outcome
        })
}
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
#[allow(clippy::cast_possible_wrap)]
mod tests {
    use super::*;
    use crate::filter::parse;
    use crate::hnsw::config::HnswConfig;
    use crate::hnsw::graph::VectorId;

    /// Builds an index with `count` vectors of dimension `dim` plus matching metadata.
    ///
    /// Vector `i` has components `i + d` for `d` in `0..dim`. Its metadata carries
    /// `category` (cycling "gpu", "cpu", "memory"), `price` (`i * 100`) and `active`
    /// (`true` for even `i`).
    fn create_test_index(
        count: usize,
        dim: u32,
    ) -> (HnswIndex, VectorStorage, VectorMetadataStore) {
        let config = HnswConfig::new(dim);
        let mut storage = VectorStorage::new(&config, None);
        let mut index = HnswIndex::new(config, &storage).expect("Failed to create index");
        let mut metadata_store = VectorMetadataStore::with_capacity(count);

        for i in 0..count {
            #[allow(clippy::cast_precision_loss)]
            let vector: Vec<f32> = (0..dim).map(|d| (i + d as usize) as f32).collect();
            let _inserted_id = index
                .insert(&vector, &mut storage)
                .expect("Failed to insert into index");

            let category = match i % 3 {
                0 => "gpu",
                1 => "cpu",
                _ => "memory",
            };
            let meta = HashMap::from([
                (
                    "category".to_string(),
                    MetadataValue::String(category.to_string()),
                ),
                (
                    "price".to_string(),
                    MetadataValue::Integer((i * 100) as i64),
                ),
                ("active".to_string(), MetadataValue::Boolean(i % 2 == 0)),
            ]);
            metadata_store.push(meta);
        }

        (index, storage, metadata_store)
    }

    /// Without a filter the searcher returns `k` hits and reports selectivity 1.0.
    #[test]
    fn test_search_filtered_no_filter() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 10, None, FilterStrategy::Auto)
            .unwrap();

        assert_eq!(outcome.results.len(), 10);
        assert!(outcome.complete);
        assert!((outcome.observed_selectivity - 1.0).abs() < 0.001);
    }

    /// The pre-filter strategy is reported back and yields at least one hit.
    #[test]
    fn test_search_filtered_prefilter() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let gpu_only = parse("category = \"gpu\"").unwrap();
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 10, Some(&gpu_only), FilterStrategy::PreFilter)
            .unwrap();

        assert_eq!(outcome.strategy_used, FilterStrategy::PreFilter);
        assert!(!outcome.results.is_empty());
    }

    /// The post-filter strategy is reported back and evaluates at least `k` vectors.
    #[test]
    fn test_search_filtered_postfilter() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let gpu_only = parse("category = \"gpu\"").unwrap();
        let query = vec![0.0_f32; 8];
        let strategy = FilterStrategy::PostFilter { oversample: 5.0 };

        let outcome = searcher
            .search_filtered(&query, 5, Some(&gpu_only), strategy)
            .unwrap();

        assert!(matches!(
            outcome.strategy_used,
            FilterStrategy::PostFilter { .. }
        ));
        assert!(outcome.vectors_evaluated >= 5);
    }

    /// The hybrid strategy is reported back as `Hybrid`.
    #[test]
    fn test_search_filtered_hybrid() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let active_only = parse("active = true").unwrap();
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 5, Some(&active_only), FilterStrategy::HYBRID_DEFAULT)
            .unwrap();

        assert!(matches!(
            outcome.strategy_used,
            FilterStrategy::Hybrid { .. }
        ));
    }

    /// `Auto` is always resolved to a concrete strategy for a non-trivial filter.
    #[test]
    fn test_search_filtered_auto() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let active_only = parse("active = true").unwrap();
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 5, Some(&active_only), FilterStrategy::Auto)
            .unwrap();

        assert!(!matches!(outcome.strategy_used, FilterStrategy::Auto));
    }

    /// A filtered query against an empty index yields an empty, complete result.
    #[test]
    fn test_search_filtered_empty_index() {
        let config = HnswConfig::new(8);
        let storage = VectorStorage::new(&config, None);
        let index = HnswIndex::new(config, &storage).expect("Failed to create index");
        let metadata = VectorMetadataStore::new();
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let active_only = parse("active = true").unwrap();
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 5, Some(&active_only), FilterStrategy::Auto)
            .unwrap();

        assert!(outcome.results.is_empty());
        assert!(outcome.complete);
    }

    /// An always-true filter behaves like an unfiltered search.
    #[test]
    fn test_search_filtered_tautology() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let always_true = FilterExpr::LiteralBool(true);
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 10, Some(&always_true), FilterStrategy::Auto)
            .unwrap();

        assert_eq!(outcome.results.len(), 10);
        assert!((outcome.observed_selectivity - 1.0).abs() < 0.001);
    }

    /// An always-false filter returns nothing and evaluates no vectors.
    #[test]
    fn test_search_filtered_contradiction() {
        let (index, storage, metadata) = create_test_index(100, 8);
        let mut searcher = FilteredSearcher::new(&index, &storage, &metadata);
        let always_false = FilterExpr::LiteralBool(false);
        let query = vec![0.0_f32; 8];

        let outcome = searcher
            .search_filtered(&query, 10, Some(&always_false), FilterStrategy::Auto)
            .unwrap();

        assert!(outcome.results.is_empty());
        assert!(outcome.complete);
        assert_eq!(outcome.vectors_evaluated, 0);
    }

    /// Pushed entries are retrievable by index; out-of-range lookups return `None`.
    #[test]
    fn test_vector_metadata_store() {
        let mut store = VectorMetadataStore::new();
        for value in ["value1", "value2"] {
            let mut entry = HashMap::new();
            entry.insert(
                "key".to_string(),
                MetadataValue::String(value.to_string()),
            );
            store.push(entry);
        }

        assert_eq!(store.len(), 2);
        assert!(store.get(0).is_some());
        assert!(store.get(1).is_some());
        assert!(store.get(2).is_none());
    }

    /// `empty` and `full` populate the bookkeeping fields as documented.
    #[test]
    fn test_filtered_search_result_constructors() {
        let empty = FilteredSearchResult::empty(FilterStrategy::PreFilter);
        assert!(empty.results.is_empty());
        assert!(empty.complete);
        assert_eq!(empty.observed_selectivity, 0.0);

        let hits = vec![SearchResult {
            vector_id: VectorId(1),
            distance: 0.5,
        }];
        let full = FilteredSearchResult::full(hits, FilterStrategy::Auto);
        assert_eq!(full.results.len(), 1);
        assert!(full.complete);
        assert_eq!(full.observed_selectivity, 1.0);
    }
}