Skip to main content

grafeo_core/statistics/
collector.rs

1//! Collecting and storing graph statistics.
2//!
//! Use `StatisticsCollector` (currently compiled only under `#[cfg(test)]`) to stream
//! values through and build statistics, or construct [`ColumnStatistics`] directly if
//! you already know the numbers.
5//! The [`Statistics`] struct holds everything the optimizer needs.
6
7use super::histogram::Histogram;
8use grafeo_common::types::Value;
9use std::collections::HashMap;
10
/// A property key identifier.
///
/// Plain `String` alias; it is the key type of the per-property statistics maps
/// in this module.
pub type PropertyKey = String;
13
/// Everything the optimizer knows about the data - cardinalities, distributions, degrees.
///
/// This is the main struct the query planner consults when choosing between
/// different execution strategies. The derived [`Default`] yields empty maps
/// and zero totals.
#[derive(Debug, Clone, Default)]
pub struct Statistics {
    /// Per-label statistics, keyed by label name.
    pub labels: HashMap<String, LabelStatistics>,
    /// Per-edge-type statistics, keyed by edge type name.
    pub edge_types: HashMap<String, EdgeTypeStatistics>,
    /// Per-property statistics (across all labels), keyed by property name.
    pub properties: HashMap<PropertyKey, ColumnStatistics>,
    /// Total node count. The `update_*` methods do not touch this field.
    pub total_nodes: u64,
    /// Total edge count. The `update_*` methods do not touch this field.
    pub total_edges: u64,
}
31
32impl Statistics {
33    /// Creates a new empty statistics object.
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Updates label statistics.
39    pub fn update_label(&mut self, label: &str, stats: LabelStatistics) {
40        self.labels.insert(label.to_string(), stats);
41    }
42
43    /// Updates edge type statistics.
44    pub fn update_edge_type(&mut self, edge_type: &str, stats: EdgeTypeStatistics) {
45        self.edge_types.insert(edge_type.to_string(), stats);
46    }
47
48    /// Updates property statistics.
49    pub fn update_property(&mut self, property: &str, stats: ColumnStatistics) {
50        self.properties.insert(property.to_string(), stats);
51    }
52
53    /// Gets label statistics.
54    pub fn get_label(&self, label: &str) -> Option<&LabelStatistics> {
55        self.labels.get(label)
56    }
57
58    /// Gets edge type statistics.
59    pub fn get_edge_type(&self, edge_type: &str) -> Option<&EdgeTypeStatistics> {
60        self.edge_types.get(edge_type)
61    }
62
63    /// Gets property statistics.
64    pub fn get_property(&self, property: &str) -> Option<&ColumnStatistics> {
65        self.properties.get(property)
66    }
67
68    /// Estimates the cardinality of a label scan.
69    pub fn estimate_label_cardinality(&self, label: &str) -> f64 {
70        self.labels
71            .get(label)
72            .map_or(1000.0, |s| s.node_count as f64) // Default estimate if no statistics
73    }
74
75    /// Estimates the average degree for an edge type.
76    pub fn estimate_avg_degree(&self, edge_type: &str, outgoing: bool) -> f64 {
77        self.edge_types.get(edge_type).map_or(10.0, |s| {
78            if outgoing {
79                s.avg_out_degree
80            } else {
81                s.avg_in_degree
82            }
83        }) // Default estimate
84    }
85
86    /// Estimates selectivity of an equality predicate.
87    pub fn estimate_equality_selectivity(&self, property: &str, _value: &Value) -> f64 {
88        self.properties.get(property).map_or(0.5, |s| {
89            if s.distinct_count > 0 {
90                1.0 / s.distinct_count as f64
91            } else {
92                0.5
93            }
94        })
95    }
96
97    /// Estimates selectivity of a range predicate.
98    pub fn estimate_range_selectivity(
99        &self,
100        property: &str,
101        lower: Option<&Value>,
102        upper: Option<&Value>,
103    ) -> f64 {
104        self.properties
105            .get(property)
106            .and_then(|s| s.histogram.as_ref())
107            .map_or(0.33, |h| {
108                h.estimate_range_selectivity(lower, upper, true, true)
109            }) // Default for range predicates
110    }
111}
112
/// Statistics for nodes with a particular label (like "Person" or "Company").
#[derive(Debug, Clone)]
pub struct LabelStatistics {
    /// Number of nodes with this label.
    pub node_count: u64,
    /// Average outgoing degree (`0.0` until set, e.g. via `with_degrees`).
    pub avg_out_degree: f64,
    /// Average incoming degree (`0.0` until set, e.g. via `with_degrees`).
    pub avg_in_degree: f64,
    /// Per-property statistics for nodes with this label, keyed by property name.
    pub properties: HashMap<PropertyKey, ColumnStatistics>,
}
125
126impl LabelStatistics {
127    /// Creates new label statistics.
128    pub fn new(node_count: u64) -> Self {
129        Self {
130            node_count,
131            avg_out_degree: 0.0,
132            avg_in_degree: 0.0,
133            properties: HashMap::new(),
134        }
135    }
136
137    /// Sets the average degrees.
138    pub fn with_degrees(mut self, out_degree: f64, in_degree: f64) -> Self {
139        self.avg_out_degree = out_degree;
140        self.avg_in_degree = in_degree;
141        self
142    }
143
144    /// Adds property statistics.
145    pub fn with_property(mut self, property: &str, stats: ColumnStatistics) -> Self {
146        self.properties.insert(property.to_string(), stats);
147        self
148    }
149}
150
/// Alias of [`LabelStatistics`] under the name used in relational contexts
/// (table statistics).
pub type TableStatistics = LabelStatistics;
153
/// Statistics for edges of a particular type (like "KNOWS" or "WORKS_AT").
#[derive(Debug, Clone)]
pub struct EdgeTypeStatistics {
    /// Number of edges of this type.
    pub edge_count: u64,
    /// Average outgoing degree (edges per source node).
    pub avg_out_degree: f64,
    /// Average incoming degree (edges per target node).
    pub avg_in_degree: f64,
    /// Per-property statistics for edges of this type, keyed by property name.
    pub properties: HashMap<PropertyKey, ColumnStatistics>,
}
166
167impl EdgeTypeStatistics {
168    /// Creates new edge type statistics.
169    pub fn new(edge_count: u64, avg_out_degree: f64, avg_in_degree: f64) -> Self {
170        Self {
171            edge_count,
172            avg_out_degree,
173            avg_in_degree,
174            properties: HashMap::new(),
175        }
176    }
177
178    /// Adds property statistics.
179    pub fn with_property(mut self, property: &str, stats: ColumnStatistics) -> Self {
180        self.properties.insert(property.to_string(), stats);
181        self
182    }
183}
184
/// Detailed statistics about a property's values - min, max, histogram, null ratio.
///
/// Only the three counts are mandatory; every other field is optional and is
/// consulted by the selectivity estimators when present.
#[derive(Debug, Clone)]
pub struct ColumnStatistics {
    /// Number of distinct values.
    pub distinct_count: u64,
    /// Total number of values (including nulls).
    pub total_count: u64,
    /// Number of null values.
    pub null_count: u64,
    /// Minimum value (if applicable).
    pub min_value: Option<Value>,
    /// Maximum value (if applicable).
    pub max_value: Option<Value>,
    /// Average value (for numeric types).
    pub avg_value: Option<f64>,
    /// Equi-depth histogram (for selectivity estimation).
    pub histogram: Option<Histogram>,
    /// Most common values with their frequencies. A matching entry's frequency
    /// is returned unchanged as the equality selectivity.
    pub most_common: Vec<(Value, f64)>,
}
205
206impl ColumnStatistics {
207    /// Creates new column statistics with basic info.
208    pub fn new(distinct_count: u64, total_count: u64, null_count: u64) -> Self {
209        Self {
210            distinct_count,
211            total_count,
212            null_count,
213            min_value: None,
214            max_value: None,
215            avg_value: None,
216            histogram: None,
217            most_common: Vec::new(),
218        }
219    }
220
221    /// Sets min/max values.
222    pub fn with_min_max(mut self, min: Value, max: Value) -> Self {
223        self.min_value = Some(min);
224        self.max_value = Some(max);
225        self
226    }
227
228    /// Sets the average value.
229    pub fn with_avg(mut self, avg: f64) -> Self {
230        self.avg_value = Some(avg);
231        self
232    }
233
234    /// Sets the histogram.
235    pub fn with_histogram(mut self, histogram: Histogram) -> Self {
236        self.histogram = Some(histogram);
237        self
238    }
239
240    /// Sets the most common values.
241    pub fn with_most_common(mut self, values: Vec<(Value, f64)>) -> Self {
242        self.most_common = values;
243        self
244    }
245
246    /// Returns the null fraction.
247    pub fn null_fraction(&self) -> f64 {
248        if self.total_count == 0 {
249            0.0
250        } else {
251            self.null_count as f64 / self.total_count as f64
252        }
253    }
254
255    /// Estimates selectivity for an equality predicate.
256    pub fn estimate_equality_selectivity(&self, value: &Value) -> f64 {
257        // Check most common values first
258        for (mcv, freq) in &self.most_common {
259            if mcv == value {
260                return *freq;
261            }
262        }
263
264        // Use histogram if available
265        if let Some(ref hist) = self.histogram {
266            return hist.estimate_equality_selectivity(value);
267        }
268
269        // Fall back to uniform distribution assumption
270        if self.distinct_count > 0 {
271            1.0 / self.distinct_count as f64
272        } else {
273            0.0
274        }
275    }
276
277    /// Estimates selectivity for a range predicate.
278    pub fn estimate_range_selectivity(&self, lower: Option<&Value>, upper: Option<&Value>) -> f64 {
279        if let Some(ref hist) = self.histogram {
280            return hist.estimate_range_selectivity(lower, upper, true, true);
281        }
282
283        // Without histogram, use min/max if available
284        match (&self.min_value, &self.max_value, lower, upper) {
285            (Some(min), Some(max), Some(l), Some(u)) => {
286                // Linear interpolation
287                estimate_linear_range(min, max, l, u)
288            }
289            (Some(_), Some(_), Some(_), None) => 0.5, // Greater than
290            (Some(_), Some(_), None, Some(_)) => 0.5, // Less than
291            _ => 0.33,                                // Default
292        }
293    }
294}
295
296/// Estimates range selectivity using linear interpolation.
297fn estimate_linear_range(min: &Value, max: &Value, lower: &Value, upper: &Value) -> f64 {
298    match (min, max, lower, upper) {
299        (
300            Value::Int64(min_v),
301            Value::Int64(max_v),
302            Value::Int64(lower_v),
303            Value::Int64(upper_v),
304        ) => {
305            let total_range = (max_v - min_v) as f64;
306            if total_range <= 0.0 {
307                return 1.0;
308            }
309
310            let effective_lower = (*lower_v).max(*min_v);
311            let effective_upper = (*upper_v).min(*max_v);
312
313            if effective_upper < effective_lower {
314                return 0.0;
315            }
316
317            (effective_upper - effective_lower) as f64 / total_range
318        }
319        (
320            Value::Float64(min_v),
321            Value::Float64(max_v),
322            Value::Float64(lower_v),
323            Value::Float64(upper_v),
324        ) => {
325            let total_range = max_v - min_v;
326            if total_range <= 0.0 {
327                return 1.0;
328            }
329
330            let effective_lower = lower_v.max(*min_v);
331            let effective_upper = upper_v.min(*max_v);
332
333            if effective_upper < effective_lower {
334                return 0.0;
335            }
336
337            (effective_upper - effective_lower) / total_range
338        }
339        _ => 0.33,
340    }
341}
342
/// Streams values through to build statistics automatically.
///
/// Call [`add()`](Self::add) for each value, then [`build()`](Self::build)
/// to get the final [`ColumnStatistics`] with histogram and most common values.
#[cfg(test)]
pub(crate) struct StatisticsCollector {
    /// Non-null values collected for histogram building.
    values: Vec<Value>,
    /// Running min.
    min: Option<Value>,
    /// Running max.
    max: Option<Value>,
    /// Running sum of the numeric values.
    sum: f64,
    /// Number of values that contributed to `sum`.
    numeric_count: u64,
    /// Null count.
    null_count: u64,
    /// Occurrence counter keyed by each value's `Debug` rendering. The first
    /// value seen for a key is kept next to its count, so the most common
    /// values never have to be parsed back from text. The number of keys is
    /// the distinct count.
    frequencies: HashMap<String, (Value, u64)>,
}

#[cfg(test)]
impl StatisticsCollector {
    /// Creates a new, empty statistics collector.
    pub fn new() -> Self {
        Self {
            values: Vec::new(),
            min: None,
            max: None,
            sum: 0.0,
            numeric_count: 0,
            null_count: 0,
            frequencies: HashMap::new(),
        }
    }

    /// Adds a value to the collector.
    ///
    /// `Value::Null` only bumps the null counter. Every other value updates
    /// the frequency table, the running min/max, the numeric sum (when the
    /// value is numeric) and is retained for histogram building.
    pub fn add(&mut self, value: Value) {
        if matches!(value, Value::Null) {
            self.null_count += 1;
            return;
        }

        // Track distinct values and their frequencies in one table; the value
        // is cloned only the first time its key shows up.
        let key = format!("{value:?}");
        self.frequencies
            .entry(key)
            .or_insert_with(|| (value.clone(), 0))
            .1 += 1;

        // Track min/max
        self.update_min_max(&value);

        // Track sum for numeric values, remembering how many contributed so
        // the average is not diluted by non-numeric values.
        if let Some(v) = value_to_f64(&value) {
            self.sum += v;
            self.numeric_count += 1;
        }

        self.values.push(value);
    }

    /// Folds `value` into the running minimum and maximum.
    ///
    /// A value that cannot be ordered against the current extreme leaves that
    /// extreme unchanged.
    fn update_min_max(&mut self, value: &Value) {
        // NaN compares with nothing; adopting it as an extreme would block
        // every later update.
        if matches!(value, Value::Float64(f) if f.is_nan()) {
            return;
        }

        let replaces_min = match &self.min {
            None => true,
            Some(current) => compare_values(value, current) == Some(std::cmp::Ordering::Less),
        };
        if replaces_min {
            self.min = Some(value.clone());
        }

        let replaces_max = match &self.max {
            None => true,
            Some(current) => compare_values(value, current) == Some(std::cmp::Ordering::Greater),
        };
        if replaces_max {
            self.max = Some(value.clone());
        }
    }

    /// Builds column statistics from collected data.
    ///
    /// * `num_histogram_buckets` - bucket count passed to `Histogram::build`.
    ///   No histogram is built when it is zero or exceeds the number of
    ///   non-null values.
    /// * `num_mcv` - maximum number of most common values to keep.
    ///
    /// The average is taken over numeric values only and is absent when there
    /// are none. Most common values are ordered by descending count; ties are
    /// broken by the value's `Debug` rendering so the result is deterministic.
    pub fn build(mut self, num_histogram_buckets: usize, num_mcv: usize) -> ColumnStatistics {
        let non_null_count = self.values.len();
        let total_count = non_null_count as u64 + self.null_count;
        let distinct_count = self.frequencies.len() as u64;

        let avg = if self.numeric_count > 0 {
            Some(self.sum / self.numeric_count as f64)
        } else {
            None
        };

        // Build histogram
        self.values
            .sort_by(|a, b| compare_values(a, b).unwrap_or(std::cmp::Ordering::Equal));
        let histogram = if num_histogram_buckets > 0 && non_null_count >= num_histogram_buckets {
            Some(Histogram::build(&self.values, num_histogram_buckets))
        } else {
            None
        };

        // Find most common values
        let total_non_null = non_null_count as f64;
        let mut by_frequency: Vec<(String, (Value, u64))> = self.frequencies.into_iter().collect();
        by_frequency.sort_by(|(key_a, (_, count_a)), (key_b, (_, count_b))| {
            count_b.cmp(count_a).then_with(|| key_a.cmp(key_b))
        });

        let most_common: Vec<(Value, f64)> = by_frequency
            .into_iter()
            .take(num_mcv)
            .map(|(_, (value, count))| (value, count as f64 / total_non_null))
            .collect();

        let mut stats = ColumnStatistics::new(distinct_count, total_count, self.null_count);

        if let (Some(min), Some(max)) = (self.min, self.max) {
            stats = stats.with_min_max(min, max);
        }

        if let Some(avg) = avg {
            stats = stats.with_avg(avg);
        }

        if let Some(hist) = histogram {
            stats = stats.with_histogram(hist);
        }

        if !most_common.is_empty() {
            stats = stats.with_most_common(most_common);
        }

        stats
    }
}
496
/// `Default` defers to [`StatisticsCollector::new`].
#[cfg(test)]
impl Default for StatisticsCollector {
    fn default() -> Self {
        Self::new()
    }
}
503
/// Numeric view of a value.
///
/// `Int64` is cast to `f64`, `Float64` is returned as is, and every other
/// variant yields `None`.
#[cfg(test)]
fn value_to_f64(value: &Value) -> Option<f64> {
    if let Value::Int64(int) = value {
        return Some(*int as f64);
    }
    if let Value::Float64(float) = value {
        return Some(*float);
    }
    None
}
513
/// Orders two values when they are comparable.
///
/// Values of the same supported variant compare by their payload, and `Int64`
/// compares with `Float64` after casting the integer to `f64`. Float
/// comparisons use `partial_cmp`. Any other pairing yields `None`.
#[cfg(test)]
fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
    let ordering = match (a, b) {
        // Comparisons involving floats may be undefined, so they return directly.
        (Value::Float64(x), Value::Float64(y)) => return x.partial_cmp(y),
        (Value::Int64(x), Value::Float64(y)) => return (*x as f64).partial_cmp(y),
        (Value::Float64(x), Value::Int64(y)) => return x.partial_cmp(&(*y as f64)),
        // Totally ordered payloads.
        (Value::Int64(x), Value::Int64(y)) => x.cmp(y),
        (Value::String(x), Value::String(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        (Value::Timestamp(x), Value::Timestamp(y)) => x.cmp(y),
        (Value::Date(x), Value::Date(y)) => x.cmp(y),
        (Value::Time(x), Value::Time(y)) => x.cmp(y),
        _ => return None,
    };
    Some(ordering)
}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `count` freshly made values into `collector`.
    fn add_repeated(collector: &mut StatisticsCollector, count: usize, make: impl Fn() -> Value) {
        for _ in 0..count {
            collector.add(make());
        }
    }

    /// Counts, distinct count and min/max for a uniform integer stream.
    #[test]
    fn test_statistics_collector() {
        let mut collector = StatisticsCollector::new();

        // Values 0-9, each appearing 10 times
        (0..100).for_each(|i| collector.add(Value::Int64(i % 10)));

        let stats = collector.build(10, 5);

        assert_eq!(stats.distinct_count, 10);
        assert_eq!(stats.total_count, 100);
        assert_eq!(stats.null_count, 0);
        assert_eq!(stats.min_value, Some(Value::Int64(0)));
        assert_eq!(stats.max_value, Some(Value::Int64(9)));
    }

    /// Nulls count towards the total but not towards the distinct values.
    #[test]
    fn test_statistics_with_nulls() {
        let mut collector = StatisticsCollector::new();

        let inputs = [
            Value::Int64(1),
            Value::Null,
            Value::Int64(2),
            Value::Null,
            Value::Int64(3),
        ];
        for value in inputs {
            collector.add(value);
        }

        let stats = collector.build(5, 3);

        assert_eq!(stats.total_count, 5);
        assert_eq!(stats.null_count, 2);
        assert_eq!(stats.distinct_count, 3);
        assert!((stats.null_fraction() - 0.4).abs() < 0.01);
    }

    /// The label statistics builders store what they are given.
    #[test]
    fn test_label_statistics() {
        let age_stats =
            ColumnStatistics::new(50, 1000, 10).with_min_max(Value::Int64(0), Value::Int64(100));

        let stats = LabelStatistics::new(1000)
            .with_degrees(5.0, 3.0)
            .with_property("age", age_stats);

        assert_eq!(stats.node_count, 1000);
        assert_eq!(stats.avg_out_degree, 5.0);
        assert!(stats.properties.contains_key("age"));
    }

    /// Min and max both keep moving as more extreme values arrive.
    #[test]
    fn test_statistics_min_max_updates() {
        let mut collector = StatisticsCollector::new();

        // Start in the middle, then alternate new minimum / new maximum so
        // both update paths run more than once.
        for n in [50, 10, 90, 5, 95] {
            collector.add(Value::Int64(n));
        }

        let stats = collector.build(2, 3);

        assert_eq!(stats.min_value, Some(Value::Int64(5)));
        assert_eq!(stats.max_value, Some(Value::Int64(95)));
    }

    /// Most common values come back ordered by frequency, for integers and strings.
    #[test]
    fn test_statistics_most_common_values() {
        let mut collector = StatisticsCollector::new();

        // Known frequencies: 50 x 42, 30 x 7, 20 x "hello".
        add_repeated(&mut collector, 50, || Value::Int64(42));
        add_repeated(&mut collector, 30, || Value::Int64(7));
        add_repeated(&mut collector, 20, || Value::String("hello".into()));

        let stats = collector.build(5, 3);

        // Should have most_common populated with Int64 and String values
        assert!(
            !stats.most_common.is_empty(),
            "MCV list should be populated"
        );

        // The most frequent value should be Int64(42) at freq 0.5
        let top = &stats.most_common[0];
        assert_eq!(top.0, Value::Int64(42));
        assert!((top.1 - 0.5).abs() < 0.01, "42 appears 50/100 = 0.5");

        // String values must be present as well
        let has_string = stats
            .most_common
            .iter()
            .any(|entry| matches!(entry.0, Value::String(_)));
        assert!(has_string, "String MCVs should be parsed back");
    }

    /// Database-level estimates use recorded statistics and fall back to defaults.
    #[test]
    fn test_database_statistics() {
        let mut db_stats = Statistics::new();

        let person = LabelStatistics::new(10000).with_degrees(10.0, 10.0);
        db_stats.update_label("Person", person);

        let knows = EdgeTypeStatistics::new(50000, 5.0, 5.0);
        db_stats.update_edge_type("KNOWS", knows);

        assert_eq!(db_stats.estimate_label_cardinality("Person"), 10000.0);
        assert_eq!(db_stats.estimate_label_cardinality("Unknown"), 1000.0); // Default

        assert_eq!(db_stats.estimate_avg_degree("KNOWS", true), 5.0);
    }
}