use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::hash::Hash;
use crate::collation::Collatable;
use ankql::ast;
/// Index of comparison predicates, keyed by the byte encoding of each predicate's operand.
///
/// Every bucket holds the identifiers of type `T` that registered a predicate under that key.
#[derive(Debug)]
pub(crate) struct ComparisonIndex<T> {
    /// Identifiers registered for `Equal`, keyed by the operand's bytes.
    pub(crate) eq: HashMap<Vec<u8>, Vec<T>>,
    /// Identifiers registered for `NotEqual`, keyed by the operand's bytes.
    pub(crate) ne: HashMap<Vec<u8>, Vec<T>>,
    /// Lower thresholds in byte order: `GreaterThan` operands, and `GreaterThanOrEqual`
    /// operands filed under their predecessor.
    pub(crate) gt: BTreeMap<Vec<u8>, Vec<T>>,
    /// Upper thresholds in byte order: `LessThan` operands, and `LessThanOrEqual`
    /// operands filed under their successor.
    pub(crate) lt: BTreeMap<Vec<u8>, Vec<T>>,
}
impl<T> Default for ComparisonIndex<T> {
    /// Builds an index whose four maps are all empty.
    ///
    /// Implemented by hand: a derived `Default` would add a `T: Default` bound.
    fn default() -> Self {
        let eq = HashMap::new();
        let ne = HashMap::new();
        let gt = BTreeMap::new();
        let lt = BTreeMap::new();
        Self { eq, ne, gt, lt }
    }
}
impl<T: Clone + Eq + Hash + Ord> ComparisonIndex<T> {
#[allow(unused)]
pub fn new() -> Self { Self::default() }
fn for_entry<F, V>(&mut self, value: V, op: ast::ComparisonOperator, f: F)
where
F: FnOnce(&mut Vec<T>),
V: Collatable,
{
match op {
ast::ComparisonOperator::Equal => {
let entry = self.eq.entry(value.to_bytes()).or_default();
f(entry);
}
ast::ComparisonOperator::NotEqual => {
let entry = self.ne.entry(value.to_bytes()).or_default();
f(entry);
}
ast::ComparisonOperator::GreaterThan => {
let entry = self.gt.entry(value.to_bytes()).or_default();
f(entry);
}
ast::ComparisonOperator::LessThan => {
let entry = self.lt.entry(value.to_bytes()).or_default();
f(entry);
}
ast::ComparisonOperator::GreaterThanOrEqual => {
if let Some(pred) = value.predecessor_bytes() {
let entry = self.gt.entry(pred).or_default();
f(entry);
} else {
let entry = self.gt.entry(Vec::new()).or_default();
f(entry);
}
}
ast::ComparisonOperator::LessThanOrEqual => {
if let Some(succ) = value.successor_bytes() {
let entry = self.lt.entry(succ).or_default();
f(entry);
}
}
_ => panic!("Unsupported operator: {:?}", op),
}
}
pub fn add<V: Collatable>(&mut self, value: V, op: ast::ComparisonOperator, watcher_id: T) {
self.for_entry(value, op, |entries| entries.push(watcher_id));
}
pub fn remove<V: Collatable>(&mut self, value: V, op: ast::ComparisonOperator, watcher_id: T) {
self.for_entry(value, op, |entries| {
if let Some(pos) = entries.iter().position(|id| *id == watcher_id) {
entries.remove(pos);
}
});
}
pub fn find_matching<V: Collatable>(&self, value: V) -> std::collections::btree_set::IntoIter<T> {
let mut result = BTreeSet::new();
let bytes = value.to_bytes();
if let Some(subs) = self.eq.get(&bytes) {
result.extend(subs.iter().cloned());
}
for (stored_bytes, subs) in &self.ne {
if bytes != *stored_bytes {
result.extend(subs.iter().cloned());
}
}
for (_threshold, subs) in self.gt.range(..bytes.clone()) {
result.extend(subs.iter().cloned());
}
if let Some(pred) = value.successor_bytes() {
for (_threshold, subs) in self.lt.range(pred..) {
result.extend(subs.iter().cloned());
}
}
result.into_iter()
}
}
#[cfg(test)]
mod tests {
    use super::ComparisonIndex;
    use crate::value::Value;
    use ankql::ast;
    use ankurah_proto as proto;

    /// Collects the ids the index reports for the integer `n`, in the order they are yielded.
    fn hits(index: &ComparisonIndex<proto::QueryId>, n: i64) -> Vec<proto::QueryId> {
        index.find_matching(Value::I64(n)).collect()
    }

    #[test]
    fn test_field_index() {
        let mut index = ComparisonIndex::<proto::QueryId>::new();

        // sub0 watches `< 8`: the threshold itself must not match, the value below it must.
        let sub0 = proto::QueryId::test(0);
        index.add(ast::Literal::I64(8), ast::ComparisonOperator::LessThan, sub0);
        assert_eq!(hits(&index, 8), vec![]);
        assert_eq!(hits(&index, 7), vec![sub0]);

        // sub1 watches `> 20`.
        let sub1 = proto::QueryId::test(1);
        index.add(ast::Literal::I64(20), ast::ComparisonOperator::GreaterThan, sub1);
        assert_eq!(hits(&index, 20), vec![]);
        assert_eq!(hits(&index, 21), vec![sub1]);

        // sub0 additionally watches `== 5`; it is still reported only once.
        index.add(ast::Literal::I64(5), ast::ComparisonOperator::Equal, sub0);
        assert_eq!(hits(&index, 5), vec![sub0]);

        // sub0 additionally watches `< 25`, which overlaps sub1's range.
        index.add(ast::Literal::I64(25), ast::ComparisonOperator::LessThan, sub0);
        assert_eq!(hits(&index, 22), vec![sub0, sub1]);
        assert_eq!(hits(&index, 25), vec![sub1]);
        assert_eq!(hits(&index, 26), vec![sub1]);
    }

    #[test]
    fn test_field_index_not_equal() {
        let mut index = ComparisonIndex::<proto::QueryId>::new();

        // sub0 watches `!= 8`.
        let sub0 = proto::QueryId::test(0);
        index.add(ast::Literal::I64(8), ast::ComparisonOperator::NotEqual, sub0);
        assert_eq!(hits(&index, 8), vec![]);
        assert_eq!(hits(&index, 9), vec![sub0]);
    }
}