1use crate::error::{GraphError, Result};
11use crate::query::{normalize_record_batch, normalize_schema};
12use arrow_array::RecordBatch;
13use datafusion::datasource::MemTable;
14use datafusion::execution::context::SessionContext;
15use std::collections::HashMap;
16use std::sync::Arc;
17
/// A SQL query executed against in-memory Arrow datasets via DataFusion.
///
/// Holds only the raw SQL text; the tables it references are supplied at
/// execution time, either as a map of named [`RecordBatch`]es or as a
/// pre-populated [`SessionContext`].
#[derive(Debug, Clone)]
pub struct SqlQuery {
    // Raw SQL text, stored verbatim as passed to `new`; parsed lazily at
    // execute/explain time.
    sql: String,
}
43
44impl SqlQuery {
45 pub fn new(sql: &str) -> Self {
49 Self {
50 sql: sql.to_string(),
51 }
52 }
53
54 pub fn sql(&self) -> &str {
56 &self.sql
57 }
58
59 pub async fn execute(&self, datasets: HashMap<String, RecordBatch>) -> Result<RecordBatch> {
70 let ctx = self.build_context(datasets)?;
71 self.execute_with_context(ctx).await
72 }
73
74 pub async fn execute_with_context(&self, ctx: SessionContext) -> Result<RecordBatch> {
79 let df = ctx
80 .sql(&self.sql)
81 .await
82 .map_err(|e| GraphError::PlanError {
83 message: format!("SQL execution error: {}", e),
84 location: snafu::Location::new(file!(), line!(), column!()),
85 })?;
86
87 let batches = df.collect().await.map_err(|e| GraphError::PlanError {
88 message: format!("Failed to collect SQL results: {}", e),
89 location: snafu::Location::new(file!(), line!(), column!()),
90 })?;
91
92 if batches.is_empty() {
93 let schema = df_schema_from_ctx(&ctx, &self.sql).await?;
95 return Ok(RecordBatch::new_empty(schema));
96 }
97
98 let schema = batches[0].schema();
99 arrow::compute::concat_batches(&schema, &batches).map_err(|e| GraphError::PlanError {
100 message: format!("Failed to concatenate result batches: {}", e),
101 location: snafu::Location::new(file!(), line!(), column!()),
102 })
103 }
104
105 pub async fn explain(&self, datasets: HashMap<String, RecordBatch>) -> Result<String> {
109 let ctx = self.build_context(datasets)?;
110
111 let df = ctx
112 .sql(&self.sql)
113 .await
114 .map_err(|e| GraphError::PlanError {
115 message: format!("SQL explain error: {}", e),
116 location: snafu::Location::new(file!(), line!(), column!()),
117 })?;
118
119 let logical_plan = df.logical_plan();
120
121 let physical_plan = ctx
122 .state()
123 .create_physical_plan(logical_plan)
124 .await
125 .map_err(|e| GraphError::PlanError {
126 message: format!("Failed to create physical plan: {}", e),
127 location: snafu::Location::new(file!(), line!(), column!()),
128 })?;
129
130 let physical_plan_str = datafusion::physical_plan::displayable(physical_plan.as_ref())
131 .indent(true)
132 .to_string();
133
134 Ok(format!(
135 "== Logical Plan ==\n{}\n\n== Physical Plan ==\n{}",
136 logical_plan.display_indent(),
137 physical_plan_str,
138 ))
139 }
140
141 fn build_context(&self, datasets: HashMap<String, RecordBatch>) -> Result<SessionContext> {
143 let ctx = SessionContext::new();
144
145 for (name, batch) in datasets {
146 let normalized_batch = normalize_record_batch(&batch)?;
147 let schema = normalized_batch.schema();
148 let mem_table = Arc::new(
149 MemTable::try_new(schema, vec![vec![normalized_batch]]).map_err(|e| {
150 GraphError::PlanError {
151 message: format!("Failed to create MemTable for {}: {}", name, e),
152 location: snafu::Location::new(file!(), line!(), column!()),
153 }
154 })?,
155 );
156
157 let normalized_name = name.to_lowercase();
158 ctx.register_table(&normalized_name, mem_table)
159 .map_err(|e| GraphError::PlanError {
160 message: format!("Failed to register table {}: {}", name, e),
161 location: snafu::Location::new(file!(), line!(), column!()),
162 })?;
163 }
164
165 Ok(ctx)
166 }
167}
168
169async fn df_schema_from_ctx(ctx: &SessionContext, sql: &str) -> Result<Arc<arrow_schema::Schema>> {
171 let df = ctx.sql(sql).await.map_err(|e| GraphError::PlanError {
172 message: format!("Failed to plan SQL for schema: {}", e),
173 location: snafu::Location::new(file!(), line!(), column!()),
174 })?;
175 let arrow_schema = Arc::new(arrow_schema::Schema::from(df.schema()));
176 normalize_schema(arrow_schema)
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Fixture table shared by most tests: four people with id, name,
    /// age, and city columns.
    fn person_batch() -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
            Field::new("city", DataType::Utf8, false),
        ]);
        RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])),
                Arc::new(Int64Array::from(vec![28, 34, 29, 42])),
                Arc::new(StringArray::from(vec![
                    "New York",
                    "San Francisco",
                    "New York",
                    "Chicago",
                ])),
            ],
        )
        .unwrap()
    }

    /// Wraps a single named batch into the dataset map `execute` expects.
    fn datasets_with(name: &str, batch: RecordBatch) -> HashMap<String, RecordBatch> {
        HashMap::from([(name.to_string(), batch)])
    }

    /// Extracts a Utf8 column from a batch as `Vec<&str>` for easy
    /// comparison in assertions.
    fn string_column<'a>(batch: &'a RecordBatch, column: &str) -> Vec<&'a str> {
        batch
            .column_by_name(column)
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
            .iter()
            .map(|v| v.unwrap())
            .collect()
    }

    #[tokio::test]
    async fn test_basic_select() {
        let query = SqlQuery::new("SELECT name, age FROM person WHERE age > 30 ORDER BY age");
        let batch = query
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();

        assert_eq!(string_column(&batch, "name"), vec!["Bob", "David"]);
    }

    #[tokio::test]
    async fn test_select_star() {
        let batch = SqlQuery::new("SELECT * FROM person")
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!(batch.num_rows(), 4);
        assert_eq!(batch.num_columns(), 4);
    }

    #[tokio::test]
    async fn test_limit() {
        let batch = SqlQuery::new("SELECT name FROM person ORDER BY name LIMIT 2")
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!(batch.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_aggregation() {
        let query = SqlQuery::new(
            "SELECT COUNT(*) as cnt, AVG(age) as avg_age, SUM(age) as total_age FROM person",
        );
        let batch = query
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!(batch.num_rows(), 1);

        let cnt_col = batch
            .column_by_name("cnt")
            .unwrap()
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(cnt_col.value(0), 4);

        let avg_col = batch
            .column_by_name("avg_age")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((avg_col.value(0) - 33.25).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_group_by() {
        let query = SqlQuery::new(
            "SELECT city, COUNT(*) as cnt FROM person GROUP BY city ORDER BY cnt DESC",
        );
        let batch = query
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();

        // New York appears twice, so it must sort first under cnt DESC.
        let cities = string_column(&batch, "city");
        assert_eq!(cities[0], "New York");
    }

    #[tokio::test]
    async fn test_invalid_sql() {
        let outcome = SqlQuery::new("INVALID SQL STATEMENT")
            .execute(datasets_with("person", person_batch()))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_explain() {
        let plan_text = SqlQuery::new("SELECT name FROM person WHERE age > 30")
            .explain(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert!(plan_text.contains("Logical Plan"));
        assert!(plan_text.contains("Physical Plan"));
    }

    #[tokio::test]
    async fn test_execute_with_context() {
        // Register the fixture under a different table name directly on a
        // caller-provided context (no normalization path involved).
        let ctx = SessionContext::new();
        let batch = person_batch();
        let schema = batch.schema();
        let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap());
        ctx.register_table("people", mem_table).unwrap();

        let result = SqlQuery::new("SELECT name FROM people ORDER BY name LIMIT 1")
            .execute_with_context(ctx)
            .await
            .unwrap();

        assert_eq!(string_column(&result, "name"), vec!["Alice"]);
    }

    #[tokio::test]
    async fn test_sql_text_accessor() {
        let query = SqlQuery::new("SELECT 1");
        assert_eq!(query.sql(), "SELECT 1");
    }
}