// Source: robin_sparkless/sql/translator.rs

1//! Translate sqlparser AST to DataFrame operations.
2//! Resolves unknown functions as UDFs from the session registry.
3
4use std::collections::HashMap;
5
6use crate::column::Column;
7use crate::dataframe::{DataFrame, JoinType, join};
8use crate::functions;
9use crate::session::{SparkSession, set_thread_udf_session};
10use polars::prelude::{DataFrame as PlDataFrame, Expr, PolarsError, col, lit};
11use sqlparser::ast::{
12    BinaryOperator, Expr as SqlExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments,
13    GroupByExpr, JoinConstraint, JoinOperator, ObjectType, OrderByKind, Query, Select, SelectItem,
14    SetExpr, Statement, TableFactor, Value, ValueWithSpan,
15};
16
17use super::parser;
18
19/// Return a slice of positional function arguments for List variant; empty otherwise.
20fn function_args_slice(args: &FunctionArguments) -> &[FunctionArg] {
21    match args {
22        FunctionArguments::List(list) => &list.args,
23        _ => &[],
24    }
25}
26
27/// Parse a single SQL expression string and convert to Polars Expr using the given DataFrame for column resolution.
28/// Used by selectExpr/expr() for PySpark parity. Parses "SELECT expr FROM __t" and returns the first select item's Expr.
29pub fn expr_string_to_polars(
30    expr_str: &str,
31    session: &SparkSession,
32    df: &DataFrame,
33) -> Result<Expr, PolarsError> {
34    let query = format!("SELECT {} FROM __selectexpr_t", expr_str);
35    let stmt = parser::parse_sql(&query)?;
36    let query_ast = match &stmt {
37        Statement::Query(q) => q.as_ref(),
38        _ => {
39            return Err(PolarsError::InvalidOperation(
40                "expr_string_to_polars: expected SELECT statement".into(),
41            ));
42        }
43    };
44    let body = match query_ast.body.as_ref() {
45        SetExpr::Select(s) => s.as_ref(),
46        _ => {
47            return Err(PolarsError::InvalidOperation(
48                "expr_string_to_polars: expected SELECT".into(),
49            ));
50        }
51    };
52    let first = body.projection.first().ok_or_else(|| {
53        PolarsError::InvalidOperation("expr_string_to_polars: empty SELECT list".into())
54    })?;
55    set_thread_udf_session(session.clone());
56    let (sql_expr, alias) = match first {
57        SelectItem::UnnamedExpr(e) => ((*e).clone(), None),
58        SelectItem::ExprWithAlias { expr, alias: a } => ((*expr).clone(), Some(a.value.as_str())),
59        _ => {
60            return Err(PolarsError::InvalidOperation(
61                format!("expr_string_to_polars: unsupported select item {:?}", first).into(),
62            ));
63        }
64    };
65    let expr = sql_expr_to_polars(&sql_expr, session, Some(df), None)?;
66    Ok(match alias {
67        Some(a) => expr.alias(a),
68        None => expr,
69    })
70}
71
72/// Translate a parsed Statement (Query or DDL) into a DataFrame using the session catalog.
73/// CREATE SCHEMA / CREATE DATABASE return empty DataFrame. DROP TABLE / DROP VIEW remove from session catalog.
74pub fn translate(
75    session: &SparkSession,
76    stmt: &Statement,
77) -> Result<crate::dataframe::DataFrame, PolarsError> {
78    set_thread_udf_session(session.clone());
79    match stmt {
80        Statement::Query(q) => translate_query(session, q.as_ref()),
81        Statement::CreateSchema { schema_name, .. } => {
82            let name = schema_name.to_string();
83            session.register_database(&name);
84            Ok(DataFrame::from_polars_with_options(
85                PlDataFrame::empty(),
86                session.is_case_sensitive(),
87            ))
88        }
89        Statement::CreateDatabase { db_name, .. } => {
90            let name = db_name.to_string();
91            session.register_database(&name);
92            Ok(DataFrame::from_polars_with_options(
93                PlDataFrame::empty(),
94                session.is_case_sensitive(),
95            ))
96        }
97        Statement::Drop {
98            object_type: ObjectType::Table | ObjectType::View,
99            names,
100            ..
101        } => {
102            for obj_name in names {
103                let name = obj_name.to_string();
104                if name.starts_with("global_temp.") {
105                    if let Some(suffix) = name.strip_prefix("global_temp.") {
106                        session.drop_global_temp_view(suffix);
107                    }
108                }
109                session.drop_temp_view(&name);
110                session.drop_table(&name);
111            }
112            Ok(DataFrame::from_polars_with_options(
113                PlDataFrame::empty(),
114                session.is_case_sensitive(),
115            ))
116        }
117        Statement::Drop {
118            object_type: ObjectType::Schema,
119            names,
120            ..
121        } => {
122            for obj_name in names {
123                session.drop_database(&obj_name.to_string());
124            }
125            Ok(DataFrame::from_polars_with_options(
126                PlDataFrame::empty(),
127                session.is_case_sensitive(),
128            ))
129        }
130        _ => Err(PolarsError::InvalidOperation(
131            "SQL: only SELECT, CREATE SCHEMA/DATABASE, and DROP TABLE/VIEW/SCHEMA are supported."
132                .into(),
133        )),
134    }
135}
136
137fn translate_query(
138    session: &SparkSession,
139    query: &Query,
140) -> Result<crate::dataframe::DataFrame, PolarsError> {
141    let body = match query.body.as_ref() {
142        SetExpr::Select(select) => select.as_ref(),
143        _ => {
144            return Err(PolarsError::InvalidOperation(
145                "SQL: only SELECT (no UNION/EXCEPT/INTERSECT) is supported.".into(),
146            ));
147        }
148    };
149    let mut df = translate_select_from(session, body)?;
150    if let Some(selection) = &body.selection {
151        let expr = sql_expr_to_polars(selection, session, Some(&df), None)?;
152        df = df.filter(expr)?;
153    }
154    let group_exprs: &[SqlExpr] = match &body.group_by {
155        GroupByExpr::Expressions(exprs, _) => exprs.as_slice(),
156        GroupByExpr::All(_) => {
157            return Err(PolarsError::InvalidOperation(
158                "SQL: GROUP BY ALL is not supported. Use explicit GROUP BY columns.".into(),
159            ));
160        }
161    };
162    let has_group_by = !group_exprs.is_empty();
163    let mut having_agg_map: HashMap<(String, String), String> = HashMap::new();
164    if has_group_by {
165        // Support GROUP BY column name or expression, e.g. GROUP BY age or GROUP BY (age > 30) (issue #588).
166        let pairs: Vec<(Expr, String)> = group_exprs
167            .iter()
168            .enumerate()
169            .map(|(i, e)| {
170                Ok(match e {
171                    SqlExpr::Identifier(ident) => {
172                        let name = ident.value.as_str();
173                        let resolved = df.resolve_column_name(name)?;
174                        (col(resolved.as_str()), resolved)
175                    }
176                    SqlExpr::CompoundIdentifier(parts) => {
177                        let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
178                        let resolved = df.resolve_column_name(name)?;
179                        (col(resolved.as_str()), resolved)
180                    }
181                    _ => {
182                        let expr = sql_expr_to_polars(e, session, Some(&df), None)?;
183                        let name = format!("group_{}", i);
184                        (expr.alias(&name), name)
185                    }
186                })
187            })
188            .collect::<Result<Vec<_>, PolarsError>>()?;
189        let (group_exprs_polars, group_cols): (Vec<Expr>, Vec<String>) = pairs.into_iter().unzip();
190        let grouped = df.group_by_exprs(group_exprs_polars, group_cols.clone())?;
191        let mut agg_exprs = projection_to_agg_exprs(&body.projection, &group_cols, &df)?;
192        if let Some(having_expr) = &body.having {
193            let having_list = extract_having_agg_calls(having_expr);
194            for (func, alias) in &having_list {
195                push_agg_function(
196                    &func.name,
197                    function_args_slice(&func.args),
198                    &df,
199                    Some(alias.as_str()),
200                    &mut agg_exprs,
201                )?;
202            }
203            having_agg_map = having_list
204                .into_iter()
205                .filter_map(|(f, alias)| agg_function_key(&f).map(|k| (k, alias)))
206                .collect();
207        }
208        if agg_exprs.is_empty() {
209            df = grouped.count()?;
210        } else {
211            df = grouped.agg(agg_exprs)?;
212        }
213    } else if projection_is_scalar_aggregate(&body.projection) {
214        // SELECT AVG(salary) FROM t (no GROUP BY) — scalar aggregation (issue #587).
215        let agg_exprs = projection_to_agg_exprs(&body.projection, &[], &df)?;
216        let pl_df = df.lazy_frame().select(agg_exprs).collect()?;
217        df = DataFrame::from_polars_with_options(pl_df, df.case_sensitive);
218    } else {
219        df = apply_projection(&df, &body.projection, session)?;
220    }
221    if let Some(having_expr) = &body.having {
222        let having_polars = sql_expr_to_polars(
223            having_expr,
224            session,
225            Some(&df),
226            Some(&having_agg_map).filter(|m| !m.is_empty()),
227        )?;
228        df = df.filter(having_polars)?;
229    }
230    if let Some(order_by) = &query.order_by {
231        if let OrderByKind::Expressions(exprs) = &order_by.kind {
232            if !exprs.is_empty() {
233                let pairs: Vec<(String, bool)> = exprs
234                    .iter()
235                    .map(|o| {
236                        let col_name = sql_expr_to_col_name(&o.expr)?;
237                        let resolved = df.resolve_column_name(&col_name)?;
238                        let ascending = o.options.asc.unwrap_or(true);
239                        Ok((resolved, ascending))
240                    })
241                    .collect::<Result<Vec<_>, PolarsError>>()?;
242                let (cols, asc): (Vec<String>, Vec<bool>) = pairs.into_iter().unzip();
243                let col_refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
244                df = df.order_by(col_refs, asc)?;
245            }
246        }
247    }
248    let limit_expr = query.fetch.as_ref().and_then(|f| f.quantity.as_ref());
249    if let Some(limit_expr) = limit_expr {
250        let n = sql_limit_to_usize(limit_expr)?;
251        df = df.limit(n)?;
252    }
253    Ok(df)
254}
255
256fn translate_select_from(
257    session: &SparkSession,
258    select: &Select,
259) -> Result<crate::dataframe::DataFrame, PolarsError> {
260    if select.from.is_empty() {
261        return Err(PolarsError::InvalidOperation(
262            "SQL: FROM clause is required. Register a table with create_or_replace_temp_view."
263                .into(),
264        ));
265    }
266    let first_tj = &select.from[0];
267    let mut df = resolve_table_factor(session, &first_tj.relation)?;
268    for join_spec in &first_tj.joins {
269        let right_df = resolve_table_factor(session, &join_spec.relation)?;
270        let join_type = match &join_spec.join_operator {
271            JoinOperator::Inner(_) => JoinType::Inner,
272            JoinOperator::LeftOuter(_) => JoinType::Left,
273            JoinOperator::RightOuter(_) => JoinType::Right,
274            JoinOperator::FullOuter(_) => JoinType::Outer,
275            _ => {
276                return Err(PolarsError::InvalidOperation(
277                    "SQL: only INNER, LEFT, RIGHT, FULL JOIN are supported.".into(),
278                ));
279            }
280        };
281        let on_cols = join_condition_to_on_columns(&join_spec.join_operator)?;
282        let on_refs: Vec<&str> = on_cols.iter().map(|s| s.as_str()).collect();
283        df = join(
284            &df,
285            &right_df,
286            on_refs,
287            join_type,
288            session.is_case_sensitive(),
289        )?;
290    }
291    Ok(df)
292}
293
294fn resolve_table_factor(
295    session: &SparkSession,
296    factor: &TableFactor,
297) -> Result<crate::dataframe::DataFrame, PolarsError> {
298    match factor {
299        TableFactor::Table { name, .. } => {
300            // Build full name for global_temp.xyz (sqlparser: ObjectNamePart::Identifier(...))
301            let table_name = if name.0.len() >= 2 {
302                let parts: Vec<String> = name
303                    .0
304                    .iter()
305                    .filter_map(|p| p.as_ident().map(|i| i.value.clone()))
306                    .collect();
307                parts.join(".")
308            } else {
309                name.0
310                    .last()
311                    .and_then(|p| p.as_ident())
312                    .map(|i| i.value.clone())
313                    .unwrap_or_default()
314            };
315            session.table(&table_name)
316        }
317        _ => Err(PolarsError::InvalidOperation(
318            "SQL: only plain table names are supported in FROM (no subqueries, derived tables). Register with create_or_replace_temp_view.".into(),
319        )),
320    }
321}
322
323fn join_condition_to_on_columns(join_op: &JoinOperator) -> Result<Vec<String>, PolarsError> {
324    let constraint = match join_op {
325        JoinOperator::Inner(c)
326        | JoinOperator::LeftOuter(c)
327        | JoinOperator::RightOuter(c)
328        | JoinOperator::FullOuter(c) => c,
329        _ => {
330            return Err(PolarsError::InvalidOperation(
331                "SQL: only INNER/LEFT/RIGHT/FULL JOIN with ON are supported.".into(),
332            ));
333        }
334    };
335    match constraint {
336        JoinConstraint::On(expr) => match expr {
337            SqlExpr::BinaryOp {
338                left,
339                op: BinaryOperator::Eq,
340                right,
341            } => {
342                let l = sql_expr_to_col_name(left.as_ref())?;
343                let r = sql_expr_to_col_name(right.as_ref())?;
344                if l != r {
345                    return Err(PolarsError::InvalidOperation(
346                            "SQL: JOIN ON must use same column name on both sides (e.g. a.id = b.id where both become 'id').".into(),
347                        ));
348                }
349                Ok(vec![l])
350            }
351            _ => Err(PolarsError::InvalidOperation(
352                "SQL: JOIN ON must be a single equality (col = col).".into(),
353            )),
354        },
355        _ => Err(PolarsError::InvalidOperation(
356            "SQL: JOIN must use ON (equality); NATURAL/USING not supported.".into(),
357        )),
358    }
359}
360
361fn sql_expr_to_polars(
362    expr: &SqlExpr,
363    session: &SparkSession,
364    df: Option<&DataFrame>,
365    having_agg_map: Option<&HashMap<(String, String), String>>,
366) -> Result<Expr, PolarsError> {
367    match expr {
368        SqlExpr::Identifier(ident) => {
369            let name = ident.value.as_str();
370            let resolved = df
371                .map(|d| d.resolve_column_name(name))
372                .transpose()?
373                .unwrap_or_else(|| name.to_string());
374            Ok(col(resolved.as_str()))
375        }
376        SqlExpr::CompoundIdentifier(parts) => {
377            let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
378            let resolved = df
379                .map(|d| d.resolve_column_name(name))
380                .transpose()?
381                .unwrap_or_else(|| name.to_string());
382            Ok(col(resolved.as_str()))
383        }
384        SqlExpr::Value(ValueWithSpan { value: Value::Number(s, _), .. }) => {
385            if s.contains('.') {
386                let v: f64 = s.parse().map_err(|_| {
387                    PolarsError::InvalidOperation(format!("SQL: invalid number literal '{}'", s).into())
388                })?;
389                Ok(lit(v))
390            } else {
391                let v: i64 = s.parse().map_err(|_| {
392                    PolarsError::InvalidOperation(format!("SQL: invalid integer literal '{}'", s).into())
393                })?;
394                Ok(lit(v))
395            }
396        }
397        SqlExpr::Value(ValueWithSpan { value: Value::SingleQuotedString(s), .. }) => Ok(lit(s.as_str())),
398        SqlExpr::Value(ValueWithSpan { value: Value::Boolean(b), .. }) => Ok(lit(*b)),
399        SqlExpr::Value(ValueWithSpan { value: Value::Null, .. }) => Ok(lit(polars::prelude::NULL)),
400        SqlExpr::BinaryOp { left, op, right } => {
401            let l = sql_expr_to_polars(left, session, df, having_agg_map)?;
402            let r = sql_expr_to_polars(right, session, df, having_agg_map)?;
403            match op {
404                BinaryOperator::Eq => Ok(l.eq(r)),
405                BinaryOperator::NotEq => Ok(l.eq(r).not()),
406                BinaryOperator::Gt => Ok(l.gt(r)),
407                BinaryOperator::GtEq => Ok(l.gt_eq(r)),
408                BinaryOperator::Lt => Ok(l.lt(r)),
409                BinaryOperator::LtEq => Ok(l.lt_eq(r)),
410                BinaryOperator::And => Ok(l.and(r)),
411                BinaryOperator::Or => Ok(l.or(r)),
412                _ => Err(PolarsError::InvalidOperation(
413                    format!("SQL: unsupported operator in WHERE: {:?}. Use =, <>, <, <=, >, >=, AND, OR.", op).into(),
414                )),
415            }
416        }
417        SqlExpr::Nested(inner) => sql_expr_to_polars(inner, session, df, having_agg_map),
418        SqlExpr::IsNull(expr) => Ok(sql_expr_to_polars(expr, session, df, having_agg_map)?.is_null()),
419        SqlExpr::IsNotNull(expr) => Ok(sql_expr_to_polars(expr, session, df, having_agg_map)?.is_not_null()),
420        SqlExpr::UnaryOp { op, expr } => {
421            let e = sql_expr_to_polars(expr, session, df, having_agg_map)?;
422            match op {
423                sqlparser::ast::UnaryOperator::Not => Ok(e.not()),
424                _ => Err(PolarsError::InvalidOperation(
425                    format!("SQL: unsupported unary operator in WHERE: {:?}", op).into(),
426                )),
427            }
428        }
429        SqlExpr::Function(func) => {
430            if let Some(map) = having_agg_map {
431                if let Some(key) = agg_function_key(func) {
432                    if let Some(col_name) = map.get(&key) {
433                        return Ok(col(col_name.as_str()));
434                    }
435                }
436            }
437            sql_function_to_expr(func, session, df)
438        }
439        SqlExpr::Like {
440            negated,
441            expr: left,
442            pattern,
443            escape_char,
444            any: _,
445        } => {
446            let col_expr = sql_expr_to_polars(left.as_ref(), session, df, having_agg_map)?;
447            let pattern_str = sql_expr_to_string_literal(pattern.as_ref())?;
448            let col_col = crate::column::Column::from_expr(col_expr, None);
449            let escape: Option<char> = escape_char.as_ref().and_then(|v| match v {
450                Value::SingleQuotedString(s) => s.chars().next(),
451                _ => None,
452            });
453            let like_expr = col_col.like(&pattern_str, escape).into_expr();
454            Ok(if *negated {
455                like_expr.not()
456            } else {
457                like_expr
458            })
459        }
460        SqlExpr::InList {
461            expr: left,
462            list,
463            negated,
464        } => {
465            let col_expr = sql_expr_to_polars(left.as_ref(), session, df, having_agg_map)?;
466            if list.is_empty() {
467                return Ok(lit(false));
468            }
469            let series = sql_in_list_to_series(list)?;
470            let in_expr = col_expr.is_in(lit(series), false);
471            Ok(if *negated {
472                in_expr.not()
473            } else {
474                in_expr
475            })
476        }
477        _ => Err(PolarsError::InvalidOperation(
478            format!("SQL: unsupported expression in WHERE: {:?}. Use column, literal, =, <, >, AND, OR, IS NULL, LIKE, IN.", expr).into(),
479        )),
480    }
481}
482
483/// Convert SQL function call to Polars Expr. Supports built-ins (UPPER, LOWER, etc.) and UDFs.
484/// For Python UDF in WHERE/HAVING we cannot return a lazy Expr; returns error (Python UDF in
485/// predicates requires eager materialization - deferred).
486fn sql_function_to_expr(
487    func: &Function,
488    session: &SparkSession,
489    df: Option<&DataFrame>,
490) -> Result<Expr, PolarsError> {
491    let func_name = func
492        .name
493        .0
494        .last()
495        .and_then(|p| p.as_ident())
496        .map(|i| i.value.as_str())
497        .unwrap_or("");
498    let args = sql_function_args_to_columns(func, session, df)?;
499
500    let case_sensitive = session.is_case_sensitive();
501
502    // Built-in scalar functions (single-column arg)
503    if let Some(col) = args.first() {
504        let builtin_expr = match func_name.to_uppercase().as_str() {
505            "UPPER" | "UCASE" if args.len() == 1 => Some(functions::upper(col).expr().clone()),
506            "LOWER" | "LCASE" if args.len() == 1 => Some(functions::lower(col).expr().clone()),
507            _ => None,
508        };
509        if let Some(e) = builtin_expr {
510            return Ok(e);
511        }
512    }
513
514    // UDF lookup
515    if session.udf_registry.has_udf(func_name, case_sensitive) {
516        let col = functions::call_udf(func_name, &args)?;
517        if col.udf_call.is_some() {
518            return Err(PolarsError::InvalidOperation(
519                "SQL: Python UDF in WHERE/HAVING not yet supported. Use in SELECT.".into(),
520            ));
521        }
522        return Ok(col.expr().clone());
523    }
524
525    Err(PolarsError::InvalidOperation(
526        format!("SQL: unknown function '{}'. Register with spark.udf.register() or use built-ins: UPPER, LOWER.", func_name).into(),
527    ))
528}
529
530fn sql_function_args_to_columns(
531    func: &Function,
532    session: &SparkSession,
533    df: Option<&DataFrame>,
534) -> Result<Vec<Column>, PolarsError> {
535    let mut cols = Vec::new();
536    for arg in function_args_slice(&func.args) {
537        if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg {
538            let e = sql_expr_to_polars(expr, session, df, None)?;
539            cols.push(Column::from_expr(e, None));
540        } else {
541            return Err(PolarsError::InvalidOperation(
542                "SQL: only positional function arguments supported.".into(),
543            ));
544        }
545    }
546    Ok(cols)
547}
548
549fn sql_expr_to_col_name(expr: &SqlExpr) -> Result<String, PolarsError> {
550    match expr {
551        SqlExpr::Identifier(ident) => Ok(ident.value.clone()),
552        SqlExpr::CompoundIdentifier(parts) => parts
553            .last()
554            .map(|i| i.value.clone())
555            .ok_or_else(|| PolarsError::InvalidOperation("SQL: empty compound identifier.".into())),
556        _ => Err(PolarsError::InvalidOperation(
557            format!("SQL: expected column name, got {:?}", expr).into(),
558        )),
559    }
560}
561
562/// Extract a string literal from a SQL expression (for LIKE pattern). Issue #590.
563fn sql_expr_to_string_literal(expr: &SqlExpr) -> Result<String, PolarsError> {
564    match expr {
565        SqlExpr::Value(ValueWithSpan {
566            value: Value::SingleQuotedString(s),
567            ..
568        }) => Ok(s.clone()),
569        _ => Err(PolarsError::InvalidOperation(
570            format!("SQL: LIKE pattern must be a string literal, got {:?}", expr).into(),
571        )),
572    }
573}
574
575/// Build a Polars Series from SQL IN list literals (for WHERE col IN (1,2,3)). Issue #590.
576fn sql_in_list_to_series(list: &[SqlExpr]) -> Result<polars::prelude::Series, PolarsError> {
577    use polars::prelude::Series;
578    let mut str_vals: Vec<String> = Vec::new();
579    let mut int_vals: Vec<i64> = Vec::new();
580    let mut float_vals: Vec<f64> = Vec::new();
581    let mut has_string = false;
582    let mut has_float = false;
583    for e in list {
584        match e {
585            SqlExpr::Value(ValueWithSpan {
586                value: Value::SingleQuotedString(s),
587                ..
588            }) => {
589                str_vals.push(s.clone());
590                has_string = true;
591            }
592            SqlExpr::Value(ValueWithSpan {
593                value: Value::Number(n, _),
594                ..
595            }) => {
596                str_vals.push(n.clone());
597                if n.contains('.') {
598                    let v: f64 = n.parse().map_err(|_| {
599                        PolarsError::InvalidOperation(
600                            format!("SQL: invalid number in IN list '{}'", n).into(),
601                        )
602                    })?;
603                    float_vals.push(v);
604                    has_float = true;
605                } else {
606                    let v: i64 = n.parse().map_err(|_| {
607                        PolarsError::InvalidOperation(
608                            format!("SQL: invalid integer in IN list '{}'", n).into(),
609                        )
610                    })?;
611                    int_vals.push(v);
612                }
613            }
614            SqlExpr::Value(ValueWithSpan {
615                value: Value::Boolean(b),
616                ..
617            }) => {
618                str_vals.push(b.to_string());
619                has_string = true;
620            }
621            SqlExpr::Value(ValueWithSpan {
622                value: Value::Null, ..
623            }) => {}
624            _ => {
625                return Err(PolarsError::InvalidOperation(
626                    format!("SQL: IN list supports only literals, got {:?}", e).into(),
627                ));
628            }
629        }
630    }
631    let series = if has_string {
632        Series::from_iter(str_vals.iter().map(|s| s.as_str()))
633    } else if !has_float && int_vals.len() == str_vals.len() {
634        Series::from_iter(int_vals)
635    } else if float_vals.len() == str_vals.len() {
636        Series::from_iter(float_vals)
637    } else {
638        Series::from_iter(str_vals.iter().map(|s| s.as_str()))
639    };
640    Ok(series)
641}
642
/// Projection item: either a plain Expr (built-in, Rust UDF, identifier) or Python UDF Column.
/// The `String` in each variant is the output column name (alias) for the item.
enum ProjItem {
    /// A Polars expression together with the name it is aliased to in the result.
    Expr(Expr, String),
    /// A `Column` that must be added to the DataFrame with `with_column` under the given
    /// name before the final select can reference it.
    PythonUdf(Column, String),
}
648
649fn apply_projection(
650    df: &crate::dataframe::DataFrame,
651    projection: &[SelectItem],
652    session: &SparkSession,
653) -> Result<crate::dataframe::DataFrame, PolarsError> {
654    // Wildcard: expand to all columns
655    for item in projection {
656        if matches!(item, SelectItem::Wildcard(_)) {
657            let column_names = df.columns()?;
658            let all_col_names: Vec<&str> = column_names.iter().map(|s| s.as_str()).collect();
659            return df.select(all_col_names);
660        }
661    }
662
663    let mut items = Vec::new();
664    for item in projection {
665        let proj = match item {
666            SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => {
667                let name = ident.value.as_str();
668                let resolved = df.resolve_column_name(name)?;
669                ProjItem::Expr(col(resolved.as_str()), name.to_string())
670            }
671            SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(parts)) => {
672                let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
673                let resolved = df.resolve_column_name(name)?;
674                ProjItem::Expr(col(resolved.as_str()), name.to_string())
675            }
676            SelectItem::UnnamedExpr(SqlExpr::Function(func)) => {
677                projection_function_to_item(func, session, Some(df))?
678            }
679            SelectItem::ExprWithAlias { expr, alias } => {
680                let alias_str = alias.value.clone();
681                match expr {
682                    SqlExpr::Identifier(ident) => {
683                        let name = ident.value.as_str();
684                        let resolved = df.resolve_column_name(name)?;
685                        ProjItem::Expr(col(resolved.as_str()), alias_str)
686                    }
687                    SqlExpr::CompoundIdentifier(parts) => {
688                        let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
689                        let resolved = df.resolve_column_name(name)?;
690                        ProjItem::Expr(col(resolved.as_str()), alias_str)
691                    }
692                    SqlExpr::Function(func) => {
693                        let mut item = projection_function_to_item(func, session, Some(df))?;
694                        // Override alias with AS alias
695                        item = match item {
696                            ProjItem::Expr(e, _) => ProjItem::Expr(e, alias_str),
697                            ProjItem::PythonUdf(c, _) => ProjItem::PythonUdf(c, alias_str),
698                        };
699                        item
700                    }
701                    _ => {
702                        return Err(PolarsError::InvalidOperation(
703                            format!("SQL: unsupported expression with alias: {:?}", expr).into(),
704                        ));
705                    }
706                }
707            }
708            _ => {
709                return Err(PolarsError::InvalidOperation(
710                    format!(
711                        "SQL: SELECT supports column names, *, and function calls. Got {:?}",
712                        item
713                    )
714                    .into(),
715                ));
716            }
717        };
718        items.push(proj);
719    }
720
721    if items.is_empty() {
722        return Err(PolarsError::InvalidOperation(
723            "SQL: SELECT must list at least one column or *.".into(),
724        ));
725    }
726
727    // Check if any Python UDF (requires with_column path)
728    let has_python_udf = items.iter().any(|i| matches!(i, ProjItem::PythonUdf(_, _)));
729
730    let mut df = df.clone();
731
732    if has_python_udf {
733        // Add Python UDF columns first, then select all in order
734        for item in &items {
735            if let ProjItem::PythonUdf(col, alias) = item {
736                df = df.with_column(alias, col)?;
737            }
738        }
739        let exprs: Vec<Expr> = items
740            .iter()
741            .map(|i| match i {
742                ProjItem::Expr(e, alias) => e.clone().alias(alias),
743                ProjItem::PythonUdf(_, alias) => col(alias.as_str()).alias(alias),
744            })
745            .collect();
746        df.select_exprs(exprs)
747    } else {
748        // All exprs: use select_with_exprs
749        let exprs: Vec<Expr> = items
750            .iter()
751            .map(|i| match i {
752                ProjItem::Expr(e, alias) => e.clone().alias(alias),
753                ProjItem::PythonUdf(_, _) => unreachable!(),
754            })
755            .collect();
756        df.select_exprs(exprs)
757    }
758}
759
760fn sql_function_alias(func: &Function) -> String {
761    let func_name = func
762        .name
763        .0
764        .last()
765        .and_then(|p| p.as_ident())
766        .map(|i| i.value.as_str())
767        .unwrap_or("");
768    let arg_parts: Vec<String> = function_args_slice(&func.args)
769        .iter()
770        .filter_map(|a| {
771            if let FunctionArg::Unnamed(FunctionArgExpr::Expr(SqlExpr::Identifier(ident))) = a {
772                Some(ident.value.to_string())
773            } else if let FunctionArg::Unnamed(FunctionArgExpr::Expr(
774                SqlExpr::CompoundIdentifier(parts),
775            )) = a
776            {
777                parts.last().map(|i| i.value.to_string())
778            } else {
779                Some("_".to_string())
780            }
781        })
782        .collect();
783    if arg_parts.is_empty() {
784        format!("{}()", func_name)
785    } else {
786        format!("{}({})", func_name, arg_parts.join(", "))
787    }
788}
789
790fn projection_function_to_item(
791    func: &Function,
792    session: &SparkSession,
793    df: Option<&DataFrame>,
794) -> Result<ProjItem, PolarsError> {
795    let func_name = func
796        .name
797        .0
798        .last()
799        .and_then(|p| p.as_ident())
800        .map(|i| i.value.as_str())
801        .unwrap_or("");
802    let args = sql_function_args_to_columns(func, session, df)?;
803    let case_sensitive = session.is_case_sensitive();
804    let alias = sql_function_alias(func);
805
806    // Built-ins
807    if let Some(col) = args.first() {
808        let builtin = match func_name.to_uppercase().as_str() {
809            "UPPER" | "UCASE" if args.len() == 1 => {
810                Some(functions::upper(col).expr().clone().alias(&alias))
811            }
812            "LOWER" | "LCASE" if args.len() == 1 => {
813                Some(functions::lower(col).expr().clone().alias(&alias))
814            }
815            _ => None,
816        };
817        if let Some(e) = builtin {
818            return Ok(ProjItem::Expr(e, alias));
819        }
820    }
821
822    // UDF lookup
823    if session.udf_registry.has_udf(func_name, case_sensitive) {
824        let col = functions::call_udf(func_name, &args)?;
825        if col.udf_call.is_some() {
826            return Ok(ProjItem::PythonUdf(col, alias));
827        }
828        return Ok(ProjItem::Expr(col.expr().clone().alias(&alias), alias));
829    }
830
831    Err(PolarsError::InvalidOperation(
832        format!(
833            "SQL: unknown function '{}'. Register with spark.udf.register() or use built-ins: UPPER, LOWER.",
834            func_name
835        )
836        .into(),
837    ))
838}
839
840/// Push one aggregate expression from a SQL function. `alias_override`: when Some (e.g. AS cnt), use it; when None use default (count, sum(col), etc).
841fn push_agg_function(
842    name: &sqlparser::ast::ObjectName,
843    args: &[sqlparser::ast::FunctionArg],
844    df: &DataFrame,
845    alias_override: Option<&str>,
846    agg: &mut Vec<Expr>,
847) -> Result<(), PolarsError> {
848    use polars::prelude::len;
849
850    let func_name = name
851        .0
852        .last()
853        .and_then(|p| p.as_ident())
854        .map(|i| i.value.as_str())
855        .unwrap_or("");
856    let (expr, default_alias) = match func_name.to_uppercase().as_str() {
857        "COUNT" => {
858            let e = if args.is_empty() {
859                len()
860            } else if args.len() == 1 {
861                use sqlparser::ast::FunctionArgExpr;
862                match &args[0] {
863                    sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => len(),
864                    sqlparser::ast::FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => {
865                        let expr = match e {
866                            SqlExpr::Nested(inner) => inner.as_ref(),
867                            other => other,
868                        };
869                        match expr {
870                            SqlExpr::Wildcard(_) => len(),
871                            SqlExpr::Identifier(ident) => {
872                                let resolved = df.resolve_column_name(ident.value.as_str())?;
873                                col(resolved.as_str()).count()
874                            }
875                            _ => len(), // COUNT(1) etc.
876                        }
877                    }
878                    _ => {
879                        return Err(PolarsError::InvalidOperation(
880                            "SQL: COUNT(*) or COUNT(column) only.".into(),
881                        ));
882                    }
883                }
884            } else {
885                return Err(PolarsError::InvalidOperation(
886                    "SQL: COUNT takes at most one argument.".into(),
887                ));
888            };
889            (e, "count".to_string())
890        }
891        "SUM" => {
892            if let Some(sqlparser::ast::FunctionArg::Unnamed(
893                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
894            )) = args.first()
895            {
896                let resolved = df.resolve_column_name(ident.value.as_str())?;
897                (
898                    col(resolved.as_str()).sum(),
899                    format!("sum({})", ident.value),
900                )
901            } else {
902                return Err(PolarsError::InvalidOperation(
903                    "SQL: SUM(column) only.".into(),
904                ));
905            }
906        }
907        "AVG" | "MEAN" => {
908            if let Some(sqlparser::ast::FunctionArg::Unnamed(
909                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
910            )) = args.first()
911            {
912                let resolved = df.resolve_column_name(ident.value.as_str())?;
913                (
914                    col(resolved.as_str()).mean(),
915                    format!("avg({})", ident.value),
916                )
917            } else {
918                return Err(PolarsError::InvalidOperation(
919                    "SQL: AVG(column) only.".into(),
920                ));
921            }
922        }
923        "MIN" => {
924            if let Some(sqlparser::ast::FunctionArg::Unnamed(
925                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
926            )) = args.first()
927            {
928                let resolved = df.resolve_column_name(ident.value.as_str())?;
929                (
930                    col(resolved.as_str()).min(),
931                    format!("min({})", ident.value),
932                )
933            } else {
934                return Err(PolarsError::InvalidOperation(
935                    "SQL: MIN(column) only.".into(),
936                ));
937            }
938        }
939        "MAX" => {
940            if let Some(sqlparser::ast::FunctionArg::Unnamed(
941                sqlparser::ast::FunctionArgExpr::Expr(SqlExpr::Identifier(ident)),
942            )) = args.first()
943            {
944                let resolved = df.resolve_column_name(ident.value.as_str())?;
945                (
946                    col(resolved.as_str()).max(),
947                    format!("max({})", ident.value),
948                )
949            } else {
950                return Err(PolarsError::InvalidOperation(
951                    "SQL: MAX(column) only.".into(),
952                ));
953            }
954        }
955        _ => {
956            return Err(PolarsError::InvalidOperation(
957                format!(
958                    "SQL: unsupported aggregate in SELECT: {}. Use COUNT, SUM, AVG, MIN, MAX.",
959                    func_name
960                )
961                .into(),
962            ));
963        }
964    };
965    let name = alias_override.unwrap_or(default_alias.as_str());
966    agg.push(expr.alias(name));
967    Ok(())
968}
969
970/// True if the projection contains only aggregate function calls (COUNT, SUM, AVG, MIN, MAX).
971/// Used for scalar aggregation: SELECT AVG(salary) FROM t (issue #587).
972fn projection_is_scalar_aggregate(projection: &[SelectItem]) -> bool {
973    use sqlparser::ast::SelectItem;
974    if projection.is_empty() {
975        return false;
976    }
977    for item in projection {
978        let is_agg = match item {
979            SelectItem::UnnamedExpr(SqlExpr::Function(f)) => is_agg_function_name(f),
980            SelectItem::ExprWithAlias {
981                expr: SqlExpr::Function(f),
982                ..
983            } => is_agg_function_name(f),
984            _ => false,
985        };
986        if !is_agg {
987            return false;
988        }
989    }
990    true
991}
992
993fn is_agg_function_name(func: &Function) -> bool {
994    let name = func
995        .name
996        .0
997        .last()
998        .and_then(|p| p.as_ident())
999        .map(|i| i.value.as_str())
1000        .unwrap_or("");
1001    matches!(
1002        name.to_uppercase().as_str(),
1003        "COUNT" | "SUM" | "AVG" | "MEAN" | "MIN" | "MAX"
1004    )
1005}
1006
1007/// Key for deduplicating aggregate function calls in HAVING (issue #589).
1008fn agg_function_key(func: &Function) -> Option<(String, String)> {
1009    let name = func
1010        .name
1011        .0
1012        .last()
1013        .and_then(|p| p.as_ident())
1014        .map(|i| i.value.as_str())
1015        .unwrap_or("");
1016    if !matches!(
1017        name.to_uppercase().as_str(),
1018        "COUNT" | "SUM" | "AVG" | "MEAN" | "MIN" | "MAX"
1019    ) {
1020        return None;
1021    }
1022    let arg_desc = match function_args_slice(&func.args).first() {
1023        None => "*".to_string(),
1024        Some(sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
1025            SqlExpr::Identifier(ident),
1026        ))) => ident.value.to_string(),
1027        Some(sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
1028            SqlExpr::Wildcard(_),
1029        ))) => "*".to_string(),
1030        _ => return None,
1031    };
1032    Some((name.to_uppercase(), arg_desc))
1033}
1034
1035/// Collect unique aggregate function calls from a HAVING expression and assign __having_0, __having_1, ...
1036fn extract_having_agg_calls(expr: &SqlExpr) -> Vec<(Function, String)> {
1037    let mut seen: HashMap<(String, String), String> = HashMap::new();
1038    let mut list: Vec<(Function, String)> = Vec::new();
1039    fn walk(
1040        e: &SqlExpr,
1041        seen: &mut HashMap<(String, String), String>,
1042        list: &mut Vec<(Function, String)>,
1043    ) {
1044        if let SqlExpr::Function(f) = e {
1045            if let Some(key) = agg_function_key(f) {
1046                if !seen.contains_key(&key) {
1047                    let alias = format!("__having_{}", list.len());
1048                    seen.insert(key.clone(), alias.clone());
1049                    list.push((f.clone(), alias));
1050                }
1051                return;
1052            }
1053        }
1054        match e {
1055            SqlExpr::BinaryOp { left, right, .. } => {
1056                walk(left.as_ref(), seen, list);
1057                walk(right.as_ref(), seen, list);
1058            }
1059            SqlExpr::UnaryOp { expr: inner, .. } => walk(inner.as_ref(), seen, list),
1060            SqlExpr::IsNull(inner) | SqlExpr::IsNotNull(inner) => walk(inner.as_ref(), seen, list),
1061            SqlExpr::Function(f) => {
1062                for arg in function_args_slice(&f.args) {
1063                    if let FunctionArg::Unnamed(FunctionArgExpr::Expr(a)) = arg {
1064                        walk(a, seen, list);
1065                    }
1066                }
1067            }
1068            _ => {}
1069        }
1070    }
1071    walk(expr, &mut seen, &mut list);
1072    list
1073}
1074
1075fn projection_to_agg_exprs(
1076    projection: &[SelectItem],
1077    group_cols: &[String],
1078    df: &DataFrame,
1079) -> Result<Vec<Expr>, PolarsError> {
1080    let mut agg = Vec::new();
1081    for item in projection {
1082        match item {
1083            SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => {
1084                let resolved = df.resolve_column_name(ident.value.as_str())?;
1085                if !group_cols.iter().any(|c| c == &resolved) {
1086                    return Err(PolarsError::InvalidOperation(
1087                        format!(
1088                            "SQL: non-aggregated column '{}' must appear in GROUP BY.",
1089                            ident.value
1090                        )
1091                        .into(),
1092                    ));
1093                }
1094            }
1095            SelectItem::UnnamedExpr(SqlExpr::CompoundIdentifier(parts)) => {
1096                let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
1097                let resolved = df.resolve_column_name(name)?;
1098                if !group_cols.iter().any(|c| c == &resolved) {
1099                    return Err(PolarsError::InvalidOperation(
1100                        format!(
1101                            "SQL: non-aggregated column '{}' must appear in GROUP BY.",
1102                            name
1103                        )
1104                        .into(),
1105                    ));
1106                }
1107            }
1108            SelectItem::UnnamedExpr(SqlExpr::Function(Function { name, args, .. })) => {
1109                push_agg_function(name, function_args_slice(args), df, None, &mut agg)?;
1110            }
1111            SelectItem::ExprWithAlias { expr, alias } => {
1112                let alias_str = alias.value.as_str();
1113                match expr {
1114                    SqlExpr::Identifier(ident) => {
1115                        let resolved = df.resolve_column_name(ident.value.as_str())?;
1116                        if !group_cols.iter().any(|c| c == &resolved) {
1117                            return Err(PolarsError::InvalidOperation(
1118                                format!(
1119                                    "SQL: non-aggregated column '{}' must appear in GROUP BY.",
1120                                    ident.value
1121                                )
1122                                .into(),
1123                            ));
1124                        }
1125                        // Group column with alias (e.g. grp AS g): validation only; result keeps group column name from frame.
1126                    }
1127                    SqlExpr::CompoundIdentifier(parts) => {
1128                        let name = parts.last().map(|i| i.value.as_str()).unwrap_or("");
1129                        let resolved = df.resolve_column_name(name)?;
1130                        if !group_cols.iter().any(|c| c == &resolved) {
1131                            return Err(PolarsError::InvalidOperation(
1132                                format!(
1133                                    "SQL: non-aggregated column '{}' must appear in GROUP BY.",
1134                                    name
1135                                )
1136                                .into(),
1137                            ));
1138                        }
1139                    }
1140                    SqlExpr::Function(Function { name, args, .. }) => {
1141                        push_agg_function(
1142                            name,
1143                            function_args_slice(args),
1144                            df,
1145                            Some(alias_str),
1146                            &mut agg,
1147                        )?;
1148                    }
1149                    _ => {
1150                        return Err(PolarsError::InvalidOperation(
1151                            format!(
1152                                "SQL: unsupported aliased SELECT item in aggregation: {:?}",
1153                                expr
1154                            )
1155                            .into(),
1156                        ));
1157                    }
1158                }
1159            }
1160            SelectItem::Wildcard(_) => {
1161                return Err(PolarsError::InvalidOperation(
1162                    "SQL: SELECT * with GROUP BY is not supported; list columns and aggregates explicitly.".into(),
1163                ));
1164            }
1165            _ => {
1166                return Err(PolarsError::InvalidOperation(
1167                    format!("SQL: unsupported SELECT item in aggregation: {:?}", item).into(),
1168                ));
1169            }
1170        }
1171    }
1172    Ok(agg)
1173}
1174
1175fn sql_limit_to_usize(expr: &SqlExpr) -> Result<usize, PolarsError> {
1176    match expr {
1177        SqlExpr::Value(ValueWithSpan {
1178            value: Value::Number(s, _),
1179            ..
1180        }) => {
1181            let n: i64 = s.parse().map_err(|_| {
1182                PolarsError::InvalidOperation(
1183                    format!("SQL: LIMIT must be a positive integer, got '{}'", s).into(),
1184                )
1185            })?;
1186            if n < 0 {
1187                return Err(PolarsError::InvalidOperation(
1188                    "SQL: LIMIT must be non-negative.".into(),
1189                ));
1190            }
1191            Ok(n as usize)
1192        }
1193        _ => Err(PolarsError::InvalidOperation(
1194            "SQL: LIMIT must be a literal integer.".into(),
1195        )),
1196    }
1197}