reddb_server/storage/query/optimizer/
stats.rs1use std::collections::HashMap;
6use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
7
/// Takes a shared (read) guard on `lock`, tolerating poisoning.
///
/// When the lock is poisoned, the guard carried inside the poison error is
/// returned instead of the error, so callers always get access to the data.
fn read_unpoisoned<'a, T>(lock: &'a RwLock<T>) -> RwLockReadGuard<'a, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
11
/// Takes an exclusive (write) guard on `lock`, tolerating poisoning.
///
/// When the lock is poisoned, the guard carried inside the poison error is
/// returned instead of the error, so callers can still mutate the data.
fn write_unpoisoned<'a, T>(lock: &'a RwLock<T>) -> RwLockWriteGuard<'a, T> {
    match lock.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
15
/// Per-column summary statistics used to estimate predicate selectivity.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of distinct values; `0` is treated as "unknown" by
    /// [`ColumnStats::equality_selectivity`].
    pub ndv: u64,
    /// Fraction of rows whose value is NULL.
    pub null_fraction: f64,
    /// Smallest observed numeric value, if any.
    pub min_value: Option<f64>,
    /// Largest observed numeric value, if any.
    pub max_value: Option<f64>,
}

impl ColumnStats {
    /// Selectivity assumed for an equality predicate when `ndv` is zero.
    const DEFAULT_EQUALITY_SELECTIVITY: f64 = 0.01;
    /// Selectivity assumed for a range predicate when no usable value range is known.
    const DEFAULT_RANGE_SELECTIVITY: f64 = 0.25;

    /// Creates empty statistics for the column `name`: zero distinct values,
    /// zero null fraction and no known value range.
    pub fn new(name: String) -> Self {
        Self {
            name,
            ndv: 0,
            null_fraction: 0.0,
            min_value: None,
            max_value: None,
        }
    }

    /// Sets the number of distinct values.
    pub fn with_ndv(mut self, ndv: u64) -> Self {
        self.ndv = ndv;
        self
    }

    /// Sets the null fraction, clamped to `[0.0, 1.0]`.
    pub fn with_null_fraction(mut self, fraction: f64) -> Self {
        self.null_fraction = fraction.clamp(0.0, 1.0);
        self
    }

    /// Sets the observed value range to `[min, max]`.
    pub fn with_range(mut self, min: f64, max: f64) -> Self {
        self.min_value = Some(min);
        self.max_value = Some(max);
        self
    }

    /// Estimated fraction of rows matching `column = constant`.
    ///
    /// Returns `1 / ndv` when the distinct count is known, otherwise a fixed
    /// default of `0.01`.
    pub fn equality_selectivity(&self) -> f64 {
        if self.ndv > 0 {
            1.0 / self.ndv as f64
        } else {
            Self::DEFAULT_EQUALITY_SELECTIVITY
        }
    }

    /// Estimated fraction of rows with a value inside `[lower, upper]`.
    ///
    /// A `None` (or NaN) bound is treated as open and replaced by the column's
    /// own minimum/maximum. Both bounds are clipped to the column's
    /// `[min, max]` range before the overlap is measured, so a query range
    /// that only partly overlaps the column (or lies entirely outside it)
    /// is not over-estimated.
    ///
    /// Returns a fixed default of `0.25` when the column range is unknown,
    /// empty (`max <= min`), or not finite.
    pub fn range_selectivity(&self, lower: Option<f64>, upper: Option<f64>) -> f64 {
        let (min, max) = match (self.min_value, self.max_value) {
            // `max > min` also rejects NaN endpoints; the finiteness check
            // avoids an `inf / inf` NaN result below.
            (Some(min), Some(max)) if max > min && (max - min).is_finite() => (min, max),
            _ => return Self::DEFAULT_RANGE_SELECTIVITY,
        };

        // Only the part of the query interval that overlaps the column's
        // value range can match any rows.
        let low = lower
            .filter(|bound| !bound.is_nan())
            .unwrap_or(min)
            .clamp(min, max);
        let high = upper
            .filter(|bound| !bound.is_nan())
            .unwrap_or(max)
            .clamp(min, max);

        // An inverted interval (`low > high`) yields a negative ratio, which
        // the final clamp turns into zero.
        ((high - low) / (max - min)).clamp(0.0, 1.0)
    }
}
84
85#[derive(Debug, Clone)]
87pub struct TableStats {
88 pub name: String,
90 pub row_count: u64,
92 columns: HashMap<String, ColumnStats>,
94 pub avg_row_size: Option<usize>,
96 pub last_updated: Option<u64>,
98}
99
100impl TableStats {
101 pub fn new(name: String, row_count: u64) -> Self {
103 Self {
104 name,
105 row_count,
106 columns: HashMap::new(),
107 avg_row_size: None,
108 last_updated: None,
109 }
110 }
111
112 pub fn add_column(&mut self, stats: ColumnStats) {
114 self.columns.insert(stats.name.clone(), stats);
115 }
116
117 pub fn get_column(&self, name: &str) -> Option<&ColumnStats> {
119 self.columns.get(name)
120 }
121
122 pub fn column_names(&self) -> Vec<&str> {
124 self.columns.keys().map(|s| s.as_str()).collect()
125 }
126
127 pub fn with_avg_row_size(mut self, size: usize) -> Self {
129 self.avg_row_size = Some(size);
130 self
131 }
132
133 pub fn estimated_size(&self) -> Option<u64> {
135 self.avg_row_size.map(|size| self.row_count * size as u64)
136 }
137}
138
139pub struct StatsCollector {
141 columns: HashMap<String, ColumnCollector>,
143 row_count: u64,
145 total_size: usize,
147}
148
149impl StatsCollector {
150 pub fn new() -> Self {
152 Self {
153 columns: HashMap::new(),
154 row_count: 0,
155 total_size: 0,
156 }
157 }
158
159 pub fn add_column(&mut self, name: &str) {
161 self.columns
162 .insert(name.to_string(), ColumnCollector::new(name.to_string()));
163 }
164
165 pub fn observe_row(&mut self, row_size: usize) {
167 self.row_count += 1;
168 self.total_size += row_size;
169 }
170
171 pub fn observe_value(&mut self, column: &str, value: Option<&ObservedValue>) {
173 if let Some(collector) = self.columns.get_mut(column) {
174 collector.observe(value);
175 }
176 }
177
178 pub fn build(self, table_name: String) -> TableStats {
180 let mut stats = TableStats::new(table_name, self.row_count);
181
182 if self.row_count > 0 {
183 stats.avg_row_size = Some(self.total_size / self.row_count as usize);
184 }
185
186 for (_, collector) in self.columns {
187 stats.add_column(collector.build(self.row_count));
188 }
189
190 stats
191 }
192}
193
194impl Default for StatsCollector {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
/// A single non-NULL value handed to the statistics collector.
#[derive(Debug, Clone)]
pub enum ObservedValue {
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Text value.
    String(String),
    /// Boolean value.
    Bool(bool),
    /// Raw byte string.
    Bytes(Vec<u8>),
}

impl ObservedValue {
    /// Numeric view of the value.
    ///
    /// Integers are converted with an `as` cast and floats are returned
    /// unchanged; every other variant has no numeric view and yields `None`.
    pub fn as_f64(&self) -> Option<f64> {
        match self {
            Self::Int(number) => Some(*number as f64),
            Self::Float(number) => Some(*number),
            Self::String(_) | Self::Bool(_) | Self::Bytes(_) => None,
        }
    }
}
219
220struct ColumnCollector {
222 name: String,
223 distinct: std::collections::HashSet<u64>,
225 null_count: u64,
227 min_value: Option<f64>,
229 max_value: Option<f64>,
231}
232
233impl ColumnCollector {
234 fn new(name: String) -> Self {
235 Self {
236 name,
237 distinct: std::collections::HashSet::new(),
238 null_count: 0,
239 min_value: None,
240 max_value: None,
241 }
242 }
243
244 fn observe(&mut self, value: Option<&ObservedValue>) {
245 match value {
246 None => {
247 self.null_count += 1;
248 }
249 Some(v) => {
250 let hash = Self::hash_value(v);
252 self.distinct.insert(hash);
253
254 if let Some(f) = v.as_f64() {
256 self.min_value = Some(match self.min_value {
257 Some(min) => min.min(f),
258 None => f,
259 });
260 self.max_value = Some(match self.max_value {
261 Some(max) => max.max(f),
262 None => f,
263 });
264 }
265 }
266 }
267 }
268
269 fn hash_value(value: &ObservedValue) -> u64 {
270 use std::hash::{Hash, Hasher};
271 let mut hasher = std::collections::hash_map::DefaultHasher::new();
272
273 match value {
274 ObservedValue::Int(i) => i.hash(&mut hasher),
275 ObservedValue::Float(f) => f.to_bits().hash(&mut hasher),
276 ObservedValue::String(s) => s.hash(&mut hasher),
277 ObservedValue::Bool(b) => b.hash(&mut hasher),
278 ObservedValue::Bytes(b) => b.hash(&mut hasher),
279 }
280
281 hasher.finish()
282 }
283
284 fn build(self, row_count: u64) -> ColumnStats {
285 let null_fraction = if row_count > 0 {
286 self.null_count as f64 / row_count as f64
287 } else {
288 0.0
289 };
290
291 ColumnStats {
292 name: self.name,
293 ndv: self.distinct.len() as u64,
294 null_fraction,
295 min_value: self.min_value,
296 max_value: self.max_value,
297 }
298 }
299}
300
301pub struct StatsRegistry {
303 tables: RwLock<HashMap<String, TableStats>>,
305}
306
307impl StatsRegistry {
308 pub fn new() -> Self {
310 Self {
311 tables: RwLock::new(HashMap::new()),
312 }
313 }
314
315 pub fn register(&self, stats: TableStats) {
317 let mut tables = write_unpoisoned(&self.tables);
318 tables.insert(stats.name.clone(), stats);
319 }
320
321 pub fn get(&self, table_name: &str) -> Option<TableStats> {
323 let tables = read_unpoisoned(&self.tables);
324 tables.get(table_name).cloned()
325 }
326
327 pub fn remove(&self, table_name: &str) -> Option<TableStats> {
329 let mut tables = write_unpoisoned(&self.tables);
330 tables.remove(table_name)
331 }
332
333 pub fn list(&self) -> Vec<String> {
335 let tables = read_unpoisoned(&self.tables);
336 tables.keys().cloned().collect()
337 }
338
339 pub fn clear(&self) {
341 let mut tables = write_unpoisoned(&self.tables);
342 tables.clear();
343 }
344}
345
346impl Default for StatsRegistry {
347 fn default() -> Self {
348 Self::new()
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within `tolerance` of `expected`.
    fn close(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn test_column_stats() {
        let stats = ColumnStats::new("status".to_string())
            .with_ndv(5)
            .with_null_fraction(0.1);

        assert_eq!(stats.ndv, 5);
        assert!(close(stats.null_fraction, 0.1, 0.001));
        // Five distinct values -> one in five rows per value.
        assert!(close(stats.equality_selectivity(), 0.2, 0.001));
    }

    #[test]
    fn test_range_selectivity() {
        let stats = ColumnStats::new("age".to_string())
            .with_ndv(100)
            .with_range(0.0, 100.0);

        // Lower half of the value range.
        let lower_half = stats.range_selectivity(Some(0.0), Some(50.0));
        assert!(close(lower_half, 0.5, 0.001));

        // A quarter of the value range.
        let quarter = stats.range_selectivity(Some(25.0), Some(50.0));
        assert!(close(quarter, 0.25, 0.001));
    }

    #[test]
    fn test_table_stats() {
        let mut stats = TableStats::new("users".to_string(), 10000);

        let id_column = ColumnStats::new("id".to_string())
            .with_ndv(10000)
            .with_null_fraction(0.0);
        stats.add_column(id_column);

        let status_column = ColumnStats::new("status".to_string())
            .with_ndv(5)
            .with_null_fraction(0.02);
        stats.add_column(status_column);

        assert_eq!(stats.row_count, 10000);
        assert!(stats.get_column("id").is_some());
        assert!(stats.get_column("status").is_some());
        assert!(stats.get_column("unknown").is_none());
    }

    #[test]
    fn test_stats_collector() {
        let mut collector = StatsCollector::new();
        collector.add_column("value");

        for i in 0..100 {
            collector.observe_row(100);
            // Every tenth row is NULL; the rest cycle through five values.
            let value = (i % 10 != 0).then(|| ObservedValue::Int(i % 5));
            collector.observe_value("value", value.as_ref());
        }

        let stats = collector.build("test".to_string());

        assert_eq!(stats.row_count, 100);
        assert_eq!(stats.avg_row_size, Some(100));

        let col = stats.get_column("value").unwrap();
        assert_eq!(col.ndv, 5);
        assert!(close(col.null_fraction, 0.1, 0.01));
    }

    #[test]
    fn test_stats_registry() {
        let registry = StatsRegistry::new();
        registry.register(TableStats::new("users".to_string(), 1000));

        assert!(registry.get("users").is_some());
        assert!(registry.get("orders").is_none());

        assert_eq!(registry.list().len(), 1);

        registry.remove("users");
        assert!(registry.get("users").is_none());
    }

    #[test]
    fn test_observed_value_hashing() {
        let mut collector = StatsCollector::new();
        collector.add_column("mixed");

        // The "same" value expressed as three different types.
        let observations = [
            ObservedValue::Int(42),
            ObservedValue::String("42".to_string()),
            ObservedValue::Float(42.0),
        ];
        for observation in &observations {
            collector.observe_value("mixed", Some(observation));
        }

        let stats = collector.build("test".to_string());
        let col = stats.get_column("mixed").unwrap();

        // Each type counts as its own distinct value.
        assert_eq!(col.ndv, 3);
    }
}