Skip to main content

chartml_core/
transform.rs

1use crate::data::{get_f64, get_string, Row};
2use crate::error::ChartError;
3use crate::spec::{AggregateSpec, Dimension, FilterGroup, FilterRule, SortSpec, TransformSpec};
4use std::collections::HashMap;
5
6/// Apply transforms to data.
7pub fn apply_transforms(data: Vec<Row>, spec: &TransformSpec) -> Result<Vec<Row>, ChartError> {
8    let mut result = data;
9
10    if let Some(ref agg) = spec.aggregate {
11        result = aggregate(&result, agg)?;
12    }
13
14    Ok(result)
15}
16
17/// Aggregate data by dimensions with measures.
18/// Matches JS behavior: if no dimensions and no measures, skip aggregation
19/// and just apply filters/sort/limit to the raw data.
20fn aggregate(data: &[Row], spec: &AggregateSpec) -> Result<Vec<Row>, ChartError> {
21    // No aggregation needed — just filter/sort/limit the raw data
22    if spec.dimensions.is_empty() && spec.measures.is_empty() {
23        let mut result = data.to_vec();
24        if let Some(ref filters) = spec.filters {
25            result = apply_filters(result, filters);
26        }
27        if let Some(ref sorts) = spec.sort {
28            apply_sort(&mut result, sorts);
29        }
30        if let Some(limit) = spec.limit {
31            result.truncate(limit as usize);
32        }
33        return Ok(result);
34    }
35
36    // 1. Extract dimension output names and source column names
37    let dim_fields: Vec<String> = spec
38        .dimensions
39        .iter()
40        .map(|d| match d {
41            Dimension::Simple(name) => name.clone(),
42            Dimension::Detailed(spec) => spec.name.clone().unwrap_or_else(|| spec.column.clone()),
43        })
44        .collect();
45
46    let dim_columns: Vec<String> = spec
47        .dimensions
48        .iter()
49        .map(|d| match d {
50            Dimension::Simple(name) => name.clone(),
51            Dimension::Detailed(spec) => spec.column.clone(),
52        })
53        .collect();
54
55    // 2. Group rows by dimension values, preserving first-occurrence order of group keys.
56    let mut groups: HashMap<Vec<String>, Vec<&Row>> = HashMap::new();
57    let mut key_order: Vec<Vec<String>> = Vec::new();
58    let mut seen_keys: std::collections::HashSet<Vec<String>> = std::collections::HashSet::new();
59    for row in data {
60        let key: Vec<String> = dim_columns
61            .iter()
62            .map(|field| get_string(row, field).unwrap_or_default())
63            .collect();
64        if seen_keys.insert(key.clone()) {
65            key_order.push(key.clone());
66        }
67        groups.entry(key).or_default().push(row);
68    }
69
70    // 3. For each group, compute measures (in insertion order)
71    let mut result: Vec<Row> = Vec::new();
72    for key in &key_order {
73        let rows = &groups[key];
74        let mut out_row = Row::new();
75
76        // Add dimension values
77        for (i, field_name) in dim_fields.iter().enumerate() {
78            out_row.insert(
79                field_name.clone(),
80                serde_json::Value::String(key[i].clone()),
81            );
82        }
83
84        // Compute each measure
85        for measure in &spec.measures {
86            let value = if let Some(ref col) = measure.column {
87                if let Some(ref agg_fn) = measure.aggregation {
88                    compute_aggregation(rows, col, agg_fn)?
89                } else {
90                    // No aggregation specified, take first value
91                    rows.first()
92                        .and_then(|r| get_f64(r, col))
93                        .unwrap_or(0.0)
94                }
95            } else if let Some(ref expr) = measure.expression {
96                evaluate_expression(&out_row, expr)?
97            } else {
98                0.0
99            };
100            out_row.insert(measure.name.clone(), serde_json::json!(value));
101        }
102
103        result.push(out_row);
104    }
105
106    // 4. Apply filters (if any)
107    if let Some(ref filters) = spec.filters {
108        result = apply_filters(result, filters);
109    }
110
111    // 5. Apply sort (if any)
112    if let Some(ref sorts) = spec.sort {
113        apply_sort(&mut result, sorts);
114    }
115
116    // 6. Apply limit (if any)
117    if let Some(limit) = spec.limit {
118        result.truncate(limit as usize);
119    }
120
121    Ok(result)
122}
123
124fn compute_aggregation(rows: &[&Row], column: &str, agg: &str) -> Result<f64, ChartError> {
125    let values: Vec<f64> = rows.iter().filter_map(|r| get_f64(r, column)).collect();
126    if values.is_empty() {
127        return Ok(0.0);
128    }
129
130    Ok(match agg {
131        "sum" => values.iter().sum(),
132        "avg" => values.iter().sum::<f64>() / values.len() as f64,
133        "count" => values.len() as f64,
134        "min" => values.iter().cloned().fold(f64::INFINITY, f64::min),
135        "max" => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
136        "countDistinct" => {
137            let mut seen = std::collections::HashSet::new();
138            for v in &values {
139                seen.insert(v.to_bits());
140            }
141            seen.len() as f64
142        }
143        "median" => {
144            let mut sorted = values.clone();
145            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
146            let mid = sorted.len() / 2;
147            if sorted.len().is_multiple_of(2) {
148                (sorted[mid - 1] + sorted[mid]) / 2.0
149            } else {
150                sorted[mid]
151            }
152        }
153        other => {
154            return Err(ChartError::InvalidSpec(format!(
155                "Unknown aggregation: {}",
156                other
157            )))
158        }
159    })
160}
161
162/// Evaluate a simple arithmetic expression referencing fields already computed in the row.
163/// Supports: field_a / field_b, field_a * field_b, field_a + field_b, field_a - field_b
164fn evaluate_expression(row: &Row, expr: &str) -> Result<f64, ChartError> {
165    // Find the operator
166    let operators = ['/', '*', '+', '-'];
167    for op in &operators {
168        if let Some(pos) = expr.find(*op) {
169            // Make sure it's not a negative sign at position 0
170            if *op == '-' && pos == 0 {
171                continue;
172            }
173            let left_name = expr[..pos].trim();
174            let right_name = expr[pos + 1..].trim();
175
176            let left_val = get_f64(row, left_name).ok_or_else(|| {
177                ChartError::DataError(format!(
178                    "Expression field '{}' not found in row",
179                    left_name
180                ))
181            })?;
182            let right_val = get_f64(row, right_name).ok_or_else(|| {
183                ChartError::DataError(format!(
184                    "Expression field '{}' not found in row",
185                    right_name
186                ))
187            })?;
188
189            return Ok(match op {
190                '+' => left_val + right_val,
191                '-' => left_val - right_val,
192                '*' => left_val * right_val,
193                '/' => {
194                    if right_val == 0.0 {
195                        0.0
196                    } else {
197                        left_val / right_val
198                    }
199                }
200                _ => unreachable!(),
201            });
202        }
203    }
204
205    // No operator found — try interpreting as a field reference
206    get_f64(row, expr.trim()).ok_or_else(|| {
207        ChartError::DataError(format!(
208            "Cannot evaluate expression '{}': field not found",
209            expr
210        ))
211    })
212}
213
214fn apply_filters(data: Vec<Row>, filters: &FilterGroup) -> Vec<Row> {
215    let is_and = filters.combinator.as_deref() != Some("or");
216
217    data.into_iter()
218        .filter(|row| {
219            let results: Vec<bool> = filters.rules.iter().map(|rule| eval_filter_rule(row, rule)).collect();
220
221            if is_and {
222                results.iter().all(|&r| r)
223            } else {
224                results.iter().any(|&r| r)
225            }
226        })
227        .collect()
228}
229
230fn eval_filter_rule(row: &Row, rule: &FilterRule) -> bool {
231    let field_val = row.get(&rule.field);
232
233    match rule.operator.as_str() {
234        "isNull" => field_val.is_none() || field_val == Some(&serde_json::Value::Null),
235        "isNotNull" => field_val.is_some() && field_val != Some(&serde_json::Value::Null),
236        "=" | "==" => rule
237            .value
238            .as_ref()
239            .is_some_and(|v| field_val == Some(v)),
240        "!=" => rule
241            .value
242            .as_ref()
243            .is_some_and(|v| field_val != Some(v)),
244        ">" => compare_values(field_val, rule.value.as_ref(), |a, b| a > b),
245        ">=" => compare_values(field_val, rule.value.as_ref(), |a, b| a >= b),
246        "<" => compare_values(field_val, rule.value.as_ref(), |a, b| a < b),
247        "<=" => compare_values(field_val, rule.value.as_ref(), |a, b| a <= b),
248        "in" => {
249            if let (Some(fv), Some(serde_json::Value::Array(arr))) =
250                (field_val, rule.value.as_ref())
251            {
252                arr.contains(fv)
253            } else {
254                false
255            }
256        }
257        "contains" => {
258            if let (Some(serde_json::Value::String(fv)), Some(serde_json::Value::String(rv))) =
259                (field_val, rule.value.as_ref())
260            {
261                fv.contains(rv.as_str())
262            } else {
263                false
264            }
265        }
266        _ => true, // unknown operator, pass through
267    }
268}
269
270fn compare_values(
271    field: Option<&serde_json::Value>,
272    rule_val: Option<&serde_json::Value>,
273    cmp: impl Fn(f64, f64) -> bool,
274) -> bool {
275    match (field, rule_val) {
276        (Some(serde_json::Value::Number(a)), Some(serde_json::Value::Number(b))) => {
277            if let (Some(fa), Some(fb)) = (a.as_f64(), b.as_f64()) {
278                cmp(fa, fb)
279            } else {
280                false
281            }
282        }
283        _ => false,
284    }
285}
286
287fn apply_sort(data: &mut [Row], sorts: &[SortSpec]) {
288    data.sort_by(|a, b| {
289        for sort in sorts {
290            let a_val = get_f64(a, &sort.field);
291            let b_val = get_f64(b, &sort.field);
292            let ord = match (a_val, b_val) {
293                (Some(av), Some(bv)) => av
294                    .partial_cmp(&bv)
295                    .unwrap_or(std::cmp::Ordering::Equal),
296                _ => std::cmp::Ordering::Equal,
297            };
298            let ord = if sort.direction.as_deref() == Some("desc") {
299                ord.reverse()
300            } else {
301                ord
302            };
303            if ord != std::cmp::Ordering::Equal {
304                return ord;
305            }
306        }
307        std::cmp::Ordering::Equal
308    });
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::{AggregateSpec, Dimension, Measure, SortSpec, TransformSpec};
    use serde_json::json;

    // ---- fixtures -------------------------------------------------------

    /// Build a row from `(field, value)` pairs.
    fn make_row(pairs: Vec<(&str, serde_json::Value)>) -> Row {
        let mut row = Row::new();
        for (field, value) in pairs {
            row.insert(field.to_string(), value);
        }
        row
    }

    /// Five sales rows across the regions North (2), South (2) and East (1).
    fn sales_data() -> Vec<Row> {
        let table = [
            ("North", "Widget", 100.0, 10),
            ("North", "Gadget", 200.0, 15),
            ("South", "Widget", 150.0, 12),
            ("South", "Widget", 50.0, 5),
            ("East", "Gadget", 300.0, 20),
        ];
        table
            .iter()
            .map(|&(region, product, revenue, units)| {
                make_row(vec![
                    ("region", json!(region)),
                    ("product", json!(product)),
                    ("revenue", json!(revenue)),
                    ("units", json!(units)),
                ])
            })
            .collect()
    }

    /// A measure that aggregates `column` with `aggregation` into `name`.
    fn measure(column: &str, aggregation: &str, name: &str) -> Measure {
        Measure {
            column: Some(column.to_string()),
            aggregation: Some(aggregation.to_string()),
            name: name.to_string(),
            expression: None,
        }
    }

    /// A spec grouping by one simple dimension, without filters, sort or limit.
    fn grouped_by(dimension: &str, measures: Vec<Measure>) -> AggregateSpec {
        AggregateSpec {
            dimensions: vec![Dimension::Simple(dimension.to_string())],
            measures,
            filters: None,
            sort: None,
            limit: None,
        }
    }

    /// Group by `region` with a single measure over `revenue`.
    fn revenue_by_region(aggregation: &str, name: &str) -> AggregateSpec {
        grouped_by("region", vec![measure("revenue", aggregation, name)])
    }

    /// A single descending sort key.
    fn descending(field: &str) -> Vec<SortSpec> {
        vec![SortSpec {
            field: field.to_string(),
            direction: Some("desc".to_string()),
        }]
    }

    /// A filter rule; `value` is `None` for operators that take no operand.
    fn rule(field: &str, operator: &str, value: Option<serde_json::Value>) -> FilterRule {
        FilterRule {
            field: field.to_string(),
            operator: operator.to_string(),
            value,
        }
    }

    /// A filter group with the given combinator (`None` defaults to "and").
    fn filter_group(combinator: Option<&str>, rules: Vec<FilterRule>) -> FilterGroup {
        FilterGroup {
            combinator: combinator.map(str::to_string),
            rules,
        }
    }

    /// The first row whose `field` reads as the string `value`; panics if absent.
    fn row_where<'a>(rows: &'a [Row], field: &str, value: &str) -> &'a Row {
        rows.iter()
            .find(|r| get_string(r, field) == Some(value.to_string()))
            .unwrap()
    }

    // ---- aggregations ---------------------------------------------------

    #[test]
    fn aggregate_sum() {
        let spec = revenue_by_region("sum", "total_revenue");

        let result = aggregate(&sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 3);

        // Check each region's total.
        for (region, expected) in [("North", 300.0), ("South", 200.0), ("East", 300.0)] {
            let row = row_where(&result, "region", region);
            assert_eq!(get_f64(row, "total_revenue"), Some(expected));
        }
    }

    #[test]
    fn aggregate_avg() {
        let spec = revenue_by_region("avg", "avg_revenue");

        let result = aggregate(&sales_data(), &spec).unwrap();

        let north = row_where(&result, "region", "North");
        assert_eq!(get_f64(north, "avg_revenue"), Some(150.0)); // (100+200)/2

        let south = row_where(&result, "region", "South");
        assert_eq!(get_f64(south, "avg_revenue"), Some(100.0)); // (150+50)/2
    }

    #[test]
    fn aggregate_count() {
        let spec = revenue_by_region("count", "count");

        let result = aggregate(&sales_data(), &spec).unwrap();

        let north = row_where(&result, "region", "North");
        assert_eq!(get_f64(north, "count"), Some(2.0));

        let south = row_where(&result, "region", "South");
        assert_eq!(get_f64(south, "count"), Some(2.0));
    }

    #[test]
    fn aggregate_min_max() {
        let spec = grouped_by(
            "region",
            vec![
                measure("revenue", "min", "min_rev"),
                measure("revenue", "max", "max_rev"),
            ],
        );

        let result = aggregate(&sales_data(), &spec).unwrap();

        let south = row_where(&result, "region", "South");
        assert_eq!(get_f64(south, "min_rev"), Some(50.0));
        assert_eq!(get_f64(south, "max_rev"), Some(150.0));
    }

    #[test]
    fn aggregate_count_distinct() {
        let spec = revenue_by_region("countDistinct", "distinct_rev");

        let result = aggregate(&sales_data(), &spec).unwrap();

        let north = row_where(&result, "region", "North");
        assert_eq!(get_f64(north, "distinct_rev"), Some(2.0)); // 100 and 200
    }

    #[test]
    fn aggregate_median() {
        let data: Vec<Row> = [("A", 1.0), ("A", 3.0), ("A", 5.0), ("B", 2.0), ("B", 4.0)]
            .iter()
            .map(|&(group, value)| make_row(vec![("g", json!(group)), ("v", json!(value))]))
            .collect();
        let spec = grouped_by("g", vec![measure("v", "median", "med")]);

        let result = aggregate(&data, &spec).unwrap();

        let a = row_where(&result, "g", "A");
        assert_eq!(get_f64(a, "med"), Some(3.0)); // odd count: middle

        let b = row_where(&result, "g", "B");
        assert_eq!(get_f64(b, "med"), Some(3.0)); // even count: (2+4)/2
    }

    // ---- sort / limit / dimension naming --------------------------------

    #[test]
    fn aggregate_with_sort() {
        let spec = AggregateSpec {
            sort: Some(descending("total_revenue")),
            ..revenue_by_region("sum", "total_revenue")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 3);

        // North=300, East=300, South=200 — both 300s first, then 200
        let totals: Vec<f64> = result
            .iter()
            .map(|row| get_f64(row, "total_revenue").unwrap())
            .collect();
        assert!(totals[0] >= totals[1]);
        assert!(totals[1] >= totals[2]);
        assert_eq!(totals[2], 200.0);
    }

    #[test]
    fn aggregate_with_limit() {
        let spec = AggregateSpec {
            sort: Some(descending("total_revenue")),
            limit: Some(2),
            ..revenue_by_region("sum", "total_revenue")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn aggregate_with_detailed_dimension() {
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Detailed(crate::spec::DimensionSpec {
                column: "region".to_string(),
                name: Some("Region".to_string()),
                dim_type: None,
            })],
            ..revenue_by_region("sum", "total")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();

        // Output field name should be "Region" not "region"
        let north = row_where(&result, "Region", "North");
        assert_eq!(get_f64(north, "total"), Some(300.0));
    }

    // ---- filters on aggregated rows -------------------------------------

    #[test]
    fn aggregate_with_filter_gt() {
        let spec = AggregateSpec {
            // No combinator: defaults to "and".
            filters: Some(filter_group(
                None,
                vec![rule("total_revenue", ">", Some(json!(250)))],
            )),
            ..revenue_by_region("sum", "total_revenue")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();

        // North=300, East=300 pass; South=200 filtered out
        assert_eq!(result.len(), 2);
        for row in &result {
            assert!(get_f64(row, "total_revenue").unwrap() > 250.0);
        }
    }

    #[test]
    fn aggregate_with_filter_eq() {
        let spec = AggregateSpec {
            filters: Some(filter_group(
                None,
                vec![rule("region", "=", Some(json!("North")))],
            )),
            ..revenue_by_region("sum", "total_revenue")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(
            get_string(&result[0], "region"),
            Some("North".to_string())
        );
    }

    #[test]
    fn aggregate_with_filter_in() {
        let spec = AggregateSpec {
            filters: Some(filter_group(
                None,
                vec![rule("region", "in", Some(json!(["North", "East"])))],
            )),
            ..revenue_by_region("sum", "total_revenue")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn aggregate_with_filter_or_combinator() {
        let spec = AggregateSpec {
            filters: Some(filter_group(
                Some("or"),
                vec![
                    rule("region", "=", Some(json!("North"))),
                    rule("region", "=", Some(json!("East"))),
                ],
            )),
            ..revenue_by_region("sum", "total_revenue")
        };

        let result = aggregate(&sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 2);
    }

    // ---- expressions and the public entry point -------------------------

    #[test]
    fn aggregate_with_expression_measure() {
        let spec = grouped_by(
            "region",
            vec![
                measure("revenue", "sum", "total_revenue"),
                measure("units", "sum", "total_units"),
                Measure {
                    column: None,
                    aggregation: None,
                    name: "avg_price".to_string(),
                    expression: Some("total_revenue / total_units".to_string()),
                },
            ],
        );

        let result = aggregate(&sales_data(), &spec).unwrap();

        // North: revenue=300, units=25, avg_price=12
        let north = row_where(&result, "region", "North");
        assert_eq!(get_f64(north, "total_revenue"), Some(300.0));
        assert_eq!(get_f64(north, "total_units"), Some(25.0));
        assert_eq!(get_f64(north, "avg_price"), Some(12.0));
    }

    #[test]
    fn apply_transforms_full_pipeline() {
        let spec = TransformSpec {
            sql: None,
            forecast: None,
            aggregate: Some(AggregateSpec {
                sort: Some(descending("total_revenue")),
                limit: Some(2),
                ..revenue_by_region("sum", "total_revenue")
            }),
        };

        let result = apply_transforms(sales_data(), &spec).unwrap();
        assert_eq!(result.len(), 2);

        let first_val = get_f64(&result[0], "total_revenue").unwrap();
        let second_val = get_f64(&result[1], "total_revenue").unwrap();
        assert!(first_val >= second_val);
    }

    #[test]
    fn unknown_aggregation_returns_error() {
        let spec = revenue_by_region("bogus", "x");

        let result = aggregate(&sales_data(), &spec);
        assert!(result.is_err());

        let message = result.unwrap_err().to_string();
        assert!(message.contains("Unknown aggregation: bogus"));
    }

    // ---- individual filter operators ------------------------------------

    #[test]
    fn filter_contains() {
        let row = make_row(vec![("name", json!("Hello World"))]);

        let hit = rule("name", "contains", Some(json!("World")));
        assert!(eval_filter_rule(&row, &hit));

        let miss = rule("name", "contains", Some(json!("xyz")));
        assert!(!eval_filter_rule(&row, &miss));
    }

    #[test]
    fn filter_is_null_and_is_not_null() {
        let row_with = make_row(vec![("a", json!(42))]);
        let row_null = make_row(vec![("a", serde_json::Value::Null)]);
        let row_missing: Row = HashMap::new();

        let is_null = rule("a", "isNull", None);
        let is_not_null = rule("a", "isNotNull", None);

        // (row, expected isNull); isNotNull must be the opposite.
        for (row, expect_null) in [(&row_with, false), (&row_null, true), (&row_missing, true)] {
            assert_eq!(eval_filter_rule(row, &is_null), expect_null);
            assert_eq!(eval_filter_rule(row, &is_not_null), !expect_null);
        }
    }

    #[test]
    fn filter_ne() {
        let row = make_row(vec![("x", json!("A"))]);

        let different = rule("x", "!=", Some(json!("B")));
        assert!(eval_filter_rule(&row, &different));

        let same = rule("x", "!=", Some(json!("A")));
        assert!(!eval_filter_rule(&row, &same));
    }

    #[test]
    fn filter_lte_gte() {
        let row = make_row(vec![("v", json!(10))]);

        for (operator, expected) in [("<=", true), (">=", true), ("<", false)] {
            let candidate = rule("v", operator, Some(json!(10)));
            assert_eq!(
                eval_filter_rule(&row, &candidate),
                expected,
                "operator {operator}"
            );
        }
    }

    // ---- sorting and edge cases -----------------------------------------

    #[test]
    fn sort_asc_default() {
        let mut data: Vec<Row> = [30, 10, 20]
            .iter()
            .map(|&v| make_row(vec![("v", json!(v))]))
            .collect();

        apply_sort(
            &mut data,
            &[SortSpec {
                field: "v".to_string(),
                direction: None, // defaults to asc
            }],
        );

        assert_eq!(get_f64(&data[0], "v"), Some(10.0));
        assert_eq!(get_f64(&data[1], "v"), Some(20.0));
        assert_eq!(get_f64(&data[2], "v"), Some(30.0));
    }

    #[test]
    fn empty_data_aggregation() {
        let data: Vec<Row> = Vec::new();
        let spec = revenue_by_region("sum", "total");

        let result = aggregate(&data, &spec).unwrap();
        assert!(result.is_empty());
    }
}