polars_sql/
sql_expr.rs

1//! Expressions that are supported by the Polars SQL interface.
2//!
3//! This is useful for syntax highlighting
4//!
5//! This module defines:
6//! - all Polars SQL keywords [`all_keywords`]
7//! - all of polars SQL functions [`all_functions`]
8
9use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::prelude::typed_lit;
15use polars_plan::prelude::LiteralValue::Null;
16use polars_time::Duration;
17use rand::distributions::Alphanumeric;
18use rand::{thread_rng, Rng};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22    BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind,
23    DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident,
24    Interval, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField,
25    UnaryOperator, Value as SQLValue,
26};
27use sqlparser::dialect::GenericDialect;
28use sqlparser::parser::{Parser, ParserOptions};
29
30use crate::functions::SQLFunctionVisitor;
31use crate::types::{
32    bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33};
34use crate::SQLContext;
35
36#[inline]
37#[cold]
38#[must_use]
39/// Convert a Display-able error to PolarsError::SQLInterface
40pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
41    PolarsError::SQLInterface(err.to_string().into())
42}
43
/// Categorises the type of (allowed) subquery constraint.
///
/// Only the single-column restriction exists at present; possible future
/// additions (not implemented): `SingleRow`, `SingleValue`, `Any`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SubqueryRestriction {
    /// Subquery must return a single column
    SingleColumn,
}
54
/// Recursively walks a SQL Expr to create a polars Expr
pub(crate) struct SQLExprVisitor<'a> {
    /// SQL context used while visiting (e.g. to execute subqueries and to
    /// resolve compound identifiers).
    ctx: &'a mut SQLContext,
    /// Schema of the frame currently being queried, when known; consulted to
    /// look up column dtypes for implicit temporal string conversion.
    active_schema: Option<&'a Schema>,
}
60
61impl SQLExprVisitor<'_> {
62    fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
63        let array_elements = elements
64            .iter()
65            .map(|e| match e {
66                SQLExpr::Value(v) => self.visit_any_value(v, None),
67                SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
68                    SQLExpr::Value(v) => self.visit_any_value(v, Some(op)),
69                    _ => Err(polars_err!(SQLInterface: "expression {:?} is not currently supported", e)),
70                },
71                SQLExpr::Array(_) => {
72                    // TODO: nested arrays (handle FnMut issues)
73                    // let srs = self.array_expr_to_series(&[e.clone()])?;
74                    // Ok(AnyValue::List(srs))
75                    Err(polars_err!(SQLInterface: "nested array literals are not currently supported:\n{:?}", e))
76                },
77                _ => Err(polars_err!(SQLInterface: "expression {:?} is not currently supported", e)),
78            })
79            .collect::<PolarsResult<Vec<_>>>()?;
80
81        Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
82    }
83
84    fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
85        match expr {
86            SQLExpr::AllOp {
87                left,
88                compare_op,
89                right,
90            } => self.visit_all(left, compare_op, right),
91            SQLExpr::AnyOp {
92                left,
93                compare_op,
94                right,
95                is_some: _,
96            } => self.visit_any(left, compare_op, right),
97            SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
98            SQLExpr::Between {
99                expr,
100                negated,
101                low,
102                high,
103            } => self.visit_between(expr, *negated, low, high),
104            SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
105            SQLExpr::Cast {
106                kind,
107                expr,
108                data_type,
109                format,
110            } => self.visit_cast(expr, data_type, format, kind),
111            SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
112            SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
113            SQLExpr::Extract {
114                field,
115                syntax: _,
116                expr,
117            } => parse_extract_date_part(self.visit_expr(expr)?, field),
118            SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
119            SQLExpr::Function(function) => self.visit_function(function),
120            SQLExpr::Identifier(ident) => self.visit_identifier(ident),
121            SQLExpr::InList {
122                expr,
123                list,
124                negated,
125            } => {
126                let expr = self.visit_expr(expr)?;
127                let elems = self.visit_array_expr(list, false, Some(&expr))?;
128                let is_in = expr.is_in(elems);
129                Ok(if *negated { is_in.not() } else { is_in })
130            },
131            SQLExpr::InSubquery {
132                expr,
133                subquery,
134                negated,
135            } => self.visit_in_subquery(expr, subquery, *negated),
136            SQLExpr::Interval(interval) => self.visit_interval(interval),
137            SQLExpr::IsDistinctFrom(e1, e2) => {
138                Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
139            },
140            SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
141            SQLExpr::IsNotDistinctFrom(e1, e2) => {
142                Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
143            },
144            SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
145            SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
146            SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
147            SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
148            SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
149            SQLExpr::Like {
150                negated,
151                any,
152                expr,
153                pattern,
154                escape_char,
155            } => {
156                if *any {
157                    polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
158                }
159                self.visit_like(*negated, expr, pattern, escape_char, false)
160            },
161            SQLExpr::ILike {
162                negated,
163                any,
164                expr,
165                pattern,
166                escape_char,
167            } => {
168                if *any {
169                    polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
170                }
171                self.visit_like(*negated, expr, pattern, escape_char, true)
172            },
173            SQLExpr::Nested(expr) => self.visit_expr(expr),
174            SQLExpr::Position { expr, r#in } => Ok(
175                // note: SQL is 1-indexed
176                (self
177                    .visit_expr(r#in)?
178                    .str()
179                    .find(self.visit_expr(expr)?, true)
180                    + typed_lit(1u32))
181                .fill_null(typed_lit(0u32)),
182            ),
183            SQLExpr::RLike {
184                // note: parses both RLIKE and REGEXP
185                negated,
186                expr,
187                pattern,
188                regexp: _,
189            } => {
190                let matches = self
191                    .visit_expr(expr)?
192                    .str()
193                    .contains(self.visit_expr(pattern)?, true);
194                Ok(if *negated { matches.not() } else { matches })
195            },
196            SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript),
197            SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
198            SQLExpr::Trim {
199                expr,
200                trim_where,
201                trim_what,
202                trim_characters,
203            } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
204            SQLExpr::TypedString { data_type, value } => match data_type {
205                SQLDataType::Date => {
206                    if is_iso_date(value) {
207                        Ok(lit(value.as_str()).cast(DataType::Date))
208                    } else {
209                        polars_bail!(SQLSyntax: "invalid DATE literal '{}'", value)
210                    }
211                },
212                SQLDataType::Time(None, TimezoneInfo::None) => {
213                    if is_iso_time(value) {
214                        Ok(lit(value.as_str()).str().to_time(StrptimeOptions {
215                            strict: true,
216                            ..Default::default()
217                        }))
218                    } else {
219                        polars_bail!(SQLSyntax: "invalid TIME literal '{}'", value)
220                    }
221                },
222                SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
223                    if is_iso_datetime(value) {
224                        Ok(lit(value.as_str()).str().to_datetime(
225                            None,
226                            None,
227                            StrptimeOptions {
228                                strict: true,
229                                ..Default::default()
230                            },
231                            lit("latest"),
232                        ))
233                    } else {
234                        let fn_name = match data_type {
235                            SQLDataType::Timestamp(_, _) => "TIMESTAMP",
236                            SQLDataType::Datetime(_) => "DATETIME",
237                            _ => unreachable!(),
238                        };
239                        polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, value)
240                    }
241                },
242                _ => {
243                    polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
244                },
245            },
246            SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
247            SQLExpr::Value(value) => self.visit_literal(value),
248            SQLExpr::Wildcard(_) => Ok(Expr::Wildcard),
249            e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
250            other => {
251                polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
252            },
253        }
254    }
255
256    fn visit_subquery(
257        &mut self,
258        subquery: &Subquery,
259        restriction: SubqueryRestriction,
260    ) -> PolarsResult<Expr> {
261        if subquery.with.is_some() {
262            polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
263        }
264        let mut lf = self.ctx.execute_query_no_ctes(subquery)?;
265        let schema = self.ctx.get_frame_schema(&mut lf)?;
266
267        if restriction == SubqueryRestriction::SingleColumn {
268            if schema.len() != 1 {
269                polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
270            }
271            let rand_string: String = thread_rng()
272                .sample_iter(&Alphanumeric)
273                .take(16)
274                .map(char::from)
275                .collect();
276
277            let schema_entry = schema.get_at_index(0);
278            if let Some((old_name, _)) = schema_entry {
279                let new_name = String::from(old_name.as_str()) + rand_string.as_str();
280                lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
281                return Ok(Expr::SubPlan(
282                    SpecialEq::new(Arc::new(lf.logical_plan)),
283                    vec![new_name],
284                ));
285            }
286        };
287        polars_bail!(SQLInterface: "subquery type not supported");
288    }
289
290    /// Visit a single SQL identifier.
291    ///
292    /// e.g. column
293    fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
294        Ok(col(ident.value.as_str()))
295    }
296
297    /// Visit a compound SQL identifier
298    ///
299    /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
300    fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
301        Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
302    }
303
304    fn visit_interval(&self, interval: &Interval) -> PolarsResult<Expr> {
305        if interval.last_field.is_some()
306            || interval.leading_field.is_some()
307            || interval.leading_precision.is_some()
308            || interval.fractional_seconds_precision.is_some()
309        {
310            polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
311        }
312        let s = match &*interval.value {
313            SQLExpr::UnaryOp { .. } => {
314                polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
315            },
316            SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s),
317            _ => None,
318        };
319        match s {
320            Some(s) if s.contains('-') => {
321                polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
322            },
323            Some(s) => Ok(lit(Duration::parse_interval(s))),
324            None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
325        }
326    }
327
328    fn visit_like(
329        &mut self,
330        negated: bool,
331        expr: &SQLExpr,
332        pattern: &SQLExpr,
333        escape_char: &Option<String>,
334        case_insensitive: bool,
335    ) -> PolarsResult<Expr> {
336        if escape_char.is_some() {
337            polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
338        }
339        let pat = match self.visit_expr(pattern) {
340            Ok(Expr::Literal(LiteralValue::String(s))) => s,
341            _ => {
342                polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
343            },
344        };
345        if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
346            // empty string or other exact literal match (eg: no wildcard chars)
347            let op = if negated {
348                BinaryOperator::NotEq
349            } else {
350                BinaryOperator::Eq
351            };
352            self.visit_binary_op(expr, &op, pattern)
353        } else {
354            // create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
355            let mut rx = regex::escape(pat.as_str())
356                .replace('%', ".*")
357                .replace('_', ".");
358
359            rx = format!(
360                "^{}{}$",
361                if case_insensitive { "(?is)" } else { "(?s)" },
362                rx
363            );
364
365            let expr = self.visit_expr(expr)?;
366            let matches = expr.str().contains(lit(rx), true);
367            Ok(if negated { matches.not() } else { matches })
368        }
369    }
370
371    fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
372        let expr = self.visit_expr(expr)?;
373        Ok(match subscript {
374            Subscript::Index { index } => {
375                let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
376                expr.list().get(idx, true)
377            },
378            Subscript::Slice { .. } => {
379                polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
380            },
381        })
382    }
383
384    /// Handle implicit temporal string comparisons.
385    ///
386    /// eg: "dt >= '2024-04-30'", or "dtm::date = '2077-10-10'"
387    fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
388        if let (Some(name), Some(s), expr_dtype) = match (left, right) {
389            // identify "col <op> string" expressions
390            (Expr::Column(name), Expr::Literal(LiteralValue::String(s))) => {
391                (Some(name.clone()), Some(s), None)
392            },
393            // identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
394            (Expr::Cast { expr, dtype, .. }, Expr::Literal(LiteralValue::String(s))) => {
395                match &**expr {
396                    Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
397                    _ => (None, Some(s), Some(dtype)),
398                }
399            },
400            _ => (None, None, None),
401        } {
402            if expr_dtype.is_none() && self.active_schema.is_none() {
403                right.clone()
404            } else {
405                let left_dtype = expr_dtype.or_else(|| {
406                    self.active_schema
407                        .as_ref()
408                        .and_then(|schema| schema.get(&name))
409                });
410                match left_dtype {
411                    Some(DataType::Time) if is_iso_time(s) => {
412                        right.clone().str().to_time(StrptimeOptions {
413                            strict: true,
414                            ..Default::default()
415                        })
416                    },
417                    Some(DataType::Date) if is_iso_date(s) => {
418                        right.clone().str().to_date(StrptimeOptions {
419                            strict: true,
420                            ..Default::default()
421                        })
422                    },
423                    Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
424                        if s.len() == 10 {
425                            // handle upcast from ISO date string (10 chars) to datetime
426                            lit(format!("{}T00:00:00", s))
427                        } else {
428                            lit(s.replacen(' ', "T", 1))
429                        }
430                        .str()
431                        .to_datetime(
432                            Some(*tu),
433                            tz.clone(),
434                            StrptimeOptions {
435                                strict: true,
436                                ..Default::default()
437                            },
438                            lit("latest"),
439                        )
440                    },
441                    _ => right.clone(),
442                }
443            }
444        } else {
445            right.clone()
446        }
447    }
448
449    fn struct_field_access_expr(
450        &mut self,
451        expr: &Expr,
452        path: &str,
453        infer_index: bool,
454    ) -> PolarsResult<Expr> {
455        let path_elems = if path.starts_with('{') && path.ends_with('}') {
456            path.trim_matches(|c| c == '{' || c == '}')
457        } else {
458            path
459        }
460        .split(',');
461
462        let mut expr = expr.clone();
463        for p in path_elems {
464            let p = p.trim();
465            expr = if infer_index {
466                match p.parse::<i64>() {
467                    Ok(idx) => expr.list().get(lit(idx), true),
468                    Err(_) => expr.struct_().field_by_name(p),
469                }
470            } else {
471                expr.struct_().field_by_name(p)
472            }
473        }
474        Ok(expr)
475    }
476
477    /// Visit a SQL binary operator.
478    ///
479    /// e.g. "column + 1", "column1 <= column2"
480    fn visit_binary_op(
481        &mut self,
482        left: &SQLExpr,
483        op: &BinaryOperator,
484        right: &SQLExpr,
485    ) -> PolarsResult<Expr> {
486        let lhs = self.visit_expr(left)?;
487        let mut rhs = self.visit_expr(right)?;
488        rhs = self.convert_temporal_strings(&lhs, &rhs);
489
490        Ok(match op {
491            // ----
492            // Bitwise operators
493            // ----
494            SQLBinaryOperator::BitwiseAnd => lhs.and(rhs),  // "x & y"
495            SQLBinaryOperator::BitwiseOr => lhs.or(rhs),  // "x | y"
496            SQLBinaryOperator::Xor => lhs.xor(rhs),  // "x XOR y"
497
498            // ----
499            // General operators
500            // ----
501            SQLBinaryOperator::And => lhs.and(rhs),  // "x AND y"
502            SQLBinaryOperator::Divide => lhs / rhs,  // "x / y"
503            SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),  // "x // y"
504            SQLBinaryOperator::Eq => lhs.eq(rhs),  // "x = y"
505            SQLBinaryOperator::Gt => lhs.gt(rhs),  // "x > y"
506            SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),  // "x >= y"
507            SQLBinaryOperator::Lt => lhs.lt(rhs),  // "x < y"
508            SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),  // "x <= y"
509            SQLBinaryOperator::Minus => lhs - rhs,  // "x - y"
510            SQLBinaryOperator::Modulo => lhs % rhs,  // "x % y"
511            SQLBinaryOperator::Multiply => lhs * rhs,  // "x * y"
512            SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),  // "x != y"
513            SQLBinaryOperator::Or => lhs.or(rhs),  // "x OR y"
514            SQLBinaryOperator::Plus => lhs + rhs,  // "x + y"
515            SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),  // "x <=> y"
516            SQLBinaryOperator::StringConcat => {  // "x || y"
517                lhs.cast(DataType::String) + rhs.cast(DataType::String)
518            },
519            SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),  // "x ^@ y"
520            // ----
521            // Regular expression operators
522            // ----
523            SQLBinaryOperator::PGRegexMatch => match rhs {  // "x ~ y"
524                Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true),
525                _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
526            },
527            SQLBinaryOperator::PGRegexNotMatch => match rhs {  // "x !~ y"
528                Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(),
529                _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
530            },
531            SQLBinaryOperator::PGRegexIMatch => match rhs {  // "x ~* y"
532                Expr::Literal(LiteralValue::String(pat)) => {
533                    lhs.str().contains(lit(format!("(?i){}", pat)), true)
534                },
535                _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
536            },
537            SQLBinaryOperator::PGRegexNotIMatch => match rhs {  // "x !~* y"
538                Expr::Literal(LiteralValue::String(pat)) => {
539                    lhs.str().contains(lit(format!("(?i){}", pat)), true).not()
540                },
541                _ => {
542                    polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
543                },
544            },
545            // ----
546            // LIKE/ILIKE operators
547            // ----
548            SQLBinaryOperator::PGLikeMatch  // "x ~~ y"
549            | SQLBinaryOperator::PGNotLikeMatch  // "x !~~ y"
550            | SQLBinaryOperator::PGILikeMatch  // "x ~~* y"
551            | SQLBinaryOperator::PGNotILikeMatch => {  // "x !~~* y"
552                let expr = if matches!(
553                    op,
554                    SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
555                ) {
556                    SQLExpr::Like {
557                        negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
558                        any: false,
559                        expr: Box::new(left.clone()),
560                        pattern: Box::new(right.clone()),
561                        escape_char: None,
562                    }
563                } else {
564                    SQLExpr::ILike {
565                        negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
566                        any: false,
567                        expr: Box::new(left.clone()),
568                        pattern: Box::new(right.clone()),
569                        escape_char: None,
570                    }
571                };
572                self.visit_expr(&expr)?
573            },
574            // ----
575            // JSON/Struct field access operators
576            // ----
577            SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {  // "x -> y", "x ->> y"
578                Expr::Literal(LiteralValue::String(path)) => {
579                    let mut expr = self.struct_field_access_expr(&lhs, &path, false)?;
580                    if let SQLBinaryOperator::LongArrow = op {
581                        expr = expr.cast(DataType::String);
582                    }
583                    expr
584                },
585                Expr::Literal(LiteralValue::Int(idx)) => {
586                    let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
587                    if let SQLBinaryOperator::LongArrow = op {
588                        expr = expr.cast(DataType::String);
589                    }
590                    expr
591                },
592                _ => {
593                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
594                },
595            },
596            SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {  // "x #> y", "x #>> y"
597                if let Expr::Literal(LiteralValue::String(path)) = rhs {
598                    let mut expr = self.struct_field_access_expr(&lhs, &path, true)?;
599                    if let SQLBinaryOperator::HashLongArrow = op {
600                        expr = expr.cast(DataType::String);
601                    }
602                    expr
603                } else {
604                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
605                }
606            },
607            other => {
608                polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
609            },
610        })
611    }
612
613    /// Visit a SQL unary operator.
614    ///
615    /// e.g. +column or -column
616    fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
617        let expr = self.visit_expr(expr)?;
618        Ok(match (op, expr.clone()) {
619            // simplify the parse tree by special-casing common unary +/- ops
620            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Int(n))) => lit(n),
621            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Float(n))) => lit(n),
622            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Int(n))) => lit(-n),
623            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Float(n))) => lit(-n),
624            // general case
625            (UnaryOperator::Plus, _) => lit(0) + expr,
626            (UnaryOperator::Minus, _) => lit(0) - expr,
627            (UnaryOperator::Not, _) => expr.not(),
628            other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
629        })
630    }
631
632    /// Visit a SQL function.
633    ///
634    /// e.g. SUM(column) or COUNT(*)
635    ///
636    /// See [SQLFunctionVisitor] for more details
637    fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
638        let mut visitor = SQLFunctionVisitor {
639            func: function,
640            ctx: self.ctx,
641            active_schema: self.active_schema,
642        };
643        visitor.visit_function()
644    }
645
646    /// Visit a SQL `ALL` expression.
647    ///
648    /// e.g. `a > ALL(y)`
649    fn visit_all(
650        &mut self,
651        left: &SQLExpr,
652        compare_op: &BinaryOperator,
653        right: &SQLExpr,
654    ) -> PolarsResult<Expr> {
655        let left = self.visit_expr(left)?;
656        let right = self.visit_expr(right)?;
657
658        match compare_op {
659            BinaryOperator::Gt => Ok(left.gt(right.max())),
660            BinaryOperator::Lt => Ok(left.lt(right.min())),
661            BinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
662            BinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
663            BinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
664            BinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
665            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
666        }
667    }
668
669    /// Visit a SQL `ANY` expression.
670    ///
671    /// e.g. `a != ANY(y)`
672    fn visit_any(
673        &mut self,
674        left: &SQLExpr,
675        compare_op: &BinaryOperator,
676        right: &SQLExpr,
677    ) -> PolarsResult<Expr> {
678        let left = self.visit_expr(left)?;
679        let right = self.visit_expr(right)?;
680
681        match compare_op {
682            BinaryOperator::Gt => Ok(left.gt(right.min())),
683            BinaryOperator::Lt => Ok(left.lt(right.max())),
684            BinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
685            BinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
686            BinaryOperator::Eq => Ok(left.is_in(right)),
687            BinaryOperator::NotEq => Ok(left.is_in(right).not()),
688            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
689        }
690    }
691
692    /// Visit a SQL `ARRAY` list (including `IN` values).
693    fn visit_array_expr(
694        &mut self,
695        elements: &[SQLExpr],
696        result_as_element: bool,
697        dtype_expr_match: Option<&Expr>,
698    ) -> PolarsResult<Expr> {
699        let mut elems = self.array_expr_to_series(elements)?;
700
701        // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
702        // (not yet as versatile as the temporal string conversions in visit_binary_op)
703        if let (Some(Expr::Column(name)), Some(schema)) =
704            (dtype_expr_match, self.active_schema.as_ref())
705        {
706            if elems.dtype() == &DataType::String {
707                if let Some(dtype) = schema.get(name) {
708                    if matches!(
709                        dtype,
710                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
711                    ) {
712                        elems = elems.strict_cast(dtype)?;
713                    }
714                }
715            }
716        }
717
718        // if we are parsing the list as an element in a series, implode.
719        // otherwise, return the series as-is.
720        let res = if result_as_element {
721            elems.implode()?.into_series()
722        } else {
723            elems
724        };
725        Ok(lit(res))
726    }
727
728    /// Visit a SQL `CAST` or `TRY_CAST` expression.
729    ///
730    /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
731    fn visit_cast(
732        &mut self,
733        expr: &SQLExpr,
734        dtype: &SQLDataType,
735        format: &Option<CastFormat>,
736        cast_kind: &CastKind,
737    ) -> PolarsResult<Expr> {
738        if format.is_some() {
739            return Err(
740                polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
741            );
742        }
743        let expr = self.visit_expr(expr)?;
744
745        #[cfg(feature = "json")]
746        if dtype == &SQLDataType::JSON {
747            return Ok(expr.str().json_decode(None, None));
748        }
749        let polars_type = map_sql_dtype_to_polars(dtype)?;
750        Ok(match cast_kind {
751            CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
752            CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
753        })
754    }
755
756    /// Visit a SQL literal.
757    ///
758    /// e.g. 1, 'foo', 1.0, NULL
759    ///
760    /// See [SQLValue] and [LiteralValue] for more details
761    fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
762        // note: double-quoted strings will be parsed as identifiers, not literals
763        Ok(match value {
764            SQLValue::Boolean(b) => lit(*b),
765            SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
766            #[cfg(feature = "binary_encoding")]
767            SQLValue::HexStringLiteral(x) => {
768                if x.len() % 2 != 0 {
769                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
770                };
771                lit(hex::decode(x.clone()).unwrap())
772            },
773            SQLValue::Null => Expr::Literal(LiteralValue::Null),
774            SQLValue::Number(s, _) => {
775                // Check for existence of decimal separator dot
776                if s.contains('.') {
777                    s.parse::<f64>().map(lit).map_err(|_| ())
778                } else {
779                    s.parse::<i64>().map(lit).map_err(|_| ())
780                }
781                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
782            },
783            SQLValue::SingleQuotedByteStringLiteral(b) => {
784                // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
785                // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
786                // patterned the token name after BigQuery (where b'str' really IS a byte string)
787                bitstring_to_bytes_literal(b)?
788            },
789            SQLValue::SingleQuotedString(s) => lit(s.clone()),
790            other => {
791                polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
792            },
793        })
794    }
795
796    /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
797    fn visit_any_value(
798        &self,
799        value: &SQLValue,
800        op: Option<&UnaryOperator>,
801    ) -> PolarsResult<AnyValue> {
802        Ok(match value {
803            SQLValue::Boolean(b) => AnyValue::Boolean(*b),
804            SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
805            #[cfg(feature = "binary_encoding")]
806            SQLValue::HexStringLiteral(x) => {
807                if x.len() % 2 != 0 {
808                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
809                };
810                AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
811            },
812            SQLValue::Null => AnyValue::Null,
813            SQLValue::Number(s, _) => {
814                let negate = match op {
815                    Some(UnaryOperator::Minus) => true,
816                    // no op should be taken as plus.
817                    Some(UnaryOperator::Plus) | None => false,
818                    Some(op) => {
819                        polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
820                    },
821                };
822                // Check for existence of decimal separator dot
823                if s.contains('.') {
824                    s.parse::<f64>()
825                        .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
826                        .map_err(|_| ())
827                } else {
828                    s.parse::<i64>()
829                        .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
830                        .map_err(|_| ())
831                }
832                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
833            },
834            SQLValue::SingleQuotedByteStringLiteral(b) => {
835                // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
836                let bytes_literal = bitstring_to_bytes_literal(b)?;
837                match bytes_literal {
838                    Expr::Literal(LiteralValue::Binary(v)) => AnyValue::BinaryOwned(v.to_vec()),
839                    _ => {
840                        polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
841                    },
842                }
843            },
844            SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
845            other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
846        })
847    }
848
849    /// Visit a SQL `BETWEEN` expression.
850    /// See [sqlparser::ast::Expr::Between] for more details
851    fn visit_between(
852        &mut self,
853        expr: &SQLExpr,
854        negated: bool,
855        low: &SQLExpr,
856        high: &SQLExpr,
857    ) -> PolarsResult<Expr> {
858        let expr = self.visit_expr(expr)?;
859        let low = self.visit_expr(low)?;
860        let high = self.visit_expr(high)?;
861
862        let low = self.convert_temporal_strings(&expr, &low);
863        let high = self.convert_temporal_strings(&expr, &high);
864        Ok(if negated {
865            expr.clone().lt(low).or(expr.gt(high))
866        } else {
867            expr.clone().gt_eq(low).and(expr.lt_eq(high))
868        })
869    }
870
871    /// Visit a SQL `TRIM` function.
872    /// See [sqlparser::ast::Expr::Trim] for more details
873    fn visit_trim(
874        &mut self,
875        expr: &SQLExpr,
876        trim_where: &Option<TrimWhereField>,
877        trim_what: &Option<Box<SQLExpr>>,
878        trim_characters: &Option<Vec<SQLExpr>>,
879    ) -> PolarsResult<Expr> {
880        if trim_characters.is_some() {
881            // TODO: allow compact snowflake/bigquery syntax?
882            return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
883        };
884        let expr = self.visit_expr(expr)?;
885        let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
886        let trim_what = match trim_what {
887            Some(Expr::Literal(LiteralValue::String(val))) => Some(val),
888            None => None,
889            _ => return self.err(&expr),
890        };
891        Ok(match (trim_where, trim_what) {
892            (None | Some(TrimWhereField::Both), None) => expr.str().strip_chars(lit(Null)),
893            (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
894            (Some(TrimWhereField::Leading), None) => expr.str().strip_chars_start(lit(Null)),
895            (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
896            (Some(TrimWhereField::Trailing), None) => expr.str().strip_chars_end(lit(Null)),
897            (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
898        })
899    }
900
901    /// Visit a SQL subquery inside and `IN` expression.
902    fn visit_in_subquery(
903        &mut self,
904        expr: &SQLExpr,
905        subquery: &Subquery,
906        negated: bool,
907    ) -> PolarsResult<Expr> {
908        let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
909        let expr = self.visit_expr(expr)?;
910        Ok(if negated {
911            expr.is_in(subquery_result).not()
912        } else {
913            expr.is_in(subquery_result)
914        })
915    }
916
917    /// Visit `CASE` control flow expression.
918    fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
919        if let SQLExpr::Case {
920            operand,
921            conditions,
922            results,
923            else_result,
924        } = expr
925        {
926            polars_ensure!(
927                conditions.len() == results.len(),
928                SQLSyntax: "WHEN and THEN expressions must have the same length"
929            );
930            polars_ensure!(
931                !conditions.is_empty(),
932                SQLSyntax: "WHEN and THEN expressions must have at least one element"
933            );
934
935            let mut when_thens = conditions.iter().zip(results.iter());
936            let first = when_thens.next();
937            if first.is_none() {
938                polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
939            }
940            let else_res = match else_result {
941                Some(else_res) => self.visit_expr(else_res)?,
942                None => lit(Null), // ELSE clause is optional; when omitted, it is implicitly NULL
943            };
944            if let Some(operand_expr) = operand {
945                let first_operand_expr = self.visit_expr(operand_expr)?;
946
947                let first = first.unwrap();
948                let first_cond = first_operand_expr.eq(self.visit_expr(first.0)?);
949                let first_then = self.visit_expr(first.1)?;
950                let expr = when(first_cond).then(first_then);
951                let next = when_thens.next();
952
953                let mut when_then = if let Some((cond, res)) = next {
954                    let second_operand_expr = self.visit_expr(operand_expr)?;
955                    let cond = second_operand_expr.eq(self.visit_expr(cond)?);
956                    let res = self.visit_expr(res)?;
957                    expr.when(cond).then(res)
958                } else {
959                    return Ok(expr.otherwise(else_res));
960                };
961                for (cond, res) in when_thens {
962                    let new_operand_expr = self.visit_expr(operand_expr)?;
963                    let cond = new_operand_expr.eq(self.visit_expr(cond)?);
964                    let res = self.visit_expr(res)?;
965                    when_then = when_then.when(cond).then(res);
966                }
967                return Ok(when_then.otherwise(else_res));
968            }
969
970            let first = first.unwrap();
971            let first_cond = self.visit_expr(first.0)?;
972            let first_then = self.visit_expr(first.1)?;
973            let expr = when(first_cond).then(first_then);
974            let next = when_thens.next();
975
976            let mut when_then = if let Some((cond, res)) = next {
977                let cond = self.visit_expr(cond)?;
978                let res = self.visit_expr(res)?;
979                expr.when(cond).then(res)
980            } else {
981                return Ok(expr.otherwise(else_res));
982            };
983            for (cond, res) in when_thens {
984                let cond = self.visit_expr(cond)?;
985                let res = self.visit_expr(res)?;
986                when_then = when_then.when(cond).then(res);
987            }
988            Ok(when_then.otherwise(else_res))
989        } else {
990            unreachable!()
991        }
992    }
993
994    fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
995        polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
996    }
997}
998
999/// parse a SQL expression to a polars expression
1000/// # Example
1001/// ```rust
1002/// # use polars_sql::{SQLContext, sql_expr};
1003/// # use polars_core::prelude::*;
1004/// # use polars_lazy::prelude::*;
1005/// # fn main() {
1006///
1007/// let mut ctx = SQLContext::new();
1008/// let df = df! {
1009///    "a" =>  [1, 2, 3],
1010/// }
1011/// .unwrap();
1012/// let expr = sql_expr("MAX(a)").unwrap();
1013/// df.lazy().select(vec![expr]).collect().unwrap();
1014/// # }
1015/// ```
1016pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1017    let mut ctx = SQLContext::new();
1018
1019    let mut parser = Parser::new(&GenericDialect);
1020    parser = parser.with_options(ParserOptions {
1021        trailing_commas: true,
1022        ..Default::default()
1023    });
1024
1025    let mut ast = parser
1026        .try_with_sql(s.as_ref())
1027        .map_err(to_sql_interface_err)?;
1028    let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1029
1030    Ok(match &expr {
1031        SelectItem::ExprWithAlias { expr, alias } => {
1032            let expr = parse_sql_expr(expr, &mut ctx, None)?;
1033            expr.alias(alias.value.as_str())
1034        },
1035        SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1036        _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1037    })
1038}
1039
1040pub(crate) fn parse_sql_expr(
1041    expr: &SQLExpr,
1042    ctx: &mut SQLContext,
1043    active_schema: Option<&Schema>,
1044) -> PolarsResult<Expr> {
1045    let mut visitor = SQLExprVisitor { ctx, active_schema };
1046    visitor.visit_expr(expr)
1047}
1048
1049pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1050    match expr {
1051        SQLExpr::Array(arr) => {
1052            let mut visitor = SQLExprVisitor {
1053                ctx,
1054                active_schema: None,
1055            };
1056            visitor.array_expr_to_series(arr.elem.as_slice())
1057        },
1058        _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1059    }
1060}
1061
1062pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1063    let field = match field {
1064        // handle 'DATE_PART' and all valid abbreviations/alternates
1065        DateTimeField::Custom(Ident { value, .. }) => {
1066            let value = value.to_ascii_lowercase();
1067            match value.as_str() {
1068                "millennium" | "millennia" => &DateTimeField::Millennium,
1069                "century" | "centuries" => &DateTimeField::Century,
1070                "decade" | "decades" => &DateTimeField::Decade,
1071                "isoyear" => &DateTimeField::Isoyear,
1072                "year" | "years" | "y" => &DateTimeField::Year,
1073                "quarter" | "quarters" => &DateTimeField::Quarter,
1074                "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1075                "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1076                "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1077                "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1078                "isodow" => &DateTimeField::Isodow,
1079                "day" | "days" | "d" => &DateTimeField::Day,
1080                "hour" | "hours" | "h" => &DateTimeField::Hour,
1081                "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1082                "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1083                "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1084                "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1085                "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1086                #[cfg(feature = "timezones")]
1087                "timezone" => &DateTimeField::Timezone,
1088                "time" => &DateTimeField::Time,
1089                "epoch" => &DateTimeField::Epoch,
1090                _ => {
1091                    polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1092                },
1093            }
1094        },
1095        _ => field,
1096    };
1097    Ok(match field {
1098        DateTimeField::Millennium => expr.dt().millennium(),
1099        DateTimeField::Century => expr.dt().century(),
1100        DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1101        DateTimeField::Isoyear => expr.dt().iso_year(),
1102        DateTimeField::Year => expr.dt().year(),
1103        DateTimeField::Quarter => expr.dt().quarter(),
1104        DateTimeField::Month => expr.dt().month(),
1105        DateTimeField::Week(weekday) => {
1106            if weekday.is_some() {
1107                polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1108            }
1109            expr.dt().week()
1110        },
1111        DateTimeField::IsoWeek => expr.dt().week(),
1112        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1113        DateTimeField::DayOfWeek | DateTimeField::Dow => {
1114            let w = expr.dt().weekday();
1115            when(w.clone().eq(typed_lit(7i8)))
1116                .then(typed_lit(0i8))
1117                .otherwise(w)
1118        },
1119        DateTimeField::Isodow => expr.dt().weekday(),
1120        DateTimeField::Day => expr.dt().day(),
1121        DateTimeField::Hour => expr.dt().hour(),
1122        DateTimeField::Minute => expr.dt().minute(),
1123        DateTimeField::Second => expr.dt().second(),
1124        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1125            (expr.clone().dt().second() * typed_lit(1_000f64))
1126                + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1127        },
1128        DateTimeField::Microsecond | DateTimeField::Microseconds => {
1129            (expr.clone().dt().second() * typed_lit(1_000_000f64))
1130                + expr.dt().nanosecond().div(typed_lit(1_000f64))
1131        },
1132        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1133            (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1134        },
1135        DateTimeField::Time => expr.dt().time(),
1136        #[cfg(feature = "timezones")]
1137        DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(),
1138        DateTimeField::Epoch => {
1139            expr.clone()
1140                .dt()
1141                .timestamp(TimeUnit::Nanoseconds)
1142                .div(typed_lit(1_000_000_000i64))
1143                + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1144        },
1145        _ => {
1146            polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1147        },
1148    })
1149}
1150
1151/// Allow an expression that represents a 1-indexed parameter to
1152/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1153pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1154    match idx {
1155        Expr::Literal(Null) => lit(Null),
1156        Expr::Literal(LiteralValue::Int(0)) => {
1157            if null_if_zero {
1158                lit(Null)
1159            } else {
1160                idx
1161            }
1162        },
1163        Expr::Literal(LiteralValue::Int(n)) if n < 0 => idx,
1164        Expr::Literal(LiteralValue::Int(n)) => lit(n - 1),
1165        // TODO: when 'saturating_sub' is available, should be able
1166        //  to streamline the when/then/otherwise block below -
1167        _ => when(idx.clone().gt(lit(0)))
1168            .then(idx.clone() - lit(1))
1169            .otherwise(if null_if_zero {
1170                when(idx.clone().eq(lit(0)))
1171                    .then(lit(Null))
1172                    .otherwise(idx.clone())
1173            } else {
1174                idx.clone()
1175            }),
1176    }
1177}
1178
1179fn resolve_column<'a>(
1180    ctx: &'a mut SQLContext,
1181    ident_root: &'a Ident,
1182    name: &'a str,
1183    dtype: &'a DataType,
1184) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1185    let resolved = ctx.resolve_name(&ident_root.value, name);
1186    let resolved = resolved.as_str();
1187    Ok((
1188        if name != resolved {
1189            col(resolved).alias(name)
1190        } else {
1191            col(name)
1192        },
1193        Some(dtype),
1194    ))
1195}
1196
1197pub(crate) fn resolve_compound_identifier(
1198    ctx: &mut SQLContext,
1199    idents: &[Ident],
1200    active_schema: Option<&Schema>,
1201) -> PolarsResult<Vec<Expr>> {
1202    // inference priority: table > struct > column
1203    let ident_root = &idents[0];
1204    let mut remaining_idents = idents.iter().skip(1);
1205    let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1206
1207    let schema = if let Some(ref mut lf) = lf {
1208        lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)
1209    } else {
1210        Ok(Arc::new(if let Some(active_schema) = active_schema {
1211            active_schema.clone()
1212        } else {
1213            Schema::default()
1214        }))
1215    }?;
1216
1217    let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() {
1218        Ok((col(ident_root.value.as_str()), None))
1219    } else {
1220        let name = &remaining_idents.next().unwrap().value;
1221        if lf.is_some() && name == "*" {
1222            return Ok(schema
1223                .iter_names_and_dtypes()
1224                .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0)
1225                .collect::<Vec<_>>());
1226        } else if let Some((_, name, dtype)) = schema.get_full(name) {
1227            resolve_column(ctx, ident_root, name, dtype)
1228        } else if lf.is_none() {
1229            remaining_idents = idents.iter().skip(1);
1230            Ok((
1231                col(ident_root.value.as_str()),
1232                schema.get(&ident_root.value),
1233            ))
1234        } else {
1235            polars_bail!(
1236                SQLInterface: "no column named '{}' found in table '{}'",
1237                name,
1238                ident_root
1239            )
1240        }
1241    };
1242
1243    // additional ident levels index into struct fields
1244    let (mut column, mut dtype) = col_dtype?;
1245    for ident in remaining_idents {
1246        let name = ident.value.as_str();
1247        match dtype {
1248            Some(DataType::Struct(fields)) if name == "*" => {
1249                return Ok(fields
1250                    .iter()
1251                    .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1252                    .collect())
1253            },
1254            Some(DataType::Struct(fields)) => {
1255                dtype = fields
1256                    .iter()
1257                    .find(|fld| fld.name == name)
1258                    .map(|fld| &fld.dtype);
1259            },
1260            Some(dtype) if name == "*" => {
1261                polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1262            },
1263            _ => {
1264                dtype = None;
1265            },
1266        }
1267        column = column.struct_().field_by_name(name);
1268    }
1269    Ok(vec![column])
1270}