1mod translator;
6
7use crate::dataframe::DataFrame;
8use crate::session::SparkSession;
9use polars::prelude::PolarsError;
10use sqlparser::ast::Statement;
11
12fn parse_sql_to_statement(query: &str) -> Result<Statement, PolarsError> {
14 spark_sql_parser::parse_sql(query)
15 .map_err(|e| PolarsError::InvalidOperation(e.to_string().into()))
16}
17
18pub fn parse_sql(query: &str) -> Result<Statement, PolarsError> {
21 parse_sql_to_statement(query)
22}
23
24pub fn execute_sql(session: &SparkSession, query: &str) -> Result<DataFrame, PolarsError> {
28 let stmt = parse_sql_to_statement(query)?;
29 translator::translate(session, &stmt)
30}
31
32pub use translator::{expr_string_to_polars, translate};
33
#[cfg(test)]
mod tests {
    use crate::SparkSession;

    // Projection plus WHERE filtering against a registered temp view.
    #[test]
    fn test_sql_select_from_temp_view() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let people = session
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", people);
        let out = session.sql("SELECT id, name FROM t WHERE age > 26").unwrap();
        assert_eq!(out.columns().unwrap(), vec!["id", "name"]);
        assert_eq!(out.count().unwrap(), 2);
    }

    // `SELECT *` must preserve every column in original order.
    #[test]
    fn test_sql_select_star() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("v", frame);
        let out = session.sql("SELECT * FROM v").unwrap();
        assert_eq!(out.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(out.count().unwrap(), 2);
    }

    // GROUP BY a plain column with a COUNT aggregate: one row per group.
    #[test]
    fn test_sql_group_by_count() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 1, "a".to_string()),
                    (2, 1, "b".to_string()),
                    (3, 2, "c".to_string()),
                ],
                vec!["id", "grp", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT grp, COUNT(id) FROM t GROUP BY grp ORDER BY grp")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    // GROUP BY an arbitrary boolean expression rather than a column name.
    #[test]
    fn test_sql_group_by_expression() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 35, "b".to_string()),
                    (3, 28, "c".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT COUNT(*) as count FROM t GROUP BY (age > 30)")
            .unwrap();
        // (age > 30) splits the rows into exactly two groups.
        assert_eq!(out.count().unwrap(), 2);
    }

    // Aggregate with no GROUP BY collapses the table to a single row.
    #[test]
    fn test_sql_scalar_aggregate() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![(1, 100, "Alice".to_string()), (2, 200, "Bob".to_string())],
                vec!["id", "salary", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("test_table", frame);
        let out = session
            .sql("SELECT AVG(salary) as avg_salary FROM test_table")
            .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let records = out.collect_as_json_rows().unwrap();
        let average = records[0]
            .get("avg_salary")
            .and_then(|v| v.as_f64())
            .unwrap();
        // (100 + 200) / 2 == 150; compare with a tolerance for float math.
        assert!((average - 150.0).abs() < 1e-9);
    }

    // HAVING on a grouping key filters whole groups after aggregation.
    #[test]
    fn test_sql_having() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 25, "b".to_string()),
                    (3, 30, "c".to_string()),
                    (4, 35, "d".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT age, COUNT(id) FROM t GROUP BY age HAVING age > 26")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let records = out.collect_as_json_rows().unwrap();
        let age_values: Vec<i64> = records
            .iter()
            .map(|row| row.get("age").and_then(|v| v.as_i64()).unwrap())
            .collect();
        assert!(age_values.contains(&30));
        assert!(age_values.contains(&35));
        // The age == 25 group must be filtered out by HAVING.
        assert!(!age_values.contains(&25));
    }

    // HAVING may reference an aggregate expression, not only grouping keys.
    #[test]
    fn test_sql_having_agg() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![
                    (0, 50000, "A".to_string()),
                    (0, 60000, "A".to_string()),
                    (0, 40000, "B".to_string()),
                ],
                vec!["dummy", "salary", "dept"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT dept, AVG(salary) as avg_sal FROM t GROUP BY dept HAVING AVG(salary) >= 55000")
            .unwrap();
        // Only dept "A" (avg 55000) survives; "B" averages 40000.
        assert_eq!(out.count().unwrap(), 1);
        let records = out.collect_as_json_rows().unwrap();
        assert_eq!(
            records[0].get("dept").and_then(|v| v.as_str()).unwrap(),
            "A"
        );
    }

    // LIKE pattern matching and IN list membership in the WHERE clause.
    #[test]
    fn test_sql_where_like_and_in() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 0, "Alice".to_string()),
                    (2, 0, "Bob".to_string()),
                    (3, 0, "Carol".to_string()),
                ],
                vec!["id", "dummy", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let matched = session.sql("SELECT * FROM t WHERE name LIKE 'A%'").unwrap();
        assert_eq!(matched.count().unwrap(), 1);
        let records = matched.collect_as_json_rows().unwrap();
        assert_eq!(
            records[0].get("name").and_then(|v| v.as_str()).unwrap(),
            "Alice"
        );
        let filtered = session.sql("SELECT * FROM t WHERE id IN (1, 2)").unwrap();
        assert_eq!(filtered.count().unwrap(), 2);
    }

    // Querying an unregistered table must surface an error, not panic.
    #[test]
    fn test_sql_table_not_found() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        assert!(session.sql("SELECT 1 FROM nonexistent").is_err());
    }

    // A user-registered UDF is callable from the SELECT list.
    #[test]
    fn test_sql_udf_select() {
        use polars::prelude::DataType;

        let session = SparkSession::builder().app_name("test").get_or_create();
        session
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let frame = session
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT id, to_str(id) AS id_str, name FROM t")
            .unwrap();
        assert!(out.columns().unwrap().contains(&"id_str".to_string()));
        let records = out.collect_as_json_rows().unwrap();
        assert_eq!(records[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
    }

    // Built-in scalar function UPPER applied in the projection.
    #[test]
    fn test_sql_builtin_upper() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![(1, 25, "alice".to_string()), (2, 30, "bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT id, UPPER(name) AS upper_name FROM t ORDER BY id")
            .unwrap();
        let records = out.collect_as_json_rows().unwrap();
        assert_eq!(
            records[0].get("upper_name").and_then(|v| v.as_str()),
            Some("ALICE")
        );
    }

    // Global temp views resolve through the `global_temp` qualifier.
    #[test]
    fn test_sql_from_global_temp_view() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "v", "name"],
            )
            .unwrap();
        session.create_or_replace_global_temp_view("gv", frame);
        let out = session
            .sql("SELECT * FROM global_temp.gv ORDER BY id")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let records = out.collect_as_json_rows().unwrap();
        assert_eq!(records[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(records[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    // DDL: CREATE SCHEMA returns an empty frame and registers the database.
    #[test]
    fn test_sql_create_schema_ddl() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let produced = session.sql("CREATE SCHEMA my_schema").unwrap();
        assert_eq!(produced.count().unwrap(), 0);
        assert!(produced.columns().unwrap().is_empty());
        assert!(session.database_exists("my_schema"));
        assert!(
            session
                .list_database_names()
                .contains(&"my_schema".to_string())
        );
    }

    // DDL: CREATE DATABASE behaves the same as CREATE SCHEMA.
    #[test]
    fn test_sql_create_database_ddl() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let produced = session.sql("CREATE DATABASE my_db").unwrap();
        assert_eq!(produced.count().unwrap(), 0);
        assert!(produced.columns().unwrap().is_empty());
        assert!(session.database_exists("my_db"));
        assert!(session.list_database_names().contains(&"my_db".to_string()));
    }

    // DDL: DROP TABLE IF EXISTS is a no-op for missing tables, and a plain
    // DROP TABLE removes a registered temp view.
    #[test]
    fn test_sql_drop_table_ddl() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let produced = session
            .sql("DROP TABLE IF EXISTS my_schema.my_table")
            .unwrap();
        assert_eq!(produced.count().unwrap(), 0);
        let frame = session
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "x"])
            .unwrap();
        session.create_or_replace_temp_view("t_drop_me", frame.clone());
        assert!(session.table("t_drop_me").is_ok());
        let _ = session.sql("DROP TABLE t_drop_me").unwrap();
        assert!(session.table("t_drop_me").is_err());
    }

    // DDL: DROP SCHEMA ... CASCADE removes a previously-created schema.
    #[test]
    fn test_sql_drop_schema() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        session
            .sql("CREATE SCHEMA IF NOT EXISTS test_schema_to_drop")
            .unwrap();
        assert!(session.database_exists("test_schema_to_drop"));
        session
            .sql("DROP SCHEMA IF EXISTS test_schema_to_drop CASCADE")
            .unwrap();
        assert!(!session.database_exists("test_schema_to_drop"));
    }

    // Column references in SQL resolve case-insensitively, and the output
    // uses the casing written in the query.
    #[test]
    fn test_sql_case_insensitive_columns() {
        let session = SparkSession::builder().app_name("test").get_or_create();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT name, age FROM t WHERE age > 26 ORDER BY age")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        assert_eq!(out.columns().unwrap(), vec!["name", "age"]);
        let records = out.collect_as_json_rows().unwrap();
        assert_eq!(records[0].get("name").and_then(|v| v.as_str()), Some("Bob"));
        assert_eq!(records[0].get("age").and_then(|v| v.as_i64()), Some(30));
    }
}