1use std::{cmp::Ordering, fmt::Display};
2
3use crate::core::{output_model::OutputItems, row::Row};
4use anyhow::{Result, anyhow};
5use serde_json::Value;
6
7use crate::dsl::{
8 eval::resolve::{resolve_values, resolve_values_truthy},
9 parse::key_spec::KeySpec,
10 stages::common::{parse_alias_after_as, parse_stage_words},
11};
12
/// Aggregate functions accepted by the aggregate stage (error messages in this
/// file prefix it with `A`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AggregateFn {
    /// Number of non-null collected values.
    Count,
    /// Sum of the values that can be read as numbers.
    Sum,
    /// Arithmetic mean of the numeric values; `0.0` when there are none.
    Avg,
    /// Smallest non-null value according to `compare_values`.
    Min,
    /// Largest non-null value according to `compare_values`.
    Max,
}
21
/// Parsed form of an aggregate stage argument string.
#[derive(Debug, Clone)]
struct AggregateSpec {
    /// Which aggregate to compute.
    function: AggregateFn,
    /// Column expression exactly as written by the user (handed to
    /// `KeySpec::parse` at evaluation time). `None` means "no column": every
    /// row then contributes one placeholder `true` value.
    column_raw: Option<String>,
    /// Key under which the aggregate result is stored in the output.
    alias: String,
}
28
29pub fn apply(items: OutputItems, spec: &str) -> Result<OutputItems> {
30 let parsed = parse_aggregate_spec(spec)?;
31 match items {
32 OutputItems::Rows(rows) => {
33 let value = aggregate_rows(&rows, &parsed);
34 let mut row = Row::new();
35 row.insert(parsed.alias, value);
36 Ok(OutputItems::Rows(vec![row]))
37 }
38 OutputItems::Groups(groups) => {
39 let enriched = groups
40 .into_iter()
41 .map(|mut group| {
42 let value = aggregate_rows(&group.rows, &parsed);
43 group.aggregates.insert(parsed.alias.clone(), value);
44 group
45 })
46 .collect::<Vec<_>>();
47 Ok(OutputItems::Groups(enriched))
48 }
49 }
50}
51
52pub fn count_macro(items: OutputItems, spec: &str) -> Result<OutputItems> {
53 if !spec.trim().is_empty() {
54 return Err(anyhow!("C takes no arguments"));
55 }
56
57 match items {
58 OutputItems::Rows(rows) => {
59 let mut row = Row::new();
60 row.insert("count".to_string(), Value::from(rows.len() as i64));
61 Ok(OutputItems::Rows(vec![row]))
62 }
63 OutputItems::Groups(groups) => {
64 let rows = groups
65 .into_iter()
66 .map(|group| {
67 let mut row = group.groups;
68 row.insert("count".to_string(), Value::from(group.rows.len() as i64));
69 row
70 })
71 .collect::<Vec<_>>();
72 Ok(OutputItems::Rows(rows))
73 }
74 }
75}
76
77fn parse_aggregate_spec(spec: &str) -> Result<AggregateSpec> {
78 let words = parse_stage_words(spec)?;
79
80 if words.is_empty() {
81 return Err(anyhow!("A requires an aggregate function"));
82 }
83
84 let (function, mut column_raw, from_parenthesized) = parse_function_and_column(&words[0])?;
85 let mut index = 1usize;
86
87 if column_raw.is_none() && index < words.len() {
88 if function == AggregateFn::Count && words.len() == 2 {
89 } else if !words[index].eq_ignore_ascii_case("AS") {
91 column_raw = Some(words[index].clone());
92 index += 1;
93 }
94 }
95
96 let alias = if let Some(alias) = parse_alias_after_as(&words, index, "A")? {
97 alias
98 } else if index < words.len() {
99 words[index].clone()
100 } else if let Some(column) = &column_raw {
101 if from_parenthesized {
102 format!("{}({column})", function.as_str())
103 } else {
104 column.clone()
105 }
106 } else {
107 function.default_alias().to_string()
108 };
109
110 Ok(AggregateSpec {
111 function,
112 column_raw,
113 alias,
114 })
115}
116
117fn parse_function_and_column(input: &str) -> Result<(AggregateFn, Option<String>, bool)> {
118 if let Some(open) = input.find('(') {
119 if !input.ends_with(')') {
120 return Err(anyhow!("A: malformed function call"));
121 }
122 let function_name = &input[..open];
123 let column = &input[open + 1..input.len() - 1];
124 let function = AggregateFn::parse(function_name)?;
125 let column = if column.trim().is_empty() {
126 None
127 } else {
128 Some(column.trim().to_string())
129 };
130 return Ok((function, column, true));
131 }
132
133 let function = AggregateFn::parse(input)?;
134 Ok((function, None, false))
135}
136
137fn aggregate_rows(rows: &[Row], spec: &AggregateSpec) -> Value {
138 let values = collect_column_values(rows, spec.column_raw.as_deref());
139
140 match spec.function {
141 AggregateFn::Count => Value::from(count_values(&values) as i64),
142 AggregateFn::Sum => Value::from(sum_values(&values)),
143 AggregateFn::Avg => {
144 let numbers = numeric_values(&values);
145 if numbers.is_empty() {
146 Value::from(0.0)
147 } else {
148 Value::from(numbers.iter().sum::<f64>() / numbers.len() as f64)
149 }
150 }
151 AggregateFn::Min => min_value(&values).unwrap_or(Value::Null),
152 AggregateFn::Max => max_value(&values).unwrap_or(Value::Null),
153 }
154}
155
156fn collect_column_values(rows: &[Row], column_raw: Option<&str>) -> Vec<Value> {
157 match column_raw {
158 None => rows.iter().map(|_| Value::Bool(true)).collect(),
159 Some(column_raw) => {
160 let key_spec = KeySpec::parse(column_raw);
161 if key_spec.existence {
162 rows.iter()
163 .map(|row| {
164 let found = resolve_values_truthy(row, &key_spec.token, key_spec.exact);
165 Value::Bool(if key_spec.negated { !found } else { found })
166 })
167 .collect()
168 } else {
169 rows.iter()
170 .flat_map(|row| resolve_values(row, &key_spec.token, key_spec.exact))
171 .flat_map(expand_array_value)
172 .collect()
173 }
174 }
175 }
176}
177
178fn expand_array_value(value: Value) -> Vec<Value> {
179 match value {
180 Value::Array(values) => values,
181 scalar => vec![scalar],
182 }
183}
184
185fn count_values(values: &[Value]) -> usize {
186 values.iter().filter(|value| !value.is_null()).count()
187}
188
189fn sum_values(values: &[Value]) -> f64 {
190 numeric_values(values).iter().sum()
191}
192
193fn numeric_values(values: &[Value]) -> Vec<f64> {
194 values
195 .iter()
196 .filter_map(|value| match value {
197 Value::Number(number) => number.as_f64(),
198 Value::String(text) => text.parse::<f64>().ok(),
199 Value::Bool(flag) => Some(if *flag { 1.0 } else { 0.0 }),
200 _ => None,
201 })
202 .collect()
203}
204
205fn min_value(values: &[Value]) -> Option<Value> {
206 values
207 .iter()
208 .filter(|value| !value.is_null())
209 .min_by(|left, right| compare_values(left, right))
210 .cloned()
211}
212
213fn max_value(values: &[Value]) -> Option<Value> {
214 values
215 .iter()
216 .filter(|value| !value.is_null())
217 .max_by(|left, right| compare_values(left, right))
218 .cloned()
219}
220
221fn compare_values(left: &Value, right: &Value) -> Ordering {
222 match (left, right) {
223 (Value::Number(a), Value::Number(b)) => a
224 .as_f64()
225 .partial_cmp(&b.as_f64())
226 .unwrap_or(Ordering::Equal),
227 (Value::String(a), Value::String(b)) => a.cmp(b),
228 _ => value_to_string(left).cmp(&value_to_string(right)),
229 }
230}
231
232fn value_to_string(value: &Value) -> String {
233 match value {
234 Value::String(text) => text.clone(),
235 other => other.to_string(),
236 }
237}
238
239impl AggregateFn {
240 fn parse(value: &str) -> Result<Self> {
241 match value.to_ascii_lowercase().as_str() {
242 "count" => Ok(Self::Count),
243 "sum" => Ok(Self::Sum),
244 "avg" => Ok(Self::Avg),
245 "min" => Ok(Self::Min),
246 "max" => Ok(Self::Max),
247 other => Err(anyhow!("A: unsupported function '{other}'")),
248 }
249 }
250
251 fn as_str(self) -> &'static str {
252 match self {
253 Self::Count => "count",
254 Self::Sum => "sum",
255 Self::Avg => "avg",
256 Self::Min => "min",
257 Self::Max => "max",
258 }
259 }
260
261 fn default_alias(self) -> &'static str {
262 match self {
263 Self::Count => "count",
264 Self::Sum => "sum",
265 Self::Avg => "avg",
266 Self::Min => "min",
267 Self::Max => "max",
268 }
269 }
270}
271
272impl Display for AggregateFn {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 f.write_str(self.as_str())
275 }
276}
277
#[cfg(test)]
mod tests {
    use crate::core::output_model::{Group, OutputItems};
    use serde_json::{Map, Value, json};

    use super::{apply, count_macro};

    /// Unwraps a `json!` object literal into the map used as a row.
    fn row(value: Value) -> Map<String, Value> {
        value.as_object().cloned().expect("object")
    }

    #[test]
    fn aggregate_count_global() {
        let rows = vec![row(json!({"id": 1})), row(json!({"id": 2}))];

        let output = apply(OutputItems::Rows(rows), "count total").expect("aggregate should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected rows");
        };
        assert_eq!(rows[0].get("total").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_sum_and_avg() {
        let rows = vec![
            row(json!({"numbers": [1, 2]})),
            row(json!({"numbers": [3]})),
        ];

        let output = apply(OutputItems::Rows(rows.clone()), "sum(numbers[]) total")
            .expect("aggregate should work");
        let OutputItems::Rows(summed) = output else {
            panic!("expected rows");
        };
        assert_eq!(summed[0].get("total").and_then(Value::as_f64), Some(6.0));

        let output = apply(OutputItems::Rows(rows), "avg(numbers[]) average")
            .expect("aggregate should work");
        let OutputItems::Rows(averaged) = output else {
            panic!("expected rows");
        };
        assert_eq!(
            averaged[0].get("average").and_then(Value::as_f64),
            Some(2.0)
        );
    }

    #[test]
    fn aggregate_on_groups_adds_aggregates() {
        let groups = vec![Group {
            groups: row(json!({"dept": "sales"})),
            aggregates: Map::new(),
            rows: vec![row(json!({"amount": 100})), row(json!({"amount": 200}))],
        }];

        let output =
            apply(OutputItems::Groups(groups), "sum(amount) total").expect("aggregate should work");
        let OutputItems::Groups(groups) = output else {
            panic!("expected groups");
        };
        assert_eq!(
            groups[0].aggregates.get("total").and_then(Value::as_f64),
            Some(300.0)
        );
    }

    #[test]
    fn count_macro_returns_count_rows() {
        let rows = vec![row(json!({"id": 1})), row(json!({"id": 2}))];

        let output = count_macro(OutputItems::Rows(rows), "").expect("count should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected rows");
        };
        assert_eq!(rows[0].get("count").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_supports_min_max_and_existence_count() {
        let rows = vec![
            row(json!({"score": 10, "enabled": true, "name": "beta"})),
            row(json!({"score": 3, "enabled": false, "name": "alpha"})),
            row(json!({"name": "gamma"})),
        ];

        let min = apply(OutputItems::Rows(rows.clone()), "min(score) lowest")
            .expect("min aggregate should work");
        let OutputItems::Rows(min_rows) = min else {
            panic!("expected row output");
        };
        assert_eq!(min_rows[0].get("lowest").and_then(Value::as_i64), Some(3));

        let max = apply(OutputItems::Rows(rows.clone()), "max(name) highest")
            .expect("max aggregate should work");
        let OutputItems::Rows(max_rows) = max else {
            panic!("expected row output");
        };
        assert_eq!(
            max_rows[0].get("highest").and_then(Value::as_str),
            Some("gamma")
        );

        let count = apply(OutputItems::Rows(rows), "count(?enabled) enabled_count")
            .expect("existence count should work");
        let OutputItems::Rows(count_rows) = count else {
            panic!("expected row output");
        };
        assert_eq!(
            count_rows[0].get("enabled_count").and_then(Value::as_i64),
            Some(3)
        );
    }

    #[test]
    fn aggregate_parses_default_aliases_and_group_count_macro() {
        let rows = vec![row(json!({"amount": 4})), row(json!({"amount": 6}))];
        let summed = apply(OutputItems::Rows(rows), "sum(amount)").expect("sum should work");
        let OutputItems::Rows(rows) = summed else {
            panic!("expected row output");
        };
        assert_eq!(
            rows[0].get("sum(amount)").and_then(Value::as_f64),
            Some(10.0)
        );

        let grouped = OutputItems::Groups(vec![Group {
            groups: row(json!({"dept": "sales"})),
            aggregates: Map::new(),
            rows: vec![row(json!({"id": 1})), row(json!({"id": 2}))],
        }]);
        let counted = count_macro(grouped, "").expect("count macro should work for groups");
        let OutputItems::Rows(rows) = counted else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("dept").and_then(Value::as_str), Some("sales"));
        assert_eq!(rows[0].get("count").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_rejects_invalid_forms() {
        let rows = OutputItems::Rows(vec![row(json!({"id": 1}))]);

        let missing_fn = apply(rows.clone(), "").expect_err("missing function should fail");
        assert!(
            missing_fn
                .to_string()
                .contains("A requires an aggregate function")
        );

        let malformed = apply(rows.clone(), "sum(id").expect_err("malformed function should fail");
        assert!(malformed.to_string().contains("malformed function call"));

        let unsupported =
            apply(rows.clone(), "median(id)").expect_err("unsupported function should fail");
        assert!(
            unsupported
                .to_string()
                .contains("unsupported function 'median'")
        );

        let count_err = count_macro(rows, "extra").expect_err("C should reject arguments");
        assert!(count_err.to_string().contains("C takes no arguments"));
    }

    #[test]
    fn aggregate_supports_alias_after_as_and_mixed_numeric_inputs() {
        let rows = vec![
            row(json!({"value": "4"})),
            row(json!({"value": true})),
            row(json!({"value": 2})),
        ];

        let output =
            apply(OutputItems::Rows(rows), "sum(value) AS total").expect("sum alias should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("total").and_then(Value::as_f64), Some(7.0));
    }

    #[test]
    fn aggregate_handles_empty_inputs_and_parenthesized_count_aliases() {
        let empty_rows = OutputItems::Rows(vec![row(json!({"value": null}))]);

        let avg = apply(empty_rows.clone(), "avg(value) average").expect("avg should work");
        let OutputItems::Rows(avg_rows) = avg else {
            panic!("expected row output");
        };
        assert_eq!(
            avg_rows[0].get("average").and_then(Value::as_f64),
            Some(0.0)
        );

        let min = apply(empty_rows, "min(value) lowest").expect("min should work");
        let OutputItems::Rows(min_rows) = min else {
            panic!("expected row output");
        };
        assert_eq!(min_rows[0].get("lowest"), Some(&json!(null)));

        let count_rows = vec![
            row(json!({"enabled": true})),
            row(json!({"enabled": false})),
        ];
        let counted =
            apply(OutputItems::Rows(count_rows), "count(enabled) AS matches").expect("count");
        let OutputItems::Rows(rows) = counted else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("matches").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_prefers_alias_token_for_two_word_count_form() {
        let rows = vec![
            row(json!({"id": 1})),
            row(json!({"id": 2})),
            row(json!({"id": 3})),
        ];

        let output = apply(OutputItems::Rows(rows), "count total").expect("count should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("total").and_then(Value::as_i64), Some(3));
    }

    #[test]
    fn aggregate_space_separated_column_form_keeps_column_name_as_alias() {
        let rows = vec![row(json!({"amount": 4})), row(json!({"amount": 6}))];

        let output = apply(OutputItems::Rows(rows), "sum amount").expect("sum should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("amount").and_then(Value::as_f64), Some(10.0));
    }
}