use super::node::FilterNode;
use super::ops::{CompareOp, LogicalOp};
use crate::websocket::SocketId;
use ahash::AHashSet;
use dashmap::{DashMap, DashSet};
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
/// Hashes a `(channel, key, value)` triple into the `u64` bucket key used by
/// the eq-index. Deterministic for identical inputs within one process.
fn compute_eq_hash(channel: &str, key: &str, value: &str) -> u64 {
    use ahash::AHasher;
    let mut state = AHasher::default();
    // Hashing the tuple feeds channel, key, value into the hasher in order —
    // the same byte stream as hashing each component separately.
    (channel, key, value).hash(&mut state);
    state.finish()
}
/// Three-tier, concurrency-friendly index from channel subscriptions to
/// sockets, letting event delivery skip full filter evaluation for the
/// common cases (no filter, or a simple equality/IN filter).
#[derive(Default)]
pub struct FilterIndex {
/// `compute_eq_hash(channel, key, value)` -> sockets whose indexable
/// equality/IN filter matches that exact tag value.
eq_index: DashMap<u64, DashSet<SocketId>>,
/// channel -> every eq hash registered for it; used by `stats` and
/// `clear_channel` to locate this channel's `eq_index` entries.
channel_keys: DashMap<String, DashSet<u64>>,
/// channel -> sockets whose filter could not be indexed and must be
/// evaluated against each event.
complex_filters: DashMap<String, DashSet<SocketId>>,
/// channel -> sockets subscribed without any filter.
no_filter: DashMap<String, DashSet<SocketId>>,
}
/// Result of [`FilterIndex::lookup`], partitioning a channel's sockets by how
/// their delivery decision was (or still must be) reached.
pub struct IndexLookupResult {
/// Sockets whose indexed equality filter matched at least one event tag.
pub indexed_matches: Vec<SocketId>,
/// Sockets with non-indexable filters; the caller must evaluate them.
pub needs_evaluation: Vec<SocketId>,
/// Sockets subscribed with no filter at all.
pub no_filter: Vec<SocketId>,
}
impl FilterIndex {
pub fn new() -> Self {
Self::default()
}
/// Registers `socket_id` on `channel`, routing it into exactly one tier:
/// `no_filter` when `filter` is `None`, the eq-index when the filter reduces
/// to an indexable key/values shape, or `complex_filters` otherwise.
pub fn add_socket_filter(
    &self,
    channel: &str,
    socket_id: SocketId,
    filter: Option<&FilterNode>,
) {
    // Drop any prior membership first so re-registration stays idempotent.
    self.remove_socket_from_all_indexes(channel, socket_id, filter);
    let Some(node) = filter else {
        tracing::debug!(
            "FilterIndex: Adding socket {} to no_filter for channel {}",
            socket_id,
            channel
        );
        self.no_filter
            .entry(channel.to_string())
            .or_default()
            .insert(socket_id);
        return;
    };
    match self.extract_indexable_filter(node) {
        Some(indexable) => {
            tracing::debug!(
                "FilterIndex: Adding socket {} to eq_index for channel {}, key={}, values_count={}",
                socket_id,
                channel,
                indexable.key,
                indexable.values.len()
            );
            self.add_to_eq_index(channel, socket_id, &indexable);
        }
        None => {
            tracing::debug!(
                "FilterIndex: Adding socket {} to complex_filters for channel {} (filter not indexable)",
                socket_id,
                channel
            );
            self.complex_filters
                .entry(channel.to_string())
                .or_default()
                .insert(socket_id);
        }
    }
}
/// Drops `socket_id` from the `no_filter` and `complex_filters` tiers for
/// `channel`; called by `add_socket_filter` before re-registering.
///
/// NOTE(review): this deliberately does NOT touch `eq_index` — the
/// `_current_filter` argument is the *new* filter, and the hashes of the
/// socket's *previous* filter are unknown here. If a socket re-subscribes
/// with a different indexable filter, its old eq-index entries appear to
/// linger until `remove_socket_filter` is called with the old filter —
/// confirm callers guarantee that.
fn remove_socket_from_all_indexes(
&self,
channel: &str,
socket_id: SocketId,
_current_filter: Option<&FilterNode>,
) {
// Removing through a `get` ref works: the inner DashSet is mutated through
// a shared reference, so only a read lock on the outer map shard is held.
if let Some(set) = self.no_filter.get(channel) {
set.remove(&socket_id);
}
if let Some(set) = self.complex_filters.get(channel) {
set.remove(&socket_id);
}
}
/// Unregisters `socket_id` from every per-channel index tier.
///
/// `filter` should be the same filter the socket was registered with: the
/// eq-index entries are keyed by hashes derived from the filter's
/// key/values, so they can only be located by re-deriving those hashes.
pub fn remove_socket_filter(
    &self,
    channel: &str,
    socket_id: SocketId,
    filter: Option<&FilterNode>,
) {
    // Delegate to the shared cleanup used by `add_socket_filter` instead of
    // duplicating the no_filter / complex_filters removal inline, so the two
    // removal paths cannot drift apart.
    self.remove_socket_from_all_indexes(channel, socket_id, filter);
    if let Some(filter_node) = filter
        && let Some(indexable) = self.extract_indexable_filter(filter_node)
    {
        self.remove_from_eq_index(channel, socket_id, &indexable);
    }
}
pub fn lookup(&self, channel: &str, tags: &BTreeMap<String, String>) -> IndexLookupResult {
let no_filter_sockets = if let Some(set) = self.no_filter.get(channel) {
let mut sockets = Vec::with_capacity(set.len());
for entry in set.iter() {
sockets.push(*entry.key());
}
sockets
} else {
Vec::new()
};
let indexed_matches = self.lookup_eq_index(channel, tags);
let needs_evaluation = if let Some(set) = self.complex_filters.get(channel) {
let mut sockets = Vec::with_capacity(set.len());
for entry in set.iter() {
sockets.push(*entry.key());
}
sockets
} else {
Vec::new()
};
IndexLookupResult {
indexed_matches,
needs_evaluation,
no_filter: no_filter_sockets,
}
}
fn lookup_eq_index(&self, channel: &str, tags: &BTreeMap<String, String>) -> Vec<SocketId> {
if tags.is_empty() {
return Vec::new();
}
let mut matching_hashes: Vec<u64> = Vec::with_capacity(tags.len());
for (tag_key, tag_value) in tags {
let hash = compute_eq_hash(channel, tag_key, tag_value);
if self.eq_index.contains_key(&hash) {
matching_hashes.push(hash);
}
}
match matching_hashes.len() {
0 => Vec::new(),
1 => {
if let Some(socket_set) = self.eq_index.get(&matching_hashes[0]) {
let mut matches = Vec::with_capacity(socket_set.len());
for entry in socket_set.iter() {
matches.push(*entry.key());
}
matches
} else {
Vec::new()
}
}
_ => {
let mut dedup_set = AHashSet::new();
for hash in matching_hashes {
if let Some(socket_set) = self.eq_index.get(&hash) {
for entry in socket_set.iter() {
dedup_set.insert(*entry.key());
}
}
}
dedup_set.into_iter().collect()
}
}
}
/// Attempts to reduce `filter` to the eq-index's "key equals any of these
/// values" shape.
///
/// Indexable shapes: a single `Equal` comparison, an `In` comparison with
/// 1..=500 values, or an `Or` whose children are all indexable on the SAME
/// key (their values are merged). `And`/`Not` nodes, empty keys, and every
/// other comparison operator yield `None`.
fn extract_indexable_filter(&self, filter: &FilterNode) -> Option<IndexableFilter> {
    match filter.logical_op() {
        Some(LogicalOp::Or) => {
            let mut common_key: Option<String> = None;
            let mut merged_values = Vec::new();
            for child in filter.nodes() {
                // Any non-indexable child makes the whole OR non-indexable.
                let child_filter = self.extract_indexable_filter(child)?;
                match &common_key {
                    None => common_key = Some(child_filter.key.clone()),
                    // Children on different keys cannot share one bucket key.
                    Some(existing) if existing != &child_filter.key => return None,
                    _ => {}
                }
                merged_values.extend(child_filter.values);
            }
            common_key.map(|key| IndexableFilter {
                key,
                values: merged_values,
            })
        }
        // AND/NOT semantics cannot be expressed as a value union.
        Some(LogicalOp::And) | Some(LogicalOp::Not) => None,
        None => match filter.compare_op() {
            CompareOp::Equal => {
                let key = filter.key();
                if key.is_empty() {
                    return None;
                }
                Some(IndexableFilter {
                    key: key.to_string(),
                    values: vec![filter.val().to_string()],
                })
            }
            CompareOp::In => {
                let key = filter.key();
                if key.is_empty() {
                    return None;
                }
                let vals_ref = filter.vals();
                // Cap at 500 values to bound eq-index fan-out per socket.
                if vals_ref.is_empty() || vals_ref.len() > 500 {
                    return None;
                }
                Some(IndexableFilter {
                    key: key.to_string(),
                    values: vals_ref.to_vec(),
                })
            }
            _ => None,
        },
    }
}
/// Registers `socket_id` under every `(channel, key, value)` hash produced
/// by `indexable`, and records those hashes in `channel_keys` so `stats` and
/// `clear_channel` can find this channel's eq-index entries later.
fn add_to_eq_index(&self, channel: &str, socket_id: SocketId, indexable: &IndexableFilter) {
// Entry guard held for the whole loop. NOTE(review): this nests
// `channel_keys` -> `eq_index` locking; `clear_channel` locks in the same
// order, which presumably keeps things deadlock-free — confirm no other
// call site locks these two maps in the opposite order.
let channel_key_set = self.channel_keys.entry(channel.to_string()).or_default();
for value in &indexable.values {
let hash = compute_eq_hash(channel, &indexable.key, value);
channel_key_set.insert(hash);
// Read-lock fast path when the bucket already exists; only first
// insertion for a hash pays for the write-locking `entry` call.
if let Some(socket_set) = self.eq_index.get(&hash) {
socket_set.insert(socket_id);
} else {
self.eq_index.entry(hash).or_default().insert(socket_id);
}
}
}
/// Removes `socket_id` from every eq-index bucket derived from `indexable`.
///
/// NOTE(review): buckets emptied here are left in `eq_index` (and their
/// hashes in `channel_keys`) until `clear_channel` runs — confirm that is
/// the intended cleanup strategy.
fn remove_from_eq_index(
    &self,
    channel: &str,
    socket_id: SocketId,
    indexable: &IndexableFilter,
) {
    indexable
        .values
        .iter()
        .map(|value| compute_eq_hash(channel, &indexable.key, value))
        .for_each(|hash| {
            if let Some(bucket) = self.eq_index.get(&hash) {
                bucket.remove(&socket_id);
            }
        });
}
/// Reports occupancy counters for `channel` across all three index tiers.
pub fn stats(&self, channel: &str) -> IndexStats {
    let mut eq_entries = 0;
    let mut eq_sockets = 0;
    // Walk only the hashes registered for this channel, skipping any whose
    // eq_index bucket no longer exists.
    if let Some(registered) = self.channel_keys.get(channel) {
        for hash_ref in registered.iter() {
            if let Some(bucket) = self.eq_index.get(hash_ref.key()) {
                eq_entries += 1;
                eq_sockets += bucket.len();
            }
        }
    }
    IndexStats {
        eq_entries,
        eq_sockets,
        complex_filters: self.complex_filters.get(channel).map_or(0, |s| s.len()),
        no_filter: self.no_filter.get(channel).map_or(0, |s| s.len()),
    }
}
/// Drops every index entry for `channel`.
///
/// Safe for the eq-index because each hash already incorporates the channel
/// name, so the removed buckets belong to this channel only.
pub fn clear_channel(&self, channel: &str) {
    if let Some((_, registered_hashes)) = self.channel_keys.remove(channel) {
        registered_hashes.iter().for_each(|hash_ref| {
            self.eq_index.remove(hash_ref.key());
        });
    }
    self.complex_filters.remove(channel);
    self.no_filter.remove(channel);
}
}
/// A filter reduced to "key equals any of these values" — the only shape the
/// eq-index can serve directly; produced by `extract_indexable_filter`.
struct IndexableFilter {
// Tag key the equality test applies to (never empty).
key: String,
// Accepted values; one entry for Equal, up to 500 for In/OR-of-Equal.
values: Vec<String>,
}
/// Per-channel occupancy counters reported by [`FilterIndex::stats`].
#[derive(Debug, Clone)]
pub struct IndexStats {
/// Number of live eq-index buckets registered for the channel.
pub eq_entries: usize,
/// Total socket memberships summed across those buckets.
pub eq_sockets: usize,
/// Sockets whose filters require full per-event evaluation.
pub complex_filters: usize,
/// Sockets subscribed with no filter.
pub no_filter: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eq_index_key_hashing() {
        // The same triple must hash identically; changing only the value
        // must produce a different bucket key.
        let first = compute_eq_hash("channel1", "key1", "value1");
        let repeat = compute_eq_hash("channel1", "key1", "value1");
        let other_value = compute_eq_hash("channel1", "key1", "value2");
        assert_eq!(first, repeat);
        assert_ne!(first, other_value);
    }

    #[test]
    fn test_filter_index_no_filter() {
        // A socket registered without a filter lands in the no_filter tier.
        let index = FilterIndex::new();
        let socket_id = SocketId::new();
        index.add_socket_filter("channel1", socket_id, None);
        let result = index.lookup("channel1", &BTreeMap::new());
        assert_eq!(result.no_filter, vec![socket_id]);
    }
}