use super::histogram::Histogram;
use grafeo_common::types::Value;
use std::collections::HashMap;
pub type PropertyKey = String;
#[derive(Debug, Clone, Default)]
pub struct Statistics {
pub labels: HashMap<String, LabelStatistics>,
pub edge_types: HashMap<String, EdgeTypeStatistics>,
pub properties: HashMap<PropertyKey, ColumnStatistics>,
pub total_nodes: u64,
pub total_edges: u64,
}
impl Statistics {
pub fn new() -> Self {
Self::default()
}
pub fn update_label(&mut self, label: &str, stats: LabelStatistics) {
self.labels.insert(label.to_string(), stats);
}
pub fn update_edge_type(&mut self, edge_type: &str, stats: EdgeTypeStatistics) {
self.edge_types.insert(edge_type.to_string(), stats);
}
pub fn update_property(&mut self, property: &str, stats: ColumnStatistics) {
self.properties.insert(property.to_string(), stats);
}
pub fn get_label(&self, label: &str) -> Option<&LabelStatistics> {
self.labels.get(label)
}
pub fn get_edge_type(&self, edge_type: &str) -> Option<&EdgeTypeStatistics> {
self.edge_types.get(edge_type)
}
pub fn get_property(&self, property: &str) -> Option<&ColumnStatistics> {
self.properties.get(property)
}
pub fn estimate_label_cardinality(&self, label: &str) -> f64 {
self.labels
.get(label)
.map_or(1000.0, |s| s.node_count as f64) }
pub fn estimate_avg_degree(&self, edge_type: &str, outgoing: bool) -> f64 {
self.edge_types.get(edge_type).map_or(10.0, |s| {
if outgoing {
s.avg_out_degree
} else {
s.avg_in_degree
}
}) }
pub fn estimate_equality_selectivity(&self, property: &str, _value: &Value) -> f64 {
self.properties.get(property).map_or(0.5, |s| {
if s.distinct_count > 0 {
1.0 / s.distinct_count as f64
} else {
0.5
}
})
}
pub fn estimate_range_selectivity(
&self,
property: &str,
lower: Option<&Value>,
upper: Option<&Value>,
) -> f64 {
self.properties
.get(property)
.and_then(|s| s.histogram.as_ref())
.map_or(0.33, |h| {
h.estimate_range_selectivity(lower, upper, true, true)
}) }
}
#[derive(Debug, Clone)]
pub struct LabelStatistics {
pub node_count: u64,
pub avg_out_degree: f64,
pub avg_in_degree: f64,
pub properties: HashMap<PropertyKey, ColumnStatistics>,
}
impl LabelStatistics {
pub fn new(node_count: u64) -> Self {
Self {
node_count,
avg_out_degree: 0.0,
avg_in_degree: 0.0,
properties: HashMap::new(),
}
}
pub fn with_degrees(mut self, out_degree: f64, in_degree: f64) -> Self {
self.avg_out_degree = out_degree;
self.avg_in_degree = in_degree;
self
}
pub fn with_property(mut self, property: &str, stats: ColumnStatistics) -> Self {
self.properties.insert(property.to_string(), stats);
self
}
}
pub type TableStatistics = LabelStatistics;
#[derive(Debug, Clone)]
pub struct EdgeTypeStatistics {
pub edge_count: u64,
pub avg_out_degree: f64,
pub avg_in_degree: f64,
pub properties: HashMap<PropertyKey, ColumnStatistics>,
}
impl EdgeTypeStatistics {
pub fn new(edge_count: u64, avg_out_degree: f64, avg_in_degree: f64) -> Self {
Self {
edge_count,
avg_out_degree,
avg_in_degree,
properties: HashMap::new(),
}
}
pub fn with_property(mut self, property: &str, stats: ColumnStatistics) -> Self {
self.properties.insert(property.to_string(), stats);
self
}
}
#[derive(Debug, Clone)]
pub struct ColumnStatistics {
pub distinct_count: u64,
pub total_count: u64,
pub null_count: u64,
pub min_value: Option<Value>,
pub max_value: Option<Value>,
pub avg_value: Option<f64>,
pub histogram: Option<Histogram>,
pub most_common: Vec<(Value, f64)>,
}
impl ColumnStatistics {
pub fn new(distinct_count: u64, total_count: u64, null_count: u64) -> Self {
Self {
distinct_count,
total_count,
null_count,
min_value: None,
max_value: None,
avg_value: None,
histogram: None,
most_common: Vec::new(),
}
}
pub fn with_min_max(mut self, min: Value, max: Value) -> Self {
self.min_value = Some(min);
self.max_value = Some(max);
self
}
pub fn with_avg(mut self, avg: f64) -> Self {
self.avg_value = Some(avg);
self
}
pub fn with_histogram(mut self, histogram: Histogram) -> Self {
self.histogram = Some(histogram);
self
}
pub fn with_most_common(mut self, values: Vec<(Value, f64)>) -> Self {
self.most_common = values;
self
}
pub fn null_fraction(&self) -> f64 {
if self.total_count == 0 {
0.0
} else {
self.null_count as f64 / self.total_count as f64
}
}
pub fn estimate_equality_selectivity(&self, value: &Value) -> f64 {
for (mcv, freq) in &self.most_common {
if mcv == value {
return *freq;
}
}
if let Some(ref hist) = self.histogram {
return hist.estimate_equality_selectivity(value);
}
if self.distinct_count > 0 {
1.0 / self.distinct_count as f64
} else {
0.0
}
}
pub fn estimate_range_selectivity(&self, lower: Option<&Value>, upper: Option<&Value>) -> f64 {
if let Some(ref hist) = self.histogram {
return hist.estimate_range_selectivity(lower, upper, true, true);
}
match (&self.min_value, &self.max_value, lower, upper) {
(Some(min), Some(max), Some(l), Some(u)) => {
estimate_linear_range(min, max, l, u)
}
(Some(_), Some(_), Some(_), None) => 0.5, (Some(_), Some(_), None, Some(_)) => 0.5, _ => 0.33, }
}
}
fn estimate_linear_range(min: &Value, max: &Value, lower: &Value, upper: &Value) -> f64 {
match (min, max, lower, upper) {
(
Value::Int64(min_v),
Value::Int64(max_v),
Value::Int64(lower_v),
Value::Int64(upper_v),
) => {
let total_range = (max_v - min_v) as f64;
if total_range <= 0.0 {
return 1.0;
}
let effective_lower = (*lower_v).max(*min_v);
let effective_upper = (*upper_v).min(*max_v);
if effective_upper < effective_lower {
return 0.0;
}
(effective_upper - effective_lower) as f64 / total_range
}
(
Value::Float64(min_v),
Value::Float64(max_v),
Value::Float64(lower_v),
Value::Float64(upper_v),
) => {
let total_range = max_v - min_v;
if total_range <= 0.0 {
return 1.0;
}
let effective_lower = lower_v.max(*min_v);
let effective_upper = upper_v.min(*max_v);
if effective_upper < effective_lower {
return 0.0;
}
(effective_upper - effective_lower) / total_range
}
_ => 0.33,
}
}
#[cfg(test)]
pub(crate) struct StatisticsCollector {
values: Vec<Value>,
distinct: std::collections::HashSet<String>,
min: Option<Value>,
max: Option<Value>,
sum: f64,
null_count: u64,
frequencies: HashMap<String, u64>,
}
#[cfg(test)]
impl StatisticsCollector {
pub fn new() -> Self {
Self {
values: Vec::new(),
distinct: std::collections::HashSet::new(),
min: None,
max: None,
sum: 0.0,
null_count: 0,
frequencies: HashMap::new(),
}
}
pub fn add(&mut self, value: Value) {
if matches!(value, Value::Null) {
self.null_count += 1;
return;
}
let key = format!("{value:?}");
self.distinct.insert(key.clone());
*self.frequencies.entry(key).or_insert(0) += 1;
self.update_min_max(&value);
if let Some(v) = value_to_f64(&value) {
self.sum += v;
}
self.values.push(value);
}
fn update_min_max(&mut self, value: &Value) {
match &self.min {
None => self.min = Some(value.clone()),
Some(current) => {
if compare_values(value, current) == Some(std::cmp::Ordering::Less) {
self.min = Some(value.clone());
}
}
}
match &self.max {
None => self.max = Some(value.clone()),
Some(current) => {
if compare_values(value, current) == Some(std::cmp::Ordering::Greater) {
self.max = Some(value.clone());
}
}
}
}
pub fn build(mut self, num_histogram_buckets: usize, num_mcv: usize) -> ColumnStatistics {
let total_count = self.values.len() as u64 + self.null_count;
let distinct_count = self.distinct.len() as u64;
let avg = if !self.values.is_empty() {
Some(self.sum / self.values.len() as f64)
} else {
None
};
self.values
.sort_by(|a, b| compare_values(a, b).unwrap_or(std::cmp::Ordering::Equal));
let histogram = if self.values.len() >= num_histogram_buckets {
Some(Histogram::build(&self.values, num_histogram_buckets))
} else {
None
};
let total_non_null = self.values.len() as f64;
let mut freq_vec: Vec<_> = self.frequencies.into_iter().collect();
freq_vec.sort_by_key(|b| std::cmp::Reverse(b.1));
let most_common: Vec<(Value, f64)> = freq_vec
.into_iter()
.take(num_mcv)
.filter_map(|(key, count)| {
let freq = count as f64 / total_non_null;
if key.starts_with("Int64(") {
let num_str = key.trim_start_matches("Int64(").trim_end_matches(')');
num_str.parse::<i64>().ok().map(|n| (Value::Int64(n), freq))
} else if key.starts_with("String(") {
let s = key
.trim_start_matches("String(Arc(\"")
.trim_end_matches("\"))");
Some((Value::String(s.to_string().into()), freq))
} else {
None
}
})
.collect();
let mut stats = ColumnStatistics::new(distinct_count, total_count, self.null_count);
if let Some(min) = self.min
&& let Some(max) = self.max
{
stats = stats.with_min_max(min, max);
}
if let Some(avg) = avg {
stats = stats.with_avg(avg);
}
if let Some(hist) = histogram {
stats = stats.with_histogram(hist);
}
if !most_common.is_empty() {
stats = stats.with_most_common(most_common);
}
stats
}
}
#[cfg(test)]
impl Default for StatisticsCollector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
fn value_to_f64(value: &Value) -> Option<f64> {
match value {
Value::Int64(i) => Some(*i as f64),
Value::Float64(f) => Some(*f),
_ => None,
}
}
#[cfg(test)]
fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
match (a, b) {
(Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
(Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
(Value::String(a), Value::String(b)) => Some(a.cmp(b)),
(Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
(Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
(Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
(Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
(Value::Date(a), Value::Date(b)) => Some(a.cmp(b)),
(Value::Time(a), Value::Time(b)) => Some(a.cmp(b)),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_statistics_collector() {
let mut collector = StatisticsCollector::new();
for i in 0..100 {
collector.add(Value::Int64(i % 10)); }
let stats = collector.build(10, 5);
assert_eq!(stats.distinct_count, 10);
assert_eq!(stats.total_count, 100);
assert_eq!(stats.null_count, 0);
assert_eq!(stats.min_value, Some(Value::Int64(0)));
assert_eq!(stats.max_value, Some(Value::Int64(9)));
}
#[test]
fn test_statistics_with_nulls() {
let mut collector = StatisticsCollector::new();
collector.add(Value::Int64(1));
collector.add(Value::Null);
collector.add(Value::Int64(2));
collector.add(Value::Null);
collector.add(Value::Int64(3));
let stats = collector.build(5, 3);
assert_eq!(stats.total_count, 5);
assert_eq!(stats.null_count, 2);
assert_eq!(stats.distinct_count, 3);
assert!((stats.null_fraction() - 0.4).abs() < 0.01);
}
#[test]
fn test_label_statistics() {
let stats = LabelStatistics::new(1000)
.with_degrees(5.0, 3.0)
.with_property(
"age",
ColumnStatistics::new(50, 1000, 10)
.with_min_max(Value::Int64(0), Value::Int64(100)),
);
assert_eq!(stats.node_count, 1000);
assert_eq!(stats.avg_out_degree, 5.0);
assert!(stats.properties.contains_key("age"));
}
#[test]
fn test_statistics_min_max_updates() {
let mut collector = StatisticsCollector::new();
collector.add(Value::Int64(50));
collector.add(Value::Int64(10)); collector.add(Value::Int64(90)); collector.add(Value::Int64(5)); collector.add(Value::Int64(95));
let stats = collector.build(2, 3);
assert_eq!(stats.min_value, Some(Value::Int64(5)));
assert_eq!(stats.max_value, Some(Value::Int64(95)));
}
#[test]
fn test_statistics_most_common_values() {
let mut collector = StatisticsCollector::new();
for _ in 0..50 {
collector.add(Value::Int64(42));
}
for _ in 0..30 {
collector.add(Value::Int64(7));
}
for _ in 0..20 {
collector.add(Value::String("hello".into()));
}
let stats = collector.build(5, 3);
assert!(
!stats.most_common.is_empty(),
"MCV list should be populated"
);
let (top_val, top_freq) = &stats.most_common[0];
assert_eq!(*top_val, Value::Int64(42));
assert!((top_freq - 0.5).abs() < 0.01, "42 appears 50/100 = 0.5");
let has_string = stats
.most_common
.iter()
.any(|(v, _)| matches!(v, Value::String(_)));
assert!(has_string, "String MCVs should be parsed back");
}
#[test]
fn test_database_statistics() {
let mut db_stats = Statistics::new();
db_stats.update_label(
"Person",
LabelStatistics::new(10000).with_degrees(10.0, 10.0),
);
db_stats.update_edge_type("KNOWS", EdgeTypeStatistics::new(50000, 5.0, 5.0));
assert_eq!(db_stats.estimate_label_cardinality("Person"), 10000.0);
assert_eq!(db_stats.estimate_label_cardinality("Unknown"), 1000.0);
assert_eq!(db_stats.estimate_avg_degree("KNOWS", true), 5.0);
}
}