// osp_cli/dsl/stages/aggregate.rs

1use std::{cmp::Ordering, fmt::Display};
2
3use crate::core::{output_model::OutputItems, row::Row};
4use anyhow::{Result, anyhow};
5use serde_json::Value;
6
7use crate::dsl::{
8    eval::resolve::{resolve_values, resolve_values_truthy},
9    parse::key_spec::KeySpec,
10    stages::common::{parse_alias_after_as, parse_stage_words},
11};
12
/// The aggregate functions that the `A` stage knows how to evaluate.
///
/// The variant is chosen from the first word of the stage spec and decides
/// how the collected column values are reduced to a single JSON value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AggregateFn {
    /// Number of non-null values.
    Count,
    /// Sum of the values that can be read as numbers.
    Sum,
    /// Arithmetic mean of the values that can be read as numbers.
    Avg,
    /// Smallest non-null value.
    Min,
    /// Largest non-null value.
    Max,
}
21
22#[derive(Debug, Clone)]
23struct AggregateSpec {
24    function: AggregateFn,
25    column_raw: Option<String>,
26    alias: String,
27}
28
29pub fn apply(items: OutputItems, spec: &str) -> Result<OutputItems> {
30    let parsed = parse_aggregate_spec(spec)?;
31    match items {
32        OutputItems::Rows(rows) => {
33            let value = aggregate_rows(&rows, &parsed);
34            let mut row = Row::new();
35            row.insert(parsed.alias, value);
36            Ok(OutputItems::Rows(vec![row]))
37        }
38        OutputItems::Groups(groups) => {
39            let enriched = groups
40                .into_iter()
41                .map(|mut group| {
42                    let value = aggregate_rows(&group.rows, &parsed);
43                    group.aggregates.insert(parsed.alias.clone(), value);
44                    group
45                })
46                .collect::<Vec<_>>();
47            Ok(OutputItems::Groups(enriched))
48        }
49    }
50}
51
52pub fn count_macro(items: OutputItems, spec: &str) -> Result<OutputItems> {
53    if !spec.trim().is_empty() {
54        return Err(anyhow!("C takes no arguments"));
55    }
56
57    match items {
58        OutputItems::Rows(rows) => {
59            let mut row = Row::new();
60            row.insert("count".to_string(), Value::from(rows.len() as i64));
61            Ok(OutputItems::Rows(vec![row]))
62        }
63        OutputItems::Groups(groups) => {
64            let rows = groups
65                .into_iter()
66                .map(|group| {
67                    let mut row = group.groups;
68                    row.insert("count".to_string(), Value::from(group.rows.len() as i64));
69                    row
70                })
71                .collect::<Vec<_>>();
72            Ok(OutputItems::Rows(rows))
73        }
74    }
75}
76
77fn parse_aggregate_spec(spec: &str) -> Result<AggregateSpec> {
78    let words = parse_stage_words(spec)?;
79
80    if words.is_empty() {
81        return Err(anyhow!("A requires an aggregate function"));
82    }
83
84    let (function, mut column_raw, from_parenthesized) = parse_function_and_column(&words[0])?;
85    let mut index = 1usize;
86
87    if column_raw.is_none() && index < words.len() {
88        if function == AggregateFn::Count && words.len() == 2 {
89            // `A count alias` form
90        } else if !words[index].eq_ignore_ascii_case("AS") {
91            column_raw = Some(words[index].clone());
92            index += 1;
93        }
94    }
95
96    let alias = if let Some(alias) = parse_alias_after_as(&words, index, "A")? {
97        alias
98    } else if index < words.len() {
99        words[index].clone()
100    } else if let Some(column) = &column_raw {
101        if from_parenthesized {
102            format!("{}({column})", function.as_str())
103        } else {
104            column.clone()
105        }
106    } else {
107        function.default_alias().to_string()
108    };
109
110    Ok(AggregateSpec {
111        function,
112        column_raw,
113        alias,
114    })
115}
116
117fn parse_function_and_column(input: &str) -> Result<(AggregateFn, Option<String>, bool)> {
118    if let Some(open) = input.find('(') {
119        if !input.ends_with(')') {
120            return Err(anyhow!("A: malformed function call"));
121        }
122        let function_name = &input[..open];
123        let column = &input[open + 1..input.len() - 1];
124        let function = AggregateFn::parse(function_name)?;
125        let column = if column.trim().is_empty() {
126            None
127        } else {
128            Some(column.trim().to_string())
129        };
130        return Ok((function, column, true));
131    }
132
133    let function = AggregateFn::parse(input)?;
134    Ok((function, None, false))
135}
136
137fn aggregate_rows(rows: &[Row], spec: &AggregateSpec) -> Value {
138    let values = collect_column_values(rows, spec.column_raw.as_deref());
139
140    match spec.function {
141        AggregateFn::Count => Value::from(count_values(&values) as i64),
142        AggregateFn::Sum => Value::from(sum_values(&values)),
143        AggregateFn::Avg => {
144            let numbers = numeric_values(&values);
145            if numbers.is_empty() {
146                Value::from(0.0)
147            } else {
148                Value::from(numbers.iter().sum::<f64>() / numbers.len() as f64)
149            }
150        }
151        AggregateFn::Min => min_value(&values).unwrap_or(Value::Null),
152        AggregateFn::Max => max_value(&values).unwrap_or(Value::Null),
153    }
154}
155
156fn collect_column_values(rows: &[Row], column_raw: Option<&str>) -> Vec<Value> {
157    match column_raw {
158        None => rows.iter().map(|_| Value::Bool(true)).collect(),
159        Some(column_raw) => {
160            let key_spec = KeySpec::parse(column_raw);
161            if key_spec.existence {
162                rows.iter()
163                    .map(|row| {
164                        let found = resolve_values_truthy(row, &key_spec.token, key_spec.exact);
165                        Value::Bool(if key_spec.negated { !found } else { found })
166                    })
167                    .collect()
168            } else {
169                rows.iter()
170                    .flat_map(|row| resolve_values(row, &key_spec.token, key_spec.exact))
171                    .flat_map(expand_array_value)
172                    .collect()
173            }
174        }
175    }
176}
177
178fn expand_array_value(value: Value) -> Vec<Value> {
179    match value {
180        Value::Array(values) => values,
181        scalar => vec![scalar],
182    }
183}
184
185fn count_values(values: &[Value]) -> usize {
186    values.iter().filter(|value| !value.is_null()).count()
187}
188
189fn sum_values(values: &[Value]) -> f64 {
190    numeric_values(values).iter().sum()
191}
192
193fn numeric_values(values: &[Value]) -> Vec<f64> {
194    values
195        .iter()
196        .filter_map(|value| match value {
197            Value::Number(number) => number.as_f64(),
198            Value::String(text) => text.parse::<f64>().ok(),
199            Value::Bool(flag) => Some(if *flag { 1.0 } else { 0.0 }),
200            _ => None,
201        })
202        .collect()
203}
204
205fn min_value(values: &[Value]) -> Option<Value> {
206    values
207        .iter()
208        .filter(|value| !value.is_null())
209        .min_by(|left, right| compare_values(left, right))
210        .cloned()
211}
212
213fn max_value(values: &[Value]) -> Option<Value> {
214    values
215        .iter()
216        .filter(|value| !value.is_null())
217        .max_by(|left, right| compare_values(left, right))
218        .cloned()
219}
220
221fn compare_values(left: &Value, right: &Value) -> Ordering {
222    match (left, right) {
223        (Value::Number(a), Value::Number(b)) => a
224            .as_f64()
225            .partial_cmp(&b.as_f64())
226            .unwrap_or(Ordering::Equal),
227        (Value::String(a), Value::String(b)) => a.cmp(b),
228        _ => value_to_string(left).cmp(&value_to_string(right)),
229    }
230}
231
232fn value_to_string(value: &Value) -> String {
233    match value {
234        Value::String(text) => text.clone(),
235        other => other.to_string(),
236    }
237}
238
239impl AggregateFn {
240    fn parse(value: &str) -> Result<Self> {
241        match value.to_ascii_lowercase().as_str() {
242            "count" => Ok(Self::Count),
243            "sum" => Ok(Self::Sum),
244            "avg" => Ok(Self::Avg),
245            "min" => Ok(Self::Min),
246            "max" => Ok(Self::Max),
247            other => Err(anyhow!("A: unsupported function '{other}'")),
248        }
249    }
250
251    fn as_str(self) -> &'static str {
252        match self {
253            Self::Count => "count",
254            Self::Sum => "sum",
255            Self::Avg => "avg",
256            Self::Min => "min",
257            Self::Max => "max",
258        }
259    }
260
261    fn default_alias(self) -> &'static str {
262        match self {
263            Self::Count => "count",
264            Self::Sum => "sum",
265            Self::Avg => "avg",
266            Self::Min => "min",
267            Self::Max => "max",
268        }
269    }
270}
271
272impl Display for AggregateFn {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        f.write_str(self.as_str())
275    }
276}
277
#[cfg(test)]
mod tests {
    use crate::core::output_model::{Group, OutputItems};
    use serde_json::{Map, Value, json};

    use super::{apply, count_macro};

    /// Turns a `json!` object literal into a row map.
    fn row(value: Value) -> Map<String, Value> {
        value.as_object().cloned().expect("object")
    }

    #[test]
    fn aggregate_count_global() {
        let input = OutputItems::Rows(vec![row(json!({"id": 1})), row(json!({"id": 2}))]);

        let output = apply(input, "count total").expect("aggregate should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected rows");
        };
        assert_eq!(rows[0].get("total").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_sum_and_avg() {
        let rows = vec![
            row(json!({"numbers": [1, 2]})),
            row(json!({"numbers": [3]})),
        ];

        // Sum flattens the arrays: 1 + 2 + 3.
        let summed = apply(OutputItems::Rows(rows.clone()), "sum(numbers[]) total")
            .expect("aggregate should work");
        let OutputItems::Rows(sum_rows) = summed else {
            panic!("expected rows");
        };
        assert_eq!(sum_rows[0].get("total").and_then(Value::as_f64), Some(6.0));

        // Average over the same three flattened numbers.
        let averaged = apply(OutputItems::Rows(rows), "avg(numbers[]) average")
            .expect("aggregate should work");
        let OutputItems::Rows(avg_rows) = averaged else {
            panic!("expected rows");
        };
        assert_eq!(
            avg_rows[0].get("average").and_then(Value::as_f64),
            Some(2.0)
        );
    }

    #[test]
    fn aggregate_on_groups_adds_aggregates() {
        let groups = vec![Group {
            groups: row(json!({"dept": "sales"})),
            aggregates: serde_json::Map::new(),
            rows: vec![row(json!({"amount": 100})), row(json!({"amount": 200}))],
        }];

        let output =
            apply(OutputItems::Groups(groups), "sum(amount) total").expect("aggregate should work");
        let OutputItems::Groups(groups) = output else {
            panic!("expected groups");
        };
        assert_eq!(
            groups[0].aggregates.get("total").and_then(Value::as_f64),
            Some(300.0)
        );
    }

    #[test]
    fn count_macro_returns_count_rows() {
        let input = OutputItems::Rows(vec![row(json!({"id": 1})), row(json!({"id": 2}))]);

        let output = count_macro(input, "").expect("count should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected rows");
        };
        assert_eq!(rows[0].get("count").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_supports_min_max_and_existence_count() {
        let rows = vec![
            row(json!({"score": 10, "enabled": true, "name": "beta"})),
            row(json!({"score": 3, "enabled": false, "name": "alpha"})),
            row(json!({"name": "gamma"})),
        ];

        let min = apply(OutputItems::Rows(rows.clone()), "min(score) lowest")
            .expect("min aggregate should work");
        let OutputItems::Rows(min_rows) = min else {
            panic!("expected row output");
        };
        assert_eq!(min_rows[0].get("lowest").and_then(Value::as_i64), Some(3));

        let max = apply(OutputItems::Rows(rows.clone()), "max(name) highest")
            .expect("max aggregate should work");
        let OutputItems::Rows(max_rows) = max else {
            panic!("expected row output");
        };
        assert_eq!(
            max_rows[0].get("highest").and_then(Value::as_str),
            Some("gamma")
        );

        let count = apply(OutputItems::Rows(rows), "count(?enabled) enabled_count")
            .expect("existence count should work");
        let OutputItems::Rows(count_rows) = count else {
            panic!("expected row output");
        };
        assert_eq!(
            count_rows[0].get("enabled_count").and_then(Value::as_i64),
            Some(3)
        );
    }

    #[test]
    fn aggregate_parses_default_aliases_and_group_count_macro() {
        let input = OutputItems::Rows(vec![
            row(json!({"amount": 4})),
            row(json!({"amount": 6})),
        ]);
        let summed = apply(input, "sum(amount)").expect("sum should work");
        let OutputItems::Rows(rows) = summed else {
            panic!("expected row output");
        };
        // Parenthesized form without alias defaults to `fn(column)`.
        assert_eq!(
            rows[0].get("sum(amount)").and_then(Value::as_f64),
            Some(10.0)
        );

        let grouped = OutputItems::Groups(vec![Group {
            groups: row(json!({"dept": "sales"})),
            aggregates: serde_json::Map::new(),
            rows: vec![row(json!({"id": 1})), row(json!({"id": 2}))],
        }]);
        let counted = count_macro(grouped, "").expect("count macro should work for groups");
        let OutputItems::Rows(rows) = counted else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("dept").and_then(Value::as_str), Some("sales"));
        assert_eq!(rows[0].get("count").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_rejects_invalid_forms() {
        let rows = OutputItems::Rows(vec![row(json!({"id": 1}))]);

        let missing_fn = apply(rows.clone(), "").expect_err("missing function should fail");
        assert!(
            missing_fn
                .to_string()
                .contains("A requires an aggregate function")
        );

        let malformed = apply(rows.clone(), "sum(id").expect_err("malformed function should fail");
        assert!(malformed.to_string().contains("malformed function call"));

        let unsupported =
            apply(rows.clone(), "median(id)").expect_err("unsupported function should fail");
        assert!(
            unsupported
                .to_string()
                .contains("unsupported function 'median'")
        );

        let count_err = count_macro(rows, "extra").expect_err("C should reject arguments");
        assert!(count_err.to_string().contains("C takes no arguments"));
    }

    #[test]
    fn aggregate_supports_alias_after_as_and_mixed_numeric_inputs() {
        // A numeric string, a boolean and a number: 4 + 1 + 2.
        let input = OutputItems::Rows(vec![
            row(json!({"value": "4"})),
            row(json!({"value": true})),
            row(json!({"value": 2})),
        ]);

        let output = apply(input, "sum(value) AS total").expect("sum alias should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("total").and_then(Value::as_f64), Some(7.0));
    }

    #[test]
    fn aggregate_handles_empty_inputs_and_parenthesized_count_aliases() {
        let empty_rows = OutputItems::Rows(vec![row(json!({"value": null}))]);

        let avg = apply(empty_rows.clone(), "avg(value) average").expect("avg should work");
        let OutputItems::Rows(avg_rows) = avg else {
            panic!("expected row output");
        };
        assert_eq!(
            avg_rows[0].get("average").and_then(Value::as_f64),
            Some(0.0)
        );

        let min = apply(empty_rows, "min(value) lowest").expect("min should work");
        let OutputItems::Rows(min_rows) = min else {
            panic!("expected row output");
        };
        assert_eq!(min_rows[0].get("lowest"), Some(&json!(null)));

        let count_rows = vec![
            row(json!({"enabled": true})),
            row(json!({"enabled": false})),
        ];
        let counted =
            apply(OutputItems::Rows(count_rows), "count(enabled) AS matches").expect("count");
        let OutputItems::Rows(rows) = counted else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("matches").and_then(Value::as_i64), Some(2));
    }

    #[test]
    fn aggregate_prefers_alias_token_for_two_word_count_form() {
        let input = OutputItems::Rows(vec![
            row(json!({"id": 1})),
            row(json!({"id": 2})),
            row(json!({"id": 3})),
        ]);

        let output = apply(input, "count total").expect("count should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("total").and_then(Value::as_i64), Some(3));
    }

    #[test]
    fn aggregate_space_separated_column_form_keeps_column_name_as_alias() {
        let input = OutputItems::Rows(vec![
            row(json!({"amount": 4})),
            row(json!({"amount": 6})),
        ]);

        let output = apply(input, "sum amount").expect("sum should work");
        let OutputItems::Rows(rows) = output else {
            panic!("expected row output");
        };
        assert_eq!(rows[0].get("amount").and_then(Value::as_f64), Some(10.0));
    }
}