//! DataFusion-backed transform middleware for ChartML.
//!
//! Implements the 3-stage pipeline: SQL → Aggregate → Forecast.
//! Compatible with both native (server) and WASM (browser) targets.
pub mod conversion;
pub mod sql_builder;
pub mod stages;

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::RecordBatch;
use async_trait::async_trait;
use chartml_core::data::DataTable;
use chartml_core::error::ChartError;
use chartml_core::plugin::transform::{TransformContext, TransformMiddleware, TransformResult};
use chartml_core::spec::TransformSpec;
use datafusion::prelude::*;
/// DataFusion-backed transform middleware.
///
/// Processes data through a 3-stage pipeline:
/// 1. **SQL stage** — execute raw SQL with placeholder replacement
/// 2. **Aggregate stage** — declarative GROUP BY / measures / filters
/// 3. **Forecast stage** — time series forecasting via chartml-forecast
///
/// Stateless unit struct: each `transform` call builds its own DataFusion
/// `SessionContext`, so instances are free to construct, copy, and share.
#[derive(Debug, Default, Clone, Copy)]
pub struct DataFusionTransform;
27
28#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
29#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
30impl TransformMiddleware for DataFusionTransform {
31    async fn transform(
32        &self,
33        data: DataTable,
34        spec: &TransformSpec,
35        _context: &TransformContext,
36    ) -> Result<TransformResult, ChartError> {
37        let ctx = SessionContext::new();
38
39        // Register input data as "source" table — no conversion needed,
40        // DataTable already holds an Arrow RecordBatch.
41        let batch = data.record_batch().clone();
42        let schema = batch.schema();
43        let mem_table =
44            datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]]).map_err(|e| {
45                ChartError::DataError(format!("Failed to create source MemTable: {}", e))
46            })?;
47        ctx.register_table("source", std::sync::Arc::new(mem_table))
48            .map_err(|e| {
49                ChartError::DataError(format!("Failed to register source table: {}", e))
50            })?;
51
52        let mut current_table = "source".to_string();
53
54        // Stage 1: SQL
55        if let Some(ref sql_spec) = spec.sql {
56            current_table =
57                stages::sql_stage::execute(&ctx, &current_table, sql_spec).await?;
58        }
59
60        // Stage 2: Aggregate
61        if let Some(ref agg_spec) = spec.aggregate {
62            current_table =
63                stages::aggregate_stage::execute(&ctx, &current_table, agg_spec).await?;
64        }
65
66        // Stage 3: Forecast
67        if let Some(ref forecast_spec) = spec.forecast {
68            current_table =
69                stages::forecast_stage::execute(&ctx, &current_table, forecast_spec).await?;
70        }
71
72        // Collect final result
73        let df = ctx
74            .table(&current_table)
75            .await
76            .map_err(|e| ChartError::DataError(format!("Failed to read result table: {}", e)))?;
77        let output_schema = Arc::new(df.schema().as_arrow().clone());
78        let batches = df
79            .collect()
80            .await
81            .map_err(|e| ChartError::DataError(format!("Failed to collect results: {}", e)))?;
82
83        if batches.is_empty() {
84            // DataFusion returned no batches (e.g., WHERE filtered all rows).
85            // Return an empty DataTable preserving the output schema.
86            return Ok(TransformResult {
87                data: DataTable::from_record_batch(RecordBatch::new_empty(output_schema)),
88                metadata: HashMap::new(),
89            });
90        }
91
92        // Concatenate all output batches into a single RecordBatch and wrap in DataTable.
93        let result_batch = arrow::compute::concat_batches(
94            &output_schema,
95            &batches,
96        )
97        .map_err(|e| ChartError::DataError(format!("Failed to concat result batches: {}", e)))?;
98
99        Ok(TransformResult {
100            data: DataTable::from_record_batch(result_batch),
101            metadata: HashMap::new(),
102        })
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use chartml_core::data::Row;
    use chartml_core::spec::*;
    use serde_json::json;

    /// Build a `Row` from (column, value) pairs.
    fn make_row(pairs: Vec<(&str, serde_json::Value)>) -> Row {
        pairs
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect()
    }

    /// Five sales rows spanning 3 regions and 2 products.
    /// Region totals: North=300, South=200, East=300.
    fn sales_rows() -> Vec<Row> {
        vec![
            make_row(vec![
                ("region", json!("North")),
                ("product", json!("Widget")),
                ("revenue", json!(100.0)),
                ("units", json!(10.0)),
            ]),
            make_row(vec![
                ("region", json!("North")),
                ("product", json!("Gadget")),
                ("revenue", json!(200.0)),
                ("units", json!(15.0)),
            ]),
            make_row(vec![
                ("region", json!("South")),
                ("product", json!("Widget")),
                ("revenue", json!(150.0)),
                ("units", json!(12.0)),
            ]),
            make_row(vec![
                ("region", json!("South")),
                ("product", json!("Widget")),
                ("revenue", json!(50.0)),
                ("units", json!(5.0)),
            ]),
            make_row(vec![
                ("region", json!("East")),
                ("product", json!("Gadget")),
                ("revenue", json!(300.0)),
                ("units", json!(20.0)),
            ]),
        ]
    }

    fn sales_data() -> DataTable {
        DataTable::from_rows(&sales_rows()).unwrap()
    }

    #[tokio::test]
    async fn test_full_pipeline_aggregate() {
        let data = sales_data();
        let spec = TransformSpec {
            sql: None,
            forecast: None,
            aggregate: Some(AggregateSpec {
                dimensions: vec![Dimension::Simple("region".to_string())],
                measures: vec![Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_revenue".to_string(),
                    expression: None,
                }],
                filters: None,
                sort: Some(vec![SortSpec {
                    field: "total_revenue".to_string(),
                    direction: Some("desc".to_string()),
                }]),
                limit: None,
            }),
        };

        let transform = DataFusionTransform;
        let context = TransformContext::default();
        let result = transform.transform(data, &spec, &context).await.unwrap();

        assert_eq!(result.data.num_rows(), 3, "Should have 3 regions");

        // Results should be sorted descending by total_revenue
        let rows = result.data.to_rows();
        let revenues: Vec<f64> = rows
            .iter()
            .map(|r| r.get("total_revenue").unwrap().as_f64().unwrap())
            .collect();

        // North=300, East=300, South=200
        assert!(
            revenues[0] >= revenues[1],
            "First should be >= second: {:?}",
            revenues
        );
        assert!(
            revenues[1] >= revenues[2],
            "Second should be >= third: {:?}",
            revenues
        );
        assert_eq!(revenues[2], 200.0, "South total should be 200");
    }

    #[tokio::test]
    async fn test_full_pipeline_forecast() {
        // Create time series data (linear: y = 10 + 2x)
        let rows: Vec<Row> = (0..20)
            .map(|i| {
                make_row(vec![
                    ("timestamp", json!(1000 + i)),
                    ("value", json!(10.0 + 2.0 * i as f64)),
                ])
            })
            .collect();
        let data = DataTable::from_rows(&rows).unwrap();

        let spec = TransformSpec {
            sql: None,
            aggregate: None,
            forecast: Some(ForecastSpec {
                timestamp: "timestamp".to_string(),
                value: "value".to_string(),
                horizon: Some(5),
                confidence_level: Some(0.95),
                model: Some("linear".to_string()),
                group_by: None,
            }),
        };

        let transform = DataFusionTransform;
        let context = TransformContext::default();
        let result = transform.transform(data, &spec, &context).await.unwrap();

        // Should have 20 historical + 5 forecast rows
        assert_eq!(
            result.data.num_rows(),
            25,
            "Should have 25 rows (20 historical + 5 forecast)"
        );

        // Convert to rows for detailed field assertions
        let result_rows = result.data.to_rows();

        // Check that forecast rows have is_forecast=true
        let forecast_rows: Vec<&Row> = result_rows
            .iter()
            .filter(|r| r.get("is_forecast").and_then(|v| v.as_bool()) == Some(true))
            .collect();
        assert_eq!(forecast_rows.len(), 5, "Should have 5 forecast rows");

        // Forecast values should have forecast, lower_bound, upper_bound
        for row in &forecast_rows {
            assert!(
                row.get("forecast").is_some(),
                "Forecast row should have 'forecast' field"
            );
            assert!(
                row.get("lower_bound").is_some(),
                "Forecast row should have 'lower_bound' field"
            );
            assert!(
                row.get("upper_bound").is_some(),
                "Forecast row should have 'upper_bound' field"
            );
        }

        // Historical rows should have is_forecast=false
        let historical_rows: Vec<&Row> = result_rows
            .iter()
            .filter(|r| r.get("is_forecast").and_then(|v| v.as_bool()) == Some(false))
            .collect();
        assert_eq!(historical_rows.len(), 20, "Should have 20 historical rows");
    }

    #[tokio::test]
    async fn test_full_pipeline_sql() {
        let data = sales_data();
        let spec = TransformSpec {
            sql: Some(SqlSpec::Single(
                "SELECT * FROM \"source\" WHERE \"revenue\" > 100".to_string(),
            )),
            aggregate: None,
            forecast: None,
        };

        let transform = DataFusionTransform;
        let context = TransformContext::default();
        let result = transform.transform(data, &spec, &context).await.unwrap();

        // Only rows with revenue > 100 should remain
        let result_rows = result.data.to_rows();
        assert!(
            result_rows.len() < 5,
            "Should filter out some rows, got {}",
            result_rows.len()
        );
        for row in &result_rows {
            let rev = row.get("revenue").unwrap().as_f64().unwrap();
            assert!(rev > 100.0, "Revenue should be > 100, got {}", rev);
        }
    }

    // --- Integration tests: ChartML + DataFusionTransform wired together ---

    #[tokio::test]
    async fn test_integration_chartml_async_aggregate() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        // Renderer that asserts on the data it receives instead of drawing.
        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                // Verify the data was actually transformed (aggregated)
                // We expect 3 region groups from the sales_data
                assert_eq!(data.num_rows(), 3, "Expected 3 aggregated groups, got {}", data.num_rows());

                // Verify sort order (desc by total_revenue)
                let revenues: Vec<f64> = (0..data.num_rows())
                    .filter_map(|i| data.get_f64(i, "total_revenue"))
                    .collect();
                assert!(revenues[0] >= revenues[1], "Should be sorted desc");
                assert!(revenues[1] >= revenues[2], "Should be sorted desc");

                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock-bar".to_string(),
                    children: vec![],
                })
            }
        }

        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        chartml.register_transform(DataFusionTransform);

        let yaml = r#"
type: chart
version: 1
title: Revenue by Region
data:
  provider: inline
  rows: []
transform:
  aggregate:
    dimensions:
      - region
    measures:
      - column: revenue
        aggregation: sum
        name: total_revenue
    sort:
      - field: total_revenue
        direction: desc
visualize:
  type: bar
  columns: region
  rows: total_revenue
"#;

        let data = sales_data();
        let result = chartml.render_from_yaml_with_data_async(yaml, data).await;
        assert!(result.is_ok(), "Async render failed: {:?}", result.err());

        match result.unwrap() {
            ChartElement::Svg { class, .. } => {
                assert_eq!(class, "mock-bar");
            }
            other => panic!("Expected Svg element, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_integration_chartml_async_sql() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                // SQL filter: revenue > 100 should leave 3 rows (200, 150, 300)
                assert_eq!(data.num_rows(), 3, "Expected 3 rows after SQL filter, got {}", data.num_rows());
                for i in 0..data.num_rows() {
                    let rev = data.get_f64(i, "revenue").unwrap();
                    assert!(rev > 100.0, "Revenue should be > 100, got {}", rev);
                }

                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock-bar".to_string(),
                    children: vec![],
                })
            }
        }

        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        chartml.register_transform(DataFusionTransform);

        let yaml = r#"
type: chart
version: 1
title: High Revenue Items
data:
  provider: inline
  rows: []
transform:
  sql: "SELECT * FROM \"source\" WHERE \"revenue\" > 100"
visualize:
  type: bar
  columns: region
  rows: revenue
"#;

        let data = sales_data();
        let result = chartml.render_from_yaml_with_data_async(yaml, data).await;
        assert!(result.is_ok(), "Async SQL render failed: {:?}", result.err());
    }

    #[tokio::test]
    async fn test_integration_chartml_async_no_middleware_error() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                _data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock".to_string(),
                    children: vec![],
                })
            }
        }

        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        // Note: NOT registering a TransformMiddleware

        let yaml = r#"
type: chart
version: 1
title: Test
data:
  provider: inline
  rows: []
transform:
  sql: "SELECT * FROM source"
visualize:
  type: bar
  columns: x
  rows: y
"#;

        let empty = DataTable::from_rows(&[]).unwrap();
        let result = chartml.render_from_yaml_with_data_async(yaml, empty).await;
        assert!(result.is_err(), "Should fail when sql transform used without middleware");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("no TransformMiddleware is registered"), "Error should mention missing middleware, got: {}", err);
    }

    #[tokio::test]
    async fn test_integration_chartml_async_aggregate_only_uses_builtin() {
        // When only aggregate is specified (no sql/forecast), the async method
        // should fall back to the built-in sync apply_transforms — no middleware needed.
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                // Should be aggregated by the built-in transform
                assert_eq!(data.num_rows(), 3, "Expected 3 aggregated groups, got {}", data.num_rows());
                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock-bar".to_string(),
                    children: vec![],
                })
            }
        }

        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        // Note: NOT registering a TransformMiddleware — aggregate-only should still work

        let yaml = r#"
type: chart
version: 1
title: Revenue by Region
data:
  provider: inline
  rows: []
transform:
  aggregate:
    dimensions:
      - region
    measures:
      - column: revenue
        aggregation: sum
        name: total_revenue
visualize:
  type: bar
  columns: region
  rows: total_revenue
"#;

        let data = sales_data();
        let result = chartml.render_from_yaml_with_data_async(yaml, data).await;
        assert!(result.is_ok(), "Aggregate-only async render should work without middleware: {:?}", result.err());
    }
}