// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Direct SQL query interface for Lance datasets
//!
//! This module provides a way to execute standard SQL queries directly against
//! in-memory datasets (as RecordBatches) or a pre-configured DataFusion SessionContext,
//! without requiring a GraphConfig or Cypher parsing.

use crate::error::{GraphError, Result};
use crate::query::{normalize_record_batch, normalize_schema};
use arrow_array::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::execution::context::SessionContext;
use std::collections::HashMap;
use std::sync::Arc;
17
18/// A SQL query that can be executed against in-memory datasets or a DataFusion SessionContext.
19///
20/// Unlike `CypherQuery`, this does not require a `GraphConfig` — users write standard SQL
21/// with explicit JOINs against their node/relationship tables.
22///
23/// # Example
24///
25/// ```no_run
26/// use lance_graph::SqlQuery;
27/// use arrow_array::RecordBatch;
28/// use std::collections::HashMap;
29///
30/// # async fn example() -> lance_graph::Result<()> {
31/// let mut datasets: HashMap<String, RecordBatch> = HashMap::new();
32/// // datasets.insert("person".to_string(), person_batch);
33///
34/// let query = SqlQuery::new("SELECT name, age FROM person WHERE age > 30");
35/// // let result = query.execute(datasets).await?;
36/// # Ok(())
37/// # }
38/// ```
#[derive(Debug, Clone)]
pub struct SqlQuery {
    /// Raw SQL text; not parsed or validated until the query is executed.
    sql: String,
}
44impl SqlQuery {
45    /// Create a new SQL query from a SQL string.
46    ///
47    /// No parsing is done at construction time — the SQL is validated when executed.
48    pub fn new(sql: &str) -> Self {
49        Self {
50            sql: sql.to_string(),
51        }
52    }
53
54    /// Get the SQL query text.
55    pub fn sql(&self) -> &str {
56        &self.sql
57    }
58
59    /// Execute the SQL query against in-memory datasets.
60    ///
61    /// Each entry in `datasets` is registered as a table in a fresh DataFusion
62    /// SessionContext. Table names are lowercased for consistency.
63    ///
64    /// # Arguments
65    /// * `datasets` - HashMap of table name to RecordBatch
66    ///
67    /// # Returns
68    /// A single `RecordBatch` containing all result rows.
69    pub async fn execute(&self, datasets: HashMap<String, RecordBatch>) -> Result<RecordBatch> {
70        let ctx = self.build_context(datasets)?;
71        self.execute_with_context(ctx).await
72    }
73
74    /// Execute the SQL query against a pre-configured DataFusion SessionContext.
75    ///
76    /// Use this when tables are already registered (e.g., CSV/Parquet files,
77    /// external data sources, or a context shared across queries).
78    pub async fn execute_with_context(&self, ctx: SessionContext) -> Result<RecordBatch> {
79        let df = ctx
80            .sql(&self.sql)
81            .await
82            .map_err(|e| GraphError::PlanError {
83                message: format!("SQL execution error: {}", e),
84                location: snafu::Location::new(file!(), line!(), column!()),
85            })?;
86
87        let batches = df.collect().await.map_err(|e| GraphError::PlanError {
88            message: format!("Failed to collect SQL results: {}", e),
89            location: snafu::Location::new(file!(), line!(), column!()),
90        })?;
91
92        if batches.is_empty() {
93            // Return an empty batch with the schema from the logical plan
94            let schema = df_schema_from_ctx(&ctx, &self.sql).await?;
95            return Ok(RecordBatch::new_empty(schema));
96        }
97
98        let schema = batches[0].schema();
99        arrow::compute::concat_batches(&schema, &batches).map_err(|e| GraphError::PlanError {
100            message: format!("Failed to concatenate result batches: {}", e),
101            location: snafu::Location::new(file!(), line!(), column!()),
102        })
103    }
104
105    /// Return the DataFusion execution plan as a formatted string.
106    ///
107    /// Useful for debugging and understanding how the query will be executed.
108    pub async fn explain(&self, datasets: HashMap<String, RecordBatch>) -> Result<String> {
109        let ctx = self.build_context(datasets)?;
110
111        let df = ctx
112            .sql(&self.sql)
113            .await
114            .map_err(|e| GraphError::PlanError {
115                message: format!("SQL explain error: {}", e),
116                location: snafu::Location::new(file!(), line!(), column!()),
117            })?;
118
119        let logical_plan = df.logical_plan();
120
121        let physical_plan = ctx
122            .state()
123            .create_physical_plan(logical_plan)
124            .await
125            .map_err(|e| GraphError::PlanError {
126                message: format!("Failed to create physical plan: {}", e),
127                location: snafu::Location::new(file!(), line!(), column!()),
128            })?;
129
130        let physical_plan_str = datafusion::physical_plan::displayable(physical_plan.as_ref())
131            .indent(true)
132            .to_string();
133
134        Ok(format!(
135            "== Logical Plan ==\n{}\n\n== Physical Plan ==\n{}",
136            logical_plan.display_indent(),
137            physical_plan_str,
138        ))
139    }
140
141    /// Build a DataFusion SessionContext from in-memory datasets.
142    fn build_context(&self, datasets: HashMap<String, RecordBatch>) -> Result<SessionContext> {
143        let ctx = SessionContext::new();
144
145        for (name, batch) in datasets {
146            let normalized_batch = normalize_record_batch(&batch)?;
147            let schema = normalized_batch.schema();
148            let mem_table = Arc::new(
149                MemTable::try_new(schema, vec![vec![normalized_batch]]).map_err(|e| {
150                    GraphError::PlanError {
151                        message: format!("Failed to create MemTable for {}: {}", name, e),
152                        location: snafu::Location::new(file!(), line!(), column!()),
153                    }
154                })?,
155            );
156
157            let normalized_name = name.to_lowercase();
158            ctx.register_table(&normalized_name, mem_table)
159                .map_err(|e| GraphError::PlanError {
160                    message: format!("Failed to register table {}: {}", name, e),
161                    location: snafu::Location::new(file!(), line!(), column!()),
162                })?;
163        }
164
165        Ok(ctx)
166    }
167}
169/// Helper to get the output schema from a SQL query without executing it.
170async fn df_schema_from_ctx(ctx: &SessionContext, sql: &str) -> Result<Arc<arrow_schema::Schema>> {
171    let df = ctx.sql(sql).await.map_err(|e| GraphError::PlanError {
172        message: format!("Failed to plan SQL for schema: {}", e),
173        location: snafu::Location::new(file!(), line!(), column!()),
174    })?;
175    let arrow_schema = Arc::new(arrow_schema::Schema::from(df.schema()));
176    normalize_schema(arrow_schema)
177}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Four-row `person` fixture: (id, name, age, city).
    fn person_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::Int64, false),
            Field::new("city", DataType::Utf8, false),
        ]));
        let ids = Int64Array::from(vec![1, 2, 3, 4]);
        let names = StringArray::from(vec!["Alice", "Bob", "Carol", "David"]);
        let ages = Int64Array::from(vec![28, 34, 29, 42]);
        let cities = StringArray::from(vec![
            "New York",
            "San Francisco",
            "New York",
            "Chicago",
        ]);
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(ids),
                Arc::new(names),
                Arc::new(ages),
                Arc::new(cities),
            ],
        )
        .unwrap()
    }

    /// Single-entry dataset map keyed by `name`.
    fn datasets_with(name: &str, batch: RecordBatch) -> HashMap<String, RecordBatch> {
        HashMap::from([(name.to_string(), batch)])
    }

    /// Extract a Utf8 column from `batch` as a Vec of &str for easy assertions.
    fn string_column<'a>(batch: &'a RecordBatch, name: &str) -> Vec<&'a str> {
        batch
            .column_by_name(name)
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
            .iter()
            .flatten()
            .collect()
    }

    #[tokio::test]
    async fn test_basic_select() {
        let result = SqlQuery::new("SELECT name, age FROM person WHERE age > 30 ORDER BY age")
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!(string_column(&result, "name"), vec!["Bob", "David"]);
    }

    #[tokio::test]
    async fn test_select_star() {
        let result = SqlQuery::new("SELECT * FROM person")
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!((result.num_rows(), result.num_columns()), (4, 4));
    }

    #[tokio::test]
    async fn test_limit() {
        let result = SqlQuery::new("SELECT name FROM person ORDER BY name LIMIT 2")
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!(result.num_rows(), 2);
    }

    #[tokio::test]
    async fn test_aggregation() {
        let sql = "SELECT COUNT(*) as cnt, AVG(age) as avg_age, SUM(age) as total_age FROM person";
        let result = SqlQuery::new(sql)
            .execute(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert_eq!(result.num_rows(), 1);

        let cnt_col = result.column_by_name("cnt").unwrap();
        let cnt = cnt_col
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(cnt, 4);

        let avg_col = result.column_by_name("avg_age").unwrap();
        let avg_age = avg_col
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap()
            .value(0);
        assert!((avg_age - 33.25).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_group_by() {
        let result = SqlQuery::new(
            "SELECT city, COUNT(*) as cnt FROM person GROUP BY city ORDER BY cnt DESC",
        )
        .execute(datasets_with("person", person_batch()))
        .await
        .unwrap();

        // New York has 2 rows; the other cities have 1 each.
        let cities = string_column(&result, "city");
        assert_eq!(cities[0], "New York");
    }

    #[tokio::test]
    async fn test_invalid_sql() {
        let outcome = SqlQuery::new("INVALID SQL STATEMENT")
            .execute(datasets_with("person", person_batch()))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_explain() {
        let plan = SqlQuery::new("SELECT name FROM person WHERE age > 30")
            .explain(datasets_with("person", person_batch()))
            .await
            .unwrap();
        assert!(plan.contains("Logical Plan"));
        assert!(plan.contains("Physical Plan"));
    }

    #[tokio::test]
    async fn test_execute_with_context() {
        // Register the fixture by hand under a different table name and run
        // against the pre-built context.
        let ctx = SessionContext::new();
        let batch = person_batch();
        let table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
        ctx.register_table("people", Arc::new(table)).unwrap();

        let result = SqlQuery::new("SELECT name FROM people ORDER BY name LIMIT 1")
            .execute_with_context(ctx)
            .await
            .unwrap();
        assert_eq!(string_column(&result, "name"), vec!["Alice"]);
    }

    #[tokio::test]
    async fn test_sql_text_accessor() {
        let query = SqlQuery::new("SELECT 1");
        assert_eq!(query.sql(), "SELECT 1");
    }
}