use std::collections::{BTreeMap, HashMap};
use std::sync::RwLock;
use crate::candidate_gate::AllowedSet;
use crate::filter_ir::{FilterAtom, FilterIR, FilterValue};
/// A posting list: the ascending, duplicate-free ids of the documents that
/// carry some indexed value.
///
/// Every constructor and mutator keeps `doc_ids` sorted and unique, which is
/// what lets lookups use binary search and set operations use a linear merge.
#[derive(Debug, Clone)]
pub struct PostingSet {
    doc_ids: Vec<u64>,
}

impl PostingSet {
    /// Returns an empty posting list.
    pub fn new() -> Self {
        PostingSet { doc_ids: Vec::new() }
    }

    /// Builds a posting list from ids given in any order, dropping duplicates.
    pub fn from_vec(ids: Vec<u64>) -> Self {
        let mut doc_ids = ids;
        doc_ids.sort_unstable();
        doc_ids.dedup();
        PostingSet { doc_ids }
    }

    /// Inserts `doc_id` at its sorted position; an id already present is left alone.
    pub fn add(&mut self, doc_id: u64) {
        if let Err(slot) = self.doc_ids.binary_search(&doc_id) {
            self.doc_ids.insert(slot, doc_id);
        }
    }

    /// Deletes `doc_id` if it is present; otherwise does nothing.
    pub fn remove(&mut self, doc_id: u64) {
        if let Ok(slot) = self.doc_ids.binary_search(&doc_id) {
            self.doc_ids.remove(slot);
        }
    }

    /// Reports whether `doc_id` is in the list (binary search).
    pub fn contains(&self, doc_id: u64) -> bool {
        self.doc_ids.binary_search(&doc_id).is_ok()
    }

    /// Number of ids in the list.
    pub fn len(&self) -> usize {
        self.doc_ids.len()
    }

    /// True when the list holds no ids.
    pub fn is_empty(&self) -> bool {
        self.doc_ids.is_empty()
    }

    /// Converts to an [`AllowedSet`]: `AllowedSet::None` when empty, otherwise
    /// a copy of the sorted ids.
    pub fn to_allowed_set(&self) -> AllowedSet {
        if self.is_empty() {
            return AllowedSet::None;
        }
        AllowedSet::from_sorted_vec(self.doc_ids.clone())
    }

    /// Ids present in both `self` and `other`, via a linear two-pointer merge.
    pub fn intersect(&self, other: &PostingSet) -> PostingSet {
        let (left, right) = (&self.doc_ids, &other.doc_ids);
        let mut common = Vec::with_capacity(left.len().min(right.len()));
        let (mut a, mut b) = (0, 0);
        while let (Some(&x), Some(&y)) = (left.get(a), right.get(b)) {
            if x < y {
                a += 1;
            } else if x > y {
                b += 1;
            } else {
                common.push(x);
                a += 1;
                b += 1;
            }
        }
        PostingSet { doc_ids: common }
    }

    /// Ids present in either `self` or `other`, via a linear two-pointer merge.
    pub fn union(&self, other: &PostingSet) -> PostingSet {
        let (left, right) = (&self.doc_ids, &other.doc_ids);
        let mut merged = Vec::with_capacity(left.len() + right.len());
        let (mut a, mut b) = (0, 0);
        while let (Some(&x), Some(&y)) = (left.get(a), right.get(b)) {
            if x > y {
                merged.push(y);
                b += 1;
            } else {
                // x <= y: emit x once; on a tie also skip the duplicate in `right`.
                merged.push(x);
                a += 1;
                if x == y {
                    b += 1;
                }
            }
        }
        // At most one side still has ids left, and each tail is already sorted.
        merged.extend_from_slice(&left[a..]);
        merged.extend_from_slice(&right[b..]);
        PostingSet { doc_ids: merged }
    }

    /// Iterates the ids in ascending order.
    pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
        self.doc_ids.iter().copied()
    }
}

impl Default for PostingSet {
    fn default() -> Self {
        PostingSet::new()
    }
}
#[derive(Debug, Default)]
pub struct EqualityIndex {
string_postings: HashMap<String, PostingSet>,
int_postings: HashMap<i64, PostingSet>,
uint_postings: HashMap<u64, PostingSet>,
}
impl EqualityIndex {
pub fn new() -> Self {
Self::default()
}
pub fn add_string(&mut self, value: &str, doc_id: u64) {
self.string_postings
.entry(value.to_string())
.or_default()
.add(doc_id);
}
pub fn add_int(&mut self, value: i64, doc_id: u64) {
self.int_postings.entry(value).or_default().add(doc_id);
}
pub fn add_uint(&mut self, value: u64, doc_id: u64) {
self.uint_postings.entry(value).or_default().add(doc_id);
}
pub fn remove_string(&mut self, value: &str, doc_id: u64) {
if let Some(posting) = self.string_postings.get_mut(value) {
posting.remove(doc_id);
if posting.is_empty() {
self.string_postings.remove(value);
}
}
}
pub fn lookup_string(&self, value: &str) -> AllowedSet {
self.string_postings
.get(value)
.map(|p| p.to_allowed_set())
.unwrap_or(AllowedSet::None)
}
pub fn lookup_int(&self, value: i64) -> AllowedSet {
self.int_postings
.get(&value)
.map(|p| p.to_allowed_set())
.unwrap_or(AllowedSet::None)
}
pub fn lookup_uint(&self, value: u64) -> AllowedSet {
self.uint_postings
.get(&value)
.map(|p| p.to_allowed_set())
.unwrap_or(AllowedSet::None)
}
pub fn lookup_string_in(&self, values: &[String]) -> AllowedSet {
let sets: Vec<_> = values
.iter()
.filter_map(|v| self.string_postings.get(v))
.collect();
if sets.is_empty() {
return AllowedSet::None;
}
let mut result = sets[0].clone();
for set in &sets[1..] {
result = result.union(set);
}
result.to_allowed_set()
}
pub fn lookup_uint_in(&self, values: &[u64]) -> AllowedSet {
let sets: Vec<_> = values
.iter()
.filter_map(|v| self.uint_postings.get(v))
.collect();
if sets.is_empty() {
return AllowedSet::None;
}
let mut result = sets[0].clone();
for set in &sets[1..] {
result = result.union(set);
}
result.to_allowed_set()
}
pub fn string_values(&self) -> impl Iterator<Item = &str> {
self.string_postings.keys().map(|s| s.as_str())
}
pub fn stats(&self) -> EqualityIndexStats {
EqualityIndexStats {
unique_string_values: self.string_postings.len(),
unique_int_values: self.int_postings.len(),
unique_uint_values: self.uint_postings.len(),
total_postings: self
.string_postings
.values()
.map(|p| p.len())
.sum::<usize>()
+ self.int_postings.values().map(|p| p.len()).sum::<usize>()
+ self.uint_postings.values().map(|p| p.len()).sum::<usize>(),
}
}
}
/// Size summary of an equality index.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EqualityIndexStats {
    /// Number of distinct string values in the index.
    pub unique_string_values: usize,
    /// Number of distinct signed integer values in the index.
    pub unique_int_values: usize,
    /// Number of distinct unsigned integer values in the index.
    pub unique_uint_values: usize,
    /// Sum of all posting-list lengths, i.e. the number of `(value, doc_id)` pairs.
    pub total_postings: usize,
}
#[derive(Debug, Default)]
pub struct RangeIndex {
entries: BTreeMap<i64, PostingSet>,
doc_count: usize,
}
impl RangeIndex {
pub fn new() -> Self {
Self::default()
}
pub fn add(&mut self, value: i64, doc_id: u64) {
self.entries.entry(value).or_default().add(doc_id);
self.doc_count += 1;
}
pub fn add_uint(&mut self, value: u64, doc_id: u64) {
self.add(value as i64, doc_id);
}
pub fn remove(&mut self, value: i64, doc_id: u64) {
if let Some(posting) = self.entries.get_mut(&value) {
posting.remove(doc_id);
if posting.is_empty() {
self.entries.remove(&value);
}
self.doc_count -= 1;
}
}
pub fn range_query(
&self,
min: Option<i64>,
max: Option<i64>,
min_inclusive: bool,
max_inclusive: bool,
) -> AllowedSet {
use std::ops::Bound;
let start = match min {
Some(v) if min_inclusive => Bound::Included(v),
Some(v) => Bound::Excluded(v),
None => Bound::Unbounded,
};
let end = match max {
Some(v) if max_inclusive => Bound::Included(v),
Some(v) => Bound::Excluded(v),
None => Bound::Unbounded,
};
let mut result = PostingSet::new();
for (_, posting) in self.entries.range((start, end)) {
result = result.union(posting);
}
result.to_allowed_set()
}
pub fn greater_than(&self, value: i64, inclusive: bool) -> AllowedSet {
self.range_query(Some(value), None, inclusive, true)
}
pub fn less_than(&self, value: i64, inclusive: bool) -> AllowedSet {
self.range_query(None, Some(value), true, inclusive)
}
pub fn stats(&self) -> RangeIndexStats {
let values: Vec<_> = self.entries.keys().collect();
RangeIndexStats {
unique_values: self.entries.len(),
total_docs: self.doc_count,
min_value: values.first().copied().copied(),
max_value: values.last().copied().copied(),
}
}
}
/// Size summary of a range index.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RangeIndexStats {
    /// Number of distinct indexed values.
    pub unique_values: usize,
    /// The index's running document count.
    pub total_docs: usize,
    /// Smallest indexed value, or `None` when the index is empty.
    pub min_value: Option<i64>,
    /// Largest indexed value, or `None` when the index is empty.
    pub max_value: Option<i64>,
}
#[derive(Debug, Default)]
pub struct MetadataIndex {
equality_indexes: HashMap<String, EqualityIndex>,
range_indexes: HashMap<String, RangeIndex>,
doc_count: usize,
}
impl MetadataIndex {
pub fn new() -> Self {
Self::default()
}
pub fn add_equality(&mut self, field: &str, value: &FilterValue, doc_id: u64) {
let index = self.equality_indexes.entry(field.to_string()).or_default();
match value {
FilterValue::String(s) => index.add_string(s, doc_id),
FilterValue::Int64(i) => index.add_int(*i, doc_id),
FilterValue::Uint64(u) => index.add_uint(*u, doc_id),
_ => {} }
}
pub fn add_string(&mut self, field: &str, value: &str, doc_id: u64) {
self.equality_indexes
.entry(field.to_string())
.or_default()
.add_string(value, doc_id);
}
pub fn add_range(&mut self, field: &str, value: i64, doc_id: u64) {
self.range_indexes
.entry(field.to_string())
.or_default()
.add(value, doc_id);
}
pub fn add_timestamp(&mut self, field: &str, timestamp: u64, doc_id: u64) {
self.add_range(field, timestamp as i64, doc_id);
}
pub fn set_doc_count(&mut self, count: usize) {
self.doc_count = count;
}
pub fn inc_doc_count(&mut self) {
self.doc_count += 1;
}
pub fn doc_count(&self) -> usize {
self.doc_count
}
pub fn evaluate_atom(&self, atom: &FilterAtom) -> AllowedSet {
match atom {
FilterAtom::Eq { field, value } => {
if let Some(index) = self.equality_indexes.get(field) {
match value {
FilterValue::String(s) => index.lookup_string(s),
FilterValue::Int64(i) => index.lookup_int(*i),
FilterValue::Uint64(u) => index.lookup_uint(*u),
_ => AllowedSet::All, }
} else {
AllowedSet::All }
}
FilterAtom::In { field, values } => {
if let Some(index) = self.equality_indexes.get(field) {
let strings: Vec<String> = values
.iter()
.filter_map(|v| match v {
FilterValue::String(s) => Some(s.clone()),
_ => None,
})
.collect();
if strings.len() == values.len() {
return index.lookup_string_in(&strings);
}
let uints: Vec<u64> = values
.iter()
.filter_map(|v| match v {
FilterValue::Uint64(u) => Some(*u),
_ => None,
})
.collect();
if uints.len() == values.len() {
return index.lookup_uint_in(&uints);
}
}
AllowedSet::All }
FilterAtom::Range {
field,
min,
max,
min_inclusive,
max_inclusive,
} => {
if let Some(index) = self.range_indexes.get(field) {
let min_val = min.as_ref().and_then(|v| match v {
FilterValue::Int64(i) => Some(*i),
FilterValue::Uint64(u) => Some(*u as i64),
_ => None,
});
let max_val = max.as_ref().and_then(|v| match v {
FilterValue::Int64(i) => Some(*i),
FilterValue::Uint64(u) => Some(*u as i64),
_ => None,
});
index.range_query(min_val, max_val, *min_inclusive, *max_inclusive)
} else {
AllowedSet::All
}
}
FilterAtom::True => AllowedSet::All,
FilterAtom::False => AllowedSet::None,
_ => AllowedSet::All,
}
}
pub fn evaluate(&self, filter: &FilterIR) -> AllowedSet {
if filter.is_all() {
return AllowedSet::All;
}
if filter.is_none() {
return AllowedSet::None;
}
let mut result = AllowedSet::All;
for clause in &filter.clauses {
let clause_result = self.evaluate_disjunction(clause);
result = result.intersect(&clause_result);
if result.is_empty() {
return AllowedSet::None;
}
}
result
}
fn evaluate_disjunction(&self, clause: &crate::filter_ir::Disjunction) -> AllowedSet {
if clause.atoms.len() == 1 {
return self.evaluate_atom(&clause.atoms[0]);
}
let mut result = AllowedSet::None;
for atom in &clause.atoms {
let atom_result = self.evaluate_atom(atom);
result = result.union(&atom_result);
if result.is_all() {
return AllowedSet::All;
}
}
result
}
pub fn estimate_selectivity(&self, filter: &FilterIR) -> f64 {
if self.doc_count == 0 {
return 1.0;
}
let allowed = self.evaluate(filter);
allowed.selectivity(self.doc_count)
}
}
pub struct ConcurrentMetadataIndex {
inner: RwLock<MetadataIndex>,
}
impl ConcurrentMetadataIndex {
pub fn new() -> Self {
Self {
inner: RwLock::new(MetadataIndex::new()),
}
}
pub fn add_string(&self, field: &str, value: &str, doc_id: u64) {
self.inner.write().unwrap().add_string(field, value, doc_id);
}
pub fn add_timestamp(&self, field: &str, timestamp: u64, doc_id: u64) {
self.inner
.write()
.unwrap()
.add_timestamp(field, timestamp, doc_id);
}
pub fn evaluate(&self, filter: &FilterIR) -> AllowedSet {
self.inner.read().unwrap().evaluate(filter)
}
pub fn set_doc_count(&self, count: usize) {
self.inner.write().unwrap().set_doc_count(count);
}
}
impl Default for ConcurrentMetadataIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filter_ir::FilterBuilder;

    /// Out-of-order inserts still give a sorted set with correct membership.
    #[test]
    fn test_posting_set_basic() {
        let mut set = PostingSet::new();
        for &id in &[1, 5, 3] {
            set.add(id);
        }
        for &present in &[1, 3, 5] {
            assert!(set.contains(present));
        }
        assert!(!set.contains(2));
        assert_eq!(set.len(), 3);
    }

    /// {1..5} ∩ {3..7} = {3, 4, 5}.
    #[test]
    fn test_posting_set_intersection() {
        let left = PostingSet::from_vec(vec![1, 2, 3, 4, 5]);
        let right = PostingSet::from_vec(vec![3, 4, 5, 6, 7]);
        let shared = left.intersect(&right);
        assert_eq!(shared.len(), 3);
        assert!([3, 4, 5].iter().all(|&id| shared.contains(id)));
    }

    /// Known values return their postings; unknown values return nothing.
    #[test]
    fn test_equality_index() {
        let mut idx = EqualityIndex::new();
        for &(env, doc) in &[("production", 1), ("production", 2), ("staging", 3)] {
            idx.add_string(env, doc);
        }
        assert_eq!(idx.lookup_string("production").cardinality(), Some(2));
        assert_eq!(idx.lookup_string("staging").cardinality(), Some(1));
        assert!(idx.lookup_string("dev").is_empty());
    }

    /// Docs 1..=5 sit at values 100..=500.
    #[test]
    fn test_range_index() {
        let mut idx = RangeIndex::new();
        for doc in 1..=5u64 {
            idx.add(doc as i64 * 100, doc);
        }
        // [200, 400] covers 200, 300, 400.
        assert_eq!(
            idx.range_query(Some(200), Some(400), true, true).cardinality(),
            Some(3)
        );
        // (300, +inf) covers 400, 500.
        assert_eq!(idx.greater_than(300, false).cardinality(), Some(2));
        // (-inf, 300] covers 100, 200, 300.
        assert_eq!(idx.less_than(300, true).cardinality(), Some(3));
    }

    /// Docs 0..10 are "production", 10..20 "staging"; timestamps 1000 + 100*i.
    #[test]
    fn test_metadata_index_evaluation() {
        let mut idx = MetadataIndex::new();
        for doc in 0..20u64 {
            let ns = if doc < 10 { "production" } else { "staging" };
            idx.add_string("namespace", ns, doc);
            idx.add_timestamp("created_at", 1000 + doc * 100, doc);
        }
        idx.set_doc_count(20);

        let by_namespace = FilterBuilder::new().namespace("production").build();
        assert_eq!(idx.evaluate(&by_namespace).cardinality(), Some(10));

        // Production docs with created_at >= 1500 are docs 5..10.
        let by_namespace_and_time = FilterBuilder::new()
            .namespace("production")
            .gte("created_at", 1500i64)
            .build();
        assert_eq!(idx.evaluate(&by_namespace_and_time).cardinality(), Some(5));
    }

    /// 10 "rare" vs 90 "common" docs out of 100.
    #[test]
    fn test_selectivity_estimate() {
        let mut idx = MetadataIndex::new();
        for doc in 0..100u64 {
            let ns = match doc % 10 {
                0 => "rare",
                _ => "common",
            };
            idx.add_string("namespace", ns, doc);
        }
        idx.set_doc_count(100);

        let common = idx.estimate_selectivity(&FilterBuilder::new().namespace("common").build());
        let rare = idx.estimate_selectivity(&FilterBuilder::new().namespace("rare").build());

        assert!(common > rare);
        assert!((common - 0.9).abs() < 0.01);
        assert!((rare - 0.1).abs() < 0.01);
    }
}