robin_sparkless/sql/mod.rs

//! SQL parsing and translation to DataFrame operations.
//!
//! Parsing is provided by the `spark-sql-parser` crate; this module translates the
//! resulting AST into DataFrame operations. Compiled only when the `sql` feature is
//! enabled.
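//!
//! A minimal usage sketch (a hedged example: it assumes the crate is consumed as
//! `robin_sparkless` and re-exports `SparkSession` at the root, as the tests below
//! do via `crate::SparkSession`):
//!
//! ```no_run
//! use robin_sparkless::SparkSession;
//!
//! let spark = SparkSession::builder().app_name("demo").get_or_create();
//! let df = spark
//!     .create_dataframe(vec![(1, 25, "Alice".to_string())], vec!["id", "age", "name"])
//!     .unwrap();
//! spark.create_or_replace_temp_view("people", df);
//! let out = spark.sql("SELECT name FROM people WHERE age > 21").unwrap();
//! assert_eq!(out.count().unwrap(), 1);
//! ```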

mod translator;

use crate::dataframe::DataFrame;
use crate::session::SparkSession;
use polars::prelude::PolarsError;
use sqlparser::ast::Statement;

/// Parse a single SQL statement using [`spark_sql_parser`]; errors are mapped to
/// [`PolarsError`] for compatibility with the session and translator.
fn parse_sql_to_statement(query: &str) -> Result<Statement, PolarsError> {
    spark_sql_parser::parse_sql(query)
        .map_err(|e| PolarsError::InvalidOperation(e.to_string().into()))
}

/// Parse a single SQL statement (SELECT, or DDL: CREATE SCHEMA / CREATE DATABASE / DROP TABLE).
/// Delegates to the [spark-sql-parser](https://crates.io/crates/spark-sql-parser) crate.
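///
/// A minimal sketch (the module path `robin_sparkless::sql` is an assumption about
/// how the crate exposes this function):
///
/// ```no_run
/// use robin_sparkless::sql::parse_sql;
///
/// // On success this yields a `sqlparser::ast::Statement` AST node.
/// let stmt = parse_sql("SELECT id, name FROM t WHERE age > 26").unwrap();
/// println!("{stmt:?}");
/// ```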
pub fn parse_sql(query: &str) -> Result<Statement, PolarsError> {
    parse_sql_to_statement(query)
}

/// Parse a SQL string and execute it against the session's catalog.
/// Supports SELECT (columns or `*`), FROM with a single table or a two-table JOIN,
/// WHERE (basic predicates), GROUP BY with aggregates, HAVING, ORDER BY, and LIMIT.
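///
/// A minimal sketch (the view name `people` is illustrative, and the paths assume
/// crate-root re-exports):
///
/// ```no_run
/// use robin_sparkless::{SparkSession, sql::execute_sql};
///
/// let spark = SparkSession::builder().app_name("demo").get_or_create();
/// let df = spark
///     .create_dataframe(vec![(1, 30, "Ada".to_string())], vec!["id", "age", "name"])
///     .unwrap();
/// spark.create_or_replace_temp_view("people", df);
/// let out = execute_sql(&spark, "SELECT name FROM people WHERE age > 21").unwrap();
/// assert_eq!(out.count().unwrap(), 1);
/// ```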
pub fn execute_sql(session: &SparkSession, query: &str) -> Result<DataFrame, PolarsError> {
    let stmt = parse_sql_to_statement(query)?;
    translator::translate(session, &stmt)
}
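// Re-export the translator entry points so callers can translate pre-parsed
// statements or expression strings directly.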
pub use translator::{expr_string_to_polars, translate};

#[cfg(test)]
mod tests {
    use crate::SparkSession;

    #[test]
    fn test_sql_select_from_temp_view() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark.sql("SELECT id, name FROM t WHERE age > 26").unwrap();
        let cols = result.columns().unwrap();
        assert_eq!(cols, vec!["id", "name"]);
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_select_star() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("v", df);
        let result = spark.sql("SELECT * FROM v").unwrap();
        assert_eq!(result.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_group_by_count() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 1, "a".to_string()),
                    (2, 1, "b".to_string()),
                    (3, 2, "c".to_string()),
                ],
                vec!["id", "grp", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT grp, COUNT(id) FROM t GROUP BY grp ORDER BY grp")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_group_by_expression() {
        // Issue #588: GROUP BY (age > 30) groups on an expression rather than a column
        // name. Ages 25, 35, 28 split into two boolean groups, hence two result rows.
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 35, "b".to_string()),
                    (3, 28, "c".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT COUNT(*) as count FROM t GROUP BY (age > 30)")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_scalar_aggregate() {
        // Issue #587: SELECT AVG(salary) FROM t (no GROUP BY) performs a scalar aggregation.
        // create_dataframe takes (i64, i64, String) -> columns ["id", "salary", "name"]
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 100, "Alice".to_string()), (2, 200, "Bob".to_string())],
                vec!["id", "salary", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("test_table", df);
        let result = spark
            .sql("SELECT AVG(salary) as avg_salary FROM test_table")
            .unwrap();
        assert_eq!(result.count().unwrap(), 1);
        let rows = result.collect_as_json_rows().unwrap();
        let avg_val = rows[0].get("avg_salary").and_then(|v| v.as_f64()).unwrap();
        assert!((avg_val - 150.0).abs() < 1e-9);
    }

    #[test]
    fn test_sql_having() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 25, "b".to_string()),
                    (3, 30, "c".to_string()),
                    (4, 35, "d".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT age, COUNT(id) FROM t GROUP BY age HAVING age > 26")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let rows = result.collect_as_json_rows().unwrap();
        let ages: Vec<i64> = rows
            .iter()
            .map(|r| r.get("age").and_then(|v| v.as_i64()).unwrap())
            .collect();
        assert!(ages.contains(&30));
        assert!(ages.contains(&35));
        assert!(!ages.contains(&25));
    }

    #[test]
    fn test_sql_having_agg() {
        // Issue #589: HAVING with an aggregate expression (HAVING AVG(salary) >= 55000).
        // create_dataframe takes (i64, i64, String) -> columns ["dummy", "salary", "dept"]
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (0, 50000, "A".to_string()),
                    (0, 60000, "A".to_string()),
                    (0, 40000, "B".to_string()),
                ],
                vec!["dummy", "salary", "dept"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT dept, AVG(salary) as avg_sal FROM t GROUP BY dept HAVING AVG(salary) >= 55000")
            .unwrap();
        assert_eq!(result.count().unwrap(), 1);
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("dept").and_then(|v| v.as_str()).unwrap(), "A");
    }

    #[test]
    fn test_sql_where_like_and_in() {
        // Issue #590: WHERE with LIKE and IN.
        // create_dataframe takes (i64, i64, String) -> columns ["id", "dummy", "name"]
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 0, "Alice".to_string()),
                    (2, 0, "Bob".to_string()),
                    (3, 0, "Carol".to_string()),
                ],
                vec!["id", "dummy", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let like_result = spark.sql("SELECT * FROM t WHERE name LIKE 'A%'").unwrap();
        assert_eq!(like_result.count().unwrap(), 1);
        let rows = like_result.collect_as_json_rows().unwrap();
        assert_eq!(
            rows[0].get("name").and_then(|v| v.as_str()).unwrap(),
            "Alice"
        );
        let in_result = spark.sql("SELECT * FROM t WHERE id IN (1, 2)").unwrap();
        assert_eq!(in_result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_table_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let result = spark.sql("SELECT 1 FROM nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_sql_udf_select() {
        use polars::prelude::DataType;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        spark
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT id, to_str(id) AS id_str, name FROM t")
            .unwrap();
        let cols = result.columns().unwrap();
        assert!(cols.contains(&"id_str".to_string()));
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
    }

    #[test]
    fn test_sql_builtin_upper() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "alice".to_string()), (2, 30, "bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT id, UPPER(name) AS upper_name FROM t ORDER BY id")
            .unwrap();
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(
            rows[0].get("upper_name").and_then(|v| v.as_str()),
            Some("ALICE")
        );
    }

    #[test]
    fn test_sql_from_global_temp_view() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "v", "name"],
            )
            .unwrap();
        spark.create_or_replace_global_temp_view("gv", df);
        let result = spark
            .sql("SELECT * FROM global_temp.gv ORDER BY id")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(rows[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    #[test]
    fn test_sql_create_schema_ddl() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // CREATE SCHEMA persists the name and returns an empty DataFrame (issue #347).
        let out = spark.sql("CREATE SCHEMA my_schema").unwrap();
        assert_eq!(out.count().unwrap(), 0);
        assert!(out.columns().unwrap().is_empty());
        assert!(spark.database_exists("my_schema"));
        assert!(spark.list_database_names().contains(&"my_schema".to_string()));
    }

    #[test]
    fn test_sql_create_database_ddl() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let out = spark.sql("CREATE DATABASE my_db").unwrap();
        assert_eq!(out.count().unwrap(), 0);
        assert!(out.columns().unwrap().is_empty());
        assert!(spark.database_exists("my_db"));
        assert!(spark.list_database_names().contains(&"my_db".to_string()));
    }

    #[test]
    fn test_sql_drop_table_ddl() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // DROP TABLE IF EXISTS (no error when the table does not exist)
        let out = spark
            .sql("DROP TABLE IF EXISTS my_schema.my_table")
            .unwrap();
        assert_eq!(out.count().unwrap(), 0);
        // Create a temp view, then DROP TABLE
        let df = spark
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "x"])
            .unwrap();
        spark.create_or_replace_temp_view("t_drop_me", df.clone());
        assert!(spark.table("t_drop_me").is_ok());
        let _ = spark.sql("DROP TABLE t_drop_me").unwrap();
        assert!(spark.table("t_drop_me").is_err());
    }

    #[test]
    fn test_sql_drop_schema() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // CREATE then DROP SCHEMA (issue #526; sqlparser 0.45 has no DROP DATABASE token)
        spark
            .sql("CREATE SCHEMA IF NOT EXISTS test_schema_to_drop")
            .unwrap();
        assert!(spark.database_exists("test_schema_to_drop"));
        spark
            .sql("DROP SCHEMA IF EXISTS test_schema_to_drop CASCADE")
            .unwrap();
        assert!(!spark.database_exists("test_schema_to_drop"));
    }

    /// Case-insensitive column resolution (PySpark default; issue #194).
    #[test]
    fn test_sql_case_insensitive_columns() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        // Lowercase column names in the SQL resolve to Id, Age, Name.
        let result = spark
            .sql("SELECT name, age FROM t WHERE age > 26 ORDER BY age")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let cols = result.columns().unwrap();
        assert_eq!(cols, vec!["name", "age"]);
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Bob"));
        assert_eq!(rows[0].get("age").and_then(|v| v.as_i64()), Some(30));
    }
}