polars_sql/
sql_expr.rs

1//! Expressions that are supported by the Polars SQL interface.
2//!
3//! This is useful for syntax highlighting
4//!
5//! This module defines:
6//! - all Polars SQL keywords [`all_keywords`]
//! - all Polars SQL functions [`all_functions`]
8
9use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use rand::distributions::Alphanumeric;
18use rand::{Rng, thread_rng};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22    BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
23    DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
24    SelectItem, Subscript, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue,
25};
26use sqlparser::dialect::GenericDialect;
27use sqlparser::parser::{Parser, ParserOptions};
28
29use crate::SQLContext;
30use crate::functions::SQLFunctionVisitor;
31use crate::types::{
32    bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33};
34
35#[inline]
36#[cold]
37#[must_use]
38/// Convert a Display-able error to PolarsError::SQLInterface
39pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
40    PolarsError::SQLInterface(err.to_string().into())
41}
42
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)]
/// Categorises the type of (allowed) subquery constraint.
///
/// The restriction is passed to `visit_subquery`, which validates the shape
/// of the executed subquery against it.
pub enum SubqueryRestriction {
    /// Subquery must return a single column
    SingleColumn,
    // The following restrictions are currently disabled (not variants):
    // SingleRow,
    // SingleValue,
    // Any
}
53
/// Recursively walks a SQL Expr to create a polars Expr
pub(crate) struct SQLExprVisitor<'a> {
    /// SQL context, borrowed mutably for the lifetime of the visitor; used to
    /// execute subqueries and to resolve compound identifiers.
    ctx: &'a mut SQLContext,
    /// Optional schema that is consulted to look up column dtypes by name
    /// (eg: when converting implicit temporal string literals).
    active_schema: Option<&'a Schema>,
}
59
60impl SQLExprVisitor<'_> {
61    fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
62        let mut array_elements = Vec::with_capacity(elements.len());
63        for e in elements {
64            let val = match e {
65                SQLExpr::Value(v) => self.visit_any_value(v, None),
66                SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
67                    SQLExpr::Value(v) => self.visit_any_value(v, Some(op)),
68                    _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
69                },
70                SQLExpr::Array(values) => {
71                    let srs = self.array_expr_to_series(&values.elem)?;
72                    Ok(AnyValue::List(srs))
73                },
74                _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
75            }?
76            .into_static();
77            array_elements.push(val);
78        }
79        Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
80    }
81
82    fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
83        match expr {
84            SQLExpr::AllOp {
85                left,
86                compare_op,
87                right,
88            } => self.visit_all(left, compare_op, right),
89            SQLExpr::AnyOp {
90                left,
91                compare_op,
92                right,
93                is_some: _,
94            } => self.visit_any(left, compare_op, right),
95            SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
96            SQLExpr::Between {
97                expr,
98                negated,
99                low,
100                high,
101            } => self.visit_between(expr, *negated, low, high),
102            SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
103            SQLExpr::Cast {
104                kind,
105                expr,
106                data_type,
107                format,
108            } => self.visit_cast(expr, data_type, format, kind),
109            SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
110            SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
111            SQLExpr::Extract {
112                field,
113                syntax: _,
114                expr,
115            } => parse_extract_date_part(self.visit_expr(expr)?, field),
116            SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
117            SQLExpr::Function(function) => self.visit_function(function),
118            SQLExpr::Identifier(ident) => self.visit_identifier(ident),
119            SQLExpr::InList {
120                expr,
121                list,
122                negated,
123            } => {
124                let expr = self.visit_expr(expr)?;
125                let elems = self.visit_array_expr(list, true, Some(&expr))?;
126                let is_in = expr.is_in(elems, false);
127                Ok(if *negated { is_in.not() } else { is_in })
128            },
129            SQLExpr::InSubquery {
130                expr,
131                subquery,
132                negated,
133            } => self.visit_in_subquery(expr, subquery, *negated),
134            SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
135            SQLExpr::IsDistinctFrom(e1, e2) => {
136                Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
137            },
138            SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
139            SQLExpr::IsNotDistinctFrom(e1, e2) => {
140                Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
141            },
142            SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
143            SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
144            SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
145            SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
146            SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
147            SQLExpr::Like {
148                negated,
149                any,
150                expr,
151                pattern,
152                escape_char,
153            } => {
154                if *any {
155                    polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
156                }
157                self.visit_like(*negated, expr, pattern, escape_char, false)
158            },
159            SQLExpr::ILike {
160                negated,
161                any,
162                expr,
163                pattern,
164                escape_char,
165            } => {
166                if *any {
167                    polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
168                }
169                self.visit_like(*negated, expr, pattern, escape_char, true)
170            },
171            SQLExpr::Nested(expr) => self.visit_expr(expr),
172            SQLExpr::Position { expr, r#in } => Ok(
173                // note: SQL is 1-indexed
174                (self
175                    .visit_expr(r#in)?
176                    .str()
177                    .find(self.visit_expr(expr)?, true)
178                    + typed_lit(1u32))
179                .fill_null(typed_lit(0u32)),
180            ),
181            SQLExpr::RLike {
182                // note: parses both RLIKE and REGEXP
183                negated,
184                expr,
185                pattern,
186                regexp: _,
187            } => {
188                let matches = self
189                    .visit_expr(expr)?
190                    .str()
191                    .contains(self.visit_expr(pattern)?, true);
192                Ok(if *negated { matches.not() } else { matches })
193            },
194            SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript),
195            SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
196            SQLExpr::Trim {
197                expr,
198                trim_where,
199                trim_what,
200                trim_characters,
201            } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
202            SQLExpr::TypedString { data_type, value } => match data_type {
203                SQLDataType::Date => {
204                    if is_iso_date(value) {
205                        Ok(lit(value.as_str()).cast(DataType::Date))
206                    } else {
207                        polars_bail!(SQLSyntax: "invalid DATE literal '{}'", value)
208                    }
209                },
210                SQLDataType::Time(None, TimezoneInfo::None) => {
211                    if is_iso_time(value) {
212                        Ok(lit(value.as_str()).str().to_time(StrptimeOptions {
213                            strict: true,
214                            ..Default::default()
215                        }))
216                    } else {
217                        polars_bail!(SQLSyntax: "invalid TIME literal '{}'", value)
218                    }
219                },
220                SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
221                    if is_iso_datetime(value) {
222                        Ok(lit(value.as_str()).str().to_datetime(
223                            None,
224                            None,
225                            StrptimeOptions {
226                                strict: true,
227                                ..Default::default()
228                            },
229                            lit("latest"),
230                        ))
231                    } else {
232                        let fn_name = match data_type {
233                            SQLDataType::Timestamp(_, _) => "TIMESTAMP",
234                            SQLDataType::Datetime(_) => "DATETIME",
235                            _ => unreachable!(),
236                        };
237                        polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, value)
238                    }
239                },
240                _ => {
241                    polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
242                },
243            },
244            SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
245            SQLExpr::Value(value) => self.visit_literal(value),
246            SQLExpr::Wildcard(_) => Ok(Expr::Wildcard),
247            e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
248            other => {
249                polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
250            },
251        }
252    }
253
254    fn visit_subquery(
255        &mut self,
256        subquery: &Subquery,
257        restriction: SubqueryRestriction,
258    ) -> PolarsResult<Expr> {
259        if subquery.with.is_some() {
260            polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
261        }
262        let mut lf = self.ctx.execute_query_no_ctes(subquery)?;
263        let schema = self.ctx.get_frame_schema(&mut lf)?;
264
265        if restriction == SubqueryRestriction::SingleColumn {
266            if schema.len() != 1 {
267                polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
268            }
269            let rand_string: String = thread_rng()
270                .sample_iter(&Alphanumeric)
271                .take(16)
272                .map(char::from)
273                .collect();
274
275            let schema_entry = schema.get_at_index(0);
276            if let Some((old_name, _)) = schema_entry {
277                let new_name = String::from(old_name.as_str()) + rand_string.as_str();
278                lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
279                return Ok(Expr::SubPlan(
280                    SpecialEq::new(Arc::new(lf.logical_plan)),
281                    vec![new_name],
282                ));
283            }
284        };
285        polars_bail!(SQLInterface: "subquery type not supported");
286    }
287
288    /// Visit a single SQL identifier.
289    ///
290    /// e.g. column
291    fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
292        Ok(col(ident.value.as_str()))
293    }
294
295    /// Visit a compound SQL identifier
296    ///
297    /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
298    fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
299        Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
300    }
301
302    fn visit_like(
303        &mut self,
304        negated: bool,
305        expr: &SQLExpr,
306        pattern: &SQLExpr,
307        escape_char: &Option<String>,
308        case_insensitive: bool,
309    ) -> PolarsResult<Expr> {
310        if escape_char.is_some() {
311            polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
312        }
313        let pat = match self.visit_expr(pattern) {
314            Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
315                PlSmallStr::from_str(lv.extract_str().unwrap())
316            },
317            _ => {
318                polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
319            },
320        };
321        if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
322            // empty string or other exact literal match (eg: no wildcard chars)
323            let op = if negated {
324                SQLBinaryOperator::NotEq
325            } else {
326                SQLBinaryOperator::Eq
327            };
328            self.visit_binary_op(expr, &op, pattern)
329        } else {
330            // create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
331            let mut rx = regex::escape(pat.as_str())
332                .replace('%', ".*")
333                .replace('_', ".");
334
335            rx = format!(
336                "^{}{}$",
337                if case_insensitive { "(?is)" } else { "(?s)" },
338                rx
339            );
340
341            let expr = self.visit_expr(expr)?;
342            let matches = expr.str().contains(lit(rx), true);
343            Ok(if negated { matches.not() } else { matches })
344        }
345    }
346
347    fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
348        let expr = self.visit_expr(expr)?;
349        Ok(match subscript {
350            Subscript::Index { index } => {
351                let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
352                expr.list().get(idx, true)
353            },
354            Subscript::Slice { .. } => {
355                polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
356            },
357        })
358    }
359
360    /// Handle implicit temporal string comparisons.
361    ///
362    /// eg: clauses such as -
363    ///   "dt >= '2024-04-30'"
364    ///   "dt = '2077-10-10'::date"
365    ///   "dtm::date = '2077-10-10'
366    fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
367        if let (Some(name), Some(s), expr_dtype) = match (left, right) {
368            // identify "col <op> string" expressions
369            (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
370                (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
371            },
372            // identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
373            (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
374                let s = lv.extract_str().unwrap();
375                match &**expr {
376                    Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
377                    _ => (None, Some(s), Some(dtype)),
378                }
379            },
380            _ => (None, None, None),
381        } {
382            if expr_dtype.is_none() && self.active_schema.is_none() {
383                right.clone()
384            } else {
385                let left_dtype = expr_dtype.or_else(|| {
386                    self.active_schema
387                        .as_ref()
388                        .and_then(|schema| schema.get(&name))
389                });
390                match left_dtype {
391                    Some(DataType::Time) if is_iso_time(s) => {
392                        right.clone().str().to_time(StrptimeOptions {
393                            strict: true,
394                            ..Default::default()
395                        })
396                    },
397                    Some(DataType::Date) if is_iso_date(s) => {
398                        right.clone().str().to_date(StrptimeOptions {
399                            strict: true,
400                            ..Default::default()
401                        })
402                    },
403                    Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
404                        if s.len() == 10 {
405                            // handle upcast from ISO date string (10 chars) to datetime
406                            lit(format!("{}T00:00:00", s))
407                        } else {
408                            lit(s.replacen(' ', "T", 1))
409                        }
410                        .str()
411                        .to_datetime(
412                            Some(*tu),
413                            tz.clone(),
414                            StrptimeOptions {
415                                strict: true,
416                                ..Default::default()
417                            },
418                            lit("latest"),
419                        )
420                    },
421                    _ => right.clone(),
422                }
423            }
424        } else {
425            right.clone()
426        }
427    }
428
429    fn struct_field_access_expr(
430        &mut self,
431        expr: &Expr,
432        path: &str,
433        infer_index: bool,
434    ) -> PolarsResult<Expr> {
435        let path_elems = if path.starts_with('{') && path.ends_with('}') {
436            path.trim_matches(|c| c == '{' || c == '}')
437        } else {
438            path
439        }
440        .split(',');
441
442        let mut expr = expr.clone();
443        for p in path_elems {
444            let p = p.trim();
445            expr = if infer_index {
446                match p.parse::<i64>() {
447                    Ok(idx) => expr.list().get(lit(idx), true),
448                    Err(_) => expr.struct_().field_by_name(p),
449                }
450            } else {
451                expr.struct_().field_by_name(p)
452            }
453        }
454        Ok(expr)
455    }
456
    /// Visit a SQL binary operator.
    ///
    /// e.g. "column + 1", "column1 <= column2"
    ///
    /// Interval operands are special-cased before the operands are visited:
    /// adding/subtracting an interval becomes a datetime offset, and comparing
    /// two intervals is evaluated immediately to a boolean literal. For all
    /// other inputs both sides are visited, the right-hand side is passed
    /// through `convert_temporal_strings`, and the operator is mapped onto the
    /// matching polars expression.
    ///
    /// # Errors
    ///
    /// Returns an error for unsupported operators, for regex / path-extract
    /// operators whose right-hand side is not an accepted literal, and
    /// propagates errors from visiting the operands.
    fn visit_binary_op(
        &mut self,
        left: &SQLExpr,
        op: &SQLBinaryOperator,
        right: &SQLExpr,
    ) -> PolarsResult<Expr> {
        // need special handling for interval offsets and comparisons
        let (lhs, mut rhs) = match (left, op, right) {
            (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
                let duration = interval_to_duration(v, false)?;
                // subtraction is expressed as a negated offset string
                return Ok(self
                    .visit_expr(left)?
                    .dt()
                    .offset_by(lit(format!("-{}", duration))));
            },
            (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
                let duration = interval_to_duration(v, false)?;
                return Ok(self
                    .visit_expr(left)?
                    .dt()
                    .offset_by(lit(format!("{}", duration))));
            },
            (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
                // shortcut interval comparison evaluation (-> bool)
                let d1 = interval_to_duration(v1, false)?;
                let d2 = interval_to_duration(v2, false)?;
                let res = match op {
                    SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
                    SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
                    SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
                    SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
                    SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
                    SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
                    _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
                };
                if res.is_ok() {
                    return res;
                }
                (self.visit_expr(left)?, self.visit_expr(right)?)
            },
            _ => (self.visit_expr(left)?, self.visit_expr(right)?),
        };
        // string literals compared against temporal columns are parsed to the column dtype
        rhs = self.convert_temporal_strings(&lhs, &rhs);

        Ok(match op {
            // ----
            // Bitwise operators
            // ----
            SQLBinaryOperator::BitwiseAnd => lhs.and(rhs),  // "x & y"
            SQLBinaryOperator::BitwiseOr => lhs.or(rhs),  // "x | y"
            SQLBinaryOperator::Xor => lhs.xor(rhs),  // "x XOR y"

            // ----
            // General operators
            // ----
            SQLBinaryOperator::And => lhs.and(rhs),  // "x AND y"
            SQLBinaryOperator::Divide => lhs / rhs,  // "x / y"
            SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),  // "x // y"
            SQLBinaryOperator::Eq => lhs.eq(rhs),  // "x = y"
            SQLBinaryOperator::Gt => lhs.gt(rhs),  // "x > y"
            SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),  // "x >= y"
            SQLBinaryOperator::Lt => lhs.lt(rhs),  // "x < y"
            SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),  // "x <= y"
            SQLBinaryOperator::Minus => lhs - rhs,  // "x - y"
            SQLBinaryOperator::Modulo => lhs % rhs,  // "x % y"
            SQLBinaryOperator::Multiply => lhs * rhs,  // "x * y"
            SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),  // "x != y"
            SQLBinaryOperator::Or => lhs.or(rhs),  // "x OR y"
            SQLBinaryOperator::Plus => lhs + rhs,  // "x + y"
            SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),  // "x <=> y"
            SQLBinaryOperator::StringConcat => {  // "x || y"
                // both operands are cast to String before concatenation
                lhs.cast(DataType::String) + rhs.cast(DataType::String)
            },
            SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),  // "x ^@ y"
            // ----
            // Regular expression operators
            // (the pattern must be a string literal; "*" variants prepend the "(?i)" flag)
            // ----
            SQLBinaryOperator::PGRegexMatch => match rhs {  // "x ~ y"
                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
                _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
            },
            SQLBinaryOperator::PGRegexNotMatch => match rhs {  // "x !~ y"
                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
                _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
            },
            SQLBinaryOperator::PGRegexIMatch => match rhs {  // "x ~* y"
                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
                    let pat = lv.extract_str().unwrap();
                    lhs.str().contains(lit(format!("(?i){}", pat)), true)
                },
                _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
            },
            SQLBinaryOperator::PGRegexNotIMatch => match rhs {  // "x !~* y"
                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
                    let pat = lv.extract_str().unwrap();
                    lhs.str().contains(lit(format!("(?i){}", pat)), true).not()
                },
                _ => {
                    polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
                },
            },
            // ----
            // LIKE/ILIKE operators
            // (rebuilt as the equivalent LIKE/ILIKE expression and re-dispatched via visit_expr)
            // ----
            SQLBinaryOperator::PGLikeMatch  // "x ~~ y"
            | SQLBinaryOperator::PGNotLikeMatch  // "x !~~ y"
            | SQLBinaryOperator::PGILikeMatch  // "x ~~* y"
            | SQLBinaryOperator::PGNotILikeMatch => {  // "x !~~* y"
                let expr = if matches!(
                    op,
                    SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
                ) {
                    SQLExpr::Like {
                        negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
                        any: false,
                        expr: Box::new(left.clone()),
                        pattern: Box::new(right.clone()),
                        escape_char: None,
                    }
                } else {
                    SQLExpr::ILike {
                        negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
                        any: false,
                        expr: Box::new(left.clone()),
                        pattern: Box::new(right.clone()),
                        escape_char: None,
                    }
                };
                self.visit_expr(&expr)?
            },
            // ----
            // JSON/Struct field access operators
            // (the "long" arrow variants additionally cast the result to String)
            // ----
            SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {  // "x -> y", "x ->> y"
                Expr::Literal(lv) if lv.extract_str().is_some() => {
                    let path = lv.extract_str().unwrap();
                    let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
                    if let SQLBinaryOperator::LongArrow = op {
                        expr = expr.cast(DataType::String);
                    }
                    expr
                },
                // an integer literal is treated as a list index rather than a field name
                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
                    let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
                    if let SQLBinaryOperator::LongArrow = op {
                        expr = expr.cast(DataType::String);
                    }
                    expr
                },
                _ => {
                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
                },
            },
            SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {  // "x #> y", "x #>> y"
                match rhs {
                    Expr::Literal(lv) if lv.extract_str().is_some() => {
                        let path = lv.extract_str().unwrap();
                        let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
                        if let SQLBinaryOperator::HashLongArrow = op {
                            expr = expr.cast(DataType::String);
                        }
                        expr
                    },
                    _ => {
                        polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
                    }
                }
            },
            other => {
                polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
            },
        })
    }
633
634    /// Visit a SQL unary operator.
635    ///
636    /// e.g. +column or -column
637    fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
638        let expr = self.visit_expr(expr)?;
639        Ok(match (op, expr.clone()) {
640            // simplify the parse tree by special-casing common unary +/- ops
641            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
642                lit(n)
643            },
644            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
645                lit(n)
646            },
647            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
648                lit(-n)
649            },
650            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
651                lit(-n)
652            },
653            // general case
654            (UnaryOperator::Plus, _) => lit(0) + expr,
655            (UnaryOperator::Minus, _) => lit(0) - expr,
656            (UnaryOperator::Not, _) => expr.not(),
657            other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
658        })
659    }
660
661    /// Visit a SQL function.
662    ///
663    /// e.g. SUM(column) or COUNT(*)
664    ///
665    /// See [SQLFunctionVisitor] for more details
666    fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
667        let mut visitor = SQLFunctionVisitor {
668            func: function,
669            ctx: self.ctx,
670            active_schema: self.active_schema,
671        };
672        visitor.visit_function()
673    }
674
675    /// Visit a SQL `ALL` expression.
676    ///
677    /// e.g. `a > ALL(y)`
678    fn visit_all(
679        &mut self,
680        left: &SQLExpr,
681        compare_op: &SQLBinaryOperator,
682        right: &SQLExpr,
683    ) -> PolarsResult<Expr> {
684        let left = self.visit_expr(left)?;
685        let right = self.visit_expr(right)?;
686
687        match compare_op {
688            SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
689            SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
690            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
691            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
692            SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
693            SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
694            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
695        }
696    }
697
698    /// Visit a SQL `ANY` expression.
699    ///
700    /// e.g. `a != ANY(y)`
701    fn visit_any(
702        &mut self,
703        left: &SQLExpr,
704        compare_op: &SQLBinaryOperator,
705        right: &SQLExpr,
706    ) -> PolarsResult<Expr> {
707        let left = self.visit_expr(left)?;
708        let right = self.visit_expr(right)?;
709
710        match compare_op {
711            SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
712            SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
713            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
714            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
715            SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
716            SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
717            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
718        }
719    }
720
721    /// Visit a SQL `ARRAY` list (including `IN` values).
722    fn visit_array_expr(
723        &mut self,
724        elements: &[SQLExpr],
725        result_as_element: bool,
726        dtype_expr_match: Option<&Expr>,
727    ) -> PolarsResult<Expr> {
728        let mut elems = self.array_expr_to_series(elements)?;
729
730        // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
731        // (not yet as versatile as the temporal string conversions in visit_binary_op)
732        if let (Some(Expr::Column(name)), Some(schema)) =
733            (dtype_expr_match, self.active_schema.as_ref())
734        {
735            if elems.dtype() == &DataType::String {
736                if let Some(dtype) = schema.get(name) {
737                    if matches!(
738                        dtype,
739                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
740                    ) {
741                        elems = elems.strict_cast(dtype)?;
742                    }
743                }
744            }
745        }
746
747        // if we are parsing the list as an element in a series, implode.
748        // otherwise, return the series as-is.
749        let res = if result_as_element {
750            elems.implode()?.into_series()
751        } else {
752            elems
753        };
754        Ok(lit(res))
755    }
756
757    /// Visit a SQL `CAST` or `TRY_CAST` expression.
758    ///
759    /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
760    fn visit_cast(
761        &mut self,
762        expr: &SQLExpr,
763        dtype: &SQLDataType,
764        format: &Option<CastFormat>,
765        cast_kind: &CastKind,
766    ) -> PolarsResult<Expr> {
767        if format.is_some() {
768            return Err(
769                polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
770            );
771        }
772        let expr = self.visit_expr(expr)?;
773
774        #[cfg(feature = "json")]
775        if dtype == &SQLDataType::JSON {
776            return Ok(expr.str().json_decode(None, None));
777        }
778        let polars_type = map_sql_dtype_to_polars(dtype)?;
779        Ok(match cast_kind {
780            CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
781            CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
782        })
783    }
784
785    /// Visit a SQL literal.
786    ///
787    /// e.g. 1, 'foo', 1.0, NULL
788    ///
789    /// See [SQLValue] and [LiteralValue] for more details
790    fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
791        // note: double-quoted strings will be parsed as identifiers, not literals
792        Ok(match value {
793            SQLValue::Boolean(b) => lit(*b),
794            SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
795            #[cfg(feature = "binary_encoding")]
796            SQLValue::HexStringLiteral(x) => {
797                if x.len() % 2 != 0 {
798                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
799                };
800                lit(hex::decode(x.clone()).unwrap())
801            },
802            SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
803            SQLValue::Number(s, _) => {
804                // Check for existence of decimal separator dot
805                if s.contains('.') {
806                    s.parse::<f64>().map(lit).map_err(|_| ())
807                } else {
808                    s.parse::<i64>().map(lit).map_err(|_| ())
809                }
810                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
811            },
812            SQLValue::SingleQuotedByteStringLiteral(b) => {
813                // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
814                // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
815                // patterned the token name after BigQuery (where b'str' really IS a byte string)
816                bitstring_to_bytes_literal(b)?
817            },
818            SQLValue::SingleQuotedString(s) => lit(s.clone()),
819            other => {
820                polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
821            },
822        })
823    }
824
825    /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
826    fn visit_any_value(
827        &self,
828        value: &SQLValue,
829        op: Option<&UnaryOperator>,
830    ) -> PolarsResult<AnyValue> {
831        Ok(match value {
832            SQLValue::Boolean(b) => AnyValue::Boolean(*b),
833            SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
834            #[cfg(feature = "binary_encoding")]
835            SQLValue::HexStringLiteral(x) => {
836                if x.len() % 2 != 0 {
837                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
838                };
839                AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
840            },
841            SQLValue::Null => AnyValue::Null,
842            SQLValue::Number(s, _) => {
843                let negate = match op {
844                    Some(UnaryOperator::Minus) => true,
845                    // no op should be taken as plus.
846                    Some(UnaryOperator::Plus) | None => false,
847                    Some(op) => {
848                        polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
849                    },
850                };
851                // Check for existence of decimal separator dot
852                if s.contains('.') {
853                    s.parse::<f64>()
854                        .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
855                        .map_err(|_| ())
856                } else {
857                    s.parse::<i64>()
858                        .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
859                        .map_err(|_| ())
860                }
861                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
862            },
863            SQLValue::SingleQuotedByteStringLiteral(b) => {
864                // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
865                let bytes_literal = bitstring_to_bytes_literal(b)?;
866                match bytes_literal {
867                    Expr::Literal(lv) if lv.extract_binary().is_some() => {
868                        AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
869                    },
870                    _ => {
871                        polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
872                    },
873                }
874            },
875            SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
876            other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
877        })
878    }
879
880    /// Visit a SQL `BETWEEN` expression.
881    /// See [sqlparser::ast::Expr::Between] for more details
882    fn visit_between(
883        &mut self,
884        expr: &SQLExpr,
885        negated: bool,
886        low: &SQLExpr,
887        high: &SQLExpr,
888    ) -> PolarsResult<Expr> {
889        let expr = self.visit_expr(expr)?;
890        let low = self.visit_expr(low)?;
891        let high = self.visit_expr(high)?;
892
893        let low = self.convert_temporal_strings(&expr, &low);
894        let high = self.convert_temporal_strings(&expr, &high);
895        Ok(if negated {
896            expr.clone().lt(low).or(expr.gt(high))
897        } else {
898            expr.clone().gt_eq(low).and(expr.lt_eq(high))
899        })
900    }
901
902    /// Visit a SQL `TRIM` function.
903    /// See [sqlparser::ast::Expr::Trim] for more details
904    fn visit_trim(
905        &mut self,
906        expr: &SQLExpr,
907        trim_where: &Option<TrimWhereField>,
908        trim_what: &Option<Box<SQLExpr>>,
909        trim_characters: &Option<Vec<SQLExpr>>,
910    ) -> PolarsResult<Expr> {
911        if trim_characters.is_some() {
912            // TODO: allow compact snowflake/bigquery syntax?
913            return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
914        };
915        let expr = self.visit_expr(expr)?;
916        let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
917        let trim_what = match trim_what {
918            Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
919                Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
920            },
921            None => None,
922            _ => return self.err(&expr),
923        };
924        Ok(match (trim_where, trim_what) {
925            (None | Some(TrimWhereField::Both), None) => {
926                expr.str().strip_chars(lit(LiteralValue::untyped_null()))
927            },
928            (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
929            (Some(TrimWhereField::Leading), None) => expr
930                .str()
931                .strip_chars_start(lit(LiteralValue::untyped_null())),
932            (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
933            (Some(TrimWhereField::Trailing), None) => expr
934                .str()
935                .strip_chars_end(lit(LiteralValue::untyped_null())),
936            (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
937        })
938    }
939
940    /// Visit a SQL subquery inside and `IN` expression.
941    fn visit_in_subquery(
942        &mut self,
943        expr: &SQLExpr,
944        subquery: &Subquery,
945        negated: bool,
946    ) -> PolarsResult<Expr> {
947        let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
948        let expr = self.visit_expr(expr)?;
949        Ok(if negated {
950            expr.is_in(subquery_result, false).not()
951        } else {
952            expr.is_in(subquery_result, false)
953        })
954    }
955
956    /// Visit `CASE` control flow expression.
957    fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
958        if let SQLExpr::Case {
959            operand,
960            conditions,
961            results,
962            else_result,
963        } = expr
964        {
965            polars_ensure!(
966                conditions.len() == results.len(),
967                SQLSyntax: "WHEN and THEN expressions must have the same length"
968            );
969            polars_ensure!(
970                !conditions.is_empty(),
971                SQLSyntax: "WHEN and THEN expressions must have at least one element"
972            );
973
974            let mut when_thens = conditions.iter().zip(results.iter());
975            let first = when_thens.next();
976            if first.is_none() {
977                polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
978            }
979            let else_res = match else_result {
980                Some(else_res) => self.visit_expr(else_res)?,
981                None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
982            };
983            if let Some(operand_expr) = operand {
984                let first_operand_expr = self.visit_expr(operand_expr)?;
985
986                let first = first.unwrap();
987                let first_cond = first_operand_expr.eq(self.visit_expr(first.0)?);
988                let first_then = self.visit_expr(first.1)?;
989                let expr = when(first_cond).then(first_then);
990                let next = when_thens.next();
991
992                let mut when_then = if let Some((cond, res)) = next {
993                    let second_operand_expr = self.visit_expr(operand_expr)?;
994                    let cond = second_operand_expr.eq(self.visit_expr(cond)?);
995                    let res = self.visit_expr(res)?;
996                    expr.when(cond).then(res)
997                } else {
998                    return Ok(expr.otherwise(else_res));
999                };
1000                for (cond, res) in when_thens {
1001                    let new_operand_expr = self.visit_expr(operand_expr)?;
1002                    let cond = new_operand_expr.eq(self.visit_expr(cond)?);
1003                    let res = self.visit_expr(res)?;
1004                    when_then = when_then.when(cond).then(res);
1005                }
1006                return Ok(when_then.otherwise(else_res));
1007            }
1008
1009            let first = first.unwrap();
1010            let first_cond = self.visit_expr(first.0)?;
1011            let first_then = self.visit_expr(first.1)?;
1012            let expr = when(first_cond).then(first_then);
1013            let next = when_thens.next();
1014
1015            let mut when_then = if let Some((cond, res)) = next {
1016                let cond = self.visit_expr(cond)?;
1017                let res = self.visit_expr(res)?;
1018                expr.when(cond).then(res)
1019            } else {
1020                return Ok(expr.otherwise(else_res));
1021            };
1022            for (cond, res) in when_thens {
1023                let cond = self.visit_expr(cond)?;
1024                let res = self.visit_expr(res)?;
1025                when_then = when_then.when(cond).then(res);
1026            }
1027            Ok(when_then.otherwise(else_res))
1028        } else {
1029            unreachable!()
1030        }
1031    }
1032
1033    fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1034        polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1035    }
1036}
1037
1038/// parse a SQL expression to a polars expression
1039/// # Example
1040/// ```rust
1041/// # use polars_sql::{SQLContext, sql_expr};
1042/// # use polars_core::prelude::*;
1043/// # use polars_lazy::prelude::*;
1044/// # fn main() {
1045///
1046/// let mut ctx = SQLContext::new();
1047/// let df = df! {
1048///    "a" =>  [1, 2, 3],
1049/// }
1050/// .unwrap();
1051/// let expr = sql_expr("MAX(a)").unwrap();
1052/// df.lazy().select(vec![expr]).collect().unwrap();
1053/// # }
1054/// ```
1055pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1056    let mut ctx = SQLContext::new();
1057
1058    let mut parser = Parser::new(&GenericDialect);
1059    parser = parser.with_options(ParserOptions {
1060        trailing_commas: true,
1061        ..Default::default()
1062    });
1063
1064    let mut ast = parser
1065        .try_with_sql(s.as_ref())
1066        .map_err(to_sql_interface_err)?;
1067    let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1068
1069    Ok(match &expr {
1070        SelectItem::ExprWithAlias { expr, alias } => {
1071            let expr = parse_sql_expr(expr, &mut ctx, None)?;
1072            expr.alias(alias.value.as_str())
1073        },
1074        SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1075        _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1076    })
1077}
1078
1079pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1080    if interval.last_field.is_some()
1081        || interval.leading_field.is_some()
1082        || interval.leading_precision.is_some()
1083        || interval.fractional_seconds_precision.is_some()
1084    {
1085        polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1086    }
1087    let s = match &*interval.value {
1088        SQLExpr::UnaryOp { .. } => {
1089            polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1090        },
1091        SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s),
1092        _ => None,
1093    };
1094    match s {
1095        Some(s) if s.contains('-') => {
1096            polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1097        },
1098        Some(s) => {
1099            // years, quarters, and months do not have a fixed duration; these
1100            // interval parts can only be used with respect to a reference point
1101            let duration = Duration::parse_interval(s);
1102            if fixed && duration.months() != 0 {
1103                polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1104            };
1105            Ok(duration)
1106        },
1107        None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1108    }
1109}
1110
1111pub(crate) fn parse_sql_expr(
1112    expr: &SQLExpr,
1113    ctx: &mut SQLContext,
1114    active_schema: Option<&Schema>,
1115) -> PolarsResult<Expr> {
1116    let mut visitor = SQLExprVisitor { ctx, active_schema };
1117    visitor.visit_expr(expr)
1118}
1119
1120pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1121    match expr {
1122        SQLExpr::Array(arr) => {
1123            let mut visitor = SQLExprVisitor {
1124                ctx,
1125                active_schema: None,
1126            };
1127            visitor.array_expr_to_series(arr.elem.as_slice())
1128        },
1129        _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1130    }
1131}
1132
1133pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1134    let field = match field {
1135        // handle 'DATE_PART' and all valid abbreviations/alternates
1136        DateTimeField::Custom(Ident { value, .. }) => {
1137            let value = value.to_ascii_lowercase();
1138            match value.as_str() {
1139                "millennium" | "millennia" => &DateTimeField::Millennium,
1140                "century" | "centuries" => &DateTimeField::Century,
1141                "decade" | "decades" => &DateTimeField::Decade,
1142                "isoyear" => &DateTimeField::Isoyear,
1143                "year" | "years" | "y" => &DateTimeField::Year,
1144                "quarter" | "quarters" => &DateTimeField::Quarter,
1145                "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1146                "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1147                "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1148                "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1149                "isodow" => &DateTimeField::Isodow,
1150                "day" | "days" | "d" => &DateTimeField::Day,
1151                "hour" | "hours" | "h" => &DateTimeField::Hour,
1152                "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1153                "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1154                "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1155                "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1156                "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1157                #[cfg(feature = "timezones")]
1158                "timezone" => &DateTimeField::Timezone,
1159                "time" => &DateTimeField::Time,
1160                "epoch" => &DateTimeField::Epoch,
1161                _ => {
1162                    polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1163                },
1164            }
1165        },
1166        _ => field,
1167    };
1168    Ok(match field {
1169        DateTimeField::Millennium => expr.dt().millennium(),
1170        DateTimeField::Century => expr.dt().century(),
1171        DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1172        DateTimeField::Isoyear => expr.dt().iso_year(),
1173        DateTimeField::Year => expr.dt().year(),
1174        DateTimeField::Quarter => expr.dt().quarter(),
1175        DateTimeField::Month => expr.dt().month(),
1176        DateTimeField::Week(weekday) => {
1177            if weekday.is_some() {
1178                polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1179            }
1180            expr.dt().week()
1181        },
1182        DateTimeField::IsoWeek => expr.dt().week(),
1183        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1184        DateTimeField::DayOfWeek | DateTimeField::Dow => {
1185            let w = expr.dt().weekday();
1186            when(w.clone().eq(typed_lit(7i8)))
1187                .then(typed_lit(0i8))
1188                .otherwise(w)
1189        },
1190        DateTimeField::Isodow => expr.dt().weekday(),
1191        DateTimeField::Day => expr.dt().day(),
1192        DateTimeField::Hour => expr.dt().hour(),
1193        DateTimeField::Minute => expr.dt().minute(),
1194        DateTimeField::Second => expr.dt().second(),
1195        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1196            (expr.clone().dt().second() * typed_lit(1_000f64))
1197                + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1198        },
1199        DateTimeField::Microsecond | DateTimeField::Microseconds => {
1200            (expr.clone().dt().second() * typed_lit(1_000_000f64))
1201                + expr.dt().nanosecond().div(typed_lit(1_000f64))
1202        },
1203        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1204            (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1205        },
1206        DateTimeField::Time => expr.dt().time(),
1207        #[cfg(feature = "timezones")]
1208        DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(),
1209        DateTimeField::Epoch => {
1210            expr.clone()
1211                .dt()
1212                .timestamp(TimeUnit::Nanoseconds)
1213                .div(typed_lit(1_000_000_000i64))
1214                + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1215        },
1216        _ => {
1217            polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1218        },
1219    })
1220}
1221
1222/// Allow an expression that represents a 1-indexed parameter to
1223/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1224pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1225    match idx {
1226        Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1227        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1228            if null_if_zero {
1229                lit(LiteralValue::untyped_null())
1230            } else {
1231                idx
1232            }
1233        },
1234        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1235        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1236        // TODO: when 'saturating_sub' is available, should be able
1237        //  to streamline the when/then/otherwise block below -
1238        _ => when(idx.clone().gt(lit(0)))
1239            .then(idx.clone() - lit(1))
1240            .otherwise(if null_if_zero {
1241                when(idx.clone().eq(lit(0)))
1242                    .then(lit(LiteralValue::untyped_null()))
1243                    .otherwise(idx.clone())
1244            } else {
1245                idx.clone()
1246            }),
1247    }
1248}
1249
1250fn resolve_column<'a>(
1251    ctx: &'a mut SQLContext,
1252    ident_root: &'a Ident,
1253    name: &'a str,
1254    dtype: &'a DataType,
1255) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1256    let resolved = ctx.resolve_name(&ident_root.value, name);
1257    let resolved = resolved.as_str();
1258    Ok((
1259        if name != resolved {
1260            col(resolved).alias(name)
1261        } else {
1262            col(name)
1263        },
1264        Some(dtype),
1265    ))
1266}
1267
1268pub(crate) fn resolve_compound_identifier(
1269    ctx: &mut SQLContext,
1270    idents: &[Ident],
1271    active_schema: Option<&Schema>,
1272) -> PolarsResult<Vec<Expr>> {
1273    // inference priority: table > struct > column
1274    let ident_root = &idents[0];
1275    let mut remaining_idents = idents.iter().skip(1);
1276    let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1277
1278    let schema = if let Some(ref mut lf) = lf {
1279        lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)
1280    } else {
1281        Ok(Arc::new(if let Some(active_schema) = active_schema {
1282            active_schema.clone()
1283        } else {
1284            Schema::default()
1285        }))
1286    }?;
1287
1288    let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() {
1289        Ok((col(ident_root.value.as_str()), None))
1290    } else {
1291        let name = &remaining_idents.next().unwrap().value;
1292        if lf.is_some() && name == "*" {
1293            return Ok(schema
1294                .iter_names_and_dtypes()
1295                .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0)
1296                .collect::<Vec<_>>());
1297        };
1298        let root_is_field = schema.get(&ident_root.value).is_some();
1299        if lf.is_none() && root_is_field {
1300            remaining_idents = idents.iter().skip(1);
1301            Ok((
1302                col(ident_root.value.as_str()),
1303                schema.get(&ident_root.value),
1304            ))
1305        } else if lf.is_none() && !root_is_field {
1306            polars_bail!(
1307                SQLInterface: "no table or struct column named '{}' found",
1308                ident_root
1309            )
1310        } else if let Some((_, name, dtype)) = schema.get_full(name) {
1311            resolve_column(ctx, ident_root, name, dtype)
1312        } else {
1313            polars_bail!(
1314                SQLInterface: "no column named '{}' found in table '{}'",
1315                name,
1316                ident_root
1317            )
1318        }
1319    };
1320
1321    // additional ident levels index into struct fields
1322    let (mut column, mut dtype) = col_dtype?;
1323    for ident in remaining_idents {
1324        let name = ident.value.as_str();
1325        match dtype {
1326            Some(DataType::Struct(fields)) if name == "*" => {
1327                return Ok(fields
1328                    .iter()
1329                    .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1330                    .collect());
1331            },
1332            Some(DataType::Struct(fields)) => {
1333                dtype = fields
1334                    .iter()
1335                    .find(|fld| fld.name == name)
1336                    .map(|fld| &fld.dtype);
1337            },
1338            Some(dtype) if name == "*" => {
1339                polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1340            },
1341            _ => {
1342                dtype = None;
1343            },
1344        }
1345        column = column.struct_().field_by_name(name);
1346    }
1347    Ok(vec![column])
1348}