polars_sql/
sql_expr.rs

1//! Expressions that are supported by the Polars SQL interface.
2//!
3//! This is useful for syntax highlighting
4//!
5//! This module defines:
6//! - all Polars SQL keywords [`all_keywords`]
7//! - all of polars SQL functions [`all_functions`]
8
9use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use rand::Rng;
18use rand::distr::Alphanumeric;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22    BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
23    DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
24    SelectItem, Subscript, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue,
25};
26use sqlparser::dialect::GenericDialect;
27use sqlparser::parser::{Parser, ParserOptions};
28
29use crate::SQLContext;
30use crate::functions::SQLFunctionVisitor;
31use crate::types::{
32    bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33};
34
35#[inline]
36#[cold]
37#[must_use]
38/// Convert a Display-able error to PolarsError::SQLInterface
39pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
40    PolarsError::SQLInterface(err.to_string().into())
41}
42
/// Describes the constraint that a subquery result must satisfy.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SubqueryRestriction {
    /// The subquery is required to produce exactly one column.
    SingleColumn,
    // Placeholders for restrictions that are not implemented yet:
    // SingleRow,
    // SingleValue,
    // Any
}
53
54/// Recursively walks a SQL Expr to create a polars Expr
55pub(crate) struct SQLExprVisitor<'a> {
56    ctx: &'a mut SQLContext,
57    active_schema: Option<&'a Schema>,
58}
59
60impl SQLExprVisitor<'_> {
61    fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
62        let mut array_elements = Vec::with_capacity(elements.len());
63        for e in elements {
64            let val = match e {
65                SQLExpr::Value(v) => self.visit_any_value(v, None),
66                SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
67                    SQLExpr::Value(v) => self.visit_any_value(v, Some(op)),
68                    _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
69                },
70                SQLExpr::Array(values) => {
71                    let srs = self.array_expr_to_series(&values.elem)?;
72                    Ok(AnyValue::List(srs))
73                },
74                _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
75            }?
76            .into_static();
77            array_elements.push(val);
78        }
79        Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
80    }
81
82    fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
83        match expr {
84            SQLExpr::AllOp {
85                left,
86                compare_op,
87                right,
88            } => self.visit_all(left, compare_op, right),
89            SQLExpr::AnyOp {
90                left,
91                compare_op,
92                right,
93                is_some: _,
94            } => self.visit_any(left, compare_op, right),
95            SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
96            SQLExpr::Between {
97                expr,
98                negated,
99                low,
100                high,
101            } => self.visit_between(expr, *negated, low, high),
102            SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
103            SQLExpr::Cast {
104                kind,
105                expr,
106                data_type,
107                format,
108            } => self.visit_cast(expr, data_type, format, kind),
109            SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
110            SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
111            SQLExpr::Extract {
112                field,
113                syntax: _,
114                expr,
115            } => parse_extract_date_part(self.visit_expr(expr)?, field),
116            SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
117            SQLExpr::Function(function) => self.visit_function(function),
118            SQLExpr::Identifier(ident) => self.visit_identifier(ident),
119            SQLExpr::InList {
120                expr,
121                list,
122                negated,
123            } => {
124                let expr = self.visit_expr(expr)?;
125                let elems = self.visit_array_expr(list, true, Some(&expr))?;
126                let is_in = expr.is_in(elems, false);
127                Ok(if *negated { is_in.not() } else { is_in })
128            },
129            SQLExpr::InSubquery {
130                expr,
131                subquery,
132                negated,
133            } => self.visit_in_subquery(expr, subquery, *negated),
134            SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
135            SQLExpr::IsDistinctFrom(e1, e2) => {
136                Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
137            },
138            SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
139            SQLExpr::IsNotDistinctFrom(e1, e2) => {
140                Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
141            },
142            SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
143            SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
144            SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
145            SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
146            SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
147            SQLExpr::Like {
148                negated,
149                any,
150                expr,
151                pattern,
152                escape_char,
153            } => {
154                if *any {
155                    polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
156                }
157                self.visit_like(*negated, expr, pattern, escape_char, false)
158            },
159            SQLExpr::ILike {
160                negated,
161                any,
162                expr,
163                pattern,
164                escape_char,
165            } => {
166                if *any {
167                    polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
168                }
169                self.visit_like(*negated, expr, pattern, escape_char, true)
170            },
171            SQLExpr::Nested(expr) => self.visit_expr(expr),
172            SQLExpr::Position { expr, r#in } => Ok(
173                // note: SQL is 1-indexed
174                (self
175                    .visit_expr(r#in)?
176                    .str()
177                    .find(self.visit_expr(expr)?, true)
178                    + typed_lit(1u32))
179                .fill_null(typed_lit(0u32)),
180            ),
181            SQLExpr::RLike {
182                // note: parses both RLIKE and REGEXP
183                negated,
184                expr,
185                pattern,
186                regexp: _,
187            } => {
188                let matches = self
189                    .visit_expr(expr)?
190                    .str()
191                    .contains(self.visit_expr(pattern)?, true);
192                Ok(if *negated { matches.not() } else { matches })
193            },
194            SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript),
195            SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
196            SQLExpr::Trim {
197                expr,
198                trim_where,
199                trim_what,
200                trim_characters,
201            } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
202            SQLExpr::TypedString { data_type, value } => match data_type {
203                SQLDataType::Date => {
204                    if is_iso_date(value) {
205                        Ok(lit(value.as_str()).cast(DataType::Date))
206                    } else {
207                        polars_bail!(SQLSyntax: "invalid DATE literal '{}'", value)
208                    }
209                },
210                SQLDataType::Time(None, TimezoneInfo::None) => {
211                    if is_iso_time(value) {
212                        Ok(lit(value.as_str()).str().to_time(StrptimeOptions {
213                            strict: true,
214                            ..Default::default()
215                        }))
216                    } else {
217                        polars_bail!(SQLSyntax: "invalid TIME literal '{}'", value)
218                    }
219                },
220                SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
221                    if is_iso_datetime(value) {
222                        Ok(lit(value.as_str()).str().to_datetime(
223                            None,
224                            None,
225                            StrptimeOptions {
226                                strict: true,
227                                ..Default::default()
228                            },
229                            lit("latest"),
230                        ))
231                    } else {
232                        let fn_name = match data_type {
233                            SQLDataType::Timestamp(_, _) => "TIMESTAMP",
234                            SQLDataType::Datetime(_) => "DATETIME",
235                            _ => unreachable!(),
236                        };
237                        polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, value)
238                    }
239                },
240                _ => {
241                    polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
242                },
243            },
244            SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
245            SQLExpr::Value(value) => self.visit_literal(value),
246            SQLExpr::Wildcard(_) => Ok(all().as_expr()),
247            e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
248            other => {
249                polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
250            },
251        }
252    }
253
254    fn visit_subquery(
255        &mut self,
256        subquery: &Subquery,
257        restriction: SubqueryRestriction,
258    ) -> PolarsResult<Expr> {
259        if subquery.with.is_some() {
260            polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
261        }
262        let mut lf = self.ctx.execute_query_no_ctes(subquery)?;
263        let schema = self.ctx.get_frame_schema(&mut lf)?;
264
265        if restriction == SubqueryRestriction::SingleColumn {
266            if schema.len() != 1 {
267                polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
268            }
269            let rand_string: String = rand::rng()
270                .sample_iter(&Alphanumeric)
271                .take(16)
272                .map(char::from)
273                .collect();
274
275            let schema_entry = schema.get_at_index(0);
276            if let Some((old_name, _)) = schema_entry {
277                let new_name = String::from(old_name.as_str()) + rand_string.as_str();
278                lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
279                return Ok(Expr::SubPlan(
280                    SpecialEq::new(Arc::new(lf.logical_plan)),
281                    vec![new_name],
282                ));
283            }
284        };
285        polars_bail!(SQLInterface: "subquery type not supported");
286    }
287
288    /// Visit a single SQL identifier.
289    ///
290    /// e.g. column
291    fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
292        Ok(col(ident.value.as_str()))
293    }
294
295    /// Visit a compound SQL identifier
296    ///
297    /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
298    fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
299        Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
300    }
301
302    fn visit_like(
303        &mut self,
304        negated: bool,
305        expr: &SQLExpr,
306        pattern: &SQLExpr,
307        escape_char: &Option<String>,
308        case_insensitive: bool,
309    ) -> PolarsResult<Expr> {
310        if escape_char.is_some() {
311            polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
312        }
313        let pat = match self.visit_expr(pattern) {
314            Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
315                PlSmallStr::from_str(lv.extract_str().unwrap())
316            },
317            _ => {
318                polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
319            },
320        };
321        if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
322            // empty string or other exact literal match (eg: no wildcard chars)
323            let op = if negated {
324                SQLBinaryOperator::NotEq
325            } else {
326                SQLBinaryOperator::Eq
327            };
328            self.visit_binary_op(expr, &op, pattern)
329        } else {
330            // create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
331            let mut rx = regex::escape(pat.as_str())
332                .replace('%', ".*")
333                .replace('_', ".");
334
335            rx = format!(
336                "^{}{}$",
337                if case_insensitive { "(?is)" } else { "(?s)" },
338                rx
339            );
340
341            let expr = self.visit_expr(expr)?;
342            let matches = expr.str().contains(lit(rx), true);
343            Ok(if negated { matches.not() } else { matches })
344        }
345    }
346
347    fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
348        let expr = self.visit_expr(expr)?;
349        Ok(match subscript {
350            Subscript::Index { index } => {
351                let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
352                expr.list().get(idx, true)
353            },
354            Subscript::Slice { .. } => {
355                polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
356            },
357        })
358    }
359
360    /// Handle implicit temporal string comparisons.
361    ///
362    /// eg: clauses such as -
363    ///   "dt >= '2024-04-30'"
364    ///   "dt = '2077-10-10'::date"
365    ///   "dtm::date = '2077-10-10'
366    fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
367        if let (Some(name), Some(s), expr_dtype) = match (left, right) {
368            // identify "col <op> string" expressions
369            (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
370                (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
371            },
372            // identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
373            (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
374                let s = lv.extract_str().unwrap();
375                match &**expr {
376                    Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
377                    _ => (None, Some(s), Some(dtype)),
378                }
379            },
380            _ => (None, None, None),
381        } {
382            if expr_dtype.is_none() && self.active_schema.is_none() {
383                right.clone()
384            } else {
385                let left_dtype = expr_dtype.map_or_else(
386                    || {
387                        self.active_schema
388                            .as_ref()
389                            .and_then(|schema| schema.get(&name))
390                    },
391                    |dt| dt.as_literal(),
392                );
393                match left_dtype {
394                    Some(DataType::Time) if is_iso_time(s) => {
395                        right.clone().str().to_time(StrptimeOptions {
396                            strict: true,
397                            ..Default::default()
398                        })
399                    },
400                    Some(DataType::Date) if is_iso_date(s) => {
401                        right.clone().str().to_date(StrptimeOptions {
402                            strict: true,
403                            ..Default::default()
404                        })
405                    },
406                    Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
407                        if s.len() == 10 {
408                            // handle upcast from ISO date string (10 chars) to datetime
409                            lit(format!("{s}T00:00:00"))
410                        } else {
411                            lit(s.replacen(' ', "T", 1))
412                        }
413                        .str()
414                        .to_datetime(
415                            Some(*tu),
416                            tz.clone(),
417                            StrptimeOptions {
418                                strict: true,
419                                ..Default::default()
420                            },
421                            lit("latest"),
422                        )
423                    },
424                    _ => right.clone(),
425                }
426            }
427        } else {
428            right.clone()
429        }
430    }
431
432    fn struct_field_access_expr(
433        &mut self,
434        expr: &Expr,
435        path: &str,
436        infer_index: bool,
437    ) -> PolarsResult<Expr> {
438        let path_elems = if path.starts_with('{') && path.ends_with('}') {
439            path.trim_matches(|c| c == '{' || c == '}')
440        } else {
441            path
442        }
443        .split(',');
444
445        let mut expr = expr.clone();
446        for p in path_elems {
447            let p = p.trim();
448            expr = if infer_index {
449                match p.parse::<i64>() {
450                    Ok(idx) => expr.list().get(lit(idx), true),
451                    Err(_) => expr.struct_().field_by_name(p),
452                }
453            } else {
454                expr.struct_().field_by_name(p)
455            }
456        }
457        Ok(expr)
458    }
459
460    /// Visit a SQL binary operator.
461    ///
462    /// e.g. "column + 1", "column1 <= column2"
463    fn visit_binary_op(
464        &mut self,
465        left: &SQLExpr,
466        op: &SQLBinaryOperator,
467        right: &SQLExpr,
468    ) -> PolarsResult<Expr> {
469        // need special handling for interval offsets and comparisons
470        let (lhs, mut rhs) = match (left, op, right) {
471            (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
472                let duration = interval_to_duration(v, false)?;
473                return Ok(self
474                    .visit_expr(left)?
475                    .dt()
476                    .offset_by(lit(format!("-{duration}"))));
477            },
478            (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
479                let duration = interval_to_duration(v, false)?;
480                return Ok(self
481                    .visit_expr(left)?
482                    .dt()
483                    .offset_by(lit(format!("{duration}"))));
484            },
485            (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
486                // shortcut interval comparison evaluation (-> bool)
487                let d1 = interval_to_duration(v1, false)?;
488                let d2 = interval_to_duration(v2, false)?;
489                let res = match op {
490                    SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
491                    SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
492                    SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
493                    SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
494                    SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
495                    SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
496                    _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
497                };
498                if res.is_ok() {
499                    return res;
500                }
501                (self.visit_expr(left)?, self.visit_expr(right)?)
502            },
503            _ => (self.visit_expr(left)?, self.visit_expr(right)?),
504        };
505        rhs = self.convert_temporal_strings(&lhs, &rhs);
506
507        Ok(match op {
508            // ----
509            // Bitwise operators
510            // ----
511            SQLBinaryOperator::BitwiseAnd => lhs.and(rhs),  // "x & y"
512            SQLBinaryOperator::BitwiseOr => lhs.or(rhs),  // "x | y"
513            SQLBinaryOperator::Xor => lhs.xor(rhs),  // "x XOR y"
514
515            // ----
516            // General operators
517            // ----
518            SQLBinaryOperator::And => lhs.and(rhs),  // "x AND y"
519            SQLBinaryOperator::Divide => lhs / rhs,  // "x / y"
520            SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),  // "x // y"
521            SQLBinaryOperator::Eq => lhs.eq(rhs),  // "x = y"
522            SQLBinaryOperator::Gt => lhs.gt(rhs),  // "x > y"
523            SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),  // "x >= y"
524            SQLBinaryOperator::Lt => lhs.lt(rhs),  // "x < y"
525            SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),  // "x <= y"
526            SQLBinaryOperator::Minus => lhs - rhs,  // "x - y"
527            SQLBinaryOperator::Modulo => lhs % rhs,  // "x % y"
528            SQLBinaryOperator::Multiply => lhs * rhs,  // "x * y"
529            SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),  // "x != y"
530            SQLBinaryOperator::Or => lhs.or(rhs),  // "x OR y"
531            SQLBinaryOperator::Plus => lhs + rhs,  // "x + y"
532            SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),  // "x <=> y"
533            SQLBinaryOperator::StringConcat => {  // "x || y"
534                lhs.cast(DataType::String) + rhs.cast(DataType::String)
535            },
536            SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),  // "x ^@ y"
537            // ----
538            // Regular expression operators
539            // ----
540            SQLBinaryOperator::PGRegexMatch => match rhs {  // "x ~ y"
541                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
542                _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
543            },
544            SQLBinaryOperator::PGRegexNotMatch => match rhs {  // "x !~ y"
545                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
546                _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
547            },
548            SQLBinaryOperator::PGRegexIMatch => match rhs {  // "x ~* y"
549                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
550                    let pat = lv.extract_str().unwrap();
551                    lhs.str().contains(lit(format!("(?i){pat}")), true)
552                },
553                _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
554            },
555            SQLBinaryOperator::PGRegexNotIMatch => match rhs {  // "x !~* y"
556                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
557                    let pat = lv.extract_str().unwrap();
558                    lhs.str().contains(lit(format!("(?i){pat}")), true).not()
559                },
560                _ => {
561                    polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
562                },
563            },
564            // ----
565            // LIKE/ILIKE operators
566            // ----
567            SQLBinaryOperator::PGLikeMatch  // "x ~~ y"
568            | SQLBinaryOperator::PGNotLikeMatch  // "x !~~ y"
569            | SQLBinaryOperator::PGILikeMatch  // "x ~~* y"
570            | SQLBinaryOperator::PGNotILikeMatch => {  // "x !~~* y"
571                let expr = if matches!(
572                    op,
573                    SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
574                ) {
575                    SQLExpr::Like {
576                        negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
577                        any: false,
578                        expr: Box::new(left.clone()),
579                        pattern: Box::new(right.clone()),
580                        escape_char: None,
581                    }
582                } else {
583                    SQLExpr::ILike {
584                        negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
585                        any: false,
586                        expr: Box::new(left.clone()),
587                        pattern: Box::new(right.clone()),
588                        escape_char: None,
589                    }
590                };
591                self.visit_expr(&expr)?
592            },
593            // ----
594            // JSON/Struct field access operators
595            // ----
596            SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {  // "x -> y", "x ->> y"
597                Expr::Literal(lv) if lv.extract_str().is_some() => {
598                    let path = lv.extract_str().unwrap();
599                    let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
600                    if let SQLBinaryOperator::LongArrow = op {
601                        expr = expr.cast(DataType::String);
602                    }
603                    expr
604                },
605                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
606                    let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
607                    if let SQLBinaryOperator::LongArrow = op {
608                        expr = expr.cast(DataType::String);
609                    }
610                    expr
611                },
612                _ => {
613                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
614                },
615            },
616            SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {  // "x #> y", "x #>> y"
617                match rhs {
618                    Expr::Literal(lv) if lv.extract_str().is_some() => {
619                        let path = lv.extract_str().unwrap();
620                        let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
621                        if let SQLBinaryOperator::HashLongArrow = op {
622                            expr = expr.cast(DataType::String);
623                        }
624                        expr
625                    },
626                    _ => {
627                        polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
628                    }
629                }
630            },
631            other => {
632                polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
633            },
634        })
635    }
636
637    /// Visit a SQL unary operator.
638    ///
639    /// e.g. +column or -column
640    fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
641        let expr = self.visit_expr(expr)?;
642        Ok(match (op, expr.clone()) {
643            // simplify the parse tree by special-casing common unary +/- ops
644            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
645                lit(n)
646            },
647            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
648                lit(n)
649            },
650            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
651                lit(-n)
652            },
653            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
654                lit(-n)
655            },
656            // general case
657            (UnaryOperator::Plus, _) => lit(0) + expr,
658            (UnaryOperator::Minus, _) => lit(0) - expr,
659            (UnaryOperator::Not, _) => expr.not(),
660            other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
661        })
662    }
663
664    /// Visit a SQL function.
665    ///
666    /// e.g. SUM(column) or COUNT(*)
667    ///
668    /// See [SQLFunctionVisitor] for more details
669    fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
670        let mut visitor = SQLFunctionVisitor {
671            func: function,
672            ctx: self.ctx,
673            active_schema: self.active_schema,
674        };
675        visitor.visit_function()
676    }
677
678    /// Visit a SQL `ALL` expression.
679    ///
680    /// e.g. `a > ALL(y)`
681    fn visit_all(
682        &mut self,
683        left: &SQLExpr,
684        compare_op: &SQLBinaryOperator,
685        right: &SQLExpr,
686    ) -> PolarsResult<Expr> {
687        let left = self.visit_expr(left)?;
688        let right = self.visit_expr(right)?;
689
690        match compare_op {
691            SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
692            SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
693            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
694            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
695            SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
696            SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
697            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
698        }
699    }
700
701    /// Visit a SQL `ANY` expression.
702    ///
703    /// e.g. `a != ANY(y)`
704    fn visit_any(
705        &mut self,
706        left: &SQLExpr,
707        compare_op: &SQLBinaryOperator,
708        right: &SQLExpr,
709    ) -> PolarsResult<Expr> {
710        let left = self.visit_expr(left)?;
711        let right = self.visit_expr(right)?;
712
713        match compare_op {
714            SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
715            SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
716            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
717            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
718            SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
719            SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
720            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
721        }
722    }
723
724    /// Visit a SQL `ARRAY` list (including `IN` values).
725    fn visit_array_expr(
726        &mut self,
727        elements: &[SQLExpr],
728        result_as_element: bool,
729        dtype_expr_match: Option<&Expr>,
730    ) -> PolarsResult<Expr> {
731        let mut elems = self.array_expr_to_series(elements)?;
732
733        // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
734        // (not yet as versatile as the temporal string conversions in visit_binary_op)
735        if let (Some(Expr::Column(name)), Some(schema)) =
736            (dtype_expr_match, self.active_schema.as_ref())
737        {
738            if elems.dtype() == &DataType::String {
739                if let Some(dtype) = schema.get(name) {
740                    if matches!(
741                        dtype,
742                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
743                    ) {
744                        elems = elems.strict_cast(dtype)?;
745                    }
746                }
747            }
748        }
749
750        // if we are parsing the list as an element in a series, implode.
751        // otherwise, return the series as-is.
752        let res = if result_as_element {
753            elems.implode()?.into_series()
754        } else {
755            elems
756        };
757        Ok(lit(res))
758    }
759
760    /// Visit a SQL `CAST` or `TRY_CAST` expression.
761    ///
762    /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
763    fn visit_cast(
764        &mut self,
765        expr: &SQLExpr,
766        dtype: &SQLDataType,
767        format: &Option<CastFormat>,
768        cast_kind: &CastKind,
769    ) -> PolarsResult<Expr> {
770        if format.is_some() {
771            return Err(
772                polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
773            );
774        }
775        let expr = self.visit_expr(expr)?;
776
777        #[cfg(feature = "json")]
778        if dtype == &SQLDataType::JSON {
779            // @BROKEN: we cannot handle this.
780            return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
781        }
782        let polars_type = map_sql_dtype_to_polars(dtype)?;
783        Ok(match cast_kind {
784            CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
785            CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
786        })
787    }
788
789    /// Visit a SQL literal.
790    ///
791    /// e.g. 1, 'foo', 1.0, NULL
792    ///
793    /// See [SQLValue] and [LiteralValue] for more details
794    fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
795        // note: double-quoted strings will be parsed as identifiers, not literals
796        Ok(match value {
797            SQLValue::Boolean(b) => lit(*b),
798            SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
799            #[cfg(feature = "binary_encoding")]
800            SQLValue::HexStringLiteral(x) => {
801                if x.len() % 2 != 0 {
802                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
803                };
804                lit(hex::decode(x.clone()).unwrap())
805            },
806            SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
807            SQLValue::Number(s, _) => {
808                // Check for existence of decimal separator dot
809                if s.contains('.') {
810                    s.parse::<f64>().map(lit).map_err(|_| ())
811                } else {
812                    s.parse::<i64>().map(lit).map_err(|_| ())
813                }
814                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
815            },
816            SQLValue::SingleQuotedByteStringLiteral(b) => {
817                // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
818                // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
819                // patterned the token name after BigQuery (where b'str' really IS a byte string)
820                bitstring_to_bytes_literal(b)?
821            },
822            SQLValue::SingleQuotedString(s) => lit(s.clone()),
823            other => {
824                polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
825            },
826        })
827    }
828
829    /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
830    fn visit_any_value(
831        &self,
832        value: &SQLValue,
833        op: Option<&UnaryOperator>,
834    ) -> PolarsResult<AnyValue<'_>> {
835        Ok(match value {
836            SQLValue::Boolean(b) => AnyValue::Boolean(*b),
837            SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
838            #[cfg(feature = "binary_encoding")]
839            SQLValue::HexStringLiteral(x) => {
840                if x.len() % 2 != 0 {
841                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
842                };
843                AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
844            },
845            SQLValue::Null => AnyValue::Null,
846            SQLValue::Number(s, _) => {
847                let negate = match op {
848                    Some(UnaryOperator::Minus) => true,
849                    // no op should be taken as plus.
850                    Some(UnaryOperator::Plus) | None => false,
851                    Some(op) => {
852                        polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
853                    },
854                };
855                // Check for existence of decimal separator dot
856                if s.contains('.') {
857                    s.parse::<f64>()
858                        .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
859                        .map_err(|_| ())
860                } else {
861                    s.parse::<i64>()
862                        .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
863                        .map_err(|_| ())
864                }
865                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
866            },
867            SQLValue::SingleQuotedByteStringLiteral(b) => {
868                // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
869                let bytes_literal = bitstring_to_bytes_literal(b)?;
870                match bytes_literal {
871                    Expr::Literal(lv) if lv.extract_binary().is_some() => {
872                        AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
873                    },
874                    _ => {
875                        polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
876                    },
877                }
878            },
879            SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
880            other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
881        })
882    }
883
884    /// Visit a SQL `BETWEEN` expression.
885    /// See [sqlparser::ast::Expr::Between] for more details
886    fn visit_between(
887        &mut self,
888        expr: &SQLExpr,
889        negated: bool,
890        low: &SQLExpr,
891        high: &SQLExpr,
892    ) -> PolarsResult<Expr> {
893        let expr = self.visit_expr(expr)?;
894        let low = self.visit_expr(low)?;
895        let high = self.visit_expr(high)?;
896
897        let low = self.convert_temporal_strings(&expr, &low);
898        let high = self.convert_temporal_strings(&expr, &high);
899        Ok(if negated {
900            expr.clone().lt(low).or(expr.gt(high))
901        } else {
902            expr.clone().gt_eq(low).and(expr.lt_eq(high))
903        })
904    }
905
906    /// Visit a SQL `TRIM` function.
907    /// See [sqlparser::ast::Expr::Trim] for more details
908    fn visit_trim(
909        &mut self,
910        expr: &SQLExpr,
911        trim_where: &Option<TrimWhereField>,
912        trim_what: &Option<Box<SQLExpr>>,
913        trim_characters: &Option<Vec<SQLExpr>>,
914    ) -> PolarsResult<Expr> {
915        if trim_characters.is_some() {
916            // TODO: allow compact snowflake/bigquery syntax?
917            return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
918        };
919        let expr = self.visit_expr(expr)?;
920        let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
921        let trim_what = match trim_what {
922            Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
923                Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
924            },
925            None => None,
926            _ => return self.err(&expr),
927        };
928        Ok(match (trim_where, trim_what) {
929            (None | Some(TrimWhereField::Both), None) => {
930                expr.str().strip_chars(lit(LiteralValue::untyped_null()))
931            },
932            (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
933            (Some(TrimWhereField::Leading), None) => expr
934                .str()
935                .strip_chars_start(lit(LiteralValue::untyped_null())),
936            (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
937            (Some(TrimWhereField::Trailing), None) => expr
938                .str()
939                .strip_chars_end(lit(LiteralValue::untyped_null())),
940            (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
941        })
942    }
943
944    /// Visit a SQL subquery inside and `IN` expression.
945    fn visit_in_subquery(
946        &mut self,
947        expr: &SQLExpr,
948        subquery: &Subquery,
949        negated: bool,
950    ) -> PolarsResult<Expr> {
951        let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
952        let expr = self.visit_expr(expr)?;
953        Ok(if negated {
954            expr.is_in(subquery_result, false).not()
955        } else {
956            expr.is_in(subquery_result, false)
957        })
958    }
959
960    /// Visit `CASE` control flow expression.
961    fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
962        if let SQLExpr::Case {
963            operand,
964            conditions,
965            results,
966            else_result,
967        } = expr
968        {
969            polars_ensure!(
970                conditions.len() == results.len(),
971                SQLSyntax: "WHEN and THEN expressions must have the same length"
972            );
973            polars_ensure!(
974                !conditions.is_empty(),
975                SQLSyntax: "WHEN and THEN expressions must have at least one element"
976            );
977
978            let mut when_thens = conditions.iter().zip(results.iter());
979            let first = when_thens.next();
980            if first.is_none() {
981                polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
982            }
983            let else_res = match else_result {
984                Some(else_res) => self.visit_expr(else_res)?,
985                None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
986            };
987            if let Some(operand_expr) = operand {
988                let first_operand_expr = self.visit_expr(operand_expr)?;
989
990                let first = first.unwrap();
991                let first_cond = first_operand_expr.eq(self.visit_expr(first.0)?);
992                let first_then = self.visit_expr(first.1)?;
993                let expr = when(first_cond).then(first_then);
994                let next = when_thens.next();
995
996                let mut when_then = if let Some((cond, res)) = next {
997                    let second_operand_expr = self.visit_expr(operand_expr)?;
998                    let cond = second_operand_expr.eq(self.visit_expr(cond)?);
999                    let res = self.visit_expr(res)?;
1000                    expr.when(cond).then(res)
1001                } else {
1002                    return Ok(expr.otherwise(else_res));
1003                };
1004                for (cond, res) in when_thens {
1005                    let new_operand_expr = self.visit_expr(operand_expr)?;
1006                    let cond = new_operand_expr.eq(self.visit_expr(cond)?);
1007                    let res = self.visit_expr(res)?;
1008                    when_then = when_then.when(cond).then(res);
1009                }
1010                return Ok(when_then.otherwise(else_res));
1011            }
1012
1013            let first = first.unwrap();
1014            let first_cond = self.visit_expr(first.0)?;
1015            let first_then = self.visit_expr(first.1)?;
1016            let expr = when(first_cond).then(first_then);
1017            let next = when_thens.next();
1018
1019            let mut when_then = if let Some((cond, res)) = next {
1020                let cond = self.visit_expr(cond)?;
1021                let res = self.visit_expr(res)?;
1022                expr.when(cond).then(res)
1023            } else {
1024                return Ok(expr.otherwise(else_res));
1025            };
1026            for (cond, res) in when_thens {
1027                let cond = self.visit_expr(cond)?;
1028                let res = self.visit_expr(res)?;
1029                when_then = when_then.when(cond).then(res);
1030            }
1031            Ok(when_then.otherwise(else_res))
1032        } else {
1033            unreachable!()
1034        }
1035    }
1036
1037    fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1038        polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1039    }
1040}
1041
1042/// parse a SQL expression to a polars expression
1043/// # Example
1044/// ```rust
1045/// # use polars_sql::{SQLContext, sql_expr};
1046/// # use polars_core::prelude::*;
1047/// # use polars_lazy::prelude::*;
1048/// # fn main() {
1049///
1050/// let mut ctx = SQLContext::new();
1051/// let df = df! {
1052///    "a" =>  [1, 2, 3],
1053/// }
1054/// .unwrap();
1055/// let expr = sql_expr("MAX(a)").unwrap();
1056/// df.lazy().select(vec![expr]).collect().unwrap();
1057/// # }
1058/// ```
1059pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1060    let mut ctx = SQLContext::new();
1061
1062    let mut parser = Parser::new(&GenericDialect);
1063    parser = parser.with_options(ParserOptions {
1064        trailing_commas: true,
1065        ..Default::default()
1066    });
1067
1068    let mut ast = parser
1069        .try_with_sql(s.as_ref())
1070        .map_err(to_sql_interface_err)?;
1071    let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1072
1073    Ok(match &expr {
1074        SelectItem::ExprWithAlias { expr, alias } => {
1075            let expr = parse_sql_expr(expr, &mut ctx, None)?;
1076            expr.alias(alias.value.as_str())
1077        },
1078        SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1079        _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1080    })
1081}
1082
1083pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1084    if interval.last_field.is_some()
1085        || interval.leading_field.is_some()
1086        || interval.leading_precision.is_some()
1087        || interval.fractional_seconds_precision.is_some()
1088    {
1089        polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1090    }
1091    let s = match &*interval.value {
1092        SQLExpr::UnaryOp { .. } => {
1093            polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1094        },
1095        SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s),
1096        _ => None,
1097    };
1098    match s {
1099        Some(s) if s.contains('-') => {
1100            polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1101        },
1102        Some(s) => {
1103            // years, quarters, and months do not have a fixed duration; these
1104            // interval parts can only be used with respect to a reference point
1105            let duration = Duration::parse_interval(s);
1106            if fixed && duration.months() != 0 {
1107                polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1108            };
1109            Ok(duration)
1110        },
1111        None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1112    }
1113}
1114
1115pub(crate) fn parse_sql_expr(
1116    expr: &SQLExpr,
1117    ctx: &mut SQLContext,
1118    active_schema: Option<&Schema>,
1119) -> PolarsResult<Expr> {
1120    let mut visitor = SQLExprVisitor { ctx, active_schema };
1121    visitor.visit_expr(expr)
1122}
1123
1124pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1125    match expr {
1126        SQLExpr::Array(arr) => {
1127            let mut visitor = SQLExprVisitor {
1128                ctx,
1129                active_schema: None,
1130            };
1131            visitor.array_expr_to_series(arr.elem.as_slice())
1132        },
1133        _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1134    }
1135}
1136
1137pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1138    let field = match field {
1139        // handle 'DATE_PART' and all valid abbreviations/alternates
1140        DateTimeField::Custom(Ident { value, .. }) => {
1141            let value = value.to_ascii_lowercase();
1142            match value.as_str() {
1143                "millennium" | "millennia" => &DateTimeField::Millennium,
1144                "century" | "centuries" => &DateTimeField::Century,
1145                "decade" | "decades" => &DateTimeField::Decade,
1146                "isoyear" => &DateTimeField::Isoyear,
1147                "year" | "years" | "y" => &DateTimeField::Year,
1148                "quarter" | "quarters" => &DateTimeField::Quarter,
1149                "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1150                "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1151                "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1152                "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1153                "isodow" => &DateTimeField::Isodow,
1154                "day" | "days" | "d" => &DateTimeField::Day,
1155                "hour" | "hours" | "h" => &DateTimeField::Hour,
1156                "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1157                "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1158                "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1159                "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1160                "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1161                #[cfg(feature = "timezones")]
1162                "timezone" => &DateTimeField::Timezone,
1163                "time" => &DateTimeField::Time,
1164                "epoch" => &DateTimeField::Epoch,
1165                _ => {
1166                    polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1167                },
1168            }
1169        },
1170        _ => field,
1171    };
1172    Ok(match field {
1173        DateTimeField::Millennium => expr.dt().millennium(),
1174        DateTimeField::Century => expr.dt().century(),
1175        DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1176        DateTimeField::Isoyear => expr.dt().iso_year(),
1177        DateTimeField::Year => expr.dt().year(),
1178        DateTimeField::Quarter => expr.dt().quarter(),
1179        DateTimeField::Month => expr.dt().month(),
1180        DateTimeField::Week(weekday) => {
1181            if weekday.is_some() {
1182                polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1183            }
1184            expr.dt().week()
1185        },
1186        DateTimeField::IsoWeek => expr.dt().week(),
1187        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1188        DateTimeField::DayOfWeek | DateTimeField::Dow => {
1189            let w = expr.dt().weekday();
1190            when(w.clone().eq(typed_lit(7i8)))
1191                .then(typed_lit(0i8))
1192                .otherwise(w)
1193        },
1194        DateTimeField::Isodow => expr.dt().weekday(),
1195        DateTimeField::Day => expr.dt().day(),
1196        DateTimeField::Hour => expr.dt().hour(),
1197        DateTimeField::Minute => expr.dt().minute(),
1198        DateTimeField::Second => expr.dt().second(),
1199        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1200            (expr.clone().dt().second() * typed_lit(1_000f64))
1201                + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1202        },
1203        DateTimeField::Microsecond | DateTimeField::Microseconds => {
1204            (expr.clone().dt().second() * typed_lit(1_000_000f64))
1205                + expr.dt().nanosecond().div(typed_lit(1_000f64))
1206        },
1207        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1208            (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1209        },
1210        DateTimeField::Time => expr.dt().time(),
1211        #[cfg(feature = "timezones")]
1212        DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(),
1213        DateTimeField::Epoch => {
1214            expr.clone()
1215                .dt()
1216                .timestamp(TimeUnit::Nanoseconds)
1217                .div(typed_lit(1_000_000_000i64))
1218                + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1219        },
1220        _ => {
1221            polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1222        },
1223    })
1224}
1225
1226/// Allow an expression that represents a 1-indexed parameter to
1227/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1228pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1229    match idx {
1230        Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1231        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1232            if null_if_zero {
1233                lit(LiteralValue::untyped_null())
1234            } else {
1235                idx
1236            }
1237        },
1238        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1239        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1240        // TODO: when 'saturating_sub' is available, should be able
1241        //  to streamline the when/then/otherwise block below -
1242        _ => when(idx.clone().gt(lit(0)))
1243            .then(idx.clone() - lit(1))
1244            .otherwise(if null_if_zero {
1245                when(idx.clone().eq(lit(0)))
1246                    .then(lit(LiteralValue::untyped_null()))
1247                    .otherwise(idx.clone())
1248            } else {
1249                idx.clone()
1250            }),
1251    }
1252}
1253
1254fn resolve_column<'a>(
1255    ctx: &'a mut SQLContext,
1256    ident_root: &'a Ident,
1257    name: &'a str,
1258    dtype: &'a DataType,
1259) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1260    let resolved = ctx.resolve_name(&ident_root.value, name);
1261    let resolved = resolved.as_str();
1262    Ok((
1263        if name != resolved {
1264            col(resolved).alias(name)
1265        } else {
1266            col(name)
1267        },
1268        Some(dtype),
1269    ))
1270}
1271
1272pub(crate) fn resolve_compound_identifier(
1273    ctx: &mut SQLContext,
1274    idents: &[Ident],
1275    active_schema: Option<&Schema>,
1276) -> PolarsResult<Vec<Expr>> {
1277    // inference priority: table > struct > column
1278    let ident_root = &idents[0];
1279    let mut remaining_idents = idents.iter().skip(1);
1280    let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1281
1282    let schema = if let Some(ref mut lf) = lf {
1283        lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)
1284    } else {
1285        Ok(Arc::new(if let Some(active_schema) = active_schema {
1286            active_schema.clone()
1287        } else {
1288            Schema::default()
1289        }))
1290    }?;
1291
1292    let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() {
1293        Ok((col(ident_root.value.as_str()), None))
1294    } else {
1295        let name = &remaining_idents.next().unwrap().value;
1296        if lf.is_some() && name == "*" {
1297            return Ok(schema
1298                .iter_names_and_dtypes()
1299                .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0)
1300                .collect::<Vec<_>>());
1301        };
1302        let root_is_field = schema.get(&ident_root.value).is_some();
1303        if lf.is_none() && root_is_field {
1304            remaining_idents = idents.iter().skip(1);
1305            Ok((
1306                col(ident_root.value.as_str()),
1307                schema.get(&ident_root.value),
1308            ))
1309        } else if lf.is_none() && !root_is_field {
1310            polars_bail!(
1311                SQLInterface: "no table or struct column named '{}' found",
1312                ident_root
1313            )
1314        } else if let Some((_, name, dtype)) = schema.get_full(name) {
1315            resolve_column(ctx, ident_root, name, dtype)
1316        } else {
1317            polars_bail!(
1318                SQLInterface: "no column named '{}' found in table '{}'",
1319                name,
1320                ident_root
1321            )
1322        }
1323    };
1324
1325    // additional ident levels index into struct fields
1326    let (mut column, mut dtype) = col_dtype?;
1327    for ident in remaining_idents {
1328        let name = ident.value.as_str();
1329        match dtype {
1330            Some(DataType::Struct(fields)) if name == "*" => {
1331                return Ok(fields
1332                    .iter()
1333                    .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1334                    .collect());
1335            },
1336            Some(DataType::Struct(fields)) => {
1337                dtype = fields
1338                    .iter()
1339                    .find(|fld| fld.name == name)
1340                    .map(|fld| &fld.dtype);
1341            },
1342            Some(dtype) if name == "*" => {
1343                polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1344            },
1345            _ => {
1346                dtype = None;
1347            },
1348        }
1349        column = column.struct_().field_by_name(name);
1350    }
1351    Ok(vec![column])
1352}