use std::collections::HashMap;

use crate::column::Column;
use crate::dataframe::{DataFrame, JoinType, join};
use crate::functions;
use crate::session::{SparkSession, set_thread_udf_session};
use polars::prelude::{DataFrame as PlDataFrame, Expr, PolarsError, col, lit};
use sqlparser::ast::{
    BinaryOperator, Expr as SqlExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments,
    GroupByExpr, JoinConstraint, JoinOperator, ObjectType, OrderByKind, Query, Select, SelectItem,
    SetExpr, Statement, TableFactor, Value, ValueWithSpan,
};

use super::parser;

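/// Returns the positional argument list of a SQL function call as a slice;
/// calls without a plain parenthesized argument list yield an empty slice.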
fn function_args_slice(args: &FunctionArguments) -> &[FunctionArg] {
    match args {
        FunctionArguments::List(list) => &list.args,
        _ => &[],
    }
}

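/// Parses a single SQL expression string into a Polars `Expr`, resolving
/// column names against `df`. The expression is wrapped in a dummy
/// `SELECT ... FROM __selectexpr_t` so the regular SQL parser can be reused;
/// an optional `AS alias` is applied to the resulting expression.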
pub fn expr_string_to_polars(
    expr_str: &str,
    session: &SparkSession,
    df: &DataFrame,
) -> Result<Expr, PolarsError> {
    let query = format!("SELECT {} FROM __selectexpr_t", expr_str);
    let stmt = parser::parse_sql(&query)?;
    let query_ast = match &stmt {
        Statement::Query(q) => q.as_ref(),
        _ => {
            return Err(PolarsError::InvalidOperation(
                "expr_string_to_polars: expected SELECT statement".into(),
            ));
        }
    };
    let body = match query_ast.body.as_ref() {
        SetExpr::Select(s) => s.as_ref(),
        _ => {
            return Err(PolarsError::InvalidOperation(
                "expr_string_to_polars: expected SELECT".into(),
            ));
        }
    };
    let first = body.projection.first().ok_or_else(|| {
        PolarsError::InvalidOperation("expr_string_to_polars: empty SELECT list".into())
    })?;
    set_thread_udf_session(session.clone());
    let (sql_expr, alias) = match first {
        SelectItem::UnnamedExpr(e) => ((*e).clone(), None),
        SelectItem::ExprWithAlias { expr, alias: a } => ((*expr).clone(), Some(a.value.as_str())),
        _ => {
            return Err(PolarsError::InvalidOperation(
                format!("expr_string_to_polars: unsupported select item {:?}", first).into(),
            ));
        }
    };
    let expr = sql_expr_to_polars(&sql_expr, session, Some(df), None)?;
    Ok(match alias {
        Some(a) => expr.alias(a),
        None => expr,
    })
}

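/// Translates a parsed SQL statement into a `DataFrame`. `SELECT` queries are
/// handled by `translate_query`; `CREATE SCHEMA`/`CREATE DATABASE` and
/// `DROP TABLE`/`VIEW`/`SCHEMA` update the session's registered databases,
/// views, and tables and return an empty `DataFrame`.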
pub fn translate(
    session: &SparkSession,
    stmt: &Statement,
) -> Result<crate::dataframe::DataFrame, PolarsError> {
    set_thread_udf_session(session.clone());
    match stmt {
        Statement::Query(q) => translate_query(session, q.as_ref()),
        Statement::CreateSchema { schema_name, .. } => {
            let name = schema_name.to_string();
            session.register_database(&name);
            Ok(DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                session.is_case_sensitive(),
            ))
        }
        Statement::CreateDatabase { db_name, .. } => {
            let name = db_name.to_string();
            session.register_database(&name);
            Ok(DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                session.is_case_sensitive(),
            ))
        }
        Statement::Drop {
            object_type: ObjectType::Table | ObjectType::View,
            names,
            ..
        } => {
            for obj_name in names {
                let name = obj_name.to_string();
                if let Some(suffix) = name.strip_prefix("global_temp.") {
                    session.drop_global_temp_view(suffix);
                }
                session.drop_temp_view(&name);
                session.drop_table(&name);
            }
            Ok(DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                session.is_case_sensitive(),
            ))
        }
        Statement::Drop {
            object_type: ObjectType::Schema,
            names,
            ..
        } => {
            for obj_name in names {
                session.drop_database(&obj_name.to_string());
            }
            Ok(DataFrame::from_polars_with_options(
                PlDataFrame::empty(),
                session.is_case_sensitive(),
            ))
        }
        _ => Err(PolarsError::InvalidOperation(
            "SQL: only SELECT, CREATE SCHEMA/DATABASE, and DROP TABLE/VIEW/SCHEMA are supported."
                .into(),
        )),
    }
}

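/// Translates a single `SELECT` query: FROM/JOIN, WHERE, GROUP BY with
/// aggregates, HAVING (whose aggregate calls are pre-computed into hidden
/// `__having_*` columns during aggregation), ORDER BY, and a FETCH-style row
/// limit, applied in that order.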
fn translate_query(
    session: &SparkSession,
    query: &Query,
) -> Result<crate::dataframe::DataFrame, PolarsError> {
    let body = match query.body.as_ref() {
        SetExpr::Select(select) => select.as_ref(),
        _ => {
            return Err(PolarsError::InvalidOperation(
                "SQL: only SELECT (no UNION/EXCEPT/INTERSECT) is supported.".into(),
            ));
        }
    };
    let mut df = translate_select_from(session, body)?;
    if let Some(selection) = &body.selection {
        let expr = sql_expr_to_polars(selection, session, Some(&df), None)?;
        df = df.filter(expr)?;
    }
    let group_exprs: &[SqlExpr] = match &body.group_by {
        GroupByExpr::Expressions(exprs, _) => exprs.as_slice(),
        GroupByExpr::All(_) => {
            return Err(PolarsError::InvalidOperation(
                "SQL: GROUP BY ALL is not supported. Use explicit GROUP BY columns.".into(),
            ));
        }
    };
    let has_group_by = !group_exprs.is_empty();
    let mut having_agg_map: HashMap<(String, String), String> = HashMap::new();
    if has_group_by {
        let pairs: Vec<(Expr, String)> = group_exprs
            .iter()
            .enumerate()
            .map(|(i, e)| {
                Ok(match e {
                    SqlExpr::Identifier(ident) => {
                        let name = ident.value.as_str();
                        let resolved = df.resolve_column_name(name)?;
                        (col(resolved.as_str()), resolved)
                    }
                    SqlExpr::CompoundIdentifier(parts) => {
                        let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
                        let resolved = df.resolve_column_name(name)?;
                        (col(resolved.as_str()), resolved)
                    }
                    _ => {
                        let expr = sql_expr_to_polars(e, session, Some(&df), None)?;
                        let name = format!("group_{}", i);
                        (expr.alias(&name), name)
                    }
                })
            })
            .collect::<Result<Vec<_>, PolarsError>>()?;
        let (group_exprs_polars, group_cols): (Vec<Expr>, Vec<String>) = pairs.into_iter().unzip();
        let grouped = df.group_by_exprs(group_exprs_polars, group_cols.clone())?;
        let mut agg_exprs = projection_to_agg_exprs(&body.projection, &group_cols, &df)?;
        if let Some(having_expr) = &body.having {
            let having_list = extract_having_agg_calls(having_expr);
            for (func, alias) in &having_list {
                push_agg_function(
                    &func.name,
                    function_args_slice(&func.args),
                    &df,
                    Some(alias.as_str()),
                    &mut agg_exprs,
                )?;
            }
            having_agg_map = having_list
                .into_iter()
                .filter_map(|(f, alias)| agg_function_key(&f).map(|k| (k, alias)))
                .collect();
        }
        if agg_exprs.is_empty() {
            df = grouped.count()?;
        } else {
            df = grouped.agg(agg_exprs)?;
        }
    } else if projection_is_scalar_aggregate(&body.projection) {
        let agg_exprs = projection_to_agg_exprs(&body.projection, &[], &df)?;
        let pl_df = df.lazy_frame().select(agg_exprs).collect()?;
        df = DataFrame::from_polars_with_options(pl_df, df.case_sensitive);
    } else {
        df = apply_projection(&df, &body.projection, session)?;
    }
    if let Some(having_expr) = &body.having {
        let having_polars = sql_expr_to_polars(
            having_expr,
            session,
            Some(&df),
            Some(&having_agg_map).filter(|m| !m.is_empty()),
        )?;
        df = df.filter(having_polars)?;
    }
    if let Some(order_by) = &query.order_by {
        if let OrderByKind::Expressions(exprs) = &order_by.kind {
            if !exprs.is_empty() {
                let pairs: Vec<(String, bool)> = exprs
                    .iter()
                    .map(|o| {
                        let col_name = sql_expr_to_col_name(&o.expr)?;
                        let resolved = df.resolve_column_name(&col_name)?;
                        let ascending = o.options.asc.unwrap_or(true);
                        Ok((resolved, ascending))
                    })
                    .collect::<Result<Vec<_>, PolarsError>>()?;
                let (cols, asc): (Vec<String>, Vec<bool>) = pairs.into_iter().unzip();
                let col_refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
                df = df.order_by(col_refs, asc)?;
            }
        }
    }
    let limit_expr = query.fetch.as_ref().and_then(|f| f.quantity.as_ref());
    if let Some(limit_expr) = limit_expr {
        let n = sql_limit_to_usize(limit_expr)?;
        df = df.limit(n)?;
    }
    Ok(df)
}

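/// Resolves the FROM clause of a SELECT: loads the first table and applies any
/// INNER/LEFT/RIGHT/FULL joins, each of which must use a single-column
/// equality ON condition with the same column name on both sides.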
fn translate_select_from(
    session: &SparkSession,
    select: &Select,
) -> Result<crate::dataframe::DataFrame, PolarsError> {
    if select.from.is_empty() {
        return Err(PolarsError::InvalidOperation(
            "SQL: FROM clause is required. Register a table with create_or_replace_temp_view."
                .into(),
        ));
    }
    let first_tj = &select.from[0];
    let mut df = resolve_table_factor(session, &first_tj.relation)?;
    for join_spec in &first_tj.joins {
        let right_df = resolve_table_factor(session, &join_spec.relation)?;
        let join_type = match &join_spec.join_operator {
            JoinOperator::Inner(_) => JoinType::Inner,
            JoinOperator::LeftOuter(_) => JoinType::Left,
            JoinOperator::RightOuter(_) => JoinType::Right,
            JoinOperator::FullOuter(_) => JoinType::Outer,
            _ => {
                return Err(PolarsError::InvalidOperation(
                    "SQL: only INNER, LEFT, RIGHT, FULL JOIN are supported.".into(),
                ));
            }
        };
        let on_cols = join_condition_to_on_columns(&join_spec.join_operator)?;
        let on_refs: Vec<&str> = on_cols.iter().map(|s| s.as_str()).collect();
        df = join(
            &df,
            &right_df,
            on_refs,
            join_type,
            session.is_case_sensitive(),
        )?;
    }
    Ok(df)
}

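/// Resolves a FROM-clause table factor to a registered table or view. Only
/// plain (optionally qualified) table names are supported; qualified names are
/// joined with `.` before lookup.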
fn resolve_table_factor(
    session: &SparkSession,
    factor: &TableFactor,
) -> Result<crate::dataframe::DataFrame, PolarsError> {
    match factor {
        TableFactor::Table { name, .. } => {
            let table_name = if name.0.len() >= 2 {
                let parts: Vec<String> = name
                    .0
                    .iter()
                    .filter_map(|p| p.as_ident().map(|i| i.value.clone()))
                    .collect();
                parts.join(".")
            } else {
                name.0
                    .last()
                    .and_then(|p| p.as_ident())
                    .map(|i| i.value.clone())
                    .unwrap_or_default()
            };
            session.table(&table_name)
        }
        _ => Err(PolarsError::InvalidOperation(
            "SQL: only plain table names are supported in FROM (no subqueries, derived tables). Register with create_or_replace_temp_view.".into(),
        )),
    }
}

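/// Extracts the join key from a JOIN constraint. Only `ON a.col = b.col` with
/// the same column name on both sides is accepted; the shared name is returned
/// as the single join column.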
fn join_condition_to_on_columns(join_op: &JoinOperator) -> Result<Vec<String>, PolarsError> {
    let constraint = match join_op {
        JoinOperator::Inner(c)
        | JoinOperator::LeftOuter(c)
        | JoinOperator::RightOuter(c)
        | JoinOperator::FullOuter(c) => c,
        _ => {
            return Err(PolarsError::InvalidOperation(
                "SQL: only INNER/LEFT/RIGHT/FULL JOIN with ON are supported.".into(),
            ));
        }
    };
    match constraint {
        JoinConstraint::On(expr) => match expr {
            SqlExpr::BinaryOp {
                left,
                op: BinaryOperator::Eq,
                right,
            } => {
                let l = sql_expr_to_col_name(left.as_ref())?;
                let r = sql_expr_to_col_name(right.as_ref())?;
                if l != r {
                    return Err(PolarsError::InvalidOperation(
                        "SQL: JOIN ON must use same column name on both sides (e.g. a.id = b.id where both become 'id').".into(),
                    ));
                }
                Ok(vec![l])
            }
            _ => Err(PolarsError::InvalidOperation(
                "SQL: JOIN ON must be a single equality (col = col).".into(),
            )),
        },
        _ => Err(PolarsError::InvalidOperation(
            "SQL: JOIN must use ON (equality); NATURAL/USING not supported.".into(),
        )),
    }
}

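/// Converts a SQL scalar expression (as used in WHERE/HAVING and by
/// `expr_string_to_polars`) into a Polars `Expr`. Column names are resolved
/// against `df` when provided; `having_agg_map` maps aggregate calls that were
/// pre-computed during GROUP BY to their output column names so HAVING can
/// reference them.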
fn sql_expr_to_polars(
    expr: &SqlExpr,
    session: &SparkSession,
    df: Option<&DataFrame>,
    having_agg_map: Option<&HashMap<(String, String), String>>,
) -> Result<Expr, PolarsError> {
    match expr {
        SqlExpr::Identifier(ident) => {
            let name = ident.value.as_str();
            let resolved = df
                .map(|d| d.resolve_column_name(name))
                .transpose()?
                .unwrap_or_else(|| name.to_string());
            Ok(col(resolved.as_str()))
        }
        SqlExpr::CompoundIdentifier(parts) => {
            let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
            let resolved = df
                .map(|d| d.resolve_column_name(name))
                .transpose()?
                .unwrap_or_else(|| name.to_string());
            Ok(col(resolved.as_str()))
        }
        SqlExpr::Value(ValueWithSpan { value: Value::Number(s, _), .. }) => {
            if s.contains('.') {
                let v: f64 = s.parse().map_err(|_| {
                    PolarsError::InvalidOperation(format!("SQL: invalid number literal '{}'", s).into())
                })?;
                Ok(lit(v))
            } else {
                let v: i64 = s.parse().map_err(|_| {
                    PolarsError::InvalidOperation(format!("SQL: invalid integer literal '{}'", s).into())
                })?;
                Ok(lit(v))
            }
        }
        SqlExpr::Value(ValueWithSpan { value: Value::SingleQuotedString(s), .. }) => Ok(lit(s.as_str())),
        SqlExpr::Value(ValueWithSpan { value: Value::Boolean(b), .. }) => Ok(lit(*b)),
        SqlExpr::Value(ValueWithSpan { value: Value::Null, .. }) => Ok(lit(polars::prelude::NULL)),
        SqlExpr::BinaryOp { left, op, right } => {
            let l = sql_expr_to_polars(left, session, df, having_agg_map)?;
            let r = sql_expr_to_polars(right, session, df, having_agg_map)?;
            match op {
                BinaryOperator::Eq => Ok(l.eq(r)),
                BinaryOperator::NotEq => Ok(l.eq(r).not()),
                BinaryOperator::Gt => Ok(l.gt(r)),
                BinaryOperator::GtEq => Ok(l.gt_eq(r)),
                BinaryOperator::Lt => Ok(l.lt(r)),
                BinaryOperator::LtEq => Ok(l.lt_eq(r)),
                BinaryOperator::And => Ok(l.and(r)),
                BinaryOperator::Or => Ok(l.or(r)),
                _ => Err(PolarsError::InvalidOperation(
                    format!("SQL: unsupported operator in WHERE: {:?}. Use =, <>, <, <=, >, >=, AND, OR.", op).into(),
                )),
            }
        }
        SqlExpr::Nested(inner) => sql_expr_to_polars(inner, session, df, having_agg_map),
        SqlExpr::IsNull(expr) => Ok(sql_expr_to_polars(expr, session, df, having_agg_map)?.is_null()),
        SqlExpr::IsNotNull(expr) => Ok(sql_expr_to_polars(expr, session, df, having_agg_map)?.is_not_null()),
        SqlExpr::UnaryOp { op, expr } => {
            let e = sql_expr_to_polars(expr, session, df, having_agg_map)?;
            match op {
                sqlparser::ast::UnaryOperator::Not => Ok(e.not()),
                _ => Err(PolarsError::InvalidOperation(
                    format!("SQL: unsupported unary operator in WHERE: {:?}", op).into(),
                )),
            }
        }
        SqlExpr::Function(func) => {
            if let Some(map) = having_agg_map {
                if let Some(key) = agg_function_key(func) {
                    if let Some(col_name) = map.get(&key) {
                        return Ok(col(col_name.as_str()));
                    }
                }
            }
            sql_function_to_expr(func, session, df)
        }
        SqlExpr::Like {
            negated,
            expr: left,
            pattern,
            escape_char,
            any: _,
        } => {
            let col_expr = sql_expr_to_polars(left.as_ref(), session, df, having_agg_map)?;
            let pattern_str = sql_expr_to_string_literal(pattern.as_ref())?;
            let col_col = crate::column::Column::from_expr(col_expr, None);
            let escape: Option<char> = escape_char.as_ref().and_then(|v| match v {
                Value::SingleQuotedString(s) => s.chars().next(),
                _ => None,
            });
            let like_expr = col_col.like(&pattern_str, escape).into_expr();
            Ok(if *negated {
                like_expr.not()
            } else {
                like_expr
            })
        }
        SqlExpr::InList {
            expr: left,
            list,
            negated,
        } => {
            let col_expr = sql_expr_to_polars(left.as_ref(), session, df, having_agg_map)?;
            if list.is_empty() {
                return Ok(lit(false));
            }
            let series = sql_in_list_to_series(list)?;
            let in_expr = col_expr.is_in(lit(series), false);
            Ok(if *negated {
                in_expr.not()
            } else {
                in_expr
            })
        }
        _ => Err(PolarsError::InvalidOperation(
            format!("SQL: unsupported expression in WHERE: {:?}. Use column, literal, =, <, >, AND, OR, IS NULL, LIKE, IN.", expr).into(),
        )),
    }
}

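/// Translates a function call appearing in WHERE/HAVING. Handles the
/// UPPER/LOWER (UCASE/LCASE) built-ins and registered non-Python UDFs;
/// Python UDFs are rejected with an error since they are not yet supported
/// outside the SELECT list.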
fn sql_function_to_expr(
    func: &Function,
    session: &SparkSession,
    df: Option<&DataFrame>,
) -> Result<Expr, PolarsError> {
    let func_name = func
        .name
        .0
        .last()
        .and_then(|p| p.as_ident())
        .map(|i| i.value.as_str())
        .unwrap_or("");
    let args = sql_function_args_to_columns(func, session, df)?;

    let case_sensitive = session.is_case_sensitive();

    if let Some(col) = args.first() {
        let builtin_expr = match func_name.to_uppercase().as_str() {
            "UPPER" | "UCASE" if args.len() == 1 => Some(functions::upper(col).expr().clone()),
            "LOWER" | "LCASE" if args.len() == 1 => Some(functions::lower(col).expr().clone()),
            _ => None,
        };
        if let Some(e) = builtin_expr {
            return Ok(e);
        }
    }

    if session.udf_registry.has_udf(func_name, case_sensitive) {
        let col = functions::call_udf(func_name, &args)?;
        if col.udf_call.is_some() {
            return Err(PolarsError::InvalidOperation(
                "SQL: Python UDF in WHERE/HAVING not yet supported. Use in SELECT.".into(),
            ));
        }
        return Ok(col.expr().clone());
    }

    Err(PolarsError::InvalidOperation(
        format!("SQL: unknown function '{}'. Register with spark.udf.register() or use built-ins: UPPER, LOWER.", func_name).into(),
    ))
}

fn sql_function_args_to_columns(
    func: &Function,
    session: &SparkSession,
    df: Option<&DataFrame>,
) -> Result<Vec<Column>, PolarsError> {
    let mut cols = Vec::new();
    for arg in function_args_slice(&func.args) {
        if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg {
            let e = sql_expr_to_polars(expr, session, df, None)?;
            cols.push(Column::from_expr(e, None));
        } else {
            return Err(PolarsError::InvalidOperation(
                "SQL: only positional function arguments supported.".into(),
            ));
        }
    }
    Ok(cols)
}

fn sql_expr_to_col_name(expr: &SqlExpr) -> Result<String, PolarsError> {
    match expr {
        SqlExpr::Identifier(ident) => Ok(ident.value.clone()),
        SqlExpr::CompoundIdentifier(parts) => parts
            .last()
            .map(|i| i.value.clone())
            .ok_or_else(|| PolarsError::InvalidOperation("SQL: empty compound identifier.".into())),
        _ => Err(PolarsError::InvalidOperation(
            format!("SQL: expected column name, got {:?}", expr).into(),
        )),
    }
}

fn sql_expr_to_string_literal(expr: &SqlExpr) -> Result<String, PolarsError> {
    match expr {
        SqlExpr::Value(ValueWithSpan {
            value: Value::SingleQuotedString(s),
            ..
        }) => Ok(s.clone()),
        _ => Err(PolarsError::InvalidOperation(
            format!("SQL: LIKE pattern must be a string literal, got {:?}", expr).into(),
        )),
    }
}

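/// Builds a Polars `Series` from the literals of an `IN (...)` list. The dtype
/// is chosen by a simple heuristic: any string or boolean literal forces a
/// string series; otherwise an all-integer list becomes Int64 and an all-float
/// list becomes Float64, while mixed numeric lists fall back to strings.
/// NULL literals are skipped.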
fn sql_in_list_to_series(list: &[SqlExpr]) -> Result<polars::prelude::Series, PolarsError> {
    use polars::prelude::Series;
    let mut str_vals: Vec<String> = Vec::new();
    let mut int_vals: Vec<i64> = Vec::new();
    let mut float_vals: Vec<f64> = Vec::new();
    let mut has_string = false;
    let mut has_float = false;
    for e in list {
        match e {
            SqlExpr::Value(ValueWithSpan {
                value: Value::SingleQuotedString(s),
                ..
            }) => {
                str_vals.push(s.clone());
                has_string = true;
            }
            SqlExpr::Value(ValueWithSpan {
                value: Value::Number(n, _),
                ..
            }) => {
                str_vals.push(n.clone());
                if n.contains('.') {
                    let v: f64 = n.parse().map_err(|_| {
                        PolarsError::InvalidOperation(
                            format!("SQL: invalid number in IN list '{}'", n).into(),
                        )
                    })?;
                    float_vals.push(v);
                    has_float = true;
                } else {
                    let v: i64 = n.parse().map_err(|_| {
                        PolarsError::InvalidOperation(
                            format!("SQL: invalid integer in IN list '{}'", n).into(),
                        )
                    })?;
                    int_vals.push(v);
                }
            }
            SqlExpr::Value(ValueWithSpan {
                value: Value::Boolean(b),
                ..
            }) => {
                str_vals.push(b.to_string());
                has_string = true;
            }
            SqlExpr::Value(ValueWithSpan {
                value: Value::Null, ..
            }) => {}
            _ => {
                return Err(PolarsError::InvalidOperation(
                    format!("SQL: IN list supports only literals, got {:?}", e).into(),
                ));
            }
        }
    }
    let series = if has_string {
        Series::from_iter(str_vals.iter().map(|s| s.as_str()))
    } else if !has_float && int_vals.len() == str_vals.len() {
        Series::from_iter(int_vals)
    } else if float_vals.len() == str_vals.len() {
        Series::from_iter(float_vals)
    } else {
        Series::from_iter(str_vals.iter().map(|s| s.as_str()))
    };
    Ok(series)
}

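/// A single resolved SELECT-list item: either a plain Polars expression or a
/// Python UDF call that must be materialized with `with_column` before the
/// final projection. The `String` is the output column name.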
enum ProjItem {
    Expr(Expr, String),
    PythonUdf(Column, String),
}

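/// Applies a non-aggregating SELECT list. `SELECT *` keeps every column;
/// otherwise each item is resolved to a `ProjItem`. Python UDF results are
/// first added as real columns via `with_column`, then the final projection
/// selects and aliases everything in the requested order.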
fn apply_projection(
    df: &crate::dataframe::DataFrame,
    projection: &[SelectItem],
    session: &SparkSession,
) -> Result<crate::dataframe::DataFrame, PolarsError> {
    for item in projection {
        if matches!(item, SelectItem::Wildcard(_)) {
            let column_names = df.columns()?;
            let all_col_names: Vec<&str> = column_names.iter().map(|s| s.as_str()).collect();
            return df.select(all_col_names);
        }
    }

    let mut items = Vec::new();
    for item in projection {
        let proj = match item {
            SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => {
                let name = ident.value.as_str();
                let resolved = df.resolve_column_name(name)?;
                ProjItem::Expr(col(resolved.as_str()), name.to_string())
            }
            SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(parts)) => {
                let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
                let resolved = df.resolve_column_name(name)?;
                ProjItem::Expr(col(resolved.as_str()), name.to_string())
            }
            SelectItem::UnnamedExpr(SqlExpr::Function(func)) => {
                projection_function_to_item(func, session, Some(df))?
            }
            SelectItem::ExprWithAlias { expr, alias } => {
                let alias_str = alias.value.clone();
                match expr {
                    SqlExpr::Identifier(ident) => {
                        let name = ident.value.as_str();
                        let resolved = df.resolve_column_name(name)?;
                        ProjItem::Expr(col(resolved.as_str()), alias_str)
                    }
                    SqlExpr::CompoundIdentifier(parts) => {
                        let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
                        let resolved = df.resolve_column_name(name)?;
                        ProjItem::Expr(col(resolved.as_str()), alias_str)
                    }
                    SqlExpr::Function(func) => {
                        match projection_function_to_item(func, session, Some(df))? {
                            ProjItem::Expr(e, _) => ProjItem::Expr(e, alias_str),
                            ProjItem::PythonUdf(c, _) => ProjItem::PythonUdf(c, alias_str),
                        }
                    }
                    _ => {
                        return Err(PolarsError::InvalidOperation(
                            format!("SQL: unsupported expression with alias: {:?}", expr).into(),
                        ));
                    }
                }
            }
            _ => {
                return Err(PolarsError::InvalidOperation(
                    format!(
                        "SQL: SELECT supports column names, *, and function calls. Got {:?}",
                        item
                    )
                    .into(),
                ));
            }
        };
        items.push(proj);
    }

    if items.is_empty() {
        return Err(PolarsError::InvalidOperation(
            "SQL: SELECT must list at least one column or *.".into(),
        ));
    }

    let has_python_udf = items.iter().any(|i| matches!(i, ProjItem::PythonUdf(_, _)));

    let mut df = df.clone();

    if has_python_udf {
        for item in &items {
            if let ProjItem::PythonUdf(col, alias) = item {
                df = df.with_column(alias, col)?;
            }
        }
        let exprs: Vec<Expr> = items
            .iter()
            .map(|i| match i {
                ProjItem::Expr(e, alias) => e.clone().alias(alias),
                ProjItem::PythonUdf(_, alias) => col(alias.as_str()).alias(alias),
            })
            .collect();
        df.select_exprs(exprs)
    } else {
        let exprs: Vec<Expr> = items
            .iter()
            .map(|i| match i {
                ProjItem::Expr(e, alias) => e.clone().alias(alias),
                ProjItem::PythonUdf(_, _) => unreachable!(),
            })
            .collect();
        df.select_exprs(exprs)
    }
}

fn sql_function_alias(func: &Function) -> String {
    let func_name = func
        .name
        .0
        .last()
        .and_then(|p| p.as_ident())
        .map(|i| i.value.as_str())
        .unwrap_or("");
    let arg_parts: Vec<String> = function_args_slice(&func.args)
        .iter()
        .filter_map(|a| {
            if let FunctionArg::Unnamed(FunctionArgExpr::Expr(SqlExpr::Identifier(ident))) = a {
                Some(ident.value.to_string())
            } else if let FunctionArg::Unnamed(FunctionArgExpr::Expr(
                SqlExpr::CompoundIdentifier(parts),
            )) = a
            {
                parts.last().map(|i| i.value.to_string())
            } else {
                Some("_".to_string())
            }
        })
        .collect();
    if arg_parts.is_empty() {
        format!("{}()", func_name)
    } else {
        format!("{}({})", func_name, arg_parts.join(", "))
    }
}

fn projection_function_to_item(
    func: &Function,
    session: &SparkSession,
    df: Option<&DataFrame>,
) -> Result<ProjItem, PolarsError> {
    let func_name = func
        .name
        .0
        .last()
        .and_then(|p| p.as_ident())
        .map(|i| i.value.as_str())
        .unwrap_or("");
    let args = sql_function_args_to_columns(func, session, df)?;
    let case_sensitive = session.is_case_sensitive();
    let alias = sql_function_alias(func);

    if let Some(col) = args.first() {
        let builtin = match func_name.to_uppercase().as_str() {
            "UPPER" | "UCASE" if args.len() == 1 => {
                Some(functions::upper(col).expr().clone().alias(&alias))
            }
            "LOWER" | "LCASE" if args.len() == 1 => {
                Some(functions::lower(col).expr().clone().alias(&alias))
            }
            _ => None,
        };
        if let Some(e) = builtin {
            return Ok(ProjItem::Expr(e, alias));
        }
    }

    if session.udf_registry.has_udf(func_name, case_sensitive) {
        let col = functions::call_udf(func_name, &args)?;
        if col.udf_call.is_some() {
            return Ok(ProjItem::PythonUdf(col, alias));
        }
        return Ok(ProjItem::Expr(col.expr().clone().alias(&alias), alias));
    }

    Err(PolarsError::InvalidOperation(
        format!(
            "SQL: unknown function '{}'. Register with spark.udf.register() or use built-ins: UPPER, LOWER.",
            func_name
        )
        .into(),
    ))
}

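/// Appends one aggregate expression (COUNT/SUM/AVG/MIN/MAX over a single
/// column, or COUNT(*)) to `agg`, aliased either to `alias_override` or to a
/// default such as `count` or `sum(col)`.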
fn push_agg_function(
    name: &sqlparser::ast::ObjectName,
    args: &[sqlparser::ast::FunctionArg],
    df: &DataFrame,
    alias_override: Option<&str>,
    agg: &mut Vec<Expr>,
) -> Result<(), PolarsError> {
    use polars::prelude::len;

    let func_name = name
        .0
        .last()
        .and_then(|p| p.as_ident())
        .map(|i| i.value.as_str())
        .unwrap_or("");
    let (expr, default_alias) = match func_name.to_uppercase().as_str() {
        "COUNT" => {
            let e = if args.is_empty() {
                len()
            } else if args.len() == 1 {
                use sqlparser::ast::FunctionArgExpr;
                match &args[0] {
                    sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => len(),
                    sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => {
                        let expr = match e {
                            SqlExpr::Nested(inner) => inner.as_ref(),
                            other => other,
                        };
                        match expr {
                            SqlExpr::Wildcard(_) => len(),
                            SqlExpr::Identifier(ident) => {
                                let resolved = df.resolve_column_name(ident.value.as_str())?;
                                col(resolved.as_str()).count()
                            }
                            _ => len(),
                        }
                    }
                    _ => {
                        return Err(PolarsError::InvalidOperation(
                            "SQL: COUNT(*) or COUNT(column) only.".into(),
                        ));
                    }
                }
            } else {
                return Err(PolarsError::InvalidOperation(
                    "SQL: COUNT takes at most one argument.".into(),
                ));
            };
            (e, "count".to_string())
        }
        "SUM" => {
            if let Some(sqlparser::ast::FunctionArg::Unnamed(
                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
            )) = args.first()
            {
                let resolved = df.resolve_column_name(ident.value.as_str())?;
                (
                    col(resolved.as_str()).sum(),
                    format!("sum({})", ident.value),
                )
            } else {
                return Err(PolarsError::InvalidOperation(
                    "SQL: SUM(column) only.".into(),
                ));
            }
        }
        "AVG" | "MEAN" => {
            if let Some(sqlparser::ast::FunctionArg::Unnamed(
                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
            )) = args.first()
            {
                let resolved = df.resolve_column_name(ident.value.as_str())?;
                (
                    col(resolved.as_str()).mean(),
                    format!("avg({})", ident.value),
                )
            } else {
                return Err(PolarsError::InvalidOperation(
                    "SQL: AVG(column) only.".into(),
                ));
            }
        }
        "MIN" => {
            if let Some(sqlparser::ast::FunctionArg::Unnamed(
                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
            )) = args.first()
            {
                let resolved = df.resolve_column_name(ident.value.as_str())?;
                (
                    col(resolved.as_str()).min(),
                    format!("min({})", ident.value),
                )
            } else {
                return Err(PolarsError::InvalidOperation(
                    "SQL: MIN(column) only.".into(),
                ));
            }
        }
        "MAX" => {
            if let Some(sqlparser::ast::FunctionArg::Unnamed(
                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
            )) = args.first()
            {
                let resolved = df.resolve_column_name(ident.value.as_str())?;
                (
                    col(resolved.as_str()).max(),
                    format!("max({})", ident.value),
                )
            } else {
                return Err(PolarsError::InvalidOperation(
                    "SQL: MAX(column) only.".into(),
                ));
            }
        }
        _ => {
            return Err(PolarsError::InvalidOperation(
                format!(
                    "SQL: unsupported aggregate in SELECT: {}. Use COUNT, SUM, AVG, MIN, MAX.",
                    func_name
                )
                .into(),
            ));
        }
    };
    let name = alias_override.unwrap_or(default_alias.as_str());
    agg.push(expr.alias(name));
    Ok(())
}

fn projection_is_scalar_aggregate(projection: &[SelectItem]) -> bool {
    use sqlparser::ast::SelectItem;
    if projection.is_empty() {
        return false;
    }
    for item in projection {
        let is_agg = match item {
            SelectItem::UnnamedExpr(SqlExpr::Function(f)) => is_agg_function_name(f),
            SelectItem::ExprWithAlias {
                expr: SqlExpr::Function(f),
                ..
            } => is_agg_function_name(f),
            _ => false,
        };
        if !is_agg {
            return false;
        }
    }
    true
}

fn is_agg_function_name(func: &Function) -> bool {
    let name = func
        .name
        .0
        .last()
        .and_then(|p| p.as_ident())
        .map(|i| i.value.as_str())
        .unwrap_or("");
    matches!(
        name.to_uppercase().as_str(),
        "COUNT" | "SUM" | "AVG" | "MEAN" | "MIN" | "MAX"
    )
}

fn agg_function_key(func: &Function) -> Option<(String, String)> {
    let name = func
        .name
        .0
        .last()
        .and_then(|p| p.as_ident())
        .map(|i| i.value.as_str())
        .unwrap_or("");
    if !matches!(
        name.to_uppercase().as_str(),
        "COUNT" | "SUM" | "AVG" | "MEAN" | "MIN" | "MAX"
    ) {
        return None;
    }
    let arg_desc = match function_args_slice(&func.args).first() {
        None => "*".to_string(),
        Some(sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
            SqlExpr::Identifier(ident),
        ))) => ident.value.to_string(),
        Some(sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
            SqlExpr::Wildcard(_),
        ))) => "*".to_string(),
        // COUNT(*) parses as a bare wildcard argument; treat it like the
        // wildcard expression above so HAVING COUNT(*) matches its alias.
        Some(sqlparser::ast::FunctionArg::Unnamed(
            sqlparser::ast::FunctionArgExpr::Wildcard,
        )) => "*".to_string(),
        _ => return None,
    };
    Some((name.to_uppercase(), arg_desc))
}

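/// Collects the distinct aggregate function calls appearing anywhere in a
/// HAVING expression and assigns each a hidden alias (`__having_0`, ...), so
/// they can be computed during the GROUP BY aggregation and referenced later.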
fn extract_having_agg_calls(expr: &SqlExpr) -> Vec<(Function, String)> {
    let mut seen: HashMap<(String, String), String> = HashMap::new();
    let mut list: Vec<(Function, String)> = Vec::new();
    fn walk(
        e: &SqlExpr,
        seen: &mut HashMap<(String, String), String>,
        list: &mut Vec<(Function, String)>,
    ) {
        if let SqlExpr::Function(f) = e {
            if let Some(key) = agg_function_key(f) {
                if !seen.contains_key(&key) {
                    let alias = format!("__having_{}", list.len());
                    seen.insert(key.clone(), alias.clone());
                    list.push((f.clone(), alias));
                }
                return;
            }
        }
        match e {
            SqlExpr::BinaryOp { left, right, .. } => {
                walk(left.as_ref(), seen, list);
                walk(right.as_ref(), seen, list);
            }
            SqlExpr::UnaryOp { expr: inner, .. } => walk(inner.as_ref(), seen, list),
            SqlExpr::IsNull(inner) | SqlExpr::IsNotNull(inner) => walk(inner.as_ref(), seen, list),
            SqlExpr::Function(f) => {
                for arg in function_args_slice(&f.args) {
                    if let FunctionArg::Unnamed(FunctionArgExpr::Expr(a)) = arg {
                        walk(a, seen, list);
                    }
                }
            }
            _ => {}
        }
    }
    walk(expr, &mut seen, &mut list);
    list
}

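/// Builds the aggregation expressions for a grouped (or scalar-aggregate)
/// SELECT list. Plain columns are only allowed if they are grouping columns;
/// aggregate calls are delegated to `push_agg_function`, honoring any alias.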
fn projection_to_agg_exprs(
    projection: &[SelectItem],
    group_cols: &[String],
    df: &DataFrame,
) -> Result<Vec<Expr>, PolarsError> {
    let mut agg = Vec::new();
    for item in projection {
        match item {
            SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => {
                let resolved = df.resolve_column_name(ident.value.as_str())?;
                if !group_cols.iter().any(|c| c == &resolved) {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "SQL: non-aggregated column '{}' must appear in GROUP BY.",
                            ident.value
                        )
                        .into(),
                    ));
                }
            }
            SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(parts)) => {
                let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
                let resolved = df.resolve_column_name(name)?;
                if !group_cols.iter().any(|c| c == &resolved) {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "SQL: non-aggregated column '{}' must appear in GROUP BY.",
                            name
                        )
                        .into(),
                    ));
                }
            }
            SelectItem::UnnamedExpr(SqlExpr::Function(Function { name, args, .. })) => {
                push_agg_function(name, function_args_slice(args), df, None, &mut agg)?;
            }
            SelectItem::ExprWithAlias { expr, alias } => {
                let alias_str = alias.value.as_str();
                match expr {
                    SqlExpr::Identifier(ident) => {
                        let resolved = df.resolve_column_name(ident.value.as_str())?;
                        if !group_cols.iter().any(|c| c == &resolved) {
                            return Err(PolarsError::InvalidOperation(
                                format!(
                                    "SQL: non-aggregated column '{}' must appear in GROUP BY.",
                                    ident.value
                                )
                                .into(),
                            ));
                        }
                    }
                    SqlExpr::CompoundIdentifier(parts) => {
                        let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
                        let resolved = df.resolve_column_name(name)?;
                        if !group_cols.iter().any(|c| c == &resolved) {
                            return Err(PolarsError::InvalidOperation(
                                format!(
                                    "SQL: non-aggregated column '{}' must appear in GROUP BY.",
                                    name
                                )
                                .into(),
                            ));
                        }
                    }
                    SqlExpr::Function(Function { name, args, .. }) => {
                        push_agg_function(
                            name,
                            function_args_slice(args),
                            df,
                            Some(alias_str),
                            &mut agg,
                        )?;
                    }
                    _ => {
                        return Err(PolarsError::InvalidOperation(
                            format!(
                                "SQL: unsupported aliased SELECT item in aggregation: {:?}",
                                expr
                            )
                            .into(),
                        ));
                    }
                }
            }
            SelectItem::Wildcard(_) => {
                return Err(PolarsError::InvalidOperation(
                    "SQL: SELECT * with GROUP BY is not supported; list columns and aggregates explicitly.".into(),
                ));
            }
            _ => {
                return Err(PolarsError::InvalidOperation(
                    format!("SQL: unsupported SELECT item in aggregation: {:?}", item).into(),
                ));
            }
        }
    }
    Ok(agg)
}

fn sql_limit_to_usize(expr: &SqlExpr) -> Result<usize, PolarsError> {
    match expr {
        SqlExpr::Value(ValueWithSpan {
            value: Value::Number(s, _),
            ..
        }) => {
            let n: i64 = s.parse().map_err(|_| {
                PolarsError::InvalidOperation(
                    format!("SQL: LIMIT must be a positive integer, got '{}'", s).into(),
                )
            })?;
            if n < 0 {
                return Err(PolarsError::InvalidOperation(
                    "SQL: LIMIT must be non-negative.".into(),
                ));
            }
            Ok(n as usize)
        }
        _ => Err(PolarsError::InvalidOperation(
            "SQL: LIMIT must be a literal integer.".into(),
        )),
    }
}