Skip to main content

reddb_server/storage/query/optimizer/
stats.rs

1//! Statistics Collection
2//!
3//! Collects and maintains statistics for query optimization.
4
5use std::collections::HashMap;
6use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
7
/// Acquires a read guard on `lock`, ignoring lock poisoning.
///
/// If the lock is poisoned, the guard carried inside the poison error is
/// returned instead of panicking, so callers always get access to the data.
fn read_unpoisoned<'a, T>(lock: &'a RwLock<T>) -> RwLockReadGuard<'a, T> {
    match lock.read() {
        Ok(guard) => guard,
        // The error still owns a valid guard; hand it out.
        Err(poisoned) => poisoned.into_inner(),
    }
}
11
/// Acquires a write guard on `lock`, ignoring lock poisoning.
///
/// If the lock is poisoned, the guard carried inside the poison error is
/// returned instead of panicking, so callers can always mutate the data.
fn write_unpoisoned<'a, T>(lock: &'a RwLock<T>) -> RwLockWriteGuard<'a, T> {
    match lock.write() {
        Ok(guard) => guard,
        // The error still owns a valid guard; hand it out.
        Err(poisoned) => poisoned.into_inner(),
    }
}
15
/// Column statistics
///
/// Summary figures for a single column: distinct-value count, NULL fraction
/// and, for numeric columns, the known value range. The selectivity helpers
/// turn these figures into estimates in `[0.0, 1.0]`.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name
    pub name: String,
    /// Number of distinct values (NDV)
    pub ndv: u64,
    /// Fraction of NULL values
    pub null_fraction: f64,
    /// Minimum value (for numeric columns)
    pub min_value: Option<f64>,
    /// Maximum value (for numeric columns)
    pub max_value: Option<f64>,
}

impl ColumnStats {
    /// Selectivity assumed for an equality predicate when NDV is zero (unknown).
    const DEFAULT_EQUALITY_SELECTIVITY: f64 = 0.01;
    /// Selectivity assumed for a range predicate when no usable range is known.
    const DEFAULT_RANGE_SELECTIVITY: f64 = 0.25;

    /// Create new column stats
    ///
    /// Starts with `ndv = 0`, `null_fraction = 0.0` and no value range.
    pub fn new(name: String) -> Self {
        Self {
            name,
            ndv: 0,
            null_fraction: 0.0,
            min_value: None,
            max_value: None,
        }
    }

    /// Set NDV
    pub fn with_ndv(mut self, ndv: u64) -> Self {
        self.ndv = ndv;
        self
    }

    /// Set null fraction
    ///
    /// The value is clamped into `[0.0, 1.0]`. A NaN input is stored as NaN,
    /// because `f64::clamp` passes NaN through.
    pub fn with_null_fraction(mut self, fraction: f64) -> Self {
        self.null_fraction = fraction.clamp(0.0, 1.0);
        self
    }

    /// Set min/max values
    ///
    /// No ordering check is performed. When `max` is not strictly greater than
    /// `min`, [`range_selectivity`](Self::range_selectivity) falls back to its
    /// default estimate.
    pub fn with_range(mut self, min: f64, max: f64) -> Self {
        self.min_value = Some(min);
        self.max_value = Some(max);
        self
    }

    /// Estimate selectivity for equality predicate
    ///
    /// Returns `1 / ndv`, or a default of `0.01` when `ndv` is zero.
    pub fn equality_selectivity(&self) -> f64 {
        if self.ndv > 0 {
            1.0 / self.ndv as f64
        } else {
            Self::DEFAULT_EQUALITY_SELECTIVITY
        }
    }

    /// Estimate selectivity for range predicate
    ///
    /// `lower` / `upper` are the predicate bounds; `None` means the predicate
    /// is open on that side. The estimate is the fraction of the column's
    /// `[min, max]` interval covered by the predicate, in `[0.0, 1.0]`.
    ///
    /// Bounds are clamped to the column's range first, so a predicate that
    /// extends past the known range is not over-counted, and a predicate
    /// entirely outside the range yields `0.0`. A NaN bound is treated as
    /// open. An inverted predicate (`lower > upper`) yields `0.0`.
    ///
    /// Returns a default of `0.25` when the column has no range, when `max`
    /// is not strictly greater than `min`, or when `max - min` is not finite.
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let (min, max) = match (self.min_value, self.max_value) {
            // `max > min` also rejects NaN endpoints, which keeps `clamp` below panic-free.
            (Some(min), Some(max)) if max > min => (min, max),
            _ => return Self::DEFAULT_RANGE_SELECTIVITY,
        };

        let range = max - min;
        if !range.is_finite() {
            // Infinite endpoints, or overflow of `max - min`: no meaningful ratio.
            return Self::DEFAULT_RANGE_SELECTIVITY;
        }

        let low = lower
            .filter(|bound| !bound.is_nan())
            .unwrap_or(min)
            .clamp(min, max);
        let high = upper
            .filter(|bound| !bound.is_nan())
            .unwrap_or(max)
            .clamp(min, max);

        // The outer clamp turns the negative ratio of an inverted predicate into 0.0.
        ((high - low) / range).clamp(0.0, 1.0)
    }
}
84
85/// Table statistics
86#[derive(Debug, Clone)]
87pub struct TableStats {
88    /// Table name
89    pub name: String,
90    /// Row count
91    pub row_count: u64,
92    /// Column statistics
93    columns: HashMap<String, ColumnStats>,
94    /// Average row size in bytes
95    pub avg_row_size: Option<usize>,
96    /// Last updated timestamp
97    pub last_updated: Option<u64>,
98}
99
100impl TableStats {
101    /// Create new table stats
102    pub fn new(name: String, row_count: u64) -> Self {
103        Self {
104            name,
105            row_count,
106            columns: HashMap::new(),
107            avg_row_size: None,
108            last_updated: None,
109        }
110    }
111
112    /// Add column statistics
113    pub fn add_column(&mut self, stats: ColumnStats) {
114        self.columns.insert(stats.name.clone(), stats);
115    }
116
117    /// Get column statistics
118    pub fn get_column(&self, name: &str) -> Option<&ColumnStats> {
119        self.columns.get(name)
120    }
121
122    /// Get all column names
123    pub fn column_names(&self) -> Vec<&str> {
124        self.columns.keys().map(|s| s.as_str()).collect()
125    }
126
127    /// Set average row size
128    pub fn with_avg_row_size(mut self, size: usize) -> Self {
129        self.avg_row_size = Some(size);
130        self
131    }
132
133    /// Estimate table size in bytes
134    pub fn estimated_size(&self) -> Option<u64> {
135        self.avg_row_size.map(|size| self.row_count * size as u64)
136    }
137}
138
139/// Statistics collector for building table stats
140pub struct StatsCollector {
141    /// Column collectors
142    columns: HashMap<String, ColumnCollector>,
143    /// Total rows seen
144    row_count: u64,
145    /// Total size seen
146    total_size: usize,
147}
148
149impl StatsCollector {
150    /// Create new collector
151    pub fn new() -> Self {
152        Self {
153            columns: HashMap::new(),
154            row_count: 0,
155            total_size: 0,
156        }
157    }
158
159    /// Start collecting for a column
160    pub fn add_column(&mut self, name: &str) {
161        self.columns
162            .insert(name.to_string(), ColumnCollector::new(name.to_string()));
163    }
164
165    /// Observe a row
166    pub fn observe_row(&mut self, row_size: usize) {
167        self.row_count += 1;
168        self.total_size += row_size;
169    }
170
171    /// Observe a value
172    pub fn observe_value(&mut self, column: &str, value: Option<&ObservedValue>) {
173        if let Some(collector) = self.columns.get_mut(column) {
174            collector.observe(value);
175        }
176    }
177
178    /// Build final statistics
179    pub fn build(self, table_name: String) -> TableStats {
180        let mut stats = TableStats::new(table_name, self.row_count);
181
182        if self.row_count > 0 {
183            stats.avg_row_size = Some(self.total_size / self.row_count as usize);
184        }
185
186        for (_, collector) in self.columns {
187            stats.add_column(collector.build(self.row_count));
188        }
189
190        stats
191    }
192}
193
194impl Default for StatsCollector {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
/// Value type for observation
///
/// A single non-NULL column value handed to the statistics collector.
#[derive(Debug, Clone)]
pub enum ObservedValue {
    Int(i64),
    Float(f64),
    String(String),
    Bool(bool),
    Bytes(Vec<u8>),
}

impl ObservedValue {
    /// Numeric view of the value.
    ///
    /// `Int` is converted with an `as f64` cast and `Float` is returned as-is;
    /// every other variant has no numeric view and yields `None`.
    pub fn as_f64(&self) -> Option<f64> {
        match *self {
            Self::Int(int) => Some(int as f64),
            Self::Float(float) => Some(float),
            Self::String(_) | Self::Bool(_) | Self::Bytes(_) => None,
        }
    }
}
219
220/// Per-column statistics collector
221struct ColumnCollector {
222    name: String,
223    /// Distinct values (using HyperLogLog would be better for large datasets)
224    distinct: std::collections::HashSet<u64>,
225    /// NULL count
226    null_count: u64,
227    /// Min value
228    min_value: Option<f64>,
229    /// Max value
230    max_value: Option<f64>,
231}
232
233impl ColumnCollector {
234    fn new(name: String) -> Self {
235        Self {
236            name,
237            distinct: std::collections::HashSet::new(),
238            null_count: 0,
239            min_value: None,
240            max_value: None,
241        }
242    }
243
244    fn observe(&mut self, value: Option<&ObservedValue>) {
245        match value {
246            None => {
247                self.null_count += 1;
248            }
249            Some(v) => {
250                // Hash for distinct counting
251                let hash = Self::hash_value(v);
252                self.distinct.insert(hash);
253
254                // Track min/max for numeric values
255                if let Some(f) = v.as_f64() {
256                    self.min_value = Some(match self.min_value {
257                        Some(min) => min.min(f),
258                        None => f,
259                    });
260                    self.max_value = Some(match self.max_value {
261                        Some(max) => max.max(f),
262                        None => f,
263                    });
264                }
265            }
266        }
267    }
268
269    fn hash_value(value: &ObservedValue) -> u64 {
270        use std::hash::{Hash, Hasher};
271        let mut hasher = std::collections::hash_map::DefaultHasher::new();
272
273        match value {
274            ObservedValue::Int(i) => i.hash(&mut hasher),
275            ObservedValue::Float(f) => f.to_bits().hash(&mut hasher),
276            ObservedValue::String(s) => s.hash(&mut hasher),
277            ObservedValue::Bool(b) => b.hash(&mut hasher),
278            ObservedValue::Bytes(b) => b.hash(&mut hasher),
279        }
280
281        hasher.finish()
282    }
283
284    fn build(self, row_count: u64) -> ColumnStats {
285        let null_fraction = if row_count > 0 {
286            self.null_count as f64 / row_count as f64
287        } else {
288            0.0
289        };
290
291        ColumnStats {
292            name: self.name,
293            ndv: self.distinct.len() as u64,
294            null_fraction,
295            min_value: self.min_value,
296            max_value: self.max_value,
297        }
298    }
299}
300
301/// Global statistics registry
302pub struct StatsRegistry {
303    /// Table statistics
304    tables: RwLock<HashMap<String, TableStats>>,
305}
306
307impl StatsRegistry {
308    /// Create new registry
309    pub fn new() -> Self {
310        Self {
311            tables: RwLock::new(HashMap::new()),
312        }
313    }
314
315    /// Register table statistics
316    pub fn register(&self, stats: TableStats) {
317        let mut tables = write_unpoisoned(&self.tables);
318        tables.insert(stats.name.clone(), stats);
319    }
320
321    /// Get table statistics
322    pub fn get(&self, table_name: &str) -> Option<TableStats> {
323        let tables = read_unpoisoned(&self.tables);
324        tables.get(table_name).cloned()
325    }
326
327    /// Remove table statistics
328    pub fn remove(&self, table_name: &str) -> Option<TableStats> {
329        let mut tables = write_unpoisoned(&self.tables);
330        tables.remove(table_name)
331    }
332
333    /// List all tables with statistics
334    pub fn list(&self) -> Vec<String> {
335        let tables = read_unpoisoned(&self.tables);
336        tables.keys().cloned().collect()
337    }
338
339    /// Clear all statistics
340    pub fn clear(&self) {
341        let mut tables = write_unpoisoned(&self.tables);
342        tables.clear();
343    }
344}
345
346impl Default for StatsRegistry {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder-set NDV and NULL fraction are stored and drive equality selectivity.
    #[test]
    fn test_column_stats() {
        let status = ColumnStats::new(String::from("status"))
            .with_ndv(5)
            .with_null_fraction(0.1);

        assert_eq!(status.ndv, 5);
        assert!((status.null_fraction - 0.1).abs() < 0.001);
        // Five distinct values: one in five rows matches.
        assert!((status.equality_selectivity() - 0.2).abs() < 0.001);
    }

    /// Range selectivity is the covered fraction of the column's value range.
    #[test]
    fn test_range_selectivity() {
        let age = ColumnStats::new(String::from("age"))
            .with_ndv(100)
            .with_range(0.0, 100.0);

        // Half the range
        let half = age.range_selectivity(Some(0.0), Some(50.0));
        assert!((half - 0.5).abs() < 0.001);

        // Quarter of the range
        let quarter = age.range_selectivity(Some(25.0), Some(50.0));
        assert!((quarter - 0.25).abs() < 0.001);
    }

    /// Columns added to a table can be looked up by name.
    #[test]
    fn test_table_stats() {
        let id_column = ColumnStats::new(String::from("id"))
            .with_ndv(10000)
            .with_null_fraction(0.0);
        let status_column = ColumnStats::new(String::from("status"))
            .with_ndv(5)
            .with_null_fraction(0.02);

        let mut users = TableStats::new(String::from("users"), 10000);
        users.add_column(id_column);
        users.add_column(status_column);

        assert_eq!(users.row_count, 10000);
        assert!(users.get_column("id").is_some());
        assert!(users.get_column("status").is_some());
        assert!(users.get_column("unknown").is_none());
    }

    /// The collector derives row count, average size, NDV and NULL fraction.
    #[test]
    fn test_stats_collector() {
        let mut collector = StatsCollector::new();
        collector.add_column("value");

        // Observe some values: every tenth row is NULL, the rest cycle through 0..5.
        for i in 0..100 {
            collector.observe_row(100);
            let observed = if i % 10 == 0 {
                None // NULL
            } else {
                Some(ObservedValue::Int(i % 5))
            };
            collector.observe_value("value", observed.as_ref());
        }

        let table = collector.build(String::from("test"));

        assert_eq!(table.row_count, 100);
        assert_eq!(table.avg_row_size, Some(100));

        let value_column = table.get_column("value").unwrap();
        assert_eq!(value_column.ndv, 5); // 0, 1, 2, 3, 4
        assert!((value_column.null_fraction - 0.1).abs() < 0.01);
    }

    /// Registered tables can be fetched, listed and removed.
    #[test]
    fn test_stats_registry() {
        let registry = StatsRegistry::new();
        registry.register(TableStats::new(String::from("users"), 1000));

        assert!(registry.get("users").is_some());
        assert!(registry.get("orders").is_none());

        assert_eq!(registry.list().len(), 1);

        registry.remove("users");
        assert!(registry.get("users").is_none());
    }

    /// The same "42" expressed as int, string and float counts as three values.
    #[test]
    fn test_observed_value_hashing() {
        let mut collector = StatsCollector::new();
        collector.add_column("mixed");

        // Different types should hash differently
        let samples = [
            ObservedValue::Int(42),
            ObservedValue::String(String::from("42")),
            ObservedValue::Float(42.0),
        ];
        for sample in &samples {
            collector.observe_value("mixed", Some(sample));
        }

        let table = collector.build(String::from("test"));
        let mixed = table.get_column("mixed").unwrap();

        // All three should be distinct
        assert_eq!(mixed.ndv, 3);
    }
}