chartml_datafusion/stages/
sql_stage.rs1use chartml_core::error::ChartError;
4use chartml_core::spec::SqlSpec;
5use datafusion::prelude::*;
6
/// Substitutes every `{source_name}` placeholder in `sql` with `table_name`
/// written as a double-quoted SQL identifier.
///
/// The placeholder is the literal text `{` + `source_name` + `}`; text that
/// does not contain it is returned unchanged. Any `"` inside `table_name` is
/// doubled (`""`), so the name cannot close the quoted identifier early and
/// alter the surrounding statement.
fn replace_placeholders(sql: &str, source_name: &str, table_name: &str) -> String {
    // `{{` / `}}` are escaped braces, so this yields e.g. `{source}`.
    let placeholder = format!("{{{}}}", source_name);
    // Standard SQL escaping for delimited identifiers: `"` becomes `""`.
    let quoted_table = format!("\"{}\"", table_name.replace('"', "\"\""));
    sql.replace(&placeholder, &quoted_table)
}
15
16pub async fn execute(
23 ctx: &SessionContext,
24 current_table: &str,
25 spec: &SqlSpec,
26) -> Result<String, ChartError> {
27 let statements: Vec<String> = match spec {
28 SqlSpec::Single(s) => vec![s.clone()],
29 SqlSpec::Multiple(v) => v.clone(),
30 };
31
32 if statements.is_empty() {
33 return Err(ChartError::DataError(
34 "SQL stage: must contain at least one SQL statement".to_string(),
35 ));
36 }
37
38 let resolved: Vec<String> = statements
42 .iter()
43 .map(|stmt| {
44 let mut s = replace_placeholders(stmt, "source", current_table);
45 s = replace_placeholders(&s, "sourceName", current_table);
46 s
47 })
48 .collect();
49
50 for sql in &resolved[..resolved.len() - 1] {
52 ctx.sql(sql)
53 .await
54 .map_err(|e| ChartError::DataError(format!("SQL stage setup error: {}", e)))?
55 .collect()
56 .await
57 .map_err(|e| ChartError::DataError(format!("SQL stage setup collect error: {}", e)))?;
58 }
59
60 let final_sql = &resolved[resolved.len() - 1];
62 let output_table = format!("__stage_sql_{}", current_table);
63
64 let df = ctx
65 .sql(final_sql)
66 .await
67 .map_err(|e| ChartError::DataError(format!("SQL stage error: {}", e)))?;
68
69 let batches = df
70 .collect()
71 .await
72 .map_err(|e| ChartError::DataError(format!("SQL stage collect error: {}", e)))?;
73
74 let schema = if let Some(first) = batches.first() {
75 first.schema()
76 } else {
77 return Err(ChartError::DataError(
78 "SQL stage returned no results".to_string(),
79 ));
80 };
81
82 let mem_table = datafusion::datasource::MemTable::try_new(schema, vec![batches])
83 .map_err(|e| ChartError::DataError(format!("SQL stage MemTable error: {}", e)))?;
84
85 ctx.register_table(&output_table, std::sync::Arc::new(mem_table))
86 .map_err(|e| ChartError::DataError(format!("SQL stage register error: {}", e)))?;
87
88 Ok(output_table)
89}