use std::collections::HashMap;
/// Physical layout used to store the set of vector ids matching one attribute value.
///
/// The figures produced by [`FilterRepresentation::estimated_bytes`] and
/// [`FilterRepresentation::estimated_query_cost`] are rough heuristics for comparing
/// layouts against each other, not exact measurements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterRepresentation {
    /// Roaring-style bitmap, modelled as the cheaper of a sorted `u16` array and a dense bitmap.
    RoaringBitmap,
    /// Sorted list of `u32` ids (4 bytes per id).
    PostingsList,
    /// Hash set of ids (modelled at 12 bytes per entry).
    HashedSet,
    /// Postings per value plus a per-value overhead (modelled at 16 bytes per value).
    InvertedPostings,
}
impl FilterRepresentation {
    /// Estimates the memory footprint, in bytes, of storing an attribute with this layout.
    ///
    /// * `n_vectors` – total number of vectors in the collection.
    /// * `cardinality` – number of distinct attribute values (only used by `InvertedPostings`).
    /// * `density` – fraction of vectors carrying the attribute; clamped to `[0, 1]`.
    pub fn estimated_bytes(&self, n_vectors: usize, cardinality: usize, density: f32) -> usize {
        let n_set = Self::expected_set_size(n_vectors, density);
        match self {
            // Roaring stores every 2^16-id chunk either as a sorted u16 array (2 bytes per id)
            // or as a fixed-size bitmap, whichever is smaller, so the total is approximately
            // the minimum of the two layouts (crossover near 1/16 density, not 1/2).
            Self::RoaringBitmap => (n_set * 2).min((n_vectors + 7) / 8),
            Self::PostingsList => n_set * 4,
            Self::HashedSet => n_set * 12,
            Self::InvertedPostings => cardinality * 16 + n_set * 4,
        }
    }

    /// Estimates the relative cost of evaluating a filter with this layout.
    ///
    /// `density` is the fraction of vectors carrying the attribute. `selectivity` is the
    /// fraction of those entries the query is expected to return. Both are clamped to
    /// `[0, 1]`.
    ///
    /// The result is a unitless score meant only for comparing layouts. It is finite and
    /// non-negative for finite inputs.
    pub fn estimated_query_cost(
        &self,
        n_vectors: usize,
        cardinality: usize,
        density: f32,
        selectivity: f32,
    ) -> f32 {
        let n_set = Self::expected_set_size(n_vectors, density);
        let expected_result = (n_set as f32 * selectivity.clamp(0.0, 1.0)) as usize as f32;
        match self {
            // One unit of work per 64-bit word of the bitmap.
            Self::RoaringBitmap => (n_vectors / 64) as f32,
            // Clamp before log2: an empty result would otherwise produce -inf.
            Self::PostingsList => expected_result.max(1.0).log2() + expected_result,
            Self::HashedSet => expected_result,
            Self::InvertedPostings => cardinality as f32 * 0.01 + expected_result,
        }
    }

    /// Number of set entries implied by `density`, with `density` clamped to `[0, 1]`.
    fn expected_set_size(n_vectors: usize, density: f32) -> usize {
        (n_vectors as f32 * density.clamp(0.0, 1.0)) as usize
    }
}
/// Workload statistics for one filterable attribute, used to choose a [`FilterPolicy`].
///
/// Built with [`AttributeStats::new`] and the chained setters below. No field is validated.
#[derive(Debug, Clone)]
pub struct AttributeStats {
    /// Attribute name.
    pub name: String,
    /// Number of distinct values the attribute takes.
    pub cardinality: usize,
    /// Presumably the fraction of vectors carrying the attribute (expected in `[0, 1]`).
    pub density: f32,
    /// Average fraction of entries a query matches.
    pub avg_selectivity: f32,
    /// Relative query frequency (units not defined here).
    pub query_frequency: f32,
    /// Whether a single vector may carry several values of this attribute.
    pub is_multi_valued: bool,
    /// Optional per-value counts; no setter in this impl populates it.
    pub value_distribution: Option<HashMap<String, usize>>,
}
impl AttributeStats {
    /// Creates stats for `name` with neutral defaults.
    ///
    /// Defaults: cardinality 1, density 1.0, selectivity 0.5, frequency 1.0, single-valued,
    /// and no value distribution.
    pub fn new(name: &str) -> Self {
        let owned_name = String::from(name);
        Self {
            name: owned_name,
            cardinality: 1,
            density: 1.0,
            avg_selectivity: 0.5,
            query_frequency: 1.0,
            is_multi_valued: false,
            value_distribution: None,
        }
    }

    /// Returns `self` with `cardinality` replaced by `c`.
    pub fn cardinality(self, c: usize) -> Self {
        Self { cardinality: c, ..self }
    }

    /// Returns `self` with `density` replaced by `d`.
    pub fn density(self, d: f32) -> Self {
        Self { density: d, ..self }
    }

    /// Returns `self` with `avg_selectivity` replaced by `s`.
    pub fn selectivity(self, s: f32) -> Self {
        Self { avg_selectivity: s, ..self }
    }

    /// Returns `self` with `query_frequency` replaced by `f`.
    pub fn frequency(self, f: f32) -> Self {
        Self { query_frequency: f, ..self }
    }

    /// Returns `self` with `is_multi_valued` replaced by `m`.
    pub fn multi_valued(self, m: bool) -> Self {
        Self { is_multi_valued: m, ..self }
    }

    /// Distinct values per vector; `n_vectors == 0` is treated as 1 to avoid dividing by zero.
    pub fn cardinality_ratio(&self, n_vectors: usize) -> f32 {
        let denominator = n_vectors.max(1) as f32;
        self.cardinality as f32 / denominator
    }

    /// True when the attribute has fewer than 1000 values, or fewer than 1 value per 100 vectors.
    pub fn is_low_cardinality(&self, n_vectors: usize) -> bool {
        let few_values = self.cardinality < 1000;
        few_values || self.cardinality_ratio(n_vectors) < 0.01
    }

    /// True when the attribute has more than 10000 values and more than 1 value per 10 vectors.
    pub fn is_high_cardinality(&self, n_vectors: usize) -> bool {
        let many_values = self.cardinality > 10000;
        many_values && self.cardinality_ratio(n_vectors) > 0.1
    }
}
/// Storage and evaluation settings for one attribute's filter index.
#[derive(Debug, Clone)]
pub struct FilterPolicy {
    /// Layout used for each attribute value's id set.
    pub representation: FilterRepresentation,
    /// Enabled by `auto_select` for selective attributes (`avg_selectivity < 0.1`).
    /// Presumably requests per-list filtering; confirm with the consumer.
    pub per_list: bool,
    /// Enabled by `auto_select` for frequently queried attributes (`query_frequency > 10`).
    pub cache_results: bool,
    /// Not read in this file; presumably a selectivity cut-over to post-filtering. Confirm.
    pub postfilter_threshold: f32,
}
impl FilterPolicy {
    /// Chooses a policy for an attribute from its workload statistics.
    pub fn auto_select(stats: &AttributeStats, n_vectors: usize) -> Self {
        let selective = stats.avg_selectivity < 0.1;
        let hot = stats.query_frequency > 10.0;
        Self {
            representation: Self::pick_representation(stats, n_vectors),
            per_list: selective,
            cache_results: hot,
            postfilter_threshold: 0.8,
        }
    }

    /// Chooses a layout. The first matching rule wins; dense bitmaps are the fallback.
    fn pick_representation(stats: &AttributeStats, n_vectors: usize) -> FilterRepresentation {
        if stats.is_multi_valued {
            return FilterRepresentation::InvertedPostings;
        }
        if stats.is_low_cardinality(n_vectors) && stats.density > 0.5 {
            return FilterRepresentation::RoaringBitmap;
        }
        if stats.is_high_cardinality(n_vectors) && stats.density < 0.1 {
            return FilterRepresentation::PostingsList;
        }
        if stats.cardinality > 100_000 {
            return FilterRepresentation::HashedSet;
        }
        FilterRepresentation::RoaringBitmap
    }

    /// Preset: bitmap storage with result caching and a 0.8 post-filter threshold.
    pub fn bitmap() -> Self {
        Self {
            postfilter_threshold: 0.8,
            cache_results: true,
            per_list: false,
            representation: FilterRepresentation::RoaringBitmap,
        }
    }

    /// Preset: postings storage, per-list, uncached, with a 0.9 post-filter threshold.
    pub fn postings() -> Self {
        Self {
            postfilter_threshold: 0.9,
            cache_results: false,
            per_list: true,
            representation: FilterRepresentation::PostingsList,
        }
    }
}
/// Dense bitset over the id range `0..n_bits`, stored as 64-bit words.
///
/// Invariant: no bit at or beyond `n_bits` is ever set. `set` ignores such ids, and
/// `and`/`or` preserve the property, so `count` and `iter` only ever see in-range ids.
#[derive(Debug, Clone)]
pub struct BitmapFilter {
    words: Vec<u64>,
    n_bits: usize,
}
impl BitmapFilter {
    /// Creates an empty bitmap able to hold ids `0..n_bits`.
    pub fn new(n_bits: usize) -> Self {
        let n_words = (n_bits + 63) / 64;
        Self {
            words: vec![0; n_words],
            n_bits,
        }
    }

    /// Splits `idx` into its word index and a single-bit mask within that word.
    fn locate(idx: u32) -> (usize, u64) {
        let idx = idx as usize;
        (idx / 64, 1u64 << (idx % 64))
    }

    /// Sets bit `idx`. Ids at or beyond `n_bits` are silently ignored.
    pub fn set(&mut self, idx: u32) {
        if (idx as usize) < self.n_bits {
            let (word, mask) = Self::locate(idx);
            self.words[word] |= mask;
        }
    }

    /// Returns whether bit `idx` is set; out-of-range ids are reported as absent.
    pub fn contains(&self, idx: u32) -> bool {
        if (idx as usize) >= self.n_bits {
            return false;
        }
        let (word, mask) = Self::locate(idx);
        (self.words[word] & mask) != 0
    }

    /// Intersection. The result covers the smaller of the two id ranges.
    pub fn and(&self, other: &BitmapFilter) -> BitmapFilter {
        let mut result = BitmapFilter::new(self.n_bits.min(other.n_bits));
        for (dst, (a, b)) in result
            .words
            .iter_mut()
            .zip(self.words.iter().zip(&other.words))
        {
            *dst = a & b;
        }
        result
    }

    /// Union. The result covers the larger of the two id ranges.
    pub fn or(&self, other: &BitmapFilter) -> BitmapFilter {
        let mut result = BitmapFilter::new(self.n_bits.max(other.n_bits));
        for (dst, w) in result.words.iter_mut().zip(&self.words) {
            *dst |= *w;
        }
        for (dst, w) in result.words.iter_mut().zip(&other.words) {
            *dst |= *w;
        }
        result
    }

    /// Number of set bits.
    pub fn count(&self) -> usize {
        self.words.iter().map(|w| w.count_ones() as usize).sum()
    }

    /// Iterates over set ids in ascending order.
    ///
    /// Walks only the set bits of each word (via `trailing_zeros`) instead of probing all
    /// 64 positions, so the cost is proportional to words plus set bits.
    pub fn iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.words
            .iter()
            .enumerate()
            .flat_map(|(word_idx, &word)| {
                let base = word_idx * 64;
                let mut remaining = word;
                std::iter::from_fn(move || {
                    if remaining == 0 {
                        return None;
                    }
                    let bit = remaining.trailing_zeros() as usize;
                    // Clear the lowest set bit so the next call finds the following one.
                    remaining &= remaining - 1;
                    Some((base + bit) as u32)
                })
            })
            .filter(move |&idx| (idx as usize) < self.n_bits)
    }

    /// Bytes used by the word storage (excluding the `Vec` header).
    pub fn memory_bytes(&self) -> usize {
        self.words.len() * std::mem::size_of::<u64>()
    }
}
/// Sorted, deduplicated list of `u32` ids.
///
/// Both constructors and `add` keep the list sorted and free of duplicates, which is what
/// lets lookups use binary search and set operations use a linear merge.
#[derive(Debug, Clone)]
pub struct PostingsFilter {
    ids: Vec<u32>,
}
impl PostingsFilter {
    /// Creates an empty postings list.
    pub fn new() -> Self {
        PostingsFilter { ids: Vec::new() }
    }

    /// Builds a postings list from arbitrary ids, sorting and dropping duplicates.
    pub fn from_ids(ids: Vec<u32>) -> Self {
        let mut sorted = ids;
        sorted.sort_unstable();
        sorted.dedup();
        PostingsFilter { ids: sorted }
    }

    /// Inserts `id` in sorted position; already-present ids are ignored.
    ///
    /// Each insertion shifts later elements, so it costs O(len).
    pub fn add(&mut self, id: u32) {
        if let Err(slot) = self.ids.binary_search(&id) {
            self.ids.insert(slot, id);
        }
    }

    /// Binary-search membership test.
    pub fn contains(&self, id: u32) -> bool {
        match self.ids.binary_search(&id) {
            Ok(_) => true,
            Err(_) => false,
        }
    }

    /// Ids present in both lists, computed with a linear merge.
    pub fn intersect(&self, other: &PostingsFilter) -> PostingsFilter {
        let (left, right) = (&self.ids, &other.ids);
        let (mut li, mut ri) = (0usize, 0usize);
        let mut common = Vec::new();
        while li < left.len() && ri < right.len() {
            let (l, r) = (left[li], right[ri]);
            if l < r {
                li += 1;
            } else if r < l {
                ri += 1;
            } else {
                common.push(l);
                li += 1;
                ri += 1;
            }
        }
        PostingsFilter { ids: common }
    }

    /// Ids present in either list, computed with a linear merge.
    pub fn union(&self, other: &PostingsFilter) -> PostingsFilter {
        let (left, right) = (&self.ids, &other.ids);
        let (mut li, mut ri) = (0usize, 0usize);
        let mut merged = Vec::with_capacity(left.len() + right.len());
        while li < left.len() && ri < right.len() {
            let (l, r) = (left[li], right[ri]);
            if l < r {
                merged.push(l);
                li += 1;
            } else if r < l {
                merged.push(r);
                ri += 1;
            } else {
                // Present on both sides: emit once.
                merged.push(l);
                li += 1;
                ri += 1;
            }
        }
        // At most one side still has elements, and they are all larger than anything emitted.
        merged.extend_from_slice(&left[li..]);
        merged.extend_from_slice(&right[ri..]);
        PostingsFilter { ids: merged }
    }

    /// Number of stored ids.
    pub fn count(&self) -> usize {
        self.ids.len()
    }

    /// Iterates over stored ids in ascending order.
    pub fn iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.ids.iter().cloned()
    }

    /// Bytes used by the id payload (excluding the `Vec` header).
    pub fn memory_bytes(&self) -> usize {
        self.ids.len() * std::mem::size_of::<u32>()
    }
}
impl Default for PostingsFilter {
    /// Equivalent to [`PostingsFilter::new`].
    fn default() -> Self {
        PostingsFilter::new()
    }
}
/// Inverted index for one attribute: maps each attribute value to the vector ids carrying it.
///
/// Values are stored as a [`BitmapFilter`] when the policy selects
/// [`FilterRepresentation::RoaringBitmap`]. Every other representation currently falls back
/// to a sorted [`PostingsFilter`].
pub struct FilterIndex {
    attribute: String,
    policy: FilterPolicy,
    inverted: HashMap<String, FilterData>,
    /// Size of the id universe; grows to cover any id passed to [`FilterIndex::add`].
    n_vectors: usize,
}
/// Storage for the ids associated with a single attribute value.
enum FilterData {
    Bitmap(BitmapFilter),
    Postings(PostingsFilter),
}
impl FilterData {
    /// All stored ids in ascending order.
    fn ids(&self) -> Vec<u32> {
        match self {
            FilterData::Bitmap(b) => b.iter().collect(),
            FilterData::Postings(p) => p.iter().collect(),
        }
    }

    /// Whether `id` is stored.
    fn contains(&self, id: u32) -> bool {
        match self {
            FilterData::Bitmap(b) => b.contains(id),
            FilterData::Postings(p) => p.contains(id),
        }
    }

    /// Number of stored ids.
    fn count(&self) -> usize {
        match self {
            FilterData::Bitmap(b) => b.count(),
            FilterData::Postings(p) => p.count(),
        }
    }

    /// Payload bytes of the underlying storage.
    fn memory_bytes(&self) -> usize {
        match self {
            FilterData::Bitmap(b) => b.memory_bytes(),
            FilterData::Postings(p) => p.memory_bytes(),
        }
    }
}
impl FilterIndex {
    /// Creates an empty index for `attribute` over an initial universe of `n_vectors` ids.
    pub fn new(attribute: &str, n_vectors: usize, policy: FilterPolicy) -> Self {
        Self {
            attribute: attribute.to_string(),
            policy,
            inverted: HashMap::new(),
            n_vectors,
        }
    }

    /// Records that `vector_id` has attribute value `value`.
    ///
    /// Ids at or beyond the current universe are accepted rather than dropped. The
    /// universe grows to `vector_id + 1`, and a bitmap that is too small is enlarged (at
    /// least doubling) before the bit is set. Bitmap-backed and postings-backed values
    /// therefore behave the same.
    pub fn add(&mut self, vector_id: u32, value: &str) {
        let id = vector_id as usize;
        self.n_vectors = self.n_vectors.max(id.saturating_add(1));
        // Copy what the insertion closure needs so it does not capture `self` while
        // `self.inverted` is mutably borrowed (required before edition 2021).
        let universe = self.n_vectors;
        let representation = self.policy.representation;
        let filter = self
            .inverted
            .entry(value.to_string())
            .or_insert_with(|| match representation {
                FilterRepresentation::RoaringBitmap => {
                    FilterData::Bitmap(BitmapFilter::new(universe))
                }
                // No dedicated storage exists for these yet; a sorted postings list is used.
                FilterRepresentation::PostingsList
                | FilterRepresentation::HashedSet
                | FilterRepresentation::InvertedPostings => {
                    FilterData::Postings(PostingsFilter::new())
                }
            });
        match filter {
            FilterData::Bitmap(bitmap) => {
                if id >= bitmap.n_bits {
                    // This bitmap was created before the universe grew; enlarge it
                    // geometrically so sequential appends do not reallocate every time.
                    let grown = id.saturating_add(1).max(bitmap.n_bits.saturating_mul(2));
                    *bitmap = bitmap.or(&BitmapFilter::new(grown));
                }
                bitmap.set(vector_id);
            }
            FilterData::Postings(postings) => postings.add(vector_id),
        }
    }

    /// Returns the ids carrying `value` in ascending order, or `None` if the value was never added.
    pub fn query(&self, value: &str) -> Option<Vec<u32>> {
        self.inverted.get(value).map(FilterData::ids)
    }

    /// Like [`FilterIndex::query`], but also returns the fraction of the universe matched.
    ///
    /// The fraction is 0.0 when the value is unknown.
    pub fn query_with_stats(&self, value: &str) -> (Option<Vec<u32>>, f32) {
        match self.query(value) {
            Some(ids) => {
                let selectivity = ids.len() as f32 / self.n_vectors.max(1) as f32;
                (Some(ids), selectivity)
            }
            None => (None, 0.0),
        }
    }

    /// Whether `vector_id` was recorded with `value`; unknown values yield `false`.
    pub fn contains(&self, vector_id: u32, value: &str) -> bool {
        self.inverted
            .get(value)
            .map_or(false, |filter| filter.contains(vector_id))
    }

    /// Total payload bytes of all per-value storage (excludes map and key overhead).
    pub fn memory_bytes(&self) -> usize {
        self.inverted.values().map(FilterData::memory_bytes).sum()
    }

    /// Summary counts for this index.
    pub fn stats(&self) -> FilterIndexStats {
        let total_entries: usize = self.inverted.values().map(FilterData::count).sum();
        FilterIndexStats {
            attribute: self.attribute.clone(),
            n_values: self.inverted.len(),
            total_entries,
            memory_bytes: self.memory_bytes(),
            representation: self.policy.representation,
        }
    }
}
/// Summary of a `FilterIndex`, as returned by `FilterIndex::stats`.
#[derive(Debug, Clone)]
pub struct FilterIndexStats {
/// Name of the attribute the index was built for.
pub attribute: String,
/// Number of distinct attribute values stored.
pub n_values: usize,
/// Sum, over all values, of the number of ids stored for that value.
pub total_entries: usize,
/// Payload bytes of the per-value storage (bitmap words or postings ids); excludes map/key overhead.
pub memory_bytes: usize,
/// Representation named by the index's policy. Non-bitmap representations are stored as postings lists.
pub representation: FilterRepresentation,
}
/// Owns one [`FilterIndex`] per attribute name, all created with the same vector count.
pub struct FilterIndexManager {
    indexes: HashMap<String, FilterIndex>,
    n_vectors: usize,
}
impl FilterIndexManager {
    /// Creates an empty manager whose indexes will be sized for `n_vectors` vectors.
    pub fn new(n_vectors: usize) -> Self {
        let indexes = HashMap::new();
        FilterIndexManager { indexes, n_vectors }
    }

    /// Returns the index for `attribute`, creating it on first use.
    ///
    /// A new index gets a policy from [`FilterPolicy::auto_select`] based on `stats`.
    /// `stats` is ignored when the index already exists.
    pub fn get_or_create(&mut self, attribute: &str, stats: &AttributeStats) -> &mut FilterIndex {
        let vector_count = self.n_vectors;
        let build = || {
            let policy = FilterPolicy::auto_select(stats, vector_count);
            FilterIndex::new(attribute, vector_count, policy)
        };
        self.indexes.entry(attribute.to_owned()).or_insert_with(build)
    }

    /// Records `value` for `vector_id` under `attribute`.
    ///
    /// Does nothing when no index exists for `attribute`; create one with
    /// [`FilterIndexManager::get_or_create`] first.
    pub fn add(&mut self, attribute: &str, vector_id: u32, value: &str) {
        let index = match self.indexes.get_mut(attribute) {
            Some(existing) => existing,
            None => return,
        };
        index.add(vector_id, value);
    }

    /// Ids with `value` for `attribute`; `None` if either the attribute or the value is unknown.
    pub fn query(&self, attribute: &str, value: &str) -> Option<Vec<u32>> {
        self.indexes.get(attribute)?.query(value)
    }

    /// Sum of [`FilterIndex::memory_bytes`] over all managed indexes.
    pub fn memory_bytes(&self) -> usize {
        self.indexes.values().map(FilterIndex::memory_bytes).sum()
    }
}
// Unit tests for the bitmap and postings primitives, policy selection, and FilterIndex.
#[cfg(test)]
mod tests {
use super::*;
// set/contains/count on a bitmap, including the last in-range id (999 of 1000).
#[test]
fn test_bitmap_filter() {
let mut bitmap = BitmapFilter::new(1000);
bitmap.set(0);
bitmap.set(10);
bitmap.set(100);
bitmap.set(999);
assert!(bitmap.contains(0));
assert!(bitmap.contains(10));
assert!(!bitmap.contains(1));
assert!(!bitmap.contains(500));
assert_eq!(bitmap.count(), 4);
}
// Bitwise AND keeps only ids present in both bitmaps.
#[test]
fn test_bitmap_intersection() {
let mut a = BitmapFilter::new(100);
let mut b = BitmapFilter::new(100);
a.set(1);
a.set(2);
a.set(3);
b.set(2);
b.set(3);
b.set(4);
let c = a.and(&b);
assert!(!c.contains(1));
assert!(c.contains(2));
assert!(c.contains(3));
assert!(!c.contains(4));
}
// Out-of-order inserts and a duplicate (5) leave three distinct ids.
#[test]
fn test_postings_filter() {
let mut postings = PostingsFilter::new();
postings.add(5);
postings.add(10);
postings.add(3);
postings.add(5);
assert!(postings.contains(3));
assert!(postings.contains(5));
assert!(postings.contains(10));
assert!(!postings.contains(7));
assert_eq!(postings.count(), 3);
}
// Merge-based intersection of two sorted postings lists.
#[test]
fn test_postings_intersection() {
let a = PostingsFilter::from_ids(vec![1, 2, 3, 5, 7]);
let b = PostingsFilter::from_ids(vec![2, 3, 6, 7, 8]);
let c = a.intersect(&b);
assert_eq!(c.count(), 3);
assert!(c.contains(2));
assert!(c.contains(3));
assert!(c.contains(7));
}
// Dense low-cardinality -> RoaringBitmap; sparse high-cardinality -> PostingsList.
#[test]
fn test_policy_selection() {
let stats1 = AttributeStats::new("status").cardinality(5).density(0.9);
let policy1 = FilterPolicy::auto_select(&stats1, 1_000_000);
assert_eq!(policy1.representation, FilterRepresentation::RoaringBitmap);
let stats2 = AttributeStats::new("user_id")
.cardinality(500_000)
.density(0.001);
let policy2 = FilterPolicy::auto_select(&stats2, 1_000_000);
assert_eq!(policy2.representation, FilterRepresentation::PostingsList);
}
// Bitmap-backed FilterIndex returns every id recorded for each value.
#[test]
fn test_filter_index() {
let policy = FilterPolicy::bitmap();
let mut index = FilterIndex::new("category", 1000, policy);
index.add(0, "electronics");
index.add(1, "electronics");
index.add(2, "clothing");
index.add(3, "electronics");
let electronics = index.query("electronics").unwrap();
assert_eq!(electronics.len(), 3);
assert!(electronics.contains(&0));
assert!(electronics.contains(&1));
assert!(electronics.contains(&3));
let clothing = index.query("clothing").unwrap();
assert_eq!(clothing.len(), 1);
}
}