// chartml_datafusion/stages/aggregate_stage.rs
use std::sync::Arc;

use chartml_core::error::ChartError;
use chartml_core::spec::AggregateSpec;
use datafusion::prelude::*;

use crate::sql_builder;
9pub async fn execute(
16 ctx: &SessionContext,
17 current_table: &str,
18 spec: &AggregateSpec,
19) -> Result<String, ChartError> {
20 let sql = sql_builder::build_aggregate_sql(current_table, spec);
21 let output_table = format!("__stage_agg_{}", current_table);
22
23 let df = ctx
24 .sql(&sql)
25 .await
26 .map_err(|e| ChartError::DataError(format!("Aggregate stage SQL error: {}", e)))?;
27
28 let batches = df
29 .collect()
30 .await
31 .map_err(|e| ChartError::DataError(format!("Aggregate stage collect error: {}", e)))?;
32
33 let schema = if let Some(first) = batches.first() {
34 first.schema()
35 } else {
36 let empty_schema = arrow::datatypes::Schema::empty();
38 let mem_table = datafusion::datasource::MemTable::try_new(
39 std::sync::Arc::new(empty_schema),
40 vec![vec![]],
41 )
42 .map_err(|e| ChartError::DataError(format!("Aggregate stage MemTable error: {}", e)))?;
43 ctx.register_table(&output_table, std::sync::Arc::new(mem_table))
44 .map_err(|e| {
45 ChartError::DataError(format!("Aggregate stage register error: {}", e))
46 })?;
47 return Ok(output_table);
48 };
49
50 let mem_table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
51 .map_err(|e| ChartError::DataError(format!("Aggregate stage MemTable error: {}", e)))?;
52
53 ctx.register_table(&output_table, std::sync::Arc::new(mem_table))
54 .map_err(|e| ChartError::DataError(format!("Aggregate stage register error: {}", e)))?;
55
56 Ok(output_table)
57}