1mod parser;
5mod translator;
6
7use crate::dataframe::DataFrame;
8use crate::session::SparkSession;
9use polars::prelude::PolarsError;
10
11pub fn execute_sql(session: &SparkSession, query: &str) -> Result<DataFrame, PolarsError> {
15 let stmt = parser::parse_sql(query)?;
16 translator::translate(session, &stmt)
17}
18
// Re-export the parse and translate entry points so callers can reach them
// directly without depending on the private `parser`/`translator` submodules.
pub use parser::parse_sql;
pub use translator::{expr_string_to_polars, translate};
21
#[cfg(test)]
mod tests {
    use crate::SparkSession;

    // NOTE(review): `get_or_create()` may hand back a shared/global session;
    // several tests below reuse the view name "t", which could interact if the
    // session is process-wide and tests run in parallel — confirm the builder
    // returns an isolated session per call.

    /// Shared constructor for the session used by every test below.
    fn test_session() -> SparkSession {
        SparkSession::builder().app_name("test").get_or_create()
    }

    // Projection plus a WHERE filter over a registered temp view.
    #[test]
    fn test_sql_select_from_temp_view() {
        let session = test_session();
        let people = session
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", people);
        let out = session.sql("SELECT id, name FROM t WHERE age > 26").unwrap();
        assert_eq!(out.columns().unwrap(), vec!["id", "name"]);
        assert_eq!(out.count().unwrap(), 2);
    }

    // SELECT * preserves every column in its original order.
    #[test]
    fn test_sql_select_star() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("v", frame);
        let out = session.sql("SELECT * FROM v").unwrap();
        assert_eq!(out.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(out.count().unwrap(), 2);
    }

    // GROUP BY on a plain column with COUNT and ORDER BY; two distinct groups.
    #[test]
    fn test_sql_group_by_count() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 1, "a".to_string()),
                    (2, 1, "b".to_string()),
                    (3, 2, "c".to_string()),
                ],
                vec!["id", "grp", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT grp, COUNT(id) FROM t GROUP BY grp ORDER BY grp")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    // GROUP BY over a computed boolean expression yields one bucket per truth value.
    #[test]
    fn test_sql_group_by_expression() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 35, "b".to_string()),
                    (3, 28, "c".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT COUNT(*) as count FROM t GROUP BY (age > 30)")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    // An aggregate without GROUP BY collapses to a single row.
    #[test]
    fn test_sql_scalar_aggregate() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![(1, 100, "Alice".to_string()), (2, 200, "Bob".to_string())],
                vec!["id", "salary", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("test_table", frame);
        let out = session
            .sql("SELECT AVG(salary) as avg_salary FROM test_table")
            .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect_as_json_rows().unwrap();
        let avg = rows[0].get("avg_salary").and_then(|v| v.as_f64()).unwrap();
        assert!((avg - 150.0).abs() < 1e-9);
    }

    // HAVING on a grouping key filters out whole groups after aggregation.
    #[test]
    fn test_sql_having() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 25, "a".to_string()),
                    (2, 25, "b".to_string()),
                    (3, 30, "c".to_string()),
                    (4, 35, "d".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT age, COUNT(id) FROM t GROUP BY age HAVING age > 26")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let mut ages = Vec::new();
        for row in out.collect_as_json_rows().unwrap() {
            ages.push(row.get("age").and_then(|v| v.as_i64()).unwrap());
        }
        assert!(ages.contains(&30));
        assert!(ages.contains(&35));
        assert!(!ages.contains(&25));
    }

    // HAVING may reference an aggregate that also appears in the projection.
    #[test]
    fn test_sql_having_agg() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![
                    (0, 50000, "A".to_string()),
                    (0, 60000, "A".to_string()),
                    (0, 40000, "B".to_string()),
                ],
                vec!["dummy", "salary", "dept"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT dept, AVG(salary) as avg_sal FROM t GROUP BY dept HAVING AVG(salary) >= 55000")
            .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("dept").and_then(|v| v.as_str()), Some("A"));
    }

    // LIKE pattern matching and IN-list membership inside WHERE clauses.
    #[test]
    fn test_sql_where_like_and_in() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 0, "Alice".to_string()),
                    (2, 0, "Bob".to_string()),
                    (3, 0, "Carol".to_string()),
                ],
                vec!["id", "dummy", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let matched = session.sql("SELECT * FROM t WHERE name LIKE 'A%'").unwrap();
        assert_eq!(matched.count().unwrap(), 1);
        let rows = matched.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));
        let picked = session.sql("SELECT * FROM t WHERE id IN (1, 2)").unwrap();
        assert_eq!(picked.count().unwrap(), 2);
    }

    // Querying an unregistered table surfaces an Err instead of panicking.
    #[test]
    fn test_sql_table_not_found() {
        let session = test_session();
        assert!(session.sql("SELECT 1 FROM nonexistent").is_err());
    }

    // A registered UDF is callable from SQL and its output can be aliased.
    #[test]
    fn test_sql_udf_select() {
        use polars::prelude::DataType;

        let session = test_session();
        session
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let frame = session
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT id, to_str(id) AS id_str, name FROM t")
            .unwrap();
        assert!(out.columns().unwrap().contains(&"id_str".to_string()));
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
    }

    // Built-in UPPER() string function with an alias.
    #[test]
    fn test_sql_builtin_upper() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![(1, 25, "alice".to_string()), (2, 30, "bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT id, UPPER(name) AS upper_name FROM t ORDER BY id")
            .unwrap();
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(
            rows[0].get("upper_name").and_then(|v| v.as_str()),
            Some("ALICE")
        );
    }

    // Global temp views resolve through the `global_temp` qualifier.
    #[test]
    fn test_sql_from_global_temp_view() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![(1, 10, "a".to_string()), (2, 20, "b".to_string())],
                vec!["id", "v", "name"],
            )
            .unwrap();
        session.create_or_replace_global_temp_view("gv", frame);
        let out = session
            .sql("SELECT * FROM global_temp.gv ORDER BY id")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(rows[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    // CREATE SCHEMA is DDL: it returns an empty frame and registers the schema.
    #[test]
    fn test_sql_create_schema_ddl() {
        let session = test_session();
        let out = session.sql("CREATE SCHEMA my_schema").unwrap();
        assert_eq!(out.count().unwrap(), 0);
        assert!(out.columns().unwrap().is_empty());
        assert!(session.database_exists("my_schema"));
        assert!(
            session
                .list_database_names()
                .contains(&"my_schema".to_string())
        );
    }

    // CREATE DATABASE behaves identically to CREATE SCHEMA.
    #[test]
    fn test_sql_create_database_ddl() {
        let session = test_session();
        let out = session.sql("CREATE DATABASE my_db").unwrap();
        assert_eq!(out.count().unwrap(), 0);
        assert!(out.columns().unwrap().is_empty());
        assert!(session.database_exists("my_db"));
        assert!(session.list_database_names().contains(&"my_db".to_string()));
    }

    // DROP TABLE removes a temp view; IF EXISTS tolerates a missing table.
    #[test]
    fn test_sql_drop_table_ddl() {
        let session = test_session();
        let out = session
            .sql("DROP TABLE IF EXISTS my_schema.my_table")
            .unwrap();
        assert_eq!(out.count().unwrap(), 0);
        let frame = session
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "x"])
            .unwrap();
        session.create_or_replace_temp_view("t_drop_me", frame);
        assert!(session.table("t_drop_me").is_ok());
        let _ = session.sql("DROP TABLE t_drop_me").unwrap();
        assert!(session.table("t_drop_me").is_err());
    }

    // DROP SCHEMA ... CASCADE removes a schema previously created via SQL.
    #[test]
    fn test_sql_drop_schema() {
        let session = test_session();
        session
            .sql("CREATE SCHEMA IF NOT EXISTS test_schema_to_drop")
            .unwrap();
        assert!(session.database_exists("test_schema_to_drop"));
        session
            .sql("DROP SCHEMA IF EXISTS test_schema_to_drop CASCADE")
            .unwrap();
        assert!(!session.database_exists("test_schema_to_drop"));
    }

    // SQL column references resolve case-insensitively against the view schema,
    // and the output columns take the casing used in the query.
    #[test]
    fn test_sql_case_insensitive_columns() {
        let session = test_session();
        let frame = session
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
        session.create_or_replace_temp_view("t", frame);
        let out = session
            .sql("SELECT name, age FROM t WHERE age > 26 ORDER BY age")
            .unwrap();
        assert_eq!(out.count().unwrap(), 2);
        assert_eq!(out.columns().unwrap(), vec!["name", "age"]);
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Bob"));
        assert_eq!(rows[0].get("age").and_then(|v| v.as_i64()), Some(30));
    }
}