Skip to main content

chartml_datafusion/stages/
sql_stage.rs

1//! SQL stage: placeholder replacement + execution via DataFusion.
2
3use chartml_core::error::ChartError;
4use chartml_core::spec::SqlSpec;
5use datafusion::prelude::*;
6
/// Replace every `{source_name}` placeholder in `sql` with `table_name`
/// written as a double-quoted SQL identifier.
///
/// Any `"` inside `table_name` is doubled (`""`), the SQL escape for a quote
/// inside a delimited identifier, so the name cannot terminate the identifier
/// early and change the meaning of the statement.
///
/// A statement that does not contain the placeholder is returned as-is (as a
/// new `String`).
fn replace_placeholders(sql: &str, source_name: &str, table_name: &str) -> String {
    // `{{` / `}}` are literal braces in a format string: this builds `{name}`.
    let placeholder = format!("{{{}}}", source_name);
    let quoted_table = format!("\"{}\"", table_name.replace('"', "\"\""));
    sql.replace(&placeholder, &quoted_table)
}
15
16/// Execute the SQL stage.
17///
18/// For `SqlSpec::Single`: execute the single statement and register result.
19/// For `SqlSpec::Multiple`: execute all but last as setup, last becomes result.
20///
21/// Returns the name of the output table registered in the session context.
22pub async fn execute(
23    ctx: &SessionContext,
24    current_table: &str,
25    spec: &SqlSpec,
26) -> Result<String, ChartError> {
27    let statements: Vec<String> = match spec {
28        SqlSpec::Single(s) => vec![s.clone()],
29        SqlSpec::Multiple(v) => v.clone(),
30    };
31
32    if statements.is_empty() {
33        return Err(ChartError::DataError(
34            "SQL stage: must contain at least one SQL statement".to_string(),
35        ));
36    }
37
38    // Replace placeholders in all statements
39    // We use "source" as the default placeholder name and also replace
40    // the current_table name directly
41    let resolved: Vec<String> = statements
42        .iter()
43        .map(|stmt| {
44            let mut s = replace_placeholders(stmt, "source", current_table);
45            s = replace_placeholders(&s, "sourceName", current_table);
46            s
47        })
48        .collect();
49
50    // Execute setup statements (all but the last)
51    for sql in &resolved[..resolved.len() - 1] {
52        ctx.sql(sql)
53            .await
54            .map_err(|e| ChartError::DataError(format!("SQL stage setup error: {}", e)))?
55            .collect()
56            .await
57            .map_err(|e| ChartError::DataError(format!("SQL stage setup collect error: {}", e)))?;
58    }
59
60    // Execute the final statement and register as a new table
61    let final_sql = &resolved[resolved.len() - 1];
62    let output_table = format!("__stage_sql_{}", current_table);
63
64    let df = ctx
65        .sql(final_sql)
66        .await
67        .map_err(|e| ChartError::DataError(format!("SQL stage error: {}", e)))?;
68
69    let batches = df
70        .collect()
71        .await
72        .map_err(|e| ChartError::DataError(format!("SQL stage collect error: {}", e)))?;
73
74    let schema = if let Some(first) = batches.first() {
75        first.schema()
76    } else {
77        return Err(ChartError::DataError(
78            "SQL stage returned no results".to_string(),
79        ));
80    };
81
82    let mem_table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
83        .map_err(|e| ChartError::DataError(format!("SQL stage MemTable error: {}", e)))?;
84
85    ctx.register_table(&output_table, std::sync::Arc::new(mem_table))
86        .map_err(|e| ChartError::DataError(format!("SQL stage register error: {}", e)))?;
87
88    Ok(output_table)
89}