Skip to main content

chartml_datafusion/stages/
aggregate_stage.rs

1//! Aggregate stage: compile AggregateSpec → SQL → DataFusion execution.
2
3use chartml_core::error::ChartError;
4use chartml_core::spec::AggregateSpec;
5use datafusion::prelude::*;
6
7use crate::sql_builder;
8
9/// Execute the aggregate stage.
10///
11/// Generates SQL from the aggregate spec, executes via DataFusion,
12/// and registers the result as a new table.
13///
14/// Returns the name of the output table.
15pub async fn execute(
16    ctx: &SessionContext,
17    current_table: &str,
18    spec: &AggregateSpec,
19) -> Result<String, ChartError> {
20    let sql = sql_builder::build_aggregate_sql(current_table, spec);
21    let output_table = format!("__stage_agg_{}", current_table);
22
23    let df = ctx
24        .sql(&sql)
25        .await
26        .map_err(|e| ChartError::DataError(format!("Aggregate stage SQL error: {}", e)))?;
27
28    let batches = df
29        .collect()
30        .await
31        .map_err(|e| ChartError::DataError(format!("Aggregate stage collect error: {}", e)))?;
32
33    let schema = if let Some(first) = batches.first() {
34        first.schema()
35    } else {
36        // Empty result — create an empty table with no schema
37        let empty_schema = arrow::datatypes::Schema::empty();
38        let mem_table = datafusion::datasource::MemTable::try_new(
39            std::sync::Arc::new(empty_schema),
40            vec![vec![]],
41        )
42        .map_err(|e| ChartError::DataError(format!("Aggregate stage MemTable error: {}", e)))?;
43        ctx.register_table(&output_table, std::sync::Arc::new(mem_table))
44            .map_err(|e| {
45                ChartError::DataError(format!("Aggregate stage register error: {}", e))
46            })?;
47        return Ok(output_table);
48    };
49
50    let mem_table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
51        .map_err(|e| ChartError::DataError(format!("Aggregate stage MemTable error: {}", e)))?;
52
53    ctx.register_table(&output_table, std::sync::Arc::new(mem_table))
54        .map_err(|e| ChartError::DataError(format!("Aggregate stage register error: {}", e)))?;
55
56    Ok(output_table)
57}