use bit_set::BitSet;
use std::fmt::Debug;
use diskann::{graph::index::QueryLabelProvider, utils::VectorId};
use diskann_benchmark_runner::files::InputFile;
use diskann_label_filter::{
kv_index::GenericIndex,
stores::bftree_store::BfTreeStore,
traits::{
posting_list_trait::{PostingList, RoaringPostingList},
query_evaluator::QueryEvaluator,
},
ASTExpr, DefaultKeyCodec,
};
use diskann_providers::model::graph::provider::layers::BetaFilter;
use diskann_tools::utils::ground_truth::read_labels_and_compute_bitmap;
use std::sync::Arc;
/// A filter expression paired with the posting list obtained by evaluating
/// that expression against an inverted index (see `QueryBitmapEvaluator::new`).
///
/// The posting list is computed once at construction time and then reused
/// for every membership check.
pub struct QueryBitmapEvaluator {
/// The filter expression this evaluator was built from.
pub ast_expr: ASTExpr,
/// Cached result of evaluating `ast_expr`; kept private so it cannot drift
/// out of sync with the expression.
evaluated_bitmap: RoaringPostingList,
}
impl QueryBitmapEvaluator {
    /// Builds an evaluator by running `ast_expr` through `inverted_index`
    /// and caching the resulting posting list.
    ///
    /// # Panics
    ///
    /// Panics if `inverted_index.evaluate_query` does not yield a posting
    /// list, because its result is unwrapped.
    pub fn new(
        ast_expr: ASTExpr,
        inverted_index: &GenericIndex<BfTreeStore, RoaringPostingList, DefaultKeyCodec>,
    ) -> Self {
        // Fields are initialised in written order, so the expression is
        // borrowed for evaluation before it is moved into the struct.
        Self {
            evaluated_bitmap: inverted_index.evaluate_query(&ast_expr).unwrap(),
            ast_expr,
        }
    }

    /// Borrows the cached posting list.
    fn get_bitmap(&self) -> &RoaringPostingList {
        &self.evaluated_bitmap
    }

    /// Returns the length reported by the cached posting list.
    pub fn count(&self) -> usize {
        self.evaluated_bitmap.len()
    }
}
impl Debug for QueryBitmapEvaluator {
    /// Formats the evaluator as a struct listing its expression and the
    /// cached posting list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Label the output with this type's own name. The previous label,
        // "BitmapFilter", is the name of a different type in this module and
        // made debug output ambiguous.
        f.debug_struct("QueryBitmapEvaluator")
            .field("ast_expr", &self.ast_expr)
            .field("evaluated_bitmap", &self.evaluated_bitmap)
            .finish()
    }
}
impl<T: VectorId> QueryLabelProvider<T> for QueryBitmapEvaluator {
    /// Reports whether `vec_id`, converted to `usize`, is contained in the
    /// cached posting list.
    fn is_match(&self, vec_id: T) -> bool {
        let id = vec_id.into_usize();
        self.get_bitmap().contains(id)
    }
}
/// Newtype around a [`BitSet`] that acts as a label filter: a vector id
/// matches when its `usize` form is a member of the wrapped set.
#[derive(Debug)]
pub struct BitmapFilter(pub BitSet);
impl<T: VectorId> QueryLabelProvider<T> for BitmapFilter {
    /// Reports whether `vec_id`, converted to `usize`, is a member of the
    /// wrapped bit set.
    fn is_match(&self, vec_id: T) -> bool {
        let Self(members) = self;
        members.contains(vec_id.into_usize())
    }
}
/// Computes the filter bitmaps for a set of query predicates over labelled data.
///
/// Both inputs are converted to string paths and handed to
/// `read_labels_and_compute_bitmap` (data labels first, query predicates
/// second); its output is returned unchanged.
///
/// # Errors
///
/// Returns an error if either path cannot be represented as a string, or if
/// `read_labels_and_compute_bitmap` fails.
pub(crate) fn generate_bitmaps(
    query_predicates: &InputFile,
    data_labels: &InputFile,
) -> anyhow::Result<Vec<BitSet>> {
    // A path that cannot be converted is an input problem, not a broken
    // invariant, so report it through the Result instead of panicking.
    let data_labels_path = data_labels
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("data labels path cannot be represented as a string"))?;
    let query_predicates_path = query_predicates.to_str().ok_or_else(|| {
        anyhow::anyhow!("query predicates path cannot be represented as a string")
    })?;

    let bit_maps = read_labels_and_compute_bitmap(data_labels_path, query_predicates_path)?;
    Ok(bit_maps)
}
/// Wraps every label provider in a [`BetaFilter`].
///
/// Each filter is created with its own clone of `search_strategy`, the
/// provider taken from `bit_maps`, and the shared `beta` value. The output
/// preserves the iteration order of `bit_maps`.
pub(crate) fn setup_filter_strategies<I, S>(
    beta: f32,
    bit_maps: I,
    search_strategy: S,
) -> Vec<BetaFilter<S, u32>>
where
    I: IntoIterator<Item = Arc<dyn QueryLabelProvider<u32>>>,
    S: Clone,
{
    let providers = bit_maps.into_iter();
    let mut strategies = Vec::with_capacity(providers.size_hint().0);
    for provider in providers {
        let strategy = search_strategy.clone();
        strategies.push(BetaFilter::<S, u32>::new(strategy, provider, beta));
    }
    strategies
}
/// Wraps `set` in a [`BitmapFilter`] and returns it as a shared
/// `QueryLabelProvider<u32>` trait object.
pub(crate) fn as_query_label_provider(set: BitSet) -> Arc<dyn QueryLabelProvider<u32>> {
    let provider: Arc<dyn QueryLabelProvider<u32>> = Arc::new(BitmapFilter(set));
    provider
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BitmapFilter` whose set contains exactly the given ids.
    fn filter_with(ids: &[usize]) -> BitmapFilter {
        let mut members = BitSet::new();
        for &id in ids {
            members.insert(id);
        }
        BitmapFilter(members)
    }

    /// Inserted ids match; ids that were never inserted do not.
    #[test]
    fn test_bitmap_filter_match() {
        let filter = filter_with(&[1, 3]);

        assert!(filter.is_match(1u32));
        assert!(filter.is_match(3u32));

        assert!(!filter.is_match(2u32));
        assert!(!filter.is_match(0u32));
    }

    /// An empty set matches nothing.
    #[test]
    fn test_bitmap_filter_empty() {
        let filter = filter_with(&[]);

        assert!(!filter.is_match(0u32));
        assert!(!filter.is_match(10u32));
    }

    /// A single large id matches while its neighbour does not.
    #[test]
    fn test_bitmap_filter_large_id() {
        let filter = filter_with(&[1000]);

        assert!(filter.is_match(1000u32));
        assert!(!filter.is_match(999u32));
    }
}