1pub mod conversion;
7pub mod sql_builder;
8pub mod stages;
9
10use async_trait::async_trait;
11use chartml_core::data::DataTable;
12use chartml_core::error::ChartError;
13use chartml_core::plugin::transform::{TransformContext, TransformMiddleware, TransformResult};
14use chartml_core::spec::TransformSpec;
15use datafusion::prelude::*;
16use arrow::array::RecordBatch;
17use std::collections::HashMap;
18use std::sync::Arc;
19
/// Transform middleware that executes ChartML transform specs (SQL,
/// aggregate, forecast stages) using an in-memory DataFusion session.
/// Unit struct: all state lives per-call inside `transform`.
pub struct DataFusionTransform;
27
28#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
29#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
30impl TransformMiddleware for DataFusionTransform {
31 async fn transform(
32 &self,
33 data: DataTable,
34 spec: &TransformSpec,
35 _context: &TransformContext,
36 ) -> Result<TransformResult, ChartError> {
37 let ctx = SessionContext::new();
38
39 let batch = data.record_batch().clone();
42 let schema = batch.schema();
43 let mem_table =
44 datafusion::datasource::MemTable::try_new(schema, vec![vec![batch]]).map_err(|e| {
45 ChartError::DataError(format!("Failed to create source MemTable: {}", e))
46 })?;
47 ctx.register_table("source", std::sync::Arc::new(mem_table))
48 .map_err(|e| {
49 ChartError::DataError(format!("Failed to register source table: {}", e))
50 })?;
51
52 let mut current_table = "source".to_string();
53
54 if let Some(ref sql_spec) = spec.sql {
56 current_table =
57 stages::sql_stage::execute(&ctx, ¤t_table, sql_spec).await?;
58 }
59
60 if let Some(ref agg_spec) = spec.aggregate {
62 current_table =
63 stages::aggregate_stage::execute(&ctx, ¤t_table, agg_spec).await?;
64 }
65
66 if let Some(ref forecast_spec) = spec.forecast {
68 current_table =
69 stages::forecast_stage::execute(&ctx, ¤t_table, forecast_spec).await?;
70 }
71
72 let df = ctx
74 .table(¤t_table)
75 .await
76 .map_err(|e| ChartError::DataError(format!("Failed to read result table: {}", e)))?;
77 let output_schema = Arc::new(df.schema().as_arrow().clone());
78 let batches = df
79 .collect()
80 .await
81 .map_err(|e| ChartError::DataError(format!("Failed to collect results: {}", e)))?;
82
83 if batches.is_empty() {
84 return Ok(TransformResult {
87 data: DataTable::from_record_batch(RecordBatch::new_empty(output_schema)),
88 metadata: HashMap::new(),
89 });
90 }
91
92 let result_batch = arrow::compute::concat_batches(
94 &output_schema,
95 &batches,
96 )
97 .map_err(|e| ChartError::DataError(format!("Failed to concat result batches: {}", e)))?;
98
99 Ok(TransformResult {
100 data: DataTable::from_record_batch(result_batch),
101 metadata: HashMap::new(),
102 })
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use chartml_core::data::Row;
    use chartml_core::spec::*;
    use serde_json::json;

    /// Builds a `Row` from (column name, JSON value) pairs.
    fn make_row(pairs: Vec<(&str, serde_json::Value)>) -> Row {
        pairs
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect()
    }

    /// Shared fixture: 5 sales rows across 3 regions (North, South, East).
    /// Region totals: North = 300, South = 200, East = 300.
    fn sales_rows() -> Vec<Row> {
        vec![
            make_row(vec![
                ("region", json!("North")),
                ("product", json!("Widget")),
                ("revenue", json!(100.0)),
                ("units", json!(10.0)),
            ]),
            make_row(vec![
                ("region", json!("North")),
                ("product", json!("Gadget")),
                ("revenue", json!(200.0)),
                ("units", json!(15.0)),
            ]),
            make_row(vec![
                ("region", json!("South")),
                ("product", json!("Widget")),
                ("revenue", json!(150.0)),
                ("units", json!(12.0)),
            ]),
            make_row(vec![
                ("region", json!("South")),
                ("product", json!("Widget")),
                ("revenue", json!(50.0)),
                ("units", json!(5.0)),
            ]),
            make_row(vec![
                ("region", json!("East")),
                ("product", json!("Gadget")),
                ("revenue", json!(300.0)),
                ("units", json!(20.0)),
            ]),
        ]
    }

    /// Convenience wrapper converting the sales fixture into a `DataTable`.
    fn sales_data() -> DataTable {
        DataTable::from_rows(&sales_rows()).unwrap()
    }

    /// Aggregate stage alone: group by region, sum revenue, sort descending.
    #[tokio::test]
    async fn test_full_pipeline_aggregate() {
        let data = sales_data();
        let spec = TransformSpec {
            sql: None,
            forecast: None,
            aggregate: Some(AggregateSpec {
                dimensions: vec![Dimension::Simple("region".to_string())],
                measures: vec![Measure {
                    column: Some("revenue".to_string()),
                    aggregation: Some("sum".to_string()),
                    name: "total_revenue".to_string(),
                    expression: None,
                }],
                filters: None,
                sort: Some(vec![SortSpec {
                    field: "total_revenue".to_string(),
                    direction: Some("desc".to_string()),
                }]),
                limit: None,
            }),
        };

        let transform = DataFusionTransform;
        let context = TransformContext::default();
        let result = transform.transform(data, &spec, &context).await.unwrap();

        assert_eq!(result.data.num_rows(), 3, "Should have 3 regions");

        let rows = result.data.to_rows();
        let revenues: Vec<f64> = rows
            .iter()
            .map(|r| r.get("total_revenue").unwrap().as_f64().unwrap())
            .collect();

        // Descending sort order must hold across all three groups.
        assert!(
            revenues[0] >= revenues[1],
            "First should be >= second: {:?}",
            revenues
        );
        assert!(
            revenues[1] >= revenues[2],
            "Second should be >= third: {:?}",
            revenues
        );
        // North and East tie at 300, so South (200) is always last.
        assert_eq!(revenues[2], 200.0, "South total should be 200");
    }

    /// Forecast stage alone: a perfectly linear series of 20 points should
    /// produce 5 forecast rows (horizon) appended after the historical rows,
    /// each carrying forecast/lower_bound/upper_bound fields.
    #[tokio::test]
    async fn test_full_pipeline_forecast() {
        // y = 10 + 2x over timestamps 1000..1019 — trivially linear input.
        let rows: Vec<Row> = (0..20)
            .map(|i| {
                make_row(vec![
                    ("timestamp", json!(1000 + i)),
                    ("value", json!(10.0 + 2.0 * i as f64)),
                ])
            })
            .collect();
        let data = DataTable::from_rows(&rows).unwrap();

        let spec = TransformSpec {
            sql: None,
            aggregate: None,
            forecast: Some(ForecastSpec {
                timestamp: "timestamp".to_string(),
                value: "value".to_string(),
                horizon: Some(5),
                confidence_level: Some(0.95),
                model: Some("linear".to_string()),
                group_by: None,
            }),
        };

        let transform = DataFusionTransform;
        let context = TransformContext::default();
        let result = transform.transform(data, &spec, &context).await.unwrap();

        assert_eq!(
            result.data.num_rows(),
            25,
            "Should have 25 rows (20 historical + 5 forecast)"
        );

        let result_rows = result.data.to_rows();

        // Forecast rows are tagged with is_forecast = true.
        let forecast_rows: Vec<&Row> = result_rows
            .iter()
            .filter(|r| r.get("is_forecast").and_then(|v| v.as_bool()) == Some(true))
            .collect();
        assert_eq!(forecast_rows.len(), 5, "Should have 5 forecast rows");

        // Each forecast row must expose point estimate and confidence bounds.
        for row in &forecast_rows {
            assert!(
                row.get("forecast").is_some(),
                "Forecast row should have 'forecast' field"
            );
            assert!(
                row.get("lower_bound").is_some(),
                "Forecast row should have 'lower_bound' field"
            );
            assert!(
                row.get("upper_bound").is_some(),
                "Forecast row should have 'upper_bound' field"
            );
        }

        // Original rows must be preserved and tagged is_forecast = false.
        let historical_rows: Vec<&Row> = result_rows
            .iter()
            .filter(|r| r.get("is_forecast").and_then(|v| v.as_bool()) == Some(false))
            .collect();
        assert_eq!(historical_rows.len(), 20, "Should have 20 historical rows");
    }

    /// SQL stage alone: a WHERE filter against the "source" table.
    #[tokio::test]
    async fn test_full_pipeline_sql() {
        let data = sales_data();
        let spec = TransformSpec {
            sql: Some(SqlSpec::Single(
                "SELECT * FROM \"source\" WHERE \"revenue\" > 100".to_string(),
            )),
            aggregate: None,
            forecast: None,
        };

        let transform = DataFusionTransform;
        let context = TransformContext::default();
        let result = transform.transform(data, &spec, &context).await.unwrap();

        let result_rows = result.data.to_rows();
        // 5 input rows, at least one has revenue <= 100, so the count drops.
        assert!(
            result_rows.len() < 5,
            "Should filter out some rows, got {}",
            result_rows.len()
        );
        for row in &result_rows {
            let rev = row.get("revenue").unwrap().as_f64().unwrap();
            assert!(rev > 100.0, "Revenue should be > 100, got {}", rev);
        }
    }

    /// End-to-end: YAML spec -> ChartML async pipeline -> renderer, with this
    /// crate registered as the transform middleware. The mock renderer
    /// asserts on the aggregated data it receives.
    #[tokio::test]
    async fn test_integration_chartml_async_aggregate() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        // Renderer stub that validates the transformed data instead of drawing.
        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                assert_eq!(data.num_rows(), 3, "Expected 3 aggregated groups, got {}", data.num_rows());

                let revenues: Vec<f64> = (0..data.num_rows())
                    .filter_map(|i| data.get_f64(i, "total_revenue"))
                    .collect();
                assert!(revenues[0] >= revenues[1], "Should be sorted desc");
                assert!(revenues[1] >= revenues[2], "Should be sorted desc");

                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock-bar".to_string(),
                    children: vec![],
                })
            }
        }

        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        chartml.register_transform(DataFusionTransform);

        // Inline rows are empty: the actual data is supplied via
        // render_from_yaml_with_data_async below.
        let yaml = r#"
type: chart
version: 1
title: Revenue by Region
data:
  provider: inline
  rows: []
transform:
  aggregate:
    dimensions:
      - region
    measures:
      - column: revenue
        aggregation: sum
        name: total_revenue
    sort:
      - field: total_revenue
        direction: desc
visualize:
  type: bar
  columns: region
  rows: total_revenue
"#;

        let data = sales_data();
        let result = chartml.render_from_yaml_with_data_async(yaml, data).await;
        assert!(result.is_ok(), "Async render failed: {:?}", result.err());

        // The element must come from our mock renderer.
        match result.unwrap() {
            ChartElement::Svg { class, .. } => {
                assert_eq!(class, "mock-bar");
            }
            other => panic!("Expected Svg element, got {:?}", other),
        }
    }

    /// End-to-end: SQL transform specified via YAML, validated inside the
    /// mock renderer (3 rows with revenue > 100 survive from the fixture).
    #[tokio::test]
    async fn test_integration_chartml_async_sql() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                assert_eq!(data.num_rows(), 3, "Expected 3 rows after SQL filter, got {}", data.num_rows());
                for i in 0..data.num_rows() {
                    let rev = data.get_f64(i, "revenue").unwrap();
                    assert!(rev > 100.0, "Revenue should be > 100, got {}", rev);
                }

                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock-bar".to_string(),
                    children: vec![],
                })
            }
        }

        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        chartml.register_transform(DataFusionTransform);

        let yaml = r#"
type: chart
version: 1
title: High Revenue Items
data:
  provider: inline
  rows: []
transform:
  sql: "SELECT * FROM \"source\" WHERE \"revenue\" > 100"
visualize:
  type: bar
  columns: region
  rows: revenue
"#;

        let data = sales_data();
        let result = chartml.render_from_yaml_with_data_async(yaml, data).await;
        assert!(result.is_ok(), "Async SQL render failed: {:?}", result.err());
    }

    /// Without a registered transform middleware, a spec containing an `sql`
    /// stage must fail with a descriptive error rather than silently skipping
    /// the transform.
    #[tokio::test]
    async fn test_integration_chartml_async_no_middleware_error() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                _data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock".to_string(),
                    children: vec![],
                })
            }
        }

        // NOTE: register_transform is intentionally NOT called here.
        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        let yaml = r#"
type: chart
version: 1
title: Test
data:
  provider: inline
  rows: []
transform:
  sql: "SELECT * FROM source"
visualize:
  type: bar
  columns: x
  rows: y
"#;

        let empty = DataTable::from_rows(&[]).unwrap();
        let result = chartml.render_from_yaml_with_data_async(yaml, empty).await;
        assert!(result.is_err(), "Should fail when sql transform used without middleware");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("no TransformMiddleware is registered"), "Error should mention missing middleware, got: {}", err);
    }

    /// Aggregate-only specs should be handled by chartml-core's built-in
    /// aggregation, so they succeed even with no middleware registered.
    #[tokio::test]
    async fn test_integration_chartml_async_aggregate_only_uses_builtin() {
        use chartml_core::element::{ChartElement, ViewBox};
        use chartml_core::plugin::{ChartConfig, ChartRenderer};
        use chartml_core::ChartML;

        struct MockBarRenderer;
        impl ChartRenderer for MockBarRenderer {
            fn render(
                &self,
                data: &DataTable,
                _config: &ChartConfig,
            ) -> Result<ChartElement, chartml_core::error::ChartError> {
                assert_eq!(data.num_rows(), 3, "Expected 3 aggregated groups, got {}", data.num_rows());
                Ok(ChartElement::Svg {
                    viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
                    width: Some(800.0),
                    height: Some(400.0),
                    class: "mock-bar".to_string(),
                    children: vec![],
                })
            }
        }

        // NOTE: no register_transform — exercises the built-in aggregate path.
        let mut chartml = ChartML::new();
        chartml.register_renderer("bar", MockBarRenderer);
        let yaml = r#"
type: chart
version: 1
title: Revenue by Region
data:
  provider: inline
  rows: []
transform:
  aggregate:
    dimensions:
      - region
    measures:
      - column: revenue
        aggregation: sum
        name: total_revenue
visualize:
  type: bar
  columns: region
  rows: total_revenue
"#;

        let data = sales_data();
        let result = chartml.render_from_yaml_with_data_async(yaml, data).await;
        assert!(result.is_ok(), "Aggregate-only async render should work without middleware: {:?}", result.err());
    }
}