Skip to main content

polars_sql/
sql_expr.rs

1//! Expressions that are supported by the Polars SQL interface.
2//!
3//! This is useful for syntax highlighting
4//!
5//! This module defines:
6//! - all Polars SQL keywords [`all_keywords`]
7//! - all Polars SQL functions [`all_functions`]
8
9use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use polars_utils::unique_column_name;
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use sqlparser::ast::{
21    AccessExpr, BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
22    DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
23    SelectItem, Subscript, TimezoneInfo, TrimWhereField, TypedString, UnaryOperator,
24    Value as SQLValue, ValueWithSpan,
25};
26use sqlparser::dialect::GenericDialect;
27use sqlparser::parser::{Parser, ParserOptions};
28
29use crate::SQLContext;
30use crate::functions::SQLFunctionVisitor;
31use crate::types::{
32    bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33};
34
35#[inline]
36#[cold]
37#[must_use]
38/// Convert a Display-able error to PolarsError::SQLInterface
39pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
40    PolarsError::SQLInterface(err.to_string().into())
41}
42
/// Categorises the type of (allowed) subquery constraint.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SubqueryRestriction {
    /// Subquery must return a single column.
    SingleColumn,
    // Possible future restrictions (not yet implemented): SingleRow, SingleValue, Any.
}
53
54/// Recursively walks a SQL Expr to create a polars Expr
55pub(crate) struct SQLExprVisitor<'a> {
56    ctx: &'a mut SQLContext,
57    active_schema: Option<&'a Schema>,
58}
59
60impl SQLExprVisitor<'_> {
61    fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
62        let mut array_elements = Vec::with_capacity(elements.len());
63        for e in elements {
64            let val = match e {
65                SQLExpr::Value(ValueWithSpan { value: v, .. }) => self.visit_any_value(v, None),
66                SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
67                    SQLExpr::Value(ValueWithSpan { value: v, .. }) => {
68                        self.visit_any_value(v, Some(op))
69                    },
70                    _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
71                },
72                SQLExpr::Array(values) => {
73                    let srs = self.array_expr_to_series(&values.elem)?;
74                    Ok(AnyValue::List(srs))
75                },
76                _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
77            }?
78            .into_static();
79            array_elements.push(val);
80        }
81        Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
82    }
83
84    fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
85        match expr {
86            SQLExpr::AllOp {
87                left,
88                compare_op,
89                right,
90            } => self.visit_all(left, compare_op, right),
91            SQLExpr::AnyOp {
92                left,
93                compare_op,
94                right,
95                is_some: _,
96            } => self.visit_any(left, compare_op, right),
97            SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
98            SQLExpr::Between {
99                expr,
100                negated,
101                low,
102                high,
103            } => self.visit_between(expr, *negated, low, high),
104            SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
105            SQLExpr::Cast {
106                kind,
107                expr,
108                data_type,
109                format,
110            } => self.visit_cast(expr, data_type, format, kind),
111            SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
112            SQLExpr::CompoundFieldAccess { root, access_chain } => {
113                // simple subscript access (eg: "array_col[1]")
114                if access_chain.len() == 1 {
115                    match &access_chain[0] {
116                        AccessExpr::Subscript(subscript) => {
117                            return self.visit_subscript(root, subscript);
118                        },
119                        AccessExpr::Dot(_) => {
120                            polars_bail!(SQLSyntax: "dot-notation field access is currently unsupported: {:?}", access_chain[0])
121                        },
122                    }
123                }
124                // chained dot/bracket notation (eg: "struct_col.field[2].foo[0].bar")
125                polars_bail!(SQLSyntax: "complex field access chains are currently unsupported: {:?}", access_chain[0])
126            },
127            SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
128            SQLExpr::Extract {
129                field,
130                syntax: _,
131                expr,
132            } => parse_extract_date_part(self.visit_expr(expr)?, field),
133            SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
134            SQLExpr::Function(function) => self.visit_function(function),
135            SQLExpr::Identifier(ident) => self.visit_identifier(ident),
136            SQLExpr::InList {
137                expr,
138                list,
139                negated,
140            } => {
141                let expr = self.visit_expr(expr)?;
142                let elems = self.visit_array_expr(list, true, Some(&expr))?;
143                let is_in = expr.is_in(elems, false);
144                Ok(if *negated { is_in.not() } else { is_in })
145            },
146            SQLExpr::InSubquery {
147                expr,
148                subquery,
149                negated,
150            } => self.visit_in_subquery(expr, subquery, *negated),
151            SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
152            SQLExpr::IsDistinctFrom(e1, e2) => {
153                Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
154            },
155            SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
156            SQLExpr::IsNotDistinctFrom(e1, e2) => {
157                Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
158            },
159            SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
160            SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
161            SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
162            SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
163            SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
164            SQLExpr::Like {
165                negated,
166                any,
167                expr,
168                pattern,
169                escape_char,
170            } => {
171                if *any {
172                    polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
173                }
174                let escape_str = escape_char.as_ref().and_then(|v| match v {
175                    SQLValue::SingleQuotedString(s) => Some(s.clone()),
176                    _ => None,
177                });
178                self.visit_like(*negated, expr, pattern, &escape_str, false)
179            },
180            SQLExpr::ILike {
181                negated,
182                any,
183                expr,
184                pattern,
185                escape_char,
186            } => {
187                if *any {
188                    polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
189                }
190                let escape_str = escape_char.as_ref().and_then(|v| match v {
191                    SQLValue::SingleQuotedString(s) => Some(s.clone()),
192                    _ => None,
193                });
194                self.visit_like(*negated, expr, pattern, &escape_str, true)
195            },
196            SQLExpr::Nested(expr) => self.visit_expr(expr),
197            SQLExpr::Position { expr, r#in } => Ok(
198                // note: SQL is 1-indexed
199                (self
200                    .visit_expr(r#in)?
201                    .str()
202                    .find(self.visit_expr(expr)?, true)
203                    + typed_lit(1u32))
204                .fill_null(typed_lit(0u32)),
205            ),
206            SQLExpr::RLike {
207                // note: parses both RLIKE and REGEXP
208                negated,
209                expr,
210                pattern,
211                regexp: _,
212            } => {
213                let matches = self
214                    .visit_expr(expr)?
215                    .str()
216                    .contains(self.visit_expr(pattern)?, true);
217                Ok(if *negated { matches.not() } else { matches })
218            },
219            SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
220            SQLExpr::Substring {
221                expr,
222                substring_from,
223                substring_for,
224                ..
225            } => self.visit_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
226            SQLExpr::Trim {
227                expr,
228                trim_where,
229                trim_what,
230                trim_characters,
231            } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
232            SQLExpr::TypedString(TypedString {
233                data_type,
234                value:
235                    ValueWithSpan {
236                        value: SQLValue::SingleQuotedString(v),
237                        ..
238                    },
239                uses_odbc_syntax: _,
240            }) => match data_type {
241                SQLDataType::Date => {
242                    if is_iso_date(v) {
243                        Ok(lit(v.as_str()).cast(DataType::Date))
244                    } else {
245                        polars_bail!(SQLSyntax: "invalid DATE literal '{}'", v)
246                    }
247                },
248                SQLDataType::Time(None, TimezoneInfo::None) => {
249                    if is_iso_time(v) {
250                        Ok(lit(v.as_str()).str().to_time(StrptimeOptions {
251                            strict: true,
252                            ..Default::default()
253                        }))
254                    } else {
255                        polars_bail!(SQLSyntax: "invalid TIME literal '{}'", v)
256                    }
257                },
258                SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
259                    if is_iso_datetime(v) {
260                        Ok(lit(v.as_str()).str().to_datetime(
261                            None,
262                            None,
263                            StrptimeOptions {
264                                strict: true,
265                                ..Default::default()
266                            },
267                            lit("latest"),
268                        ))
269                    } else {
270                        let fn_name = match data_type {
271                            SQLDataType::Timestamp(_, _) => "TIMESTAMP",
272                            SQLDataType::Datetime(_) => "DATETIME",
273                            _ => unreachable!(),
274                        };
275                        polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, v)
276                    }
277                },
278                _ => {
279                    polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
280                },
281            },
282            SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
283            SQLExpr::Value(ValueWithSpan { value, .. }) => self.visit_literal(value),
284            SQLExpr::Wildcard(_) => Ok(all().as_expr()),
285            e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
286            other => {
287                polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
288            },
289        }
290    }
291
292    fn visit_subquery(
293        &mut self,
294        subquery: &Subquery,
295        restriction: SubqueryRestriction,
296    ) -> PolarsResult<Expr> {
297        if subquery.with.is_some() {
298            polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
299        }
300        // note: we have to execute subqueries in an isolated scope to prevent
301        // propagating any context/arena mutation into the rest of the query
302        let lf = self
303            .ctx
304            .execute_isolated(|ctx| ctx.execute_query_no_ctes(subquery))?;
305
306        if restriction == SubqueryRestriction::SingleColumn {
307            let new_name = unique_column_name();
308            return Ok(Expr::SubPlan(
309                SpecialEq::new(Arc::new(lf.logical_plan)),
310                // TODO: pass the implode depending on expr.
311                vec![(
312                    new_name.clone(),
313                    first().as_expr().implode().alias(new_name.clone()),
314                )],
315            ));
316        };
317        polars_bail!(SQLInterface: "subquery type not supported");
318    }
319
320    /// Visit a single SQL identifier.
321    ///
322    /// e.g. column
323    fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
324        Ok(col(ident.value.as_str()))
325    }
326
327    /// Visit a compound SQL identifier
328    ///
329    /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
330    fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
331        Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
332    }
333
334    fn visit_like(
335        &mut self,
336        negated: bool,
337        expr: &SQLExpr,
338        pattern: &SQLExpr,
339        escape_char: &Option<String>,
340        case_insensitive: bool,
341    ) -> PolarsResult<Expr> {
342        if escape_char.is_some() {
343            polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
344        }
345        let pat = match self.visit_expr(pattern) {
346            Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
347                PlSmallStr::from_str(lv.extract_str().unwrap())
348            },
349            _ => {
350                polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
351            },
352        };
353        if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
354            // empty string or other exact literal match (eg: no wildcard chars)
355            let op = if negated {
356                SQLBinaryOperator::NotEq
357            } else {
358                SQLBinaryOperator::Eq
359            };
360            self.visit_binary_op(expr, &op, pattern)
361        } else {
362            // create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
363            let mut rx = regex::escape(pat.as_str())
364                .replace('%', ".*")
365                .replace('_', ".");
366
367            rx = format!(
368                "^{}{}$",
369                if case_insensitive { "(?is)" } else { "(?s)" },
370                rx
371            );
372
373            let expr = self.visit_expr(expr)?;
374            let matches = expr.str().contains(lit(rx), true);
375            Ok(if negated { matches.not() } else { matches })
376        }
377    }
378
379    fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
380        let expr = self.visit_expr(expr)?;
381        Ok(match subscript {
382            Subscript::Index { index } => {
383                let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
384                expr.list().get(idx, true)
385            },
386            Subscript::Slice { .. } => {
387                polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
388            },
389        })
390    }
391
392    /// Handle implicit temporal string comparisons.
393    ///
394    /// eg: clauses such as -
395    ///   "dt >= '2024-04-30'"
396    ///   "dt = '2077-10-10'::date"
397    ///   "dtm::date = '2077-10-10'
398    fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
399        if let (Some(name), Some(s), expr_dtype) = match (left, right) {
400            // identify "col <op> string" expressions
401            (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
402                (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
403            },
404            // identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
405            (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
406                let s = lv.extract_str().unwrap();
407                match &**expr {
408                    Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
409                    _ => (None, Some(s), Some(dtype)),
410                }
411            },
412            _ => (None, None, None),
413        } {
414            if expr_dtype.is_none() && self.active_schema.is_none() {
415                right.clone()
416            } else {
417                let left_dtype = expr_dtype.map_or_else(
418                    || {
419                        self.active_schema
420                            .as_ref()
421                            .and_then(|schema| schema.get(&name))
422                    },
423                    |dt| dt.as_literal(),
424                );
425                match left_dtype {
426                    Some(DataType::Time) if is_iso_time(s) => {
427                        right.clone().str().to_time(StrptimeOptions {
428                            strict: true,
429                            ..Default::default()
430                        })
431                    },
432                    Some(DataType::Date) if is_iso_date(s) => {
433                        right.clone().str().to_date(StrptimeOptions {
434                            strict: true,
435                            ..Default::default()
436                        })
437                    },
438                    Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
439                        if s.len() == 10 {
440                            // handle upcast from ISO date string (10 chars) to datetime
441                            lit(format!("{s}T00:00:00"))
442                        } else {
443                            lit(s.replacen(' ', "T", 1))
444                        }
445                        .str()
446                        .to_datetime(
447                            Some(*tu),
448                            tz.clone(),
449                            StrptimeOptions {
450                                strict: true,
451                                ..Default::default()
452                            },
453                            lit("latest"),
454                        )
455                    },
456                    _ => right.clone(),
457                }
458            }
459        } else {
460            right.clone()
461        }
462    }
463
464    fn struct_field_access_expr(
465        &mut self,
466        expr: &Expr,
467        path: &str,
468        infer_index: bool,
469    ) -> PolarsResult<Expr> {
470        let path_elems = if path.starts_with('{') && path.ends_with('}') {
471            path.trim_matches(|c| c == '{' || c == '}')
472        } else {
473            path
474        }
475        .split(',');
476
477        let mut expr = expr.clone();
478        for p in path_elems {
479            let p = p.trim();
480            expr = if infer_index {
481                match p.parse::<i64>() {
482                    Ok(idx) => expr.list().get(lit(idx), true),
483                    Err(_) => expr.struct_().field_by_name(p),
484                }
485            } else {
486                expr.struct_().field_by_name(p)
487            }
488        }
489        Ok(expr)
490    }
491
492    /// Visit a SQL binary operator.
493    ///
494    /// e.g. "column + 1", "column1 <= column2"
495    fn visit_binary_op(
496        &mut self,
497        left: &SQLExpr,
498        op: &SQLBinaryOperator,
499        right: &SQLExpr,
500    ) -> PolarsResult<Expr> {
501        // check for (unsupported) scalar subquery comparisons
502        if matches!(left, SQLExpr::Subquery(_)) || matches!(right, SQLExpr::Subquery(_)) {
503            let (suggestion, str_op) = match op {
504                SQLBinaryOperator::NotEq => ("; use 'NOT IN' instead", "!=".to_string()),
505                SQLBinaryOperator::Eq => ("; use 'IN' instead", format!("{op}")),
506                _ => ("", format!("{op}")),
507            };
508            polars_bail!(
509                SQLSyntax: "subquery comparisons with '{str_op}' are not supported{suggestion}"
510            );
511        }
512
513        // need special handling for interval offsets and comparisons
514        let (lhs, mut rhs) = match (left, op, right) {
515            (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
516                let duration = interval_to_duration(v, false)?;
517                return Ok(self
518                    .visit_expr(left)?
519                    .dt()
520                    .offset_by(lit(format!("-{duration}"))));
521            },
522            (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
523                let duration = interval_to_duration(v, false)?;
524                return Ok(self
525                    .visit_expr(left)?
526                    .dt()
527                    .offset_by(lit(format!("{duration}"))));
528            },
529            (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
530                // shortcut interval comparison evaluation (-> bool)
531                let d1 = interval_to_duration(v1, false)?;
532                let d2 = interval_to_duration(v2, false)?;
533                let res = match op {
534                    SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
535                    SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
536                    SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
537                    SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
538                    SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
539                    SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
540                    _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
541                };
542                if res.is_ok() {
543                    return res;
544                }
545                (self.visit_expr(left)?, self.visit_expr(right)?)
546            },
547            _ => (self.visit_expr(left)?, self.visit_expr(right)?),
548        };
549        rhs = self.convert_temporal_strings(&lhs, &rhs);
550
551        Ok(match op {
552            // ----
553            // Bitwise operators
554            // ----
555            SQLBinaryOperator::BitwiseAnd => lhs.and(rhs),  // "x & y"
556            SQLBinaryOperator::BitwiseOr => lhs.or(rhs),  // "x | y"
557            SQLBinaryOperator::Xor => lhs.xor(rhs),  // "x XOR y"
558
559            // ----
560            // General operators
561            // ----
562            SQLBinaryOperator::And => lhs.and(rhs),  // "x AND y"
563            SQLBinaryOperator::Divide => lhs / rhs,  // "x / y"
564            SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),  // "x // y"
565            SQLBinaryOperator::Eq => lhs.eq(rhs),  // "x = y"
566            SQLBinaryOperator::Gt => lhs.gt(rhs),  // "x > y"
567            SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),  // "x >= y"
568            SQLBinaryOperator::Lt => lhs.lt(rhs),  // "x < y"
569            SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),  // "x <= y"
570            SQLBinaryOperator::Minus => lhs - rhs,  // "x - y"
571            SQLBinaryOperator::Modulo => lhs % rhs,  // "x % y"
572            SQLBinaryOperator::Multiply => lhs * rhs,  // "x * y"
573            SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),  // "x != y"
574            SQLBinaryOperator::Or => lhs.or(rhs),  // "x OR y"
575            SQLBinaryOperator::Plus => lhs + rhs,  // "x + y"
576            SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),  // "x <=> y"
577            SQLBinaryOperator::StringConcat => {  // "x || y"
578                lhs.cast(DataType::String) + rhs.cast(DataType::String)
579            },
580            SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),  // "x ^@ y"
581            // ----
582            // Regular expression operators
583            // ----
584            SQLBinaryOperator::PGRegexMatch => match rhs {  // "x ~ y"
585                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
586                _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
587            },
588            SQLBinaryOperator::PGRegexNotMatch => match rhs {  // "x !~ y"
589                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
590                _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
591            },
592            SQLBinaryOperator::PGRegexIMatch => match rhs {  // "x ~* y"
593                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
594                    let pat = lv.extract_str().unwrap();
595                    lhs.str().contains(lit(format!("(?i){pat}")), true)
596                },
597                _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
598            },
599            SQLBinaryOperator::PGRegexNotIMatch => match rhs {  // "x !~* y"
600                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
601                    let pat = lv.extract_str().unwrap();
602                    lhs.str().contains(lit(format!("(?i){pat}")), true).not()
603                },
604                _ => {
605                    polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
606                },
607            },
608            // ----
609            // LIKE/ILIKE operators
610            // ----
611            SQLBinaryOperator::PGLikeMatch  // "x ~~ y"
612            | SQLBinaryOperator::PGNotLikeMatch  // "x !~~ y"
613            | SQLBinaryOperator::PGILikeMatch  // "x ~~* y"
614            | SQLBinaryOperator::PGNotILikeMatch => {  // "x !~~* y"
615                let expr = if matches!(
616                    op,
617                    SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
618                ) {
619                    SQLExpr::Like {
620                        negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
621                        any: false,
622                        expr: Box::new(left.clone()),
623                        pattern: Box::new(right.clone()),
624                        escape_char: None,
625                    }
626                } else {
627                    SQLExpr::ILike {
628                        negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
629                        any: false,
630                        expr: Box::new(left.clone()),
631                        pattern: Box::new(right.clone()),
632                        escape_char: None,
633                    }
634                };
635                self.visit_expr(&expr)?
636            },
637            // ----
638            // JSON/Struct field access operators
639            // ----
640            SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {  // "x -> y", "x ->> y"
641                Expr::Literal(lv) if lv.extract_str().is_some() => {
642                    let path = lv.extract_str().unwrap();
643                    let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
644                    if let SQLBinaryOperator::LongArrow = op {
645                        expr = expr.cast(DataType::String);
646                    }
647                    expr
648                },
649                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
650                    let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
651                    if let SQLBinaryOperator::LongArrow = op {
652                        expr = expr.cast(DataType::String);
653                    }
654                    expr
655                },
656                _ => {
657                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
658                },
659            },
660            SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {  // "x #> y", "x #>> y"
661                match rhs {
662                    Expr::Literal(lv) if lv.extract_str().is_some() => {
663                        let path = lv.extract_str().unwrap();
664                        let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
665                        if let SQLBinaryOperator::HashLongArrow = op {
666                            expr = expr.cast(DataType::String);
667                        }
668                        expr
669                    },
670                    _ => {
671                        polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
672                    }
673                }
674            },
675            other => {
676                polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
677            },
678        })
679    }
680
681    /// Visit a SQL unary operator.
682    ///
683    /// e.g. +column or -column
684    fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
685        let expr = self.visit_expr(expr)?;
686        Ok(match (op, expr.clone()) {
687            // simplify the parse tree by special-casing common unary +/- ops
688            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
689                lit(n)
690            },
691            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
692                lit(n)
693            },
694            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
695                lit(-n)
696            },
697            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
698                lit(-n)
699            },
700            // general case
701            (UnaryOperator::Plus, _) => lit(0) + expr,
702            (UnaryOperator::Minus, _) => lit(0) - expr,
703            (UnaryOperator::Not, _) => match &expr {
704                Expr::Column(name)
705                    if self
706                        .active_schema
707                        .and_then(|schema| schema.get(name))
708                        .is_some_and(|dtype| matches!(dtype, DataType::Boolean)) =>
709                {
710                    // if already boolean, can operate bitwise
711                    expr.not()
712                },
713                // otherwise SQL "NOT" expects logical, not bitwise, behaviour (eg: on integers)
714                _ => expr.strict_cast(DataType::Boolean).not(),
715            },
716            other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
717        })
718    }
719
720    /// Visit a SQL function.
721    ///
722    /// e.g. SUM(column) or COUNT(*)
723    ///
724    /// See [SQLFunctionVisitor] for more details
725    fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
726        let mut visitor = SQLFunctionVisitor {
727            func: function,
728            ctx: self.ctx,
729            active_schema: self.active_schema,
730        };
731        visitor.visit_function()
732    }
733
734    /// Visit a SQL `ALL` expression.
735    ///
736    /// e.g. `a > ALL(y)`
737    fn visit_all(
738        &mut self,
739        left: &SQLExpr,
740        compare_op: &SQLBinaryOperator,
741        right: &SQLExpr,
742    ) -> PolarsResult<Expr> {
743        let left = self.visit_expr(left)?;
744        let right = self.visit_expr(right)?;
745
746        match compare_op {
747            SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
748            SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
749            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
750            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
751            SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
752            SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
753            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
754        }
755    }
756
757    /// Visit a SQL `ANY` expression.
758    ///
759    /// e.g. `a != ANY(y)`
760    fn visit_any(
761        &mut self,
762        left: &SQLExpr,
763        compare_op: &SQLBinaryOperator,
764        right: &SQLExpr,
765    ) -> PolarsResult<Expr> {
766        let left = self.visit_expr(left)?;
767        let right = self.visit_expr(right)?;
768
769        match compare_op {
770            SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
771            SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
772            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
773            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
774            SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
775            SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
776            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
777        }
778    }
779
780    /// Visit a SQL `ARRAY` list (including `IN` values).
781    fn visit_array_expr(
782        &mut self,
783        elements: &[SQLExpr],
784        result_as_element: bool,
785        dtype_expr_match: Option<&Expr>,
786    ) -> PolarsResult<Expr> {
787        let mut elems = self.array_expr_to_series(elements)?;
788
789        // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
790        // (not yet as versatile as the temporal string conversions in visit_binary_op)
791        if let (Some(Expr::Column(name)), Some(schema)) =
792            (dtype_expr_match, self.active_schema.as_ref())
793        {
794            if elems.dtype() == &DataType::String {
795                if let Some(dtype) = schema.get(name) {
796                    if matches!(
797                        dtype,
798                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
799                    ) {
800                        elems = elems.strict_cast(dtype)?;
801                    }
802                }
803            }
804        }
805
806        // if we are parsing the list as an element in a series, implode.
807        // otherwise, return the series as-is.
808        let res = if result_as_element {
809            elems.implode()?.into_series()
810        } else {
811            elems
812        };
813        Ok(lit(res))
814    }
815
816    /// Visit a SQL `CAST` or `TRY_CAST` expression.
817    ///
818    /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
819    fn visit_cast(
820        &mut self,
821        expr: &SQLExpr,
822        dtype: &SQLDataType,
823        format: &Option<CastFormat>,
824        cast_kind: &CastKind,
825    ) -> PolarsResult<Expr> {
826        if format.is_some() {
827            return Err(
828                polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
829            );
830        }
831        let expr = self.visit_expr(expr)?;
832
833        #[cfg(feature = "json")]
834        if dtype == &SQLDataType::JSON {
835            // @BROKEN: we cannot handle this.
836            return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
837        }
838        let polars_type = map_sql_dtype_to_polars(dtype)?;
839        Ok(match cast_kind {
840            CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
841            CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
842        })
843    }
844
845    /// Visit a SQL literal.
846    ///
847    /// e.g. 1, 'foo', 1.0, NULL
848    ///
849    /// See [SQLValue] and [LiteralValue] for more details
850    fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
851        // note: double-quoted strings will be parsed as identifiers, not literals
852        Ok(match value {
853            SQLValue::Boolean(b) => lit(*b),
854            SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
855            #[cfg(feature = "binary_encoding")]
856            SQLValue::HexStringLiteral(x) => {
857                if x.len() % 2 != 0 {
858                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
859                };
860                lit(hex::decode(x.clone()).unwrap())
861            },
862            SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
863            SQLValue::Number(s, _) => {
864                // Check for existence of decimal separator dot
865                if s.contains('.') {
866                    s.parse::<f64>().map(lit).map_err(|_| ())
867                } else {
868                    s.parse::<i64>().map(lit).map_err(|_| ())
869                }
870                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
871            },
872            SQLValue::SingleQuotedByteStringLiteral(b) => {
873                // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
874                // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
875                // patterned the token name after BigQuery (where b'str' really IS a byte string)
876                bitstring_to_bytes_literal(b)?
877            },
878            SQLValue::SingleQuotedString(s) => lit(s.clone()),
879            other => {
880                polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
881            },
882        })
883    }
884
885    /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
886    fn visit_any_value(
887        &self,
888        value: &SQLValue,
889        op: Option<&UnaryOperator>,
890    ) -> PolarsResult<AnyValue<'_>> {
891        Ok(match value {
892            SQLValue::Boolean(b) => AnyValue::Boolean(*b),
893            SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
894            #[cfg(feature = "binary_encoding")]
895            SQLValue::HexStringLiteral(x) => {
896                if x.len() % 2 != 0 {
897                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
898                };
899                AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
900            },
901            SQLValue::Null => AnyValue::Null,
902            SQLValue::Number(s, _) => {
903                let negate = match op {
904                    Some(UnaryOperator::Minus) => true,
905                    // no op should be taken as plus.
906                    Some(UnaryOperator::Plus) | None => false,
907                    Some(op) => {
908                        polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
909                    },
910                };
911                // Check for existence of decimal separator dot
912                if s.contains('.') {
913                    s.parse::<f64>()
914                        .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
915                        .map_err(|_| ())
916                } else {
917                    s.parse::<i64>()
918                        .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
919                        .map_err(|_| ())
920                }
921                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
922            },
923            SQLValue::SingleQuotedByteStringLiteral(b) => {
924                // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
925                let bytes_literal = bitstring_to_bytes_literal(b)?;
926                match bytes_literal {
927                    Expr::Literal(lv) if lv.extract_binary().is_some() => {
928                        AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
929                    },
930                    _ => {
931                        polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
932                    },
933                }
934            },
935            SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
936            other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
937        })
938    }
939
940    /// Visit a SQL `BETWEEN` expression.
941    /// See [sqlparser::ast::Expr::Between] for more details
942    fn visit_between(
943        &mut self,
944        expr: &SQLExpr,
945        negated: bool,
946        low: &SQLExpr,
947        high: &SQLExpr,
948    ) -> PolarsResult<Expr> {
949        let expr = self.visit_expr(expr)?;
950        let low = self.visit_expr(low)?;
951        let high = self.visit_expr(high)?;
952
953        let low = self.convert_temporal_strings(&expr, &low);
954        let high = self.convert_temporal_strings(&expr, &high);
955        Ok(if negated {
956            expr.clone().lt(low).or(expr.gt(high))
957        } else {
958            expr.clone().gt_eq(low).and(expr.lt_eq(high))
959        })
960    }
961
962    /// Visit a SQL `TRIM` function.
963    /// See [sqlparser::ast::Expr::Trim] for more details
964    fn visit_trim(
965        &mut self,
966        expr: &SQLExpr,
967        trim_where: &Option<TrimWhereField>,
968        trim_what: &Option<Box<SQLExpr>>,
969        trim_characters: &Option<Vec<SQLExpr>>,
970    ) -> PolarsResult<Expr> {
971        if trim_characters.is_some() {
972            // TODO: allow compact snowflake/bigquery syntax?
973            return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
974        };
975        let expr = self.visit_expr(expr)?;
976        let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
977        let trim_what = match trim_what {
978            Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
979                Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
980            },
981            None => None,
982            _ => return self.err(&expr),
983        };
984        Ok(match (trim_where, trim_what) {
985            (None | Some(TrimWhereField::Both), None) => {
986                expr.str().strip_chars(lit(LiteralValue::untyped_null()))
987            },
988            (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
989            (Some(TrimWhereField::Leading), None) => expr
990                .str()
991                .strip_chars_start(lit(LiteralValue::untyped_null())),
992            (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
993            (Some(TrimWhereField::Trailing), None) => expr
994                .str()
995                .strip_chars_end(lit(LiteralValue::untyped_null())),
996            (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
997        })
998    }
999
1000    fn visit_substring(
1001        &mut self,
1002        expr: &SQLExpr,
1003        substring_from: Option<&SQLExpr>,
1004        substring_for: Option<&SQLExpr>,
1005    ) -> PolarsResult<Expr> {
1006        let e = self.visit_expr(expr)?;
1007
1008        match (substring_from, substring_for) {
1009            // SUBSTRING(expr FROM start FOR length)
1010            (Some(from_expr), Some(for_expr)) => {
1011                let start = self.visit_expr(from_expr)?;
1012                let length = self.visit_expr(for_expr)?;
1013
1014                // note: SQL is 1-indexed, so we need to adjust the offsets accordingly
1015                Ok(match (start.clone(), length.clone()) {
1016                    (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1017                    (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1018                        polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", n)
1019                    },
1020                    (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => {
1021                        e.str().slice(lit(n - 1), length)
1022                    },
1023                    (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => e
1024                        .str()
1025                        .slice(lit(0), (length + lit(n - 1)).clip_min(lit(0))),
1026                    (Expr::Literal(_), _) => {
1027                        polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1028                    },
1029                    (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1030                        polars_bail!(SQLSyntax: "invalid 'length' for SUBSTRING")
1031                    },
1032                    _ => {
1033                        let adjusted_start = start - lit(1);
1034                        when(adjusted_start.clone().lt(lit(0)))
1035                            .then(e.clone().str().slice(
1036                                lit(0),
1037                                (length.clone() + adjusted_start.clone()).clip_min(lit(0)),
1038                            ))
1039                            .otherwise(e.str().slice(adjusted_start, length))
1040                    },
1041                })
1042            },
1043            // SUBSTRING(expr FROM start)
1044            (Some(from_expr), None) => {
1045                let start = self.visit_expr(from_expr)?;
1046
1047                Ok(match start {
1048                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1049                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1050                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1051                        e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null()))
1052                    },
1053                    Expr::Literal(_) => {
1054                        polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1055                    },
1056                    _ => e
1057                        .str()
1058                        .slice(start - lit(1), lit(LiteralValue::untyped_null())),
1059                })
1060            },
1061            // SUBSTRING(expr) - not valid, but handle gracefully
1062            (None, _) => {
1063                polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found 1)")
1064            },
1065        }
1066    }
1067
1068    /// Visit a SQL subquery inside an `IN` expression.
1069    fn visit_in_subquery(
1070        &mut self,
1071        expr: &SQLExpr,
1072        subquery: &Subquery,
1073        negated: bool,
1074    ) -> PolarsResult<Expr> {
1075        let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
1076        let expr = self.visit_expr(expr)?;
1077        Ok(if negated {
1078            expr.is_in(subquery_result, false).not()
1079        } else {
1080            expr.is_in(subquery_result, false)
1081        })
1082    }
1083
1084    /// Visit `CASE` control flow expression.
1085    fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1086        if let SQLExpr::Case {
1087            case_token: _,
1088            end_token: _,
1089            operand,
1090            conditions,
1091            else_result,
1092        } = expr
1093        {
1094            polars_ensure!(
1095                !conditions.is_empty(),
1096                SQLSyntax: "WHEN and THEN expressions must have at least one element"
1097            );
1098
1099            let mut when_thens = conditions.iter();
1100            let first = when_thens.next();
1101            if first.is_none() {
1102                polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
1103            }
1104            let else_res = match else_result {
1105                Some(else_res) => self.visit_expr(else_res)?,
1106                None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
1107            };
1108            if let Some(operand_expr) = operand {
1109                let first_operand_expr = self.visit_expr(operand_expr)?;
1110
1111                let first = first.unwrap();
1112                let first_cond = first_operand_expr.eq(self.visit_expr(&first.condition)?);
1113                let first_then = self.visit_expr(&first.result)?;
1114                let expr = when(first_cond).then(first_then);
1115                let next = when_thens.next();
1116
1117                let mut when_then = if let Some(case_when) = next {
1118                    let second_operand_expr = self.visit_expr(operand_expr)?;
1119                    let cond = second_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1120                    let res = self.visit_expr(&case_when.result)?;
1121                    expr.when(cond).then(res)
1122                } else {
1123                    return Ok(expr.otherwise(else_res));
1124                };
1125                for case_when in when_thens {
1126                    let new_operand_expr = self.visit_expr(operand_expr)?;
1127                    let cond = new_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1128                    let res = self.visit_expr(&case_when.result)?;
1129                    when_then = when_then.when(cond).then(res);
1130                }
1131                return Ok(when_then.otherwise(else_res));
1132            }
1133
1134            let first = first.unwrap();
1135            let first_cond = self.visit_expr(&first.condition)?;
1136            let first_then = self.visit_expr(&first.result)?;
1137            let expr = when(first_cond).then(first_then);
1138            let next = when_thens.next();
1139
1140            let mut when_then = if let Some(case_when) = next {
1141                let cond = self.visit_expr(&case_when.condition)?;
1142                let res = self.visit_expr(&case_when.result)?;
1143                expr.when(cond).then(res)
1144            } else {
1145                return Ok(expr.otherwise(else_res));
1146            };
1147            for case_when in when_thens {
1148                let cond = self.visit_expr(&case_when.condition)?;
1149                let res = self.visit_expr(&case_when.result)?;
1150                when_then = when_then.when(cond).then(res);
1151            }
1152            Ok(when_then.otherwise(else_res))
1153        } else {
1154            unreachable!()
1155        }
1156    }
1157
1158    fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1159        polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1160    }
1161}
1162
1163/// parse a SQL expression to a polars expression
1164/// # Example
1165/// ```rust
1166/// # use polars_sql::{SQLContext, sql_expr};
1167/// # use polars_core::prelude::*;
1168/// # use polars_lazy::prelude::*;
1169/// # fn main() {
1170///
1171/// let mut ctx = SQLContext::new();
1172/// let df = df! {
1173///    "a" =>  [1, 2, 3],
1174/// }
1175/// .unwrap();
1176/// let expr = sql_expr("MAX(a)").unwrap();
1177/// df.lazy().select(vec![expr]).collect().unwrap();
1178/// # }
1179/// ```
1180pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1181    let mut ctx = SQLContext::new();
1182
1183    let mut parser = Parser::new(&GenericDialect);
1184    parser = parser.with_options(ParserOptions {
1185        trailing_commas: true,
1186        ..Default::default()
1187    });
1188
1189    let mut ast = parser
1190        .try_with_sql(s.as_ref())
1191        .map_err(to_sql_interface_err)?;
1192    let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1193
1194    Ok(match &expr {
1195        SelectItem::ExprWithAlias { expr, alias } => {
1196            let expr = parse_sql_expr(expr, &mut ctx, None)?;
1197            expr.alias(alias.value.as_str())
1198        },
1199        SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1200        _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1201    })
1202}
1203
1204pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1205    if interval.last_field.is_some()
1206        || interval.leading_field.is_some()
1207        || interval.leading_precision.is_some()
1208        || interval.fractional_seconds_precision.is_some()
1209    {
1210        polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1211    }
1212    let s = match &*interval.value {
1213        SQLExpr::UnaryOp { .. } => {
1214            polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1215        },
1216        SQLExpr::Value(ValueWithSpan {
1217            value: SQLValue::SingleQuotedString(s),
1218            ..
1219        }) => Some(s),
1220        _ => None,
1221    };
1222    match s {
1223        Some(s) if s.contains('-') => {
1224            polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1225        },
1226        Some(s) => {
1227            // years, quarters, and months do not have a fixed duration; these
1228            // interval parts can only be used with respect to a reference point
1229            let duration = Duration::parse_interval(s);
1230            if fixed && duration.months() != 0 {
1231                polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1232            };
1233            Ok(duration)
1234        },
1235        None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1236    }
1237}
1238
1239pub(crate) fn parse_sql_expr(
1240    expr: &SQLExpr,
1241    ctx: &mut SQLContext,
1242    active_schema: Option<&Schema>,
1243) -> PolarsResult<Expr> {
1244    let mut visitor = SQLExprVisitor { ctx, active_schema };
1245    visitor.visit_expr(expr)
1246}
1247
1248pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1249    match expr {
1250        SQLExpr::Array(arr) => {
1251            let mut visitor = SQLExprVisitor {
1252                ctx,
1253                active_schema: None,
1254            };
1255            visitor.array_expr_to_series(arr.elem.as_slice())
1256        },
1257        _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1258    }
1259}
1260
/// Extract a date/time part from `expr`; shared implementation of `EXTRACT` and
/// `DATE_PART`.
///
/// `DateTimeField::Custom` identifiers are first normalised (case-insensitively) to a
/// canonical [`DateTimeField`]. This accepts the plural and abbreviated aliases listed
/// in the first `match`. The canonical field is then mapped to the matching polars
/// `dt()` expression.
///
/// # Errors
/// Returns a `SQLSyntax` error in these cases:
/// - an unrecognised custom part name;
/// - `WEEK(<weekday>)`;
/// - any field without a mapping below.
pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
    let field = match field {
        // handle 'DATE_PART' and all valid abbreviations/alternates
        DateTimeField::Custom(Ident { value, .. }) => {
            // alias matching is case-insensitive
            let value = value.to_ascii_lowercase();
            match value.as_str() {
                "millennium" | "millennia" => &DateTimeField::Millennium,
                "century" | "centuries" => &DateTimeField::Century,
                "decade" | "decades" => &DateTimeField::Decade,
                "isoyear" => &DateTimeField::Isoyear,
                "year" | "years" | "y" => &DateTimeField::Year,
                "quarter" | "quarters" => &DateTimeField::Quarter,
                "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
                "dayofyear" | "doy" => &DateTimeField::DayOfYear,
                "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
                // note: plain "week" is treated as the ISO week
                "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
                "isodow" => &DateTimeField::Isodow,
                "day" | "days" | "d" => &DateTimeField::Day,
                "hour" | "hours" | "h" => &DateTimeField::Hour,
                "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
                "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
                "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
                "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
                "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
                #[cfg(feature = "timezones")]
                "timezone" => &DateTimeField::Timezone,
                "time" => &DateTimeField::Time,
                "epoch" => &DateTimeField::Epoch,
                _ => {
                    polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
                },
            }
        },
        _ => field,
    };
    Ok(match field {
        DateTimeField::Millennium => expr.dt().millennium(),
        DateTimeField::Century => expr.dt().century(),
        // year divided by 10 using polars `/` on an i32 literal.
        // NOTE(review): for negative years the rounding direction depends on polars
        // integer-division semantics (floor vs truncate) - confirm.
        DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
        DateTimeField::Isoyear => expr.dt().iso_year(),
        DateTimeField::Year | DateTimeField::Years => expr.dt().year(),
        DateTimeField::Quarter => expr.dt().quarter(),
        DateTimeField::Month | DateTimeField::Months => expr.dt().month(),
        DateTimeField::Week(weekday) => {
            // `WEEK(<weekday>)` variants are not supported
            if weekday.is_some() {
                polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
            }
            expr.dt().week()
        },
        DateTimeField::IsoWeek | DateTimeField::Weeks => expr.dt().week(),
        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
        DateTimeField::DayOfWeek | DateTimeField::Dow => {
            // 'dow' numbering: remap weekday 7 to 0, leaving 1..=6 unchanged.
            // presumably weekday() is ISO (Mon=1..Sun=7), making Sunday 0 - confirm
            let w = expr.dt().weekday();
            when(w.clone().eq(typed_lit(7i8)))
                .then(typed_lit(0i8))
                .otherwise(w)
        },
        // 'isodow' returns weekday() unchanged
        DateTimeField::Isodow => expr.dt().weekday(),
        DateTimeField::Day | DateTimeField::Days => expr.dt().day(),
        DateTimeField::Hour | DateTimeField::Hours => expr.dt().hour(),
        DateTimeField::Minute | DateTimeField::Minutes => expr.dt().minute(),
        DateTimeField::Second | DateTimeField::Seconds => expr.dt().second(),
        // the sub-second parts below combine the whole seconds (scaled to the target
        // unit) with nanosecond() (scaled down), giving a fractional f64 value
        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
            (expr.clone().dt().second() * typed_lit(1_000f64))
                + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
        },
        DateTimeField::Microsecond | DateTimeField::Microseconds => {
            (expr.clone().dt().second() * typed_lit(1_000_000f64))
                + expr.dt().nanosecond().div(typed_lit(1_000f64))
        },
        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
            (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
        },
        DateTimeField::Time => expr.dt().time(),
        // total seconds of the base UTC offset
        #[cfg(feature = "timezones")]
        DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(false),
        DateTimeField::Epoch => {
            // Whole seconds (the ns timestamp divided by an i64 1e9) plus the
            // fractional second taken from nanosecond().
            // NOTE(review): for pre-1970 values with a fractional second this is only
            // correct if integer `/` floors rather than truncating toward zero - confirm.
            expr.clone()
                .dt()
                .timestamp(TimeUnit::Nanoseconds)
                .div(typed_lit(1_000_000_000i64))
                + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
        },
        _ => {
            polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
        },
    })
}
1349
1350/// Allow an expression that represents a 1-indexed parameter to
1351/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1352pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1353    match idx {
1354        Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1355        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1356            if null_if_zero {
1357                lit(LiteralValue::untyped_null())
1358            } else {
1359                idx
1360            }
1361        },
1362        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1363        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1364        // TODO: when 'saturating_sub' is available, should be able
1365        //  to streamline the when/then/otherwise block below -
1366        _ => when(idx.clone().gt(lit(0)))
1367            .then(idx.clone() - lit(1))
1368            .otherwise(if null_if_zero {
1369                when(idx.clone().eq(lit(0)))
1370                    .then(lit(LiteralValue::untyped_null()))
1371                    .otherwise(idx.clone())
1372            } else {
1373                idx.clone()
1374            }),
1375    }
1376}
1377
1378fn resolve_column<'a>(
1379    ctx: &'a mut SQLContext,
1380    ident_root: &'a Ident,
1381    name: &'a str,
1382    dtype: &'a DataType,
1383) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1384    let resolved = ctx.resolve_name(&ident_root.value, name);
1385    let resolved = resolved.as_str();
1386    Ok((
1387        if name != resolved {
1388            col(resolved).alias(name)
1389        } else {
1390            col(name)
1391        },
1392        Some(dtype),
1393    ))
1394}
1395
1396pub(crate) fn resolve_compound_identifier(
1397    ctx: &mut SQLContext,
1398    idents: &[Ident],
1399    active_schema: Option<&Schema>,
1400) -> PolarsResult<Vec<Expr>> {
1401    // inference priority: table > struct > column
1402    let ident_root = &idents[0];
1403    let mut remaining_idents = idents.iter().skip(1);
1404    let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1405
1406    // get schema from table (or the active/default schema)
1407    let schema = if let Some(ref mut lf) = lf {
1408        lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)?
1409    } else {
1410        Arc::new(active_schema.cloned().unwrap_or_default())
1411    };
1412
1413    // handle simple/unqualified column reference with no schema
1414    if lf.is_none() && schema.is_empty() {
1415        let (mut column, mut dtype): (Expr, Option<&DataType>) =
1416            (col(ident_root.value.as_str()), None);
1417
1418        // traverse the remaining struct field path (if any)
1419        for ident in remaining_idents {
1420            let name = ident.value.as_str();
1421            match dtype {
1422                Some(DataType::Struct(fields)) if name == "*" => {
1423                    return Ok(fields
1424                        .iter()
1425                        .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1426                        .collect());
1427                },
1428                Some(DataType::Struct(fields)) => {
1429                    dtype = fields
1430                        .iter()
1431                        .find(|fld| fld.name == name)
1432                        .map(|fld| &fld.dtype);
1433                },
1434                Some(dtype) if name == "*" => {
1435                    polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1436                },
1437                _ => dtype = None,
1438            }
1439            column = column.struct_().field_by_name(name);
1440        }
1441        return Ok(vec![column]);
1442    }
1443
1444    let name = &remaining_idents.next().unwrap().value;
1445
1446    // handle "table.*" wildcard expansion
1447    if lf.is_some() && name == "*" {
1448        return schema
1449            .iter_names_and_dtypes()
1450            .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).map(|(expr, _)| expr))
1451            .collect();
1452    }
1453
1454    // resolve column/struct reference
1455    let col_dtype: PolarsResult<(Expr, Option<&DataType>)> =
1456        match (lf.is_none(), schema.get(&ident_root.value)) {
1457            // root is a column/struct in schema (no table)
1458            (true, Some(dtype)) => {
1459                remaining_idents = idents.iter().skip(1);
1460                Ok((col(ident_root.value.as_str()), Some(dtype)))
1461            },
1462            // root is not in schema and no table found
1463            (true, None) => {
1464                polars_bail!(
1465                    SQLInterface: "no table or struct column named '{}' found",
1466                    ident_root
1467                )
1468            },
1469            // root is a table, resolve column from table schema
1470            (false, _) => {
1471                if let Some((_, col_name, dtype)) = schema.get_full(name) {
1472                    resolve_column(ctx, ident_root, col_name, dtype)
1473                } else {
1474                    polars_bail!(
1475                        SQLInterface: "no column named '{}' found in table '{}'",
1476                        name, ident_root
1477                    )
1478                }
1479            },
1480        };
1481
1482    // additional ident levels index into struct fields (eg: "df.col.field.nested_field")
1483    let (mut column, mut dtype) = col_dtype?;
1484    for ident in remaining_idents {
1485        let name = ident.value.as_str();
1486        match dtype {
1487            Some(DataType::Struct(fields)) if name == "*" => {
1488                return Ok(fields
1489                    .iter()
1490                    .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1491                    .collect());
1492            },
1493            Some(DataType::Struct(fields)) => {
1494                dtype = fields
1495                    .iter()
1496                    .find(|fld| fld.name == name)
1497                    .map(|fld| &fld.dtype);
1498            },
1499            Some(dtype) if name == "*" => {
1500                polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1501            },
1502            _ => {
1503                dtype = None;
1504            },
1505        }
1506        column = column.struct_().field_by_name(name);
1507    }
1508    Ok(vec![column])
1509}