cypherlite_query/executor/operators/
aggregate.rs1use crate::executor::eval::eval;
4use crate::executor::{ExecutionError, Params, Record, ScalarFnLookup, Value};
5use crate::parser::ast::Expression;
6use crate::planner::AggregateFunc;
7use cypherlite_storage::StorageEngine;
8
9pub fn execute_aggregate(
12 source_records: Vec<Record>,
13 group_keys: &[Expression],
14 aggregates: &[(String, AggregateFunc)],
15 engine: &StorageEngine,
16 params: &Params,
17 scalar_fns: &dyn ScalarFnLookup,
18) -> Result<Vec<Record>, ExecutionError> {
19 if source_records.is_empty() {
20 let mut record = Record::new();
22 for (alias, func) in aggregates {
23 let value = match func {
24 AggregateFunc::Count { .. } | AggregateFunc::CountStar => Value::Int64(0),
25 };
26 record.insert(alias.clone(), value);
27 }
28 for key_expr in group_keys {
30 let col_name = group_key_name(key_expr);
31 record.insert(col_name, Value::Null);
32 }
33 return Ok(vec![record]);
34 }
35
36 let mut groups: Vec<(Vec<Value>, Vec<&Record>)> = Vec::new();
38
39 for record in &source_records {
40 let key_values: Vec<Value> = group_keys
41 .iter()
42 .map(|expr| eval(expr, record, engine, params, scalar_fns))
43 .collect::<Result<_, _>>()?;
44
45 let found = groups.iter_mut().find(|(k, _)| k == &key_values);
47 if let Some((_, members)) = found {
48 members.push(record);
49 } else {
50 groups.push((key_values, vec![record]));
51 }
52 }
53
54 if group_keys.is_empty() && groups.is_empty() {
56 groups.push((vec![], source_records.iter().collect()));
57 }
58
59 let mut results = Vec::new();
60
61 for (key_values, members) in &groups {
62 let mut result_record = Record::new();
63
64 for (i, key_expr) in group_keys.iter().enumerate() {
66 let col_name = group_key_name(key_expr);
67 result_record.insert(col_name, key_values[i].clone());
68 }
69
70 for (alias, func) in aggregates {
72 let value = compute_aggregate(func, members, engine, params)?;
73 result_record.insert(alias.clone(), value);
74 }
75
76 results.push(result_record);
77 }
78
79 Ok(results)
80}
81
82fn compute_aggregate(
84 func: &AggregateFunc,
85 members: &[&Record],
86 _engine: &StorageEngine,
87 _params: &Params,
88) -> Result<Value, ExecutionError> {
89 match func {
90 AggregateFunc::CountStar => Ok(Value::Int64(members.len() as i64)),
91 AggregateFunc::Count { distinct: _ } => {
92 Ok(Value::Int64(members.len() as i64))
95 }
96 }
97}
98
99fn group_key_name(expr: &Expression) -> String {
101 match expr {
102 Expression::Variable(name) => name.clone(),
103 Expression::Property(_, prop) => prop.clone(),
104 _ => "key".to_string(),
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use cypherlite_core::{DatabaseConfig, SyncMode};
    use tempfile::tempdir;

    /// Opens a fresh storage engine backed by `test.cyl` inside `dir`.
    fn test_engine(dir: &std::path::Path) -> cypherlite_storage::StorageEngine {
        let config = DatabaseConfig {
            path: dir.join("test.cyl"),
            wal_sync_mode: SyncMode::Normal,
            ..Default::default()
        };
        cypherlite_storage::StorageEngine::open(config).expect("open")
    }

    /// Builds one record per value, each holding that value under `column`.
    fn single_column_records(column: &str, values: Vec<Value>) -> Vec<Record> {
        values
            .into_iter()
            .map(|value| {
                let mut record = Record::new();
                record.insert(column.to_string(), value);
                record
            })
            .collect()
    }

    #[test]
    fn test_aggregate_count_star() {
        let dir = tempdir().expect("tempdir");
        let engine = test_engine(dir.path());
        let params = Params::new();

        let input = single_column_records(
            "n",
            vec![Value::Int64(1), Value::Int64(2), Value::Int64(3)],
        );
        let aggregates = vec![("count(*)".to_string(), AggregateFunc::CountStar)];

        let records = execute_aggregate(input, &[], &aggregates, &engine, &params, &())
            .expect("should succeed");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].get("count(*)"), Some(&Value::Int64(3)));
    }

    #[test]
    fn test_aggregate_count_star_empty() {
        let dir = tempdir().expect("tempdir");
        let engine = test_engine(dir.path());
        let params = Params::new();

        let aggregates = vec![("count(*)".to_string(), AggregateFunc::CountStar)];

        let records = execute_aggregate(Vec::new(), &[], &aggregates, &engine, &params, &())
            .expect("should succeed");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].get("count(*)"), Some(&Value::Int64(0)));
    }

    #[test]
    fn test_aggregate_with_group_keys() {
        let dir = tempdir().expect("tempdir");
        let engine = test_engine(dir.path());
        let params = Params::new();

        let input = single_column_records(
            "label",
            vec![
                Value::String("A".into()),
                Value::String("B".into()),
                Value::String("A".into()),
            ],
        );
        let group_keys = vec![Expression::Variable("label".to_string())];
        let aggregates = vec![("cnt".to_string(), AggregateFunc::CountStar)];

        let records = execute_aggregate(input, &group_keys, &aggregates, &engine, &params, &())
            .expect("should succeed");
        assert_eq!(records.len(), 2);

        // Locates the output row whose `label` column equals `label`.
        let find_group = |label: &str| {
            records
                .iter()
                .find(|r| r.get("label") == Some(&Value::String(label.into())))
        };

        assert_eq!(find_group("A").expect("group A").get("cnt"), Some(&Value::Int64(2)));
        assert_eq!(find_group("B").expect("group B").get("cnt"), Some(&Value::Int64(1)));
    }
}