// chartml_core/transform.rs

1use crate::data::{get_f64, get_string, Row};
2use crate::error::ChartError;
3use crate::spec::{AggregateSpec, Dimension, FilterGroup, FilterRule, SortSpec, TransformSpec};
4use std::collections::HashMap;
5
6/// Apply transforms to data.
7pub fn apply_transforms(data: Vec<Row>, spec: &TransformSpec) -> Result<Vec<Row>, ChartError> {
8    let mut result = data;
9
10    if let Some(ref agg) = spec.aggregate {
11        result = aggregate(&result, agg)?;
12    }
13
14    Ok(result)
15}
16
17/// Aggregate data by dimensions with measures.
18/// Matches JS behavior: if no dimensions and no measures, skip aggregation
19/// and just apply filters/sort/limit to the raw data.
20fn aggregate(data: &[Row], spec: &AggregateSpec) -> Result<Vec<Row>, ChartError> {
21    // No aggregation needed — just filter/sort/limit the raw data
22    if spec.dimensions.is_empty() && spec.measures.is_empty() {
23        let mut result = data.to_vec();
24        if let Some(ref filters) = spec.filters {
25            result = apply_filters(result, filters);
26        }
27        if let Some(ref sorts) = spec.sort {
28            apply_sort(&mut result, sorts);
29        }
30        if let Some(limit) = spec.limit {
31            result.truncate(limit as usize);
32        }
33        return Ok(result);
34    }
35
36    // 1. Extract dimension output names and source column names
37    let dim_fields: Vec<String> = spec
38        .dimensions
39        .iter()
40        .map(|d| match d {
41            Dimension::Simple(name) => name.clone(),
42            Dimension::Detailed(spec) => spec.name.clone().unwrap_or_else(|| spec.column.clone()),
43        })
44        .collect();
45
46    let dim_columns: Vec<String> = spec
47        .dimensions
48        .iter()
49        .map(|d| match d {
50            Dimension::Simple(name) => name.clone(),
51            Dimension::Detailed(spec) => spec.column.clone(),
52        })
53        .collect();
54
55    // 2. Group rows by dimension values, preserving first-occurrence order of group keys.
56    let mut groups: HashMap<Vec<String>, Vec<&Row>> = HashMap::new();
57    let mut key_order: Vec<Vec<String>> = Vec::new();
58    let mut seen_keys: std::collections::HashSet<Vec<String>> = std::collections::HashSet::new();
59    for row in data {
60        let key: Vec<String> = dim_columns
61            .iter()
62            .map(|field| get_string(row, field).unwrap_or_default())
63            .collect();
64        if seen_keys.insert(key.clone()) {
65            key_order.push(key.clone());
66        }
67        groups.entry(key).or_default().push(row);
68    }
69
70    // 3. For each group, compute measures (in insertion order)
71    let mut result: Vec<Row> = Vec::new();
72    for key in &key_order {
73        let rows = &groups[key];
74        let mut out_row = Row::new();
75
76        // Add dimension values
77        for (i, field_name) in dim_fields.iter().enumerate() {
78            out_row.insert(
79                field_name.clone(),
80                serde_json::Value::String(key[i].clone()),
81            );
82        }
83
84        // Compute each measure
85        for measure in &spec.measures {
86            let value = if let Some(ref col) = measure.column {
87                if let Some(ref agg_fn) = measure.aggregation {
88                    compute_aggregation(rows, col, agg_fn)?
89                } else {
90                    // No aggregation specified, take first value
91                    rows.first()
92                        .and_then(|r| get_f64(r, col))
93                        .unwrap_or(0.0)
94                }
95            } else if let Some(ref expr) = measure.expression {
96                evaluate_expression(&out_row, expr)?
97            } else {
98                0.0
99            };
100            out_row.insert(measure.name.clone(), serde_json::json!(value));
101        }
102
103        result.push(out_row);
104    }
105
106    // 4. Apply filters (if any)
107    if let Some(ref filters) = spec.filters {
108        result = apply_filters(result, filters);
109    }
110
111    // 5. Apply sort (if any)
112    if let Some(ref sorts) = spec.sort {
113        apply_sort(&mut result, sorts);
114    }
115
116    // 6. Apply limit (if any)
117    if let Some(limit) = spec.limit {
118        result.truncate(limit as usize);
119    }
120
121    Ok(result)
122}
123
124fn compute_aggregation(rows: &[&Row], column: &str, agg: &str) -> Result<f64, ChartError> {
125    let values: Vec<f64> = rows.iter().filter_map(|r| get_f64(r, column)).collect();
126    if values.is_empty() {
127        return Ok(0.0);
128    }
129
130    Ok(match agg {
131        "sum" => values.iter().sum(),
132        "avg" => values.iter().sum::<f64>() / values.len() as f64,
133        "count" => values.len() as f64,
134        "min" => values.iter().cloned().fold(f64::INFINITY, f64::min),
135        "max" => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
136        "countDistinct" => {
137            let mut seen = std::collections::HashSet::new();
138            for v in &values {
139                seen.insert(v.to_bits());
140            }
141            seen.len() as f64
142        }
143        "median" => {
144            let mut sorted = values.clone();
145            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
146            let mid = sorted.len() / 2;
147            if sorted.len().is_multiple_of(2) {
148                (sorted[mid - 1] + sorted[mid]) / 2.0
149            } else {
150                sorted[mid]
151            }
152        }
153        other => {
154            return Err(ChartError::InvalidSpec(format!(
155                "Unknown aggregation: {}",
156                other
157            )))
158        }
159    })
160}
161
162/// Evaluate a simple arithmetic expression referencing fields already computed in the row.
163/// Supports: field_a / field_b, field_a * field_b, field_a + field_b, field_a - field_b
164fn evaluate_expression(row: &Row, expr: &str) -> Result<f64, ChartError> {
165    // Find the operator
166    let operators = ['/', '*', '+', '-'];
167    for op in &operators {
168        if let Some(pos) = expr.find(*op) {
169            // Make sure it's not a negative sign at position 0
170            if *op == '-' && pos == 0 {
171                continue;
172            }
173            let left_name = expr[..pos].trim();
174            let right_name = expr[pos + 1..].trim();
175
176            let left_val = get_f64(row, left_name).ok_or_else(|| {
177                ChartError::DataError(format!(
178                    "Expression field '{}' not found in row",
179                    left_name
180                ))
181            })?;
182            let right_val = get_f64(row, right_name).ok_or_else(|| {
183                ChartError::DataError(format!(
184                    "Expression field '{}' not found in row",
185                    right_name
186                ))
187            })?;
188
189            return Ok(match op {
190                '+' => left_val + right_val,
191                '-' => left_val - right_val,
192                '*' => left_val * right_val,
193                '/' => {
194                    if right_val == 0.0 {
195                        0.0
196                    } else {
197                        left_val / right_val
198                    }
199                }
200                _ => unreachable!(),
201            });
202        }
203    }
204
205    // No operator found — try interpreting as a field reference
206    get_f64(row, expr.trim()).ok_or_else(|| {
207        ChartError::DataError(format!(
208            "Cannot evaluate expression '{}': field not found",
209            expr
210        ))
211    })
212}
213
214fn apply_filters(data: Vec<Row>, filters: &FilterGroup) -> Vec<Row> {
215    let is_and = filters.combinator.as_deref() != Some("or");
216
217    data.into_iter()
218        .filter(|row| {
219            let results: Vec<bool> = filters.rules.iter().map(|rule| eval_filter_rule(row, rule)).collect();
220
221            if is_and {
222                results.iter().all(|&r| r)
223            } else {
224                results.iter().any(|&r| r)
225            }
226        })
227        .collect()
228}
229
230fn eval_filter_rule(row: &Row, rule: &FilterRule) -> bool {
231    let field_val = row.get(&rule.field);
232
233    match rule.operator.as_str() {
234        "isNull" => field_val.is_none() || field_val == Some(&serde_json::Value::Null),
235        "isNotNull" => field_val.is_some() && field_val != Some(&serde_json::Value::Null),
236        "=" | "==" => rule
237            .value
238            .as_ref()
239            .is_some_and(|v| field_val == Some(v)),
240        "!=" => rule
241            .value
242            .as_ref()
243            .is_some_and(|v| field_val != Some(v)),
244        ">" => compare_values(field_val, rule.value.as_ref(), |a, b| a > b),
245        ">=" => compare_values(field_val, rule.value.as_ref(), |a, b| a >= b),
246        "<" => compare_values(field_val, rule.value.as_ref(), |a, b| a < b),
247        "<=" => compare_values(field_val, rule.value.as_ref(), |a, b| a <= b),
248        "in" => {
249            if let (Some(fv), Some(serde_json::Value::Array(arr))) =
250                (field_val, rule.value.as_ref())
251            {
252                arr.contains(fv)
253            } else {
254                false
255            }
256        }
257        "contains" => {
258            if let (Some(serde_json::Value::String(fv)), Some(serde_json::Value::String(rv))) =
259                (field_val, rule.value.as_ref())
260            {
261                fv.contains(rv.as_str())
262            } else {
263                false
264            }
265        }
266        _ => true, // unknown operator, pass through
267    }
268}
269
270fn compare_values(
271    field: Option<&serde_json::Value>,
272    rule_val: Option<&serde_json::Value>,
273    cmp: impl Fn(f64, f64) -> bool,
274) -> bool {
275    match (field, rule_val) {
276        (Some(serde_json::Value::Number(a)), Some(serde_json::Value::Number(b))) => {
277            if let (Some(fa), Some(fb)) = (a.as_f64(), b.as_f64()) {
278                cmp(fa, fb)
279            } else {
280                false
281            }
282        }
283        _ => false,
284    }
285}
286
287fn apply_sort(data: &mut [Row], sorts: &[SortSpec]) {
288    data.sort_by(|a, b| {
289        for sort in sorts {
290            let a_val = get_f64(a, &sort.field);
291            let b_val = get_f64(b, &sort.field);
292            let ord = match (a_val, b_val) {
293                (Some(av), Some(bv)) => av
294                    .partial_cmp(&bv)
295                    .unwrap_or(std::cmp::Ordering::Equal),
296                _ => std::cmp::Ordering::Equal,
297            };
298            let ord = if sort.direction.as_deref() == Some("desc") {
299                ord.reverse()
300            } else {
301                ord
302            };
303            if ord != std::cmp::Ordering::Equal {
304                return ord;
305            }
306        }
307        std::cmp::Ordering::Equal
308    });
309}
310
// Unit tests for grouping/measure aggregation (with filters, sort and limit),
// the public `apply_transforms` entry point, single filter rules, and sorting.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::spec::{AggregateSpec, Dimension, Measure, SortSpec, TransformSpec};
    use serde_json::json;

    /// Builds a `Row` from `(key, value)` pairs.
    fn make_row(pairs: Vec<(&str, serde_json::Value)>) -> Row {
        pairs
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect()
    }

    /// Five sales rows over three regions.
    ///
    /// Per-region totals: North revenue 300 / units 25 (two rows), South
    /// revenue 200 / units 17 (two rows), East revenue 300 / units 20 (one row).
    fn sales_data() -> Vec<Row> {
        vec![
            make_row(vec![
                ("region", json!("North")),
                ("product", json!("Widget")),
                ("revenue", json!(100.0)),
                ("units", json!(10)),
            ]),
            make_row(vec![
                ("region", json!("North")),
                ("product", json!("Gadget")),
                ("revenue", json!(200.0)),
                ("units", json!(15)),
            ]),
            make_row(vec![
                ("region", json!("South")),
                ("product", json!("Widget")),
                ("revenue", json!(150.0)),
                ("units", json!(12)),
            ]),
            make_row(vec![
                ("region", json!("South")),
                ("product", json!("Widget")),
                ("revenue", json!(50.0)),
                ("units", json!(5)),
            ]),
            make_row(vec![
                ("region", json!("East")),
                ("product", json!("Gadget")),
                ("revenue", json!(300.0)),
                ("units", json!(20)),
            ]),
        ]
    }

    // One output row per region, holding that region's summed revenue.
    #[test]
    fn aggregate_sum() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        assert_eq!(result.len(), 3);

        // Find each region and check totals
        let north = result
            .iter()
            .find(|r| get_string(r, "region") == Some("North".to_string()))
            .unwrap();
        assert_eq!(get_f64(north, "total_revenue"), Some(300.0));

        let south = result
            .iter()
            .find(|r| get_string(r, "region") == Some("South".to_string()))
            .unwrap();
        assert_eq!(get_f64(south, "total_revenue"), Some(200.0));

        let east = result
            .iter()
            .find(|r| get_string(r, "region") == Some("East".to_string()))
            .unwrap();
        assert_eq!(get_f64(east, "total_revenue"), Some(300.0));
    }

    #[test]
    fn aggregate_avg() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("avg".to_string()),
                name: "avg_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        let north = result
            .iter()
            .find(|r| get_string(r, "region") == Some("North".to_string()))
            .unwrap();
        assert_eq!(get_f64(north, "avg_revenue"), Some(150.0)); // (100+200)/2

        let south = result
            .iter()
            .find(|r| get_string(r, "region") == Some("South".to_string()))
            .unwrap();
        assert_eq!(get_f64(south, "avg_revenue"), Some(100.0)); // (150+50)/2
    }

    // North and South each have two rows with a revenue value.
    #[test]
    fn aggregate_count() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("count".to_string()),
                name: "count".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        let north = result
            .iter()
            .find(|r| get_string(r, "region") == Some("North".to_string()))
            .unwrap();
        assert_eq!(get_f64(north, "count"), Some(2.0));

        let south = result
            .iter()
            .find(|r| get_string(r, "region") == Some("South".to_string()))
            .unwrap();
        assert_eq!(get_f64(south, "count"), Some(2.0));
    }

    // Two measures over the same column; South revenues are 150 and 50.
    #[test]
    fn aggregate_min_max() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![
                Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("min".to_string()),
                    name: "min_rev".to_string(),
                    expression: None,
                },
                Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("max".to_string()),
                    name: "max_rev".to_string(),
                    expression: None,
                },
            ],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        let south = result
            .iter()
            .find(|r| get_string(r, "region") == Some("South".to_string()))
            .unwrap();
        assert_eq!(get_f64(south, "min_rev"), Some(50.0));
        assert_eq!(get_f64(south, "max_rev"), Some(150.0));
    }

    #[test]
    fn aggregate_count_distinct() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("countDistinct".to_string()),
                name: "distinct_rev".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        let north = result
            .iter()
            .find(|r| get_string(r, "region") == Some("North".to_string()))
            .unwrap();
        assert_eq!(get_f64(north, "distinct_rev"), Some(2.0)); // 100 and 200
    }

    // Group A has an odd number of values, group B an even number.
    #[test]
    fn aggregate_median() {
        let data = vec![
            make_row(vec![("g", json!("A")), ("v", json!(1.0))]),
            make_row(vec![("g", json!("A")), ("v", json!(3.0))]),
            make_row(vec![("g", json!("A")), ("v", json!(5.0))]),
            make_row(vec![("g", json!("B")), ("v", json!(2.0))]),
            make_row(vec![("g", json!("B")), ("v", json!(4.0))]),
        ];
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("g".to_string())],
            measures: vec![Measure {
                column: Some("v".to_string()),
                aggregation: Some("median".to_string()),
                name: "med".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        let a = result
            .iter()
            .find(|r| get_string(r, "g") == Some("A".to_string()))
            .unwrap();
        assert_eq!(get_f64(a, "med"), Some(3.0)); // odd count: middle

        let b = result
            .iter()
            .find(|r| get_string(r, "g") == Some("B".to_string()))
            .unwrap();
        assert_eq!(get_f64(b, "med"), Some(3.0)); // even count: (2+4)/2
    }

    #[test]
    fn aggregate_with_sort() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: Some(vec![SortSpec {
                field: "total_revenue".to_string(),
                direction: Some("desc".to_string()),
            }]),
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        assert_eq!(result.len(), 3);
        // North=300, East=300, South=200 — both 300s first, then 200
        let first_val = get_f64(&result[0], "total_revenue").unwrap();
        let second_val = get_f64(&result[1], "total_revenue").unwrap();
        let third_val = get_f64(&result[2], "total_revenue").unwrap();
        assert!(first_val >= second_val);
        assert!(second_val >= third_val);
        assert_eq!(third_val, 200.0);
    }

    // Three region groups, truncated to two by the limit.
    #[test]
    fn aggregate_with_limit() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: None,
            sort: Some(vec![SortSpec {
                field: "total_revenue".to_string(),
                direction: Some("desc".to_string()),
            }]),
            limit: Some(2),
        };

        let result = aggregate(&data, &spec).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn aggregate_with_detailed_dimension() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Detailed(crate::spec::DimensionSpec {
                column: "region".to_string(),
                name: Some("Region".to_string()),
                dim_type: None,
            })],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        // Output field name should be "Region" not "region"
        let north = result
            .iter()
            .find(|r| get_string(r, "Region") == Some("North".to_string()))
            .unwrap();
        assert_eq!(get_f64(north, "total"), Some(300.0));
    }

    // The filter references the measure's output name, `total_revenue`.
    #[test]
    fn aggregate_with_filter_gt() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: Some(FilterGroup {
                combinator: None, // defaults to "and"
                rules: vec![FilterRule {
                    field: "total_revenue".to_string(),
                    operator: ">".to_string(),
                    value: Some(json!(250)),
                }],
            }),
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        // North=300, East=300 pass; South=200 filtered out
        assert_eq!(result.len(), 2);
        for row in &result {
            assert!(get_f64(row, "total_revenue").unwrap() > 250.0);
        }
    }

    #[test]
    fn aggregate_with_filter_eq() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: Some(FilterGroup {
                combinator: None,
                rules: vec![FilterRule {
                    field: "region".to_string(),
                    operator: "=".to_string(),
                    value: Some(json!("North")),
                }],
            }),
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(
            get_string(&result[0], "region"),
            Some("North".to_string())
        );
    }

    #[test]
    fn aggregate_with_filter_in() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: Some(FilterGroup {
                combinator: None,
                rules: vec![FilterRule {
                    field: "region".to_string(),
                    operator: "in".to_string(),
                    value: Some(json!(["North", "East"])),
                }],
            }),
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        assert_eq!(result.len(), 2);
    }

    // Either rule matching keeps the row: North or East.
    #[test]
    fn aggregate_with_filter_or_combinator() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total_revenue".to_string(),
                expression: None,
            }],
            filters: Some(FilterGroup {
                combinator: Some("or".to_string()),
                rules: vec![
                    FilterRule {
                        field: "region".to_string(),
                        operator: "=".to_string(),
                        value: Some(json!("North")),
                    },
                    FilterRule {
                        field: "region".to_string(),
                        operator: "=".to_string(),
                        value: Some(json!("East")),
                    },
                ],
            }),
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        assert_eq!(result.len(), 2);
    }

    // The expression measure divides two measures declared before it.
    #[test]
    fn aggregate_with_expression_measure() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![
                Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_revenue".to_string(),
                    expression: None,
                },
                Measure {
                    column: Some("units".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_units".to_string(),
                    expression: None,
                },
                Measure {
                    column: None,
                    aggregation: None,
                    name: "avg_price".to_string(),
                    expression: Some("total_revenue / total_units".to_string()),
                },
            ],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        let north = result
            .iter()
            .find(|r| get_string(r, "region") == Some("North".to_string()))
            .unwrap();
        // North: revenue=300, units=25, avg_price=12
        assert_eq!(get_f64(north, "total_revenue"), Some(300.0));
        assert_eq!(get_f64(north, "total_units"), Some(25.0));
        assert_eq!(get_f64(north, "avg_price"), Some(12.0));
    }

    // Aggregate + sort desc + limit 2, driven through the public entry point.
    #[test]
    fn apply_transforms_full_pipeline() {
        let data = sales_data();
        let spec = TransformSpec {
            sql: None,
            forecast: None,
            aggregate: Some(AggregateSpec {
                dimensions: vec![Dimension::Simple("region".to_string())],
                measures: vec![Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_revenue".to_string(),
                    expression: None,
                }],
                filters: None,
                sort: Some(vec![SortSpec {
                    field: "total_revenue".to_string(),
                    direction: Some("desc".to_string()),
                }]),
                limit: Some(2),
            }),
        };

        let result = apply_transforms(data, &spec).unwrap();
        assert_eq!(result.len(), 2);
        let first_val = get_f64(&result[0], "total_revenue").unwrap();
        let second_val = get_f64(&result[1], "total_revenue").unwrap();
        assert!(first_val >= second_val);
    }

    // An unrecognised aggregation name is an error whose message names it.
    #[test]
    fn unknown_aggregation_returns_error() {
        let data = sales_data();
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("bogus".to_string()),
                name: "x".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown aggregation: bogus"));
    }

    #[test]
    fn filter_contains() {
        let row = make_row(vec![("name", json!("Hello World"))]);
        let rule = FilterRule {
            field: "name".to_string(),
            operator: "contains".to_string(),
            value: Some(json!("World")),
        };
        assert!(eval_filter_rule(&row, &rule));

        let rule_miss = FilterRule {
            field: "name".to_string(),
            operator: "contains".to_string(),
            value: Some(json!("xyz")),
        };
        assert!(!eval_filter_rule(&row, &rule_miss));
    }

    // A present value, an explicit JSON null, and a missing key.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let row_with = make_row(vec![("a", json!(42))]);
        let row_null = make_row(vec![("a", serde_json::Value::Null)]);
        let row_missing: Row = HashMap::new();

        let rule_null = FilterRule {
            field: "a".to_string(),
            operator: "isNull".to_string(),
            value: None,
        };
        let rule_not_null = FilterRule {
            field: "a".to_string(),
            operator: "isNotNull".to_string(),
            value: None,
        };

        assert!(!eval_filter_rule(&row_with, &rule_null));
        assert!(eval_filter_rule(&row_with, &rule_not_null));

        assert!(eval_filter_rule(&row_null, &rule_null));
        assert!(!eval_filter_rule(&row_null, &rule_not_null));

        assert!(eval_filter_rule(&row_missing, &rule_null));
        assert!(!eval_filter_rule(&row_missing, &rule_not_null));
    }

    #[test]
    fn filter_ne() {
        let row = make_row(vec![("x", json!("A"))]);
        let rule = FilterRule {
            field: "x".to_string(),
            operator: "!=".to_string(),
            value: Some(json!("B")),
        };
        assert!(eval_filter_rule(&row, &rule));

        let rule_same = FilterRule {
            field: "x".to_string(),
            operator: "!=".to_string(),
            value: Some(json!("A")),
        };
        assert!(!eval_filter_rule(&row, &rule_same));
    }

    // Boundary value 10: inclusive operators match, strict `<` does not.
    #[test]
    fn filter_lte_gte() {
        let row = make_row(vec![("v", json!(10))]);

        assert!(eval_filter_rule(
            &row,
            &FilterRule {
                field: "v".to_string(),
                operator: "<=".to_string(),
                value: Some(json!(10)),
            }
        ));
        assert!(eval_filter_rule(
            &row,
            &FilterRule {
                field: "v".to_string(),
                operator: ">=".to_string(),
                value: Some(json!(10)),
            }
        ));
        assert!(!eval_filter_rule(
            &row,
            &FilterRule {
                field: "v".to_string(),
                operator: "<".to_string(),
                value: Some(json!(10)),
            }
        ));
    }

    #[test]
    fn sort_asc_default() {
        let mut data = vec![
            make_row(vec![("v", json!(30))]),
            make_row(vec![("v", json!(10))]),
            make_row(vec![("v", json!(20))]),
        ];
        apply_sort(
            &mut data,
            &[SortSpec {
                field: "v".to_string(),
                direction: None, // defaults to asc
            }],
        );
        assert_eq!(get_f64(&data[0], "v"), Some(10.0));
        assert_eq!(get_f64(&data[1], "v"), Some(20.0));
        assert_eq!(get_f64(&data[2], "v"), Some(30.0));
    }

    // No input rows means no groups, hence no output rows.
    #[test]
    fn empty_data_aggregation() {
        let data: Vec<Row> = vec![];
        let spec = AggregateSpec {
            dimensions: vec![Dimension::Simple("region".to_string())],
            measures: vec![Measure {
                column: Some("revenue".to_string()),
                aggregation: Some("sum".to_string()),
                name: "total".to_string(),
                expression: None,
            }],
            filters: None,
            sort: None,
            limit: None,
        };

        let result = aggregate(&data, &spec).unwrap();
        assert!(result.is_empty());
    }
}