use std::collections::{HashMap, HashSet};
use std::time::Instant;
/// Running numeric aggregate over a stream of `f64` values.
///
/// `add` keeps every field exact. `remove` can only undo the additive fields
/// (`sum`, `count`, `sum_sq`); `min` and `max` are not recomputed on removal
/// and may therefore be stale until the aggregate becomes empty or is rebuilt.
#[derive(Debug, Clone, Default)]
pub struct NumericAgg {
    /// Sum of all values currently aggregated.
    pub sum: f64,
    /// Number of values currently aggregated.
    pub count: u64,
    /// Smallest value added so far, `None` while empty.
    pub min: Option<f64>,
    /// Largest value added so far, `None` while empty.
    pub max: Option<f64>,
    /// Sum of the squares of all values currently aggregated.
    pub sum_sq: f64,
}

impl NumericAgg {
    /// Folds `value` into the aggregate, updating sum, count, extrema and
    /// sum of squares.
    pub fn add(&mut self, value: f64) {
        self.sum += value;
        self.count += 1;
        self.sum_sq += value * value;
        self.min = Some(self.min.map_or(value, |current| current.min(value)));
        self.max = Some(self.max.map_or(value, |current| current.max(value)));
    }

    /// Removes one previously added `value` from the additive fields.
    ///
    /// Does nothing when the aggregate is already empty. When the last value
    /// is removed the whole aggregate is reset, so stale extrema and
    /// accumulated floating-point drift cannot leak into later additions.
    /// While values remain, `min`/`max` are left untouched because they
    /// cannot be recomputed without the underlying data.
    pub fn remove(&mut self, value: f64) {
        match self.count {
            0 => {}
            // The aggregate becomes empty: start again from a clean slate.
            1 => *self = Self::default(),
            _ => {
                self.sum -= value;
                self.count -= 1;
                self.sum_sq -= value * value;
            }
        }
    }

    /// Arithmetic mean, or `None` when no values are aggregated.
    pub fn avg(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum / self.count as f64)
        }
    }

    /// Population variance (`E[x^2] - E[x]^2`), or `None` with fewer than two
    /// values.
    ///
    /// The result is never negative: rounding in the subtraction can yield a
    /// tiny negative number for near-constant data, which would otherwise
    /// turn `stddev` into NaN.
    pub fn variance(&self) -> Option<f64> {
        if self.count < 2 {
            return None;
        }
        let n = self.count as f64;
        let mean = self.sum / n;
        let raw = self.sum_sq / n - mean * mean;
        // Explicit comparison (not `f64::max`) so a NaN result still propagates.
        Some(if raw < 0.0 { 0.0 } else { raw })
    }

    /// Population standard deviation; `None` whenever `variance` is `None`.
    pub fn stddev(&self) -> Option<f64> {
        self.variance().map(f64::sqrt)
    }
}
/// Distinct-value counter that is exact up to a threshold and approximate
/// beyond it.
///
/// While in exact mode every hash is kept in a set. As soon as the set holds
/// more than `exact_threshold` entries, its size becomes the starting point of
/// a running counter and the set is emptied. From then on the counter is
/// bumped by one for every added hash that is a multiple of 1000; duplicates
/// are no longer detected.
#[derive(Debug, Clone)]
pub struct CardinalityEstimate {
    /// Hashes seen so far while counting exactly.
    distinct_values: HashSet<u64>,
    /// Largest set size that is still counted exactly.
    exact_threshold: usize,
    /// Running counter once exact counting has been abandoned.
    approximate: Option<u64>,
    /// Time of construction or of the most recent `add` call.
    updated_at: Instant,
}

impl CardinalityEstimate {
    /// Creates an empty estimator that counts exactly until more than
    /// `exact_threshold` distinct hashes have been added.
    pub fn new(exact_threshold: usize) -> Self {
        let created = Instant::now();
        Self {
            distinct_values: HashSet::new(),
            exact_threshold,
            approximate: None,
            updated_at: created,
        }
    }

    /// Records one value hash and refreshes the update timestamp.
    pub fn add(&mut self, hash: u64) {
        match &mut self.approximate {
            // Approximate mode: only hashes divisible by 1000 move the counter.
            Some(running) => {
                if hash.is_multiple_of(1000) {
                    *running += 1;
                }
            }
            // Exact mode: remember the hash, then switch modes if the set
            // has outgrown the threshold.
            None => {
                self.distinct_values.insert(hash);
                let seen = self.distinct_values.len();
                if seen > self.exact_threshold {
                    self.approximate = Some(seen as u64);
                    self.distinct_values.clear();
                }
            }
        }
        self.updated_at = Instant::now();
    }

    /// Current distinct count: the running counter in approximate mode,
    /// otherwise the exact set size.
    pub fn estimate(&self) -> u64 {
        match self.approximate {
            Some(running) => running,
            None => self.distinct_values.len() as u64,
        }
    }
}

impl Default for CardinalityEstimate {
    /// Estimator with an exact-counting threshold of 10000.
    fn default() -> Self {
        Self::new(10_000)
    }
}
/// Aggregate state kept for a single registered table.
#[derive(Debug)]
struct TableAggregates {
    /// Number of rows currently accounted for.
    row_count: u64,
    /// Cached row counts keyed by a caller-supplied filter key.
    filtered_counts: HashMap<String, u64>,
    /// Numeric aggregates keyed by column name.
    numeric_aggs: HashMap<String, NumericAgg>,
    /// Distinct-count estimators, one per tracked column.
    cardinalities: HashMap<String, CardinalityEstimate>,
    /// Columns for which distinct counts are tracked.
    tracked_columns: Vec<String>,
    /// Time of construction or of the last full refresh.
    last_refresh: Instant,
    /// Whether the aggregates may have drifted from the underlying data.
    stale: bool,
}

impl TableAggregates {
    /// Creates empty aggregates with a default cardinality estimator for
    /// every tracked column.
    fn new(tracked_columns: Vec<String>) -> Self {
        let mut cardinalities = HashMap::new();
        for column in &tracked_columns {
            cardinalities.insert(column.clone(), CardinalityEstimate::default());
        }

        Self {
            row_count: 0,
            filtered_counts: HashMap::new(),
            numeric_aggs: HashMap::new(),
            cardinalities,
            tracked_columns,
            last_refresh: Instant::now(),
            stale: false,
        }
    }
}
/// In-memory cache of per-table aggregates (row counts, numeric statistics,
/// distinct-count estimates) maintained incrementally from insert/delete
/// notifications and rebuilt on demand via `refresh`.
///
/// `global_row_count` is kept equal to the sum of the row counts of all
/// registered tables.
pub struct AggregationCache {
    /// Aggregates keyed by table name.
    tables: HashMap<String, TableAggregates>,
    /// Total number of rows across all registered tables.
    global_row_count: u64,
}

impl AggregationCache {
    /// Creates an empty cache with no registered tables.
    pub fn new() -> Self {
        Self {
            tables: HashMap::new(),
            global_row_count: 0,
        }
    }

    /// Registers `table` with fresh, empty aggregates, tracking distinct
    /// counts for `tracked_columns`.
    ///
    /// Registering a name that already exists replaces its aggregates; the
    /// rows of the replaced entry are removed from the global total so the
    /// two never drift apart.
    pub fn register_table(&mut self, table: &str, tracked_columns: &[&str]) {
        let columns = tracked_columns.iter().map(|s| s.to_string()).collect();
        let replaced = self
            .tables
            .insert(table.to_string(), TableAggregates::new(columns));
        if let Some(old) = replaced {
            self.global_row_count = self.global_row_count.saturating_sub(old.row_count);
        }
    }

    /// Row count of `table`, or `None` if the table is not registered.
    pub fn count(&self, table: &str) -> Option<u64> {
        self.tables.get(table).map(|t| t.row_count)
    }

    /// Cached row count for `filter_key` on `table`, or `None` if the table
    /// is unknown or no count is cached under that key.
    pub fn count_filtered(&self, table: &str, filter_key: &str) -> Option<u64> {
        self.tables
            .get(table)
            .and_then(|t| t.filtered_counts.get(filter_key).copied())
    }

    /// Caches `count` under `filter_key` for `table`; ignored when the table
    /// is not registered.
    pub fn set_filtered_count(&mut self, table: &str, filter_key: &str, count: u64) {
        if let Some(aggs) = self.tables.get_mut(table) {
            aggs.filtered_counts.insert(filter_key.to_string(), count);
        }
    }

    /// Numeric aggregate of `column` in `table`, if any numeric value has
    /// been recorded for it.
    pub fn numeric_agg(&self, table: &str, column: &str) -> Option<&NumericAgg> {
        self.tables
            .get(table)
            .and_then(|t| t.numeric_aggs.get(column))
    }

    /// Mean of `column`, or `None` if nothing is aggregated for it.
    pub fn avg(&self, table: &str, column: &str) -> Option<f64> {
        self.numeric_agg(table, column).and_then(|a| a.avg())
    }

    /// Sum of `column`, or `None` if nothing is aggregated for it.
    pub fn sum(&self, table: &str, column: &str) -> Option<f64> {
        self.numeric_agg(table, column).map(|a| a.sum)
    }

    /// Recorded minimum of `column`, or `None` if nothing is aggregated.
    pub fn min(&self, table: &str, column: &str) -> Option<f64> {
        self.numeric_agg(table, column).and_then(|a| a.min)
    }

    /// Recorded maximum of `column`, or `None` if nothing is aggregated.
    pub fn max(&self, table: &str, column: &str) -> Option<f64> {
        self.numeric_agg(table, column).and_then(|a| a.max)
    }

    /// Distinct-count estimate for a tracked `column`, or `None` if the
    /// table is unknown or the column is not tracked.
    pub fn distinct_count(&self, table: &str, column: &str) -> Option<u64> {
        self.tables
            .get(table)
            .and_then(|t| t.cardinalities.get(column))
            .map(|c| c.estimate())
    }

    /// Accounts for one inserted row. Ignored for unregistered tables.
    ///
    /// Cached filtered counts are dropped because the new row may or may not
    /// match each filter.
    pub fn on_insert(&mut self, table: &str, values: &HashMap<String, AggValue>) {
        let Some(aggs) = self.tables.get_mut(table) else {
            return;
        };
        aggs.row_count += 1;
        self.global_row_count += 1;
        Self::accumulate_row(aggs, values);
        aggs.filtered_counts.clear();
    }

    /// Accounts for one deleted row. Ignored for unregistered tables.
    ///
    /// Numeric aggregates are adjusted through `NumericAgg::remove`, but
    /// cardinality estimates are left untouched, so the table is flagged
    /// stale until the next `refresh`.
    pub fn on_delete(&mut self, table: &str, values: &HashMap<String, AggValue>) {
        let Some(aggs) = self.tables.get_mut(table) else {
            return;
        };
        // Only take a row off the global total if this table actually had
        // one; otherwise a spurious delete would eat another table's count.
        if aggs.row_count > 0 {
            aggs.row_count -= 1;
            self.global_row_count = self.global_row_count.saturating_sub(1);
        }
        for (col, value) in values {
            if let AggValue::Number(n) = value {
                if let Some(num_agg) = aggs.numeric_aggs.get_mut(col) {
                    num_agg.remove(*n);
                }
            }
        }
        aggs.stale = true;
        aggs.filtered_counts.clear();
    }

    /// Rebuilds every aggregate of `table` from `rows` and clears the stale
    /// flag. Ignored (and `rows` left unconsumed) for unregistered tables.
    ///
    /// The global total is adjusted by the difference between the old and
    /// the recounted row count, and cached filtered counts are dropped since
    /// they were computed against the pre-refresh state.
    pub fn refresh<I>(&mut self, table: &str, rows: I)
    where
        I: Iterator<Item = HashMap<String, AggValue>>,
    {
        let Some(aggs) = self.tables.get_mut(table) else {
            return;
        };
        let previous_rows = aggs.row_count;

        aggs.row_count = 0;
        aggs.numeric_aggs.clear();
        aggs.filtered_counts.clear();
        for card in aggs.cardinalities.values_mut() {
            *card = CardinalityEstimate::default();
        }

        for row in rows {
            aggs.row_count += 1;
            Self::accumulate_row(aggs, &row);
        }

        aggs.stale = false;
        aggs.last_refresh = Instant::now();
        self.global_row_count = self
            .global_row_count
            .saturating_sub(previous_rows)
            .saturating_add(aggs.row_count);
    }

    /// Total number of rows across all registered tables.
    pub fn global_count(&self) -> u64 {
        self.global_row_count
    }

    /// Whether `table` needs a refresh; unregistered tables count as stale.
    pub fn is_stale(&self, table: &str) -> bool {
        self.tables.get(table).map(|t| t.stale).unwrap_or(true)
    }

    /// Summary of the cache: table count, total rows, and the number of
    /// tracked columns summed over all tables.
    pub fn stats(&self) -> AggCacheStats {
        AggCacheStats {
            tables: self.tables.len(),
            total_rows: self.global_row_count,
            tracked_columns: self.tables.values().map(|t| t.tracked_columns.len()).sum(),
        }
    }

    /// Folds one row into the numeric aggregates and cardinality estimators
    /// of `aggs`; the caller is responsible for the row counters.
    fn accumulate_row(aggs: &mut TableAggregates, values: &HashMap<String, AggValue>) {
        for (col, value) in values {
            if let AggValue::Number(n) = value {
                aggs.numeric_aggs.entry(col.clone()).or_default().add(*n);
            }
            if let Some(card) = aggs.cardinalities.get_mut(col) {
                card.add(value.hash());
            }
        }
    }
}

impl Default for AggregationCache {
    /// Equivalent to [`AggregationCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A single column value as seen by the aggregation cache.
#[derive(Debug, Clone)]
pub enum AggValue {
    /// Numeric value; the only variant that feeds numeric aggregates.
    Number(f64),
    /// Text value.
    String(String),
    /// Boolean value.
    Bool(bool),
    /// Absent value.
    Null,
}

impl AggValue {
    /// Hashes the value for distinct counting.
    ///
    /// The variant is mixed into the hash so values of different kinds never
    /// share an input (previously `Null` and `Number(0.0)` both hashed the
    /// integer zero and were counted as one value). `-0.0` is folded into
    /// `0.0` because the two compare equal. The concrete hash values come
    /// from `DefaultHasher` and must not be persisted.
    pub fn hash(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        // Tag with the variant first so payloads of different kinds cannot collide.
        std::mem::discriminant(self).hash(&mut hasher);
        match self {
            AggValue::Number(n) => {
                // `-0.0 == 0.0` is true, so both map to positive zero.
                let canonical: f64 = if *n == 0.0 { 0.0 } else { *n };
                canonical.to_bits().hash(&mut hasher);
            }
            AggValue::String(s) => s.hash(&mut hasher),
            AggValue::Bool(b) => b.hash(&mut hasher),
            // The variant tag alone identifies `Null`.
            AggValue::Null => {}
        }
        hasher.finish()
    }
}
/// Snapshot of cache-wide counters, as produced by `AggregationCache::stats`.
#[derive(Debug, Clone)]
pub struct AggCacheStats {
/// Number of registered tables.
pub tables: usize,
/// The cache's global row counter at the time of the snapshot.
pub total_rows: u64,
/// Number of tracked columns, summed over all registered tables.
pub tracked_columns: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum, count, mean and extrema after adding three values.
    #[test]
    fn test_numeric_agg() {
        let mut stats = NumericAgg::default();
        for sample in [10.0, 20.0, 30.0] {
            stats.add(sample);
        }

        assert_eq!(stats.count, 3);
        assert_eq!(stats.sum, 60.0);
        assert_eq!(stats.avg(), Some(20.0));
        assert_eq!(stats.min, Some(10.0));
        assert_eq!(stats.max, Some(30.0));
    }

    /// Inserting three rows updates the table count and numeric aggregates.
    #[test]
    fn test_aggregation_cache() {
        let mut cache = AggregationCache::new();
        cache.register_table("hosts", &["criticality", "status"]);

        let fixtures = [(5.0, "active"), (8.0, "active"), (2.0, "inactive")];
        for (criticality, status) in fixtures {
            let row = HashMap::from([
                ("criticality".to_string(), AggValue::Number(criticality)),
                ("status".to_string(), AggValue::String(status.to_string())),
            ]);
            cache.on_insert("hosts", &row);
        }

        assert_eq!(cache.count("hosts"), Some(3));
        assert_eq!(cache.avg("hosts", "criticality"), Some(5.0));
        assert_eq!(cache.sum("hosts", "criticality"), Some(15.0));
        assert_eq!(cache.min("hosts", "criticality"), Some(2.0));
        assert_eq!(cache.max("hosts", "criticality"), Some(8.0));
    }

    /// Below the threshold, re-adding the same hashes does not change the
    /// estimate.
    #[test]
    fn test_cardinality() {
        let mut estimator = CardinalityEstimate::new(100);

        (0..50).for_each(|hash| estimator.add(hash));
        assert_eq!(estimator.estimate(), 50);

        (0..50).for_each(|hash| estimator.add(hash));
        assert_eq!(estimator.estimate(), 50);
    }
}