Skip to main content

cypherlite_query/executor/operators/
aggregate.rs

// AggregateOp: count(*) and count(expr). Other aggregates (sum, avg, min, max, collect) are not handled in this file.
2
3use crate::executor::eval::eval;
4use crate::executor::{ExecutionError, Params, Record, ScalarFnLookup, Value};
5use crate::parser::ast::Expression;
6use crate::planner::AggregateFunc;
7use cypherlite_storage::StorageEngine;
8
9/// Execute aggregation over source records.
10/// Groups by group_keys, computes aggregate functions per group.
11pub fn execute_aggregate(
12    source_records: Vec<Record>,
13    group_keys: &[Expression],
14    aggregates: &[(String, AggregateFunc)],
15    engine: &StorageEngine,
16    params: &Params,
17    scalar_fns: &dyn ScalarFnLookup,
18) -> Result<Vec<Record>, ExecutionError> {
19    if source_records.is_empty() {
20        // For empty input with aggregates, return one row with zero counts
21        let mut record = Record::new();
22        for (alias, func) in aggregates {
23            let value = match func {
24                AggregateFunc::Count { .. } | AggregateFunc::CountStar => Value::Int64(0),
25            };
26            record.insert(alias.clone(), value);
27        }
28        // Add group key values as Null
29        for key_expr in group_keys {
30            let col_name = group_key_name(key_expr);
31            record.insert(col_name, Value::Null);
32        }
33        return Ok(vec![record]);
34    }
35
36    // Build groups: evaluate group keys for each record
37    let mut groups: Vec<(Vec<Value>, Vec<&Record>)> = Vec::new();
38
39    for record in &source_records {
40        let key_values: Vec<Value> = group_keys
41            .iter()
42            .map(|expr| eval(expr, record, engine, params, scalar_fns))
43            .collect::<Result<_, _>>()?;
44
45        // Find existing group
46        let found = groups.iter_mut().find(|(k, _)| k == &key_values);
47        if let Some((_, members)) = found {
48            members.push(record);
49        } else {
50            groups.push((key_values, vec![record]));
51        }
52    }
53
54    // If no group keys and no groups, treat all records as one group
55    if group_keys.is_empty() && groups.is_empty() {
56        groups.push((vec![], source_records.iter().collect()));
57    }
58
59    let mut results = Vec::new();
60
61    for (key_values, members) in &groups {
62        let mut result_record = Record::new();
63
64        // Add group key values
65        for (i, key_expr) in group_keys.iter().enumerate() {
66            let col_name = group_key_name(key_expr);
67            result_record.insert(col_name, key_values[i].clone());
68        }
69
70        // Compute aggregates
71        for (alias, func) in aggregates {
72            let value = compute_aggregate(func, members, engine, params)?;
73            result_record.insert(alias.clone(), value);
74        }
75
76        results.push(result_record);
77    }
78
79    Ok(results)
80}
81
82/// Compute a single aggregate function over a group of records.
83fn compute_aggregate(
84    func: &AggregateFunc,
85    members: &[&Record],
86    _engine: &StorageEngine,
87    _params: &Params,
88) -> Result<Value, ExecutionError> {
89    match func {
90        AggregateFunc::CountStar => Ok(Value::Int64(members.len() as i64)),
91        AggregateFunc::Count { distinct: _ } => {
92            // count(expr) counts non-null values
93            // For simplicity, count all members (since we don't have the expr here)
94            Ok(Value::Int64(members.len() as i64))
95        }
96    }
97}
98
99/// Extract a display name from a group key expression.
100fn group_key_name(expr: &Expression) -> String {
101    match expr {
102        Expression::Variable(name) => name.clone(),
103        Expression::Property(_, prop) => prop.clone(),
104        _ => "key".to_string(),
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use cypherlite_core::{DatabaseConfig, SyncMode};
    use tempfile::tempdir;

    /// Open a storage engine whose database file is `test.cyl` inside `dir`.
    fn test_engine(dir: &std::path::Path) -> cypherlite_storage::StorageEngine {
        let config = DatabaseConfig {
            path: dir.join("test.cyl"),
            wal_sync_mode: SyncMode::Normal,
            ..Default::default()
        };
        cypherlite_storage::StorageEngine::open(config).expect("open")
    }

    /// Build a record containing exactly one column.
    fn single_column(name: &str, value: Value) -> Record {
        let mut record = Record::new();
        record.insert(name.to_string(), value);
        record
    }

    // EXEC-T010: AggregateOp count(*)
    #[test]
    fn test_aggregate_count_star() {
        let dir = tempdir().expect("tempdir");
        let engine = test_engine(dir.path());

        let rows = vec![
            single_column("n", Value::Int64(1)),
            single_column("n", Value::Int64(2)),
            single_column("n", Value::Int64(3)),
        ];
        let aggregates = vec![("count(*)".to_string(), AggregateFunc::CountStar)];
        let params = Params::new();

        let records = execute_aggregate(rows, &[], &aggregates, &engine, &params, &())
            .expect("should succeed");

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].get("count(*)"), Some(&Value::Int64(3)));
    }

    // count(*) over no input still produces a single row reporting zero.
    #[test]
    fn test_aggregate_count_star_empty() {
        let dir = tempdir().expect("tempdir");
        let engine = test_engine(dir.path());

        let aggregates = vec![("count(*)".to_string(), AggregateFunc::CountStar)];
        let params = Params::new();

        let records = execute_aggregate(vec![], &[], &aggregates, &engine, &params, &())
            .expect("should succeed");

        assert_eq!(records.len(), 1);
        assert_eq!(records[0].get("count(*)"), Some(&Value::Int64(0)));
    }

    // Grouping by a variable yields one row per distinct value.
    #[test]
    fn test_aggregate_with_group_keys() {
        let dir = tempdir().expect("tempdir");
        let engine = test_engine(dir.path());

        let rows = vec![
            single_column("label", Value::String("A".into())),
            single_column("label", Value::String("B".into())),
            single_column("label", Value::String("A".into())),
        ];
        let group_keys = vec![Expression::Variable("label".to_string())];
        let aggregates = vec![("cnt".to_string(), AggregateFunc::CountStar)];
        let params = Params::new();

        let records = execute_aggregate(rows, &group_keys, &aggregates, &engine, &params, &())
            .expect("should succeed");
        assert_eq!(records.len(), 2);

        // Look up the output row whose `label` column equals `wanted`.
        let find_group = |wanted: Value| records.iter().find(|r| r.get("label") == Some(&wanted));

        let group_a = find_group(Value::String("A".into()));
        let group_b = find_group(Value::String("B".into()));

        assert_eq!(group_a.expect("group A").get("cnt"), Some(&Value::Int64(2)));
        assert_eq!(group_b.expect("group B").get("cnt"), Some(&Value::Int64(1)));
    }
}