robin_sparkless/sql/mod.rs

//! SQL parsing and translation to DataFrame operations.
//! Compiled only when the `sql` feature is enabled.
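//!
//! A minimal sketch of typical use (illustrative only, marked `ignore`; it
//! assumes a registered temp view and mirrors the session API exercised by
//! the tests below):
//!
//! ```ignore
//! let spark = SparkSession::builder().app_name("demo").get_or_create();
//! // ... build a DataFrame `df` and register it as a view ...
//! spark.create_or_replace_temp_view("people", df);
//! let out = spark.sql("SELECT name FROM people WHERE age > 21")?;
//! ```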

mod parser;
mod translator;

use crate::dataframe::DataFrame;
use crate::session::SparkSession;
use polars::prelude::PolarsError;

/// Parse a SQL string and execute it using the session's catalog.
/// Supports: SELECT (explicit columns or `*`), FROM a single table or a
/// two-table JOIN, WHERE (basic predicates), GROUP BY + aggregates, HAVING,
/// ORDER BY, and LIMIT, plus basic DDL (CREATE SCHEMA/DATABASE, DROP TABLE,
/// DROP SCHEMA).
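///
/// # Example
///
/// A minimal sketch (marked `ignore`; assumes a session with a `people`
/// temp view already registered, as the tests below set up):
///
/// ```ignore
/// let df = execute_sql(&spark, "SELECT name FROM people WHERE age > 21")?;
/// assert_eq!(df.columns()?, vec!["name"]);
/// ```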
pub fn execute_sql(session: &SparkSession, query: &str) -> Result<DataFrame, PolarsError> {
    let stmt = parser::parse_sql(query)?;
    translator::translate(session, &stmt)
}

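// Re-exports: expose the parsing and translation entry points directly,
// in addition to the combined `execute_sql` path above.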
pub use parser::parse_sql;
pub use translator::{expr_string_to_polars, translate};

#[cfg(test)]
mod tests {
    use crate::SparkSession;

    #[test]
    fn test_sql_select_from_temp_view() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark.sql("SELECT id, name FROM t WHERE age > 26").unwrap();
        let cols = result.columns().unwrap();
        assert_eq!(cols, vec!["id", "name"]);
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_select_star() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("v", df);
        let result = spark.sql("SELECT * FROM v").unwrap();
        assert_eq!(result.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_group_by_count() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 1, "a".to_string()),
                    (2, 1, "b".to_string()),
                    (3, 2, "c".to_string()),
                ],
                vec!["id", "grp", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT grp, COUNT(id) FROM t GROUP BY grp ORDER BY grp")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_group_by_expression() {
        // Issue #588: GROUP BY (age > 30) — expression instead of column name.
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 35, "b".to_string()),
                    (3, 28, "c".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT COUNT(*) as count FROM t GROUP BY (age > 30)")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_scalar_aggregate() {
        // Issue #587: SELECT AVG(salary) FROM t (no GROUP BY) — scalar aggregation.
        // create_dataframe takes (i64, i64, String) -> columns ["id", "salary", "name"]
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 100, "Alice".to_string()), (2, 200, "Bob".to_string())],
                vec!["id", "salary", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("test_table", df);
        let result = spark
            .sql("SELECT AVG(salary) as avg_salary FROM test_table")
            .unwrap();
        assert_eq!(result.count().unwrap(), 1);
        let rows = result.collect_as_json_rows().unwrap();
        let avg_val = rows[0].get("avg_salary").and_then(|v| v.as_f64()).unwrap();
        assert!((avg_val - 150.0).abs() < 1e-9);
    }

    #[test]
    fn test_sql_having() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 25, "b".to_string()),
                    (3, 30, "c".to_string()),
                    (4, 35, "d".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT age, COUNT(id) FROM t GROUP BY age HAVING age > 26")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let rows = result.collect_as_json_rows().unwrap();
        let ages: Vec<i64> = rows
            .iter()
            .map(|r| r.get("age").and_then(|v| v.as_i64()).unwrap())
            .collect();
        assert!(ages.contains(&30));
        assert!(ages.contains(&35));
        assert!(!ages.contains(&25));
    }

    #[test]
    fn test_sql_having_agg() {
        // Issue #589: HAVING with aggregate expression (e.g. HAVING AVG(salary) > 55000).
        // create_dataframe takes (i64, i64, String) -> columns ["dummy", "salary", "dept"]
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (0, 50000, "A".to_string()),
                    (0, 60000, "A".to_string()),
                    (0, 40000, "B".to_string()),
                ],
                vec!["dummy", "salary", "dept"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT dept, AVG(salary) as avg_sal FROM t GROUP BY dept HAVING AVG(salary) >= 55000")
            .unwrap();
        assert_eq!(result.count().unwrap(), 1);
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("dept").and_then(|v| v.as_str()).unwrap(), "A");
    }

    #[test]
    fn test_sql_where_like_and_in() {
        // Issue #590: WHERE with LIKE and IN.
        // create_dataframe takes (i64, i64, String) -> columns ["id", "dummy", "name"]
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 0, "Alice".to_string()),
                    (2, 0, "Bob".to_string()),
                    (3, 0, "Carol".to_string()),
                ],
                vec!["id", "dummy", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let like_result = spark.sql("SELECT * FROM t WHERE name LIKE 'A%'").unwrap();
        assert_eq!(like_result.count().unwrap(), 1);
        let rows = like_result.collect_as_json_rows().unwrap();
        assert_eq!(
            rows[0].get("name").and_then(|v| v.as_str()).unwrap(),
            "Alice"
        );
        let in_result = spark.sql("SELECT * FROM t WHERE id IN (1, 2)").unwrap();
        assert_eq!(in_result.count().unwrap(), 2);
    }

    #[test]
    fn test_sql_table_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let result = spark.sql("SELECT 1 FROM nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_sql_udf_select() {
        use polars::prelude::DataType;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        spark
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT id, to_str(id) AS id_str, name FROM t")
            .unwrap();
        let cols = result.columns().unwrap();
        assert!(cols.contains(&"id_str".to_string()));
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
    }

    #[test]
    fn test_sql_builtin_upper() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "alice".to_string()), (2, 30, "bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        let result = spark
            .sql("SELECT id, UPPER(name) AS upper_name FROM t ORDER BY id")
            .unwrap();
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(
            rows[0].get("upper_name").and_then(|v| v.as_str()),
            Some("ALICE")
        );
    }

    #[test]
    fn test_sql_from_global_temp_view() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "v", "name"],
            )
            .unwrap();
        spark.create_or_replace_global_temp_view("gv", df);
        let result = spark
            .sql("SELECT * FROM global_temp.gv ORDER BY id")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(rows[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    #[test]
    fn test_sql_create_schema_ddl() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // CREATE SCHEMA persists the schema name and returns an empty DataFrame (issue #347).
        let out = spark.sql("CREATE SCHEMA my_schema").unwrap();
        assert_eq!(out.count().unwrap(), 0);
        assert!(out.columns().unwrap().is_empty());
        assert!(spark.database_exists("my_schema"));
        assert!(
            spark
                .list_database_names()
                .contains(&"my_schema".to_string())
        );
    }

    #[test]
    fn test_sql_create_database_ddl() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let out = spark.sql("CREATE DATABASE my_db").unwrap();
        assert_eq!(out.count().unwrap(), 0);
        assert!(out.columns().unwrap().is_empty());
        assert!(spark.database_exists("my_db"));
        assert!(spark.list_database_names().contains(&"my_db".to_string()));
    }

    #[test]
    fn test_sql_drop_table_ddl() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // DROP TABLE IF EXISTS (no error when table does not exist)
        let out = spark
            .sql("DROP TABLE IF EXISTS my_schema.my_table")
            .unwrap();
        assert_eq!(out.count().unwrap(), 0);
        // Create a temp view then DROP TABLE
        let df = spark
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "x"])
            .unwrap();
        spark.create_or_replace_temp_view("t_drop_me", df.clone());
        assert!(spark.table("t_drop_me").is_ok());
        let _ = spark.sql("DROP TABLE t_drop_me").unwrap();
        assert!(spark.table("t_drop_me").is_err());
    }

    #[test]
    fn test_sql_drop_schema() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // CREATE then DROP SCHEMA (issue #526; sqlparser 0.45 has no DROP DATABASE token)
        spark
            .sql("CREATE SCHEMA IF NOT EXISTS test_schema_to_drop")
            .unwrap();
        assert!(spark.database_exists("test_schema_to_drop"));
        spark
            .sql("DROP SCHEMA IF EXISTS test_schema_to_drop CASCADE")
            .unwrap();
        assert!(!spark.database_exists("test_schema_to_drop"));
    }

    /// Case-insensitive column resolution (PySpark default; issue #194).
    #[test]
    fn test_sql_case_insensitive_columns() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("t", df);
        // SQL with lowercase column names resolves to Id, Age, Name
        let result = spark
            .sql("SELECT name, age FROM t WHERE age > 26 ORDER BY age")
            .unwrap();
        assert_eq!(result.count().unwrap(), 2);
        let cols = result.columns().unwrap();
        assert_eq!(cols, vec!["name", "age"]);
        let rows = result.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Bob"));
        assert_eq!(rows[0].get("age").and_then(|v| v.as_i64()), Some(30));
    }
}