Skip to main content

reddb_server/storage/cache/
aggregates.rs

//! Aggregation Cache
//!
//! Precomputed aggregation values for fast query responses.
//! Inspired by Neo4j's statistics layer and Turso's precomputed counts.
//!
//! # Features
//!
//! - **Count Cache**: Precomputed COUNT(*) per table/filter
//! - **Sum/Avg Cache**: Numeric aggregations by column
//! - **Cardinality Cache**: Distinct value counts for query planning
//! - **Incremental Updates**: Delta updates instead of full recalculation
//!
//! # Example
//!
//! ```ignore
//! let mut agg = AggregationCache::new();
//!
//! // Register tables to track
//! agg.register_table("hosts", &["status", "os_family", "criticality"]);
//!
//! // Update on inserts
//! agg.on_insert("hosts", &row);
//!
//! // Fast aggregation queries
//! let count = agg.count("hosts"); // O(1)
//! let active = agg.count_filtered("hosts", "status=active"); // O(1), only if precomputed
//! let avg = agg.avg("hosts", "criticality"); // O(1)
//! let distinct = agg.distinct_count("hosts", "os_family"); // O(1)
//! ```
29
30use std::collections::{HashMap, HashSet};
31use std::time::Instant;
32
33// ============================================================================
34// Aggregation Types
35// ============================================================================
36
/// Running numeric aggregate for one column.
///
/// Keeps the sum, count, min, max and sum of squares so that average,
/// variance and standard deviation can be answered without rescanning rows.
#[derive(Debug, Clone, Default)]
pub struct NumericAgg {
    /// Sum of values
    pub sum: f64,
    /// Count of values
    pub count: u64,
    /// Minimum value (`None` while no value has been added)
    pub min: Option<f64>,
    /// Maximum value (`None` while no value has been added)
    pub max: Option<f64>,
    /// Sum of squares (for variance/stddev)
    pub sum_sq: f64,
}

impl NumericAgg {
    /// Add a value to the aggregate.
    pub fn add(&mut self, value: f64) {
        self.sum += value;
        self.count += 1;
        self.sum_sq += value * value;

        self.min = Some(self.min.map_or(value, |m| m.min(value)));
        self.max = Some(self.max.map_or(value, |m| m.max(value)));
    }

    /// Remove a previously added value (for updates/deletes).
    ///
    /// Does nothing when the aggregate is empty. Removing the last remaining
    /// value resets the whole aggregate, so no stale `min`/`max` or
    /// floating-point residue in `sum`/`sum_sq` survives an empty set.
    ///
    /// While other values remain, `min` and `max` are left untouched: they
    /// cannot be recomputed from the running totals and may therefore be out
    /// of date until the aggregate is rebuilt from the rows.
    pub fn remove(&mut self, value: f64) {
        match self.count {
            0 => {}
            1 => *self = Self::default(),
            _ => {
                self.sum -= value;
                self.count -= 1;
                self.sum_sq -= value * value;
            }
        }
    }

    /// Get the average, or `None` when no values are present.
    pub fn avg(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum / self.count as f64)
        }
    }

    /// Get the population variance (`E[x^2] - E[x]^2`).
    ///
    /// Returns `None` when fewer than two values are present.
    pub fn variance(&self) -> Option<f64> {
        if self.count < 2 {
            return None;
        }
        let n = self.count as f64;
        let mean = self.sum / n;
        let raw = self.sum_sq / n - mean * mean;
        // Cancellation can push the result a hair below zero, which would make
        // `stddev` NaN. Clamp it; a genuine NaN (from NaN inputs) is preserved.
        Some(if raw < 0.0 { 0.0 } else { raw })
    }

    /// Get the population standard deviation (square root of [`Self::variance`]).
    pub fn stddev(&self) -> Option<f64> {
        self.variance().map(f64::sqrt)
    }
}
104
/// Distinct-value counter: exact for small sets, HyperLogLog beyond a threshold.
///
/// Up to `exact_threshold` distinct hashes are stored verbatim and counted
/// exactly. Once that is exceeded, the stored hashes are folded into a
/// HyperLogLog sketch and the exact set is released.
#[derive(Debug, Clone)]
pub struct CardinalityEstimate {
    /// Distinct hashes seen so far (only populated in exact mode)
    distinct_values: HashSet<u64>,
    /// Number of distinct hashes above which the sketch takes over
    exact_threshold: usize,
    /// HyperLogLog registers; empty while still in exact mode
    registers: Vec<u8>,
    /// Exact distinct count at the moment of the switch; lower bound for estimates
    exact_floor: u64,
    /// Last update time
    updated_at: Instant,
}

impl CardinalityEstimate {
    /// Number of hash bits used to select a register.
    const PRECISION: u32 = 12;
    /// Number of HyperLogLog registers (`2^PRECISION`).
    const REGISTER_COUNT: usize = 1 << Self::PRECISION;

    /// Create an estimator that counts exactly up to `exact_threshold`
    /// distinct hashes before switching to the approximate sketch.
    pub fn new(exact_threshold: usize) -> Self {
        Self {
            distinct_values: HashSet::new(),
            exact_threshold,
            registers: Vec::new(),
            exact_floor: 0,
            updated_at: Instant::now(),
        }
    }

    /// Add a value (hash of the actual value).
    ///
    /// Adding the same hash again never changes the estimate, in either mode.
    pub fn add(&mut self, hash: u64) {
        if self.registers.is_empty() {
            self.distinct_values.insert(hash);
            if self.distinct_values.len() > self.exact_threshold {
                // Switch to approximate mode: seed the sketch with everything
                // seen so far, then drop the exact set to free its memory.
                self.exact_floor = self.distinct_values.len() as u64;
                self.registers = vec![0; Self::REGISTER_COUNT];
                let seen = std::mem::take(&mut self.distinct_values);
                for h in seen {
                    self.record(h);
                }
            }
        } else {
            self.record(hash);
        }
        self.updated_at = Instant::now();
    }

    /// Get cardinality estimate.
    ///
    /// Exact while below the threshold; afterwards the HyperLogLog estimate
    /// (with the linear-counting correction for sparsely filled registers),
    /// never lower than the exact count observed at the switch.
    pub fn estimate(&self) -> u64 {
        if self.registers.is_empty() {
            return self.distinct_values.len() as u64;
        }

        let m = self.registers.len() as f64;
        let mut inverse_sum = 0.0_f64;
        let mut empty_registers = 0_usize;
        for &rank in &self.registers {
            inverse_sum += 2f64.powi(-i32::from(rank));
            if rank == 0 {
                empty_registers += 1;
            }
        }

        // Bias-correction constant for m >= 128 registers.
        let alpha = 0.7213 / (1.0 + 1.079 / m);
        let raw = alpha * m * m / inverse_sum;
        let estimate = if raw <= 2.5 * m && empty_registers > 0 {
            // Small-range correction: linear counting on the empty registers.
            m * (m / empty_registers as f64).ln()
        } else {
            raw
        };

        (estimate.round() as u64).max(self.exact_floor)
    }

    /// Fold one hash into the HyperLogLog registers.
    fn record(&mut self, hash: u64) {
        let mixed = Self::mix(hash);
        // The top PRECISION bits pick the register; the rest determine the rank.
        let index = (mixed >> (64 - Self::PRECISION)) as usize;
        let remainder = mixed << Self::PRECISION;
        let max_rank = 64 - Self::PRECISION + 1;
        let rank = (remainder.leading_zeros() + 1).min(max_rank) as u8;
        let slot = &mut self.registers[index];
        if rank > *slot {
            *slot = rank;
        }
    }

    /// SplitMix64 finalizer: a bijective bit mixer, so callers passing
    /// low-entropy hashes (e.g. small integers) still spread over all registers.
    fn mix(hash: u64) -> u64 {
        let mut z = hash.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

impl Default for CardinalityEstimate {
    /// Estimator that stays exact up to 10 000 distinct values.
    fn default() -> Self {
        Self::new(10000)
    }
}
164
165// ============================================================================
166// Table Aggregates
167// ============================================================================
168
169/// Aggregations for a single table
170#[derive(Debug)]
171struct TableAggregates {
172    /// Total row count
173    row_count: u64,
174    /// Count by filter predicate (e.g., "status=active" -> count)
175    filtered_counts: HashMap<String, u64>,
176    /// Numeric aggregations by column
177    numeric_aggs: HashMap<String, NumericAgg>,
178    /// Cardinality estimates by column
179    cardinalities: HashMap<String, CardinalityEstimate>,
180    /// Columns being tracked
181    tracked_columns: Vec<String>,
182    /// When aggregates were last refreshed
183    last_refresh: Instant,
184    /// Whether aggregates are stale
185    stale: bool,
186}
187
188impl TableAggregates {
189    fn new(tracked_columns: Vec<String>) -> Self {
190        Self {
191            row_count: 0,
192            filtered_counts: HashMap::new(),
193            numeric_aggs: HashMap::new(),
194            cardinalities: tracked_columns
195                .iter()
196                .map(|c| (c.clone(), CardinalityEstimate::default()))
197                .collect(),
198            tracked_columns,
199            last_refresh: Instant::now(),
200            stale: false,
201        }
202    }
203}
204
205// ============================================================================
206// Aggregation Cache
207// ============================================================================
208
209/// Cache for precomputed aggregations
210pub struct AggregationCache {
211    /// Aggregations per table
212    tables: HashMap<String, TableAggregates>,
213    /// Global row count across all tables
214    global_row_count: u64,
215}
216
217impl AggregationCache {
218    /// Create a new aggregation cache
219    pub fn new() -> Self {
220        Self {
221            tables: HashMap::new(),
222            global_row_count: 0,
223        }
224    }
225
226    /// Register a table for aggregation tracking
227    pub fn register_table(&mut self, table: &str, tracked_columns: &[&str]) {
228        let columns = tracked_columns.iter().map(|s| s.to_string()).collect();
229        self.tables
230            .insert(table.to_string(), TableAggregates::new(columns));
231    }
232
233    /// Get total row count for a table
234    pub fn count(&self, table: &str) -> Option<u64> {
235        self.tables.get(table).map(|t| t.row_count)
236    }
237
238    /// Get filtered count (if cached)
239    pub fn count_filtered(&self, table: &str, filter_key: &str) -> Option<u64> {
240        self.tables
241            .get(table)
242            .and_then(|t| t.filtered_counts.get(filter_key).copied())
243    }
244
245    /// Set a filtered count (precomputed)
246    pub fn set_filtered_count(&mut self, table: &str, filter_key: &str, count: u64) {
247        if let Some(aggs) = self.tables.get_mut(table) {
248            aggs.filtered_counts.insert(filter_key.to_string(), count);
249        }
250    }
251
252    /// Get numeric aggregation for a column
253    pub fn numeric_agg(&self, table: &str, column: &str) -> Option<&NumericAgg> {
254        self.tables
255            .get(table)
256            .and_then(|t| t.numeric_aggs.get(column))
257    }
258
259    /// Get average for a column
260    pub fn avg(&self, table: &str, column: &str) -> Option<f64> {
261        self.numeric_agg(table, column).and_then(|a| a.avg())
262    }
263
264    /// Get sum for a column
265    pub fn sum(&self, table: &str, column: &str) -> Option<f64> {
266        self.numeric_agg(table, column).map(|a| a.sum)
267    }
268
269    /// Get min for a column
270    pub fn min(&self, table: &str, column: &str) -> Option<f64> {
271        self.numeric_agg(table, column).and_then(|a| a.min)
272    }
273
274    /// Get max for a column
275    pub fn max(&self, table: &str, column: &str) -> Option<f64> {
276        self.numeric_agg(table, column).and_then(|a| a.max)
277    }
278
279    /// Get distinct count estimate for a column
280    pub fn distinct_count(&self, table: &str, column: &str) -> Option<u64> {
281        self.tables
282            .get(table)
283            .and_then(|t| t.cardinalities.get(column))
284            .map(|c| c.estimate())
285    }
286
287    /// Record an insert operation
288    pub fn on_insert(&mut self, table: &str, values: &HashMap<String, AggValue>) {
289        if let Some(aggs) = self.tables.get_mut(table) {
290            aggs.row_count += 1;
291            self.global_row_count += 1;
292
293            for (col, value) in values {
294                // Update numeric aggregations
295                if let AggValue::Number(n) = value {
296                    aggs.numeric_aggs
297                        .entry(col.clone())
298                        .or_insert_with(NumericAgg::default)
299                        .add(*n);
300                }
301
302                // Update cardinality
303                if let Some(card) = aggs.cardinalities.get_mut(col) {
304                    card.add(value.hash());
305                }
306            }
307
308            // Invalidate filtered counts (need recompute)
309            aggs.filtered_counts.clear();
310        }
311    }
312
313    /// Record a delete operation
314    pub fn on_delete(&mut self, table: &str, values: &HashMap<String, AggValue>) {
315        if let Some(aggs) = self.tables.get_mut(table) {
316            aggs.row_count = aggs.row_count.saturating_sub(1);
317            self.global_row_count = self.global_row_count.saturating_sub(1);
318
319            for (col, value) in values {
320                if let AggValue::Number(n) = value {
321                    if let Some(num_agg) = aggs.numeric_aggs.get_mut(col) {
322                        num_agg.remove(*n);
323                    }
324                }
325            }
326
327            // Mark as needing refresh for min/max
328            aggs.stale = true;
329            aggs.filtered_counts.clear();
330        }
331    }
332
333    /// Full refresh for a table (recompute all aggregates)
334    pub fn refresh<I>(&mut self, table: &str, rows: I)
335    where
336        I: Iterator<Item = HashMap<String, AggValue>>,
337    {
338        if let Some(aggs) = self.tables.get_mut(table) {
339            // Reset aggregates
340            aggs.row_count = 0;
341            aggs.numeric_aggs.clear();
342            for card in aggs.cardinalities.values_mut() {
343                *card = CardinalityEstimate::default();
344            }
345
346            // Rebuild from rows
347            for row in rows {
348                aggs.row_count += 1;
349
350                for (col, value) in &row {
351                    if let AggValue::Number(n) = value {
352                        aggs.numeric_aggs
353                            .entry(col.clone())
354                            .or_insert_with(NumericAgg::default)
355                            .add(*n);
356                    }
357
358                    if let Some(card) = aggs.cardinalities.get_mut(col) {
359                        card.add(value.hash());
360                    }
361                }
362            }
363
364            aggs.stale = false;
365            aggs.last_refresh = Instant::now();
366        }
367    }
368
369    /// Get global row count
370    pub fn global_count(&self) -> u64 {
371        self.global_row_count
372    }
373
374    /// Check if aggregates are stale
375    pub fn is_stale(&self, table: &str) -> bool {
376        self.tables.get(table).map(|t| t.stale).unwrap_or(true)
377    }
378
379    /// Get statistics summary
380    pub fn stats(&self) -> AggCacheStats {
381        AggCacheStats {
382            tables: self.tables.len(),
383            total_rows: self.global_row_count,
384            tracked_columns: self.tables.values().map(|t| t.tracked_columns.len()).sum(),
385        }
386    }
387}
388
389impl Default for AggregationCache {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
/// Value type for aggregation operations
#[derive(Debug, Clone)]
pub enum AggValue {
    /// Numeric value; feeds numeric aggregates
    Number(f64),
    /// Text value
    String(String),
    /// Boolean value
    Bool(bool),
    /// Missing value
    Null,
}

impl AggValue {
    /// Get a hash of the value for cardinality estimation.
    ///
    /// Each variant is hashed together with a distinct tag so values of
    /// different types never share a hash input (previously `Null` and
    /// `Number(0.0)` both hashed the same `0u64`). Numbers are canonicalized
    /// first so `0.0`/`-0.0` and all NaN bit patterns count as one value each.
    pub fn hash(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        match self {
            AggValue::Null => 0u8.hash(&mut hasher),
            AggValue::Number(n) => {
                1u8.hash(&mut hasher);
                let canonical = if n.is_nan() {
                    f64::NAN
                } else if *n == 0.0 {
                    // Collapses -0.0 onto +0.0.
                    0.0
                } else {
                    *n
                };
                canonical.to_bits().hash(&mut hasher);
            }
            AggValue::String(s) => {
                2u8.hash(&mut hasher);
                s.hash(&mut hasher);
            }
            AggValue::Bool(b) => {
                3u8.hash(&mut hasher);
                b.hash(&mut hasher);
            }
        }
        hasher.finish()
    }
}
420
/// Aggregation cache statistics.
///
/// Plain snapshot of counters; comparable and defaultable so callers can
/// assert on or diff whole snapshots.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct AggCacheStats {
    /// Number of tables tracked
    pub tables: usize,
    /// Total rows across all tables
    pub total_rows: u64,
    /// Total tracked columns (summed over all tables)
    pub tracked_columns: usize,
}
431
432// ============================================================================
433// Tests
434// ============================================================================
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `hosts` row with a numeric criticality and a string status.
    fn host_row(criticality: f64, status: &str) -> HashMap<String, AggValue> {
        HashMap::from([
            ("criticality".to_string(), AggValue::Number(criticality)),
            ("status".to_string(), AggValue::String(status.to_string())),
        ])
    }

    #[test]
    fn test_numeric_agg() {
        let mut agg = NumericAgg::default();
        for value in [10.0, 20.0, 30.0] {
            agg.add(value);
        }

        assert_eq!(agg.count, 3);
        assert_eq!(agg.sum, 60.0);
        assert_eq!(agg.avg(), Some(20.0));
        assert_eq!(agg.min, Some(10.0));
        assert_eq!(agg.max, Some(30.0));
    }

    #[test]
    fn test_aggregation_cache() {
        let mut cache = AggregationCache::new();
        cache.register_table("hosts", &["criticality", "status"]);

        // Three inserts: two active hosts and one inactive host.
        let rows = [(5.0, "active"), (8.0, "active"), (2.0, "inactive")];
        for (criticality, status) in rows {
            cache.on_insert("hosts", &host_row(criticality, status));
        }

        assert_eq!(cache.count("hosts"), Some(3));
        assert_eq!(cache.avg("hosts", "criticality"), Some(5.0));
        assert_eq!(cache.sum("hosts", "criticality"), Some(15.0));
        assert_eq!(cache.min("hosts", "criticality"), Some(2.0));
        assert_eq!(cache.max("hosts", "criticality"), Some(8.0));
    }

    #[test]
    fn test_cardinality() {
        let mut card = CardinalityEstimate::new(100);

        // Fifty distinct hashes, well under the exact threshold.
        (0..50).for_each(|hash| card.add(hash));
        assert_eq!(card.estimate(), 50);

        // Feeding the same hashes again must not change the count.
        (0..50).for_each(|hash| card.add(hash));
        assert_eq!(card.estimate(), 50);
    }
}