Skip to main content

polars_sql/
sql_expr.rs

1//! Expressions that are supported by the Polars SQL interface.
2//!
//! The keyword and function lists below are useful for syntax highlighting.
4//!
5//! This module defines:
6//! - all Polars SQL keywords [`all_keywords`]
7//! - all Polars SQL functions [`all_functions`]
8
9use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use polars_time::chunkedarray::StringMethods;
18use polars_utils::unique_column_name;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22    AccessExpr, BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
23    DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
24    SelectItem, Subscript, TimezoneInfo, TrimWhereField, TypedString,
25    UnaryOperator as SQLUnaryOperator, Value as SQLValue, ValueWithSpan,
26};
27use sqlparser::dialect::GenericDialect;
28use sqlparser::keywords;
29use sqlparser::parser::{Parser, ParserOptions};
30use sqlparser::tokenizer::Token;
31
32use crate::SQLContext;
33use crate::functions::SQLFunctionVisitor;
34use crate::types::{
35    bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
36    timeunit_from_precision,
37};
38
39#[inline]
40#[cold]
41#[must_use]
42/// Convert a Display-able error to PolarsError::SQLInterface
43pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
44    PolarsError::SQLInterface(err.to_string().into())
45}
46
/// Kind of constraint placed on the result of a (permitted) subquery.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SubqueryRestriction {
    /// The subquery must return a single column.
    SingleColumn,
    // Possible future variants (not yet supported): SingleRow, SingleValue, Any.
}
57
58/// Recursively walks a SQL Expr to create a polars Expr
59pub(crate) struct SQLExprVisitor<'a> {
60    ctx: &'a mut SQLContext,
61    active_schema: Option<&'a Schema>,
62}
63
64/// Cast a literal series to the target dtype; uses dedicated ops for
65/// `string → date/time/datetime` and `strict_cast` for everything else.
66fn cast_literal_series(s: &Series, dtype: &DataType) -> PolarsResult<Series> {
67    if s.dtype() == &DataType::String {
68        let ca = s.str()?;
69        match dtype {
70            DataType::Date => return Ok(ca.as_date(None, false)?.into_series()),
71            DataType::Time => return Ok(ca.as_time(None, false)?.into_series()),
72            DataType::Datetime(tu, tz) => {
73                let ambiguous = StringChunked::from_slice(PlSmallStr::EMPTY, &["latest"]);
74                return Ok(ca
75                    .as_datetime(None, *tu, false, false, tz.as_ref(), &ambiguous)?
76                    .into_series());
77            },
78            _ => {},
79        }
80    }
81    s.strict_cast(dtype)
82}
83
84/// Extract the literal value; returns `(sql_value, optional_op)`.
85fn extract_literal_with_op<'a>(
86    expr: &'a SQLExpr,
87    outer_op: Option<&'a SQLUnaryOperator>,
88) -> Option<(&'a SQLValue, Option<&'a SQLUnaryOperator>)> {
89    match expr {
90        SQLExpr::Value(ValueWithSpan { value: v, .. }) => Some((v, outer_op)),
91        SQLExpr::UnaryOp { op, expr } if outer_op.is_none() => match expr.as_ref() {
92            SQLExpr::Value(ValueWithSpan { value: v, .. }) => Some((v, Some(op))),
93            _ => None,
94        },
95        _ => None,
96    }
97}
98
99/// Ensure we have a common dtype from an input set of dtypes (must match).
100fn resolve_common_dtype(dtypes: &[DataType], desc: Option<&str>) -> PolarsResult<Option<DataType>> {
101    let Some(first) = dtypes.first() else {
102        return Ok(None);
103    };
104    let desc = desc.unwrap_or("");
105    for dtype in &dtypes[1..] {
106        polars_ensure!(
107            dtype == first,
108            SQLInterface: "{desc}expected consistent dtypes (found {first:?} and {dtype:?})",
109        );
110    }
111    Ok(Some(first.clone()))
112}
113
114impl SQLExprVisitor<'_> {
115    fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
116        let mut array_elements = Vec::with_capacity(elements.len());
117        let mut cast_dtypes = Vec::new();
118        for e in elements {
119            let val = match e {
120                SQLExpr::Value(ValueWithSpan { value: v, .. }) => self.visit_any_value(v, None),
121                SQLExpr::UnaryOp { op, expr } => match extract_literal_with_op(expr, Some(op)) {
122                    Some((v, op)) => self.visit_any_value(v, op),
123                    None => match expr.as_ref() {
124                        SQLExpr::Cast {
125                            data_type, expr: inner, format: None, ..
126                        } => {
127                            cast_dtypes.push(map_sql_dtype_to_polars(data_type)?);
128                            match extract_literal_with_op(inner, Some(op)) {
129                                Some((v, op)) => self.visit_any_value(v, op),
130                                None => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
131                            }
132                        },
133                        _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
134                    },
135                },
136                SQLExpr::Array(values) => {
137                    let srs = self.array_expr_to_series(&values.elem)?;
138                    Ok(AnyValue::List(srs))
139                },
140                SQLExpr::TypedString(TypedString {
141                    data_type,
142                    value: ValueWithSpan {
143                        value: SQLValue::SingleQuotedString(v), ..
144                    },
145                    ..
146                }) => {
147                    cast_dtypes.push(self.resolve_typed_literal_dtype(data_type, v)?);
148                    Ok(AnyValue::StringOwned(v.as_str().into()))
149                },
150                SQLExpr::Cast {
151                    data_type, expr, format: None, ..
152                } => {
153                    cast_dtypes.push(map_sql_dtype_to_polars(data_type)?);
154                    match extract_literal_with_op(expr, None) {
155                        Some((v, op)) => self.visit_any_value(v, op),
156                        None => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
157                    }
158                },
159                _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
160            }?
161            .into_static();
162            array_elements.push(val);
163        }
164        let mut s = Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)?;
165        if let Some(dtype) = resolve_common_dtype(&cast_dtypes, Some("array literal "))? {
166            s = cast_literal_series(&s, &dtype)?;
167        }
168        Ok(s)
169    }
170
171    fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
172        match expr {
173            SQLExpr::AllOp {
174                left,
175                compare_op,
176                right,
177            } => self.visit_all(left, compare_op, right),
178            SQLExpr::AnyOp {
179                left,
180                compare_op,
181                right,
182                is_some: _,
183            } => self.visit_any(left, compare_op, right),
184            SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
185            SQLExpr::Between {
186                expr,
187                negated,
188                low,
189                high,
190            } => self.visit_between(expr, *negated, low, high),
191            SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
192            SQLExpr::Cast {
193                kind,
194                expr,
195                data_type,
196                format,
197            } => self.visit_cast(expr, data_type, format, kind),
198            SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
199            SQLExpr::CompoundFieldAccess { root, access_chain } => {
200                // simple subscript access (eg: "array_col[1]")
201                if access_chain.len() == 1 {
202                    match &access_chain[0] {
203                        AccessExpr::Subscript(subscript) => {
204                            return self.visit_subscript(root, subscript);
205                        },
206                        AccessExpr::Dot(_) => {
207                            polars_bail!(SQLSyntax: "dot-notation field access is currently unsupported: {:?}", access_chain[0])
208                        },
209                    }
210                }
211                // chained dot/bracket notation (eg: "struct_col.field[2].foo[0].bar")
212                polars_bail!(SQLSyntax: "complex field access chains are currently unsupported: {:?}", access_chain[0])
213            },
214            SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
215            SQLExpr::Extract {
216                field,
217                syntax: _,
218                expr,
219            } => parse_extract_date_part(self.visit_expr(expr)?, field),
220            SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
221            SQLExpr::Function(function) => self.visit_function(function),
222            SQLExpr::Identifier(ident) => self.visit_identifier(ident),
223            SQLExpr::InList {
224                expr,
225                list,
226                negated,
227            } => {
228                let expr = self.visit_expr(expr)?;
229                let elems = self.visit_array_expr(list, true, Some(&expr))?;
230                let is_in = expr.is_in(elems, false);
231                Ok(if *negated { is_in.not() } else { is_in })
232            },
233            SQLExpr::InSubquery {
234                expr,
235                subquery,
236                negated,
237            } => self.visit_in_subquery(expr, subquery, *negated),
238            SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
239            SQLExpr::IsDistinctFrom(e1, e2) => {
240                Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
241            },
242            SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
243            SQLExpr::IsNotDistinctFrom(e1, e2) => {
244                Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
245            },
246            SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
247            SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
248            SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
249            SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
250            SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
251            SQLExpr::Like {
252                negated,
253                any,
254                expr,
255                pattern,
256                escape_char,
257            } => {
258                if *any {
259                    polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
260                }
261                let escape_str = escape_char.as_ref().and_then(|v| match v {
262                    SQLValue::SingleQuotedString(s) => Some(s.clone()),
263                    _ => None,
264                });
265                self.visit_like(*negated, expr, pattern, &escape_str, false)
266            },
267            SQLExpr::ILike {
268                negated,
269                any,
270                expr,
271                pattern,
272                escape_char,
273            } => {
274                if *any {
275                    polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
276                }
277                let escape_str = escape_char.as_ref().and_then(|v| match v {
278                    SQLValue::SingleQuotedString(s) => Some(s.clone()),
279                    _ => None,
280                });
281                self.visit_like(*negated, expr, pattern, &escape_str, true)
282            },
283            SQLExpr::Nested(expr) => self.visit_expr(expr),
284            SQLExpr::Position { expr, r#in } => Ok(
285                // note: SQL is 1-indexed
286                (self
287                    .visit_expr(r#in)?
288                    .str()
289                    .find(self.visit_expr(expr)?, true)
290                    + typed_lit(1u32))
291                .fill_null(typed_lit(0u32)),
292            ),
293            SQLExpr::RLike {
294                // note: parses both RLIKE and REGEXP
295                negated,
296                expr,
297                pattern,
298                regexp: _,
299            } => {
300                let matches = self
301                    .visit_expr(expr)?
302                    .str()
303                    .contains(self.visit_expr(pattern)?, true);
304                Ok(if *negated { matches.not() } else { matches })
305            },
306            SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
307            SQLExpr::Substring {
308                expr,
309                substring_from,
310                substring_for,
311                ..
312            } => self.visit_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
313            SQLExpr::Trim {
314                expr,
315                trim_where,
316                trim_what,
317                trim_characters,
318            } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
319            SQLExpr::TypedString(TypedString {
320                data_type,
321                value:
322                    ValueWithSpan {
323                        value: SQLValue::SingleQuotedString(v),
324                        ..
325                    },
326                uses_odbc_syntax: _,
327            }) => {
328                let dtype = self.resolve_typed_literal_dtype(data_type, v)?;
329                match dtype {
330                    DataType::Date => Ok(lit(v.as_str()).cast(DataType::Date)),
331                    DataType::Time => Ok(lit(v.as_str()).str().to_time(StrptimeOptions {
332                        strict: true,
333                        ..Default::default()
334                    })),
335                    DataType::Datetime(_, _) => Ok(lit(v.as_str()).str().to_datetime(
336                        None,
337                        None,
338                        StrptimeOptions {
339                            strict: true,
340                            ..Default::default()
341                        },
342                        lit("latest"),
343                    )),
344                    _ => unreachable!(),
345                }
346            },
347            SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
348            SQLExpr::Value(ValueWithSpan { value, .. }) => self.visit_literal(value),
349            SQLExpr::Wildcard(_) => Ok(all().as_expr()),
350            e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
351            other => {
352                polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
353            },
354        }
355    }
356
357    fn visit_subquery(
358        &mut self,
359        subquery: &Subquery,
360        restriction: SubqueryRestriction,
361    ) -> PolarsResult<Expr> {
362        if subquery.with.is_some() {
363            polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
364        }
365        // note: we have to execute subqueries in an isolated scope to prevent
366        // propagating any context/arena mutation into the rest of the query
367        let lf = self
368            .ctx
369            .execute_isolated(|ctx| ctx.execute_query_no_ctes(subquery))?;
370
371        if restriction == SubqueryRestriction::SingleColumn {
372            let new_name = unique_column_name();
373            return Ok(Expr::SubPlan(
374                SpecialEq::new(Arc::new(lf.logical_plan)),
375                // TODO: pass the implode depending on expr.
376                vec![(
377                    new_name.clone(),
378                    first().as_expr().implode(true).alias(new_name.clone()),
379                )],
380            ));
381        };
382        polars_bail!(SQLInterface: "subquery type not supported");
383    }
384
385    /// Visit a single SQL identifier.
386    ///
387    /// e.g. column
388    fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
389        Ok(col(ident.value.as_str()))
390    }
391
392    /// Visit a compound SQL identifier
393    ///
394    /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
395    fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
396        Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
397    }
398
399    fn visit_like(
400        &mut self,
401        negated: bool,
402        expr: &SQLExpr,
403        pattern: &SQLExpr,
404        escape_char: &Option<String>,
405        case_insensitive: bool,
406    ) -> PolarsResult<Expr> {
407        if escape_char.is_some() {
408            polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
409        }
410        let pat = match self.visit_expr(pattern) {
411            Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
412                PlSmallStr::from_str(lv.extract_str().unwrap())
413            },
414            _ => {
415                polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
416            },
417        };
418        if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
419            // empty string or other exact literal match (eg: no wildcard chars)
420            let op = if negated {
421                SQLBinaryOperator::NotEq
422            } else {
423                SQLBinaryOperator::Eq
424            };
425            self.visit_binary_op(expr, &op, pattern)
426        } else {
427            // create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
428            let mut rx = regex::escape(pat.as_str())
429                .replace('%', ".*")
430                .replace('_', ".");
431
432            rx = format!(
433                "^{}{}$",
434                if case_insensitive { "(?is)" } else { "(?s)" },
435                rx
436            );
437
438            let expr = self.visit_expr(expr)?;
439            let matches = expr.str().contains(lit(rx), true);
440            Ok(if negated { matches.not() } else { matches })
441        }
442    }
443
444    fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
445        let expr = self.visit_expr(expr)?;
446        Ok(match subscript {
447            Subscript::Index { index } => {
448                let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
449                expr.list().get(idx, true)
450            },
451            Subscript::Slice { .. } => {
452                polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
453            },
454        })
455    }
456
457    /// Handle implicit temporal string comparisons.
458    ///
459    /// eg: clauses such as -
460    ///   "dt >= '2024-04-30'"
461    ///   "dt = '2077-10-10'::date"
462    ///   "dtm::date = '2077-10-10'
463    fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
464        if let (Some(name), Some(s), expr_dtype) = match (left, right) {
465            // identify "col <op> string" expressions
466            (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
467                (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
468            },
469            // identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
470            (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
471                let s = lv.extract_str().unwrap();
472                match &**expr {
473                    Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
474                    _ => (None, Some(s), Some(dtype)),
475                }
476            },
477            _ => (None, None, None),
478        } {
479            if expr_dtype.is_none() && self.active_schema.is_none() {
480                right.clone()
481            } else {
482                let left_dtype = expr_dtype.map_or_else(
483                    || {
484                        self.active_schema
485                            .as_ref()
486                            .and_then(|schema| schema.get(&name))
487                    },
488                    |dt| dt.as_literal(),
489                );
490                match left_dtype {
491                    Some(DataType::Time) if is_iso_time(s) => {
492                        right.clone().str().to_time(StrptimeOptions {
493                            strict: true,
494                            ..Default::default()
495                        })
496                    },
497                    Some(DataType::Date) if is_iso_date(s) => {
498                        right.clone().str().to_date(StrptimeOptions {
499                            strict: true,
500                            ..Default::default()
501                        })
502                    },
503                    Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
504                        if s.len() == 10 {
505                            // handle upcast from ISO date string (10 chars) to datetime
506                            lit(format!("{s}T00:00:00"))
507                        } else {
508                            lit(s.replacen(' ', "T", 1))
509                        }
510                        .str()
511                        .to_datetime(
512                            Some(*tu),
513                            tz.clone(),
514                            StrptimeOptions {
515                                strict: true,
516                                ..Default::default()
517                            },
518                            lit("latest"),
519                        )
520                    },
521                    _ => right.clone(),
522                }
523            }
524        } else {
525            right.clone()
526        }
527    }
528
529    fn struct_field_access_expr(
530        &mut self,
531        expr: &Expr,
532        path: &str,
533        infer_index: bool,
534    ) -> PolarsResult<Expr> {
535        let path_elems = if path.starts_with('{') && path.ends_with('}') {
536            path.trim_matches(|c| c == '{' || c == '}')
537        } else {
538            path
539        }
540        .split(',');
541
542        let mut expr = expr.clone();
543        for p in path_elems {
544            let p = p.trim();
545            expr = if infer_index {
546                match p.parse::<i64>() {
547                    Ok(idx) => expr.list().get(lit(idx), true),
548                    Err(_) => expr.struct_().field_by_name(p),
549                }
550            } else {
551                expr.struct_().field_by_name(p)
552            }
553        }
554        Ok(expr)
555    }
556
557    /// Visit a SQL binary operator.
558    ///
559    /// e.g. "column + 1", "column1 <= column2"
560    fn visit_binary_op(
561        &mut self,
562        left: &SQLExpr,
563        op: &SQLBinaryOperator,
564        right: &SQLExpr,
565    ) -> PolarsResult<Expr> {
566        // check for (unsupported) scalar subquery comparisons
567        if matches!(left, SQLExpr::Subquery(_)) || matches!(right, SQLExpr::Subquery(_)) {
568            let (suggestion, str_op) = match op {
569                SQLBinaryOperator::NotEq => ("; use 'NOT IN' instead", "!=".to_string()),
570                SQLBinaryOperator::Eq => ("; use 'IN' instead", format!("{op}")),
571                _ => ("", format!("{op}")),
572            };
573            polars_bail!(
574                SQLSyntax: "subquery comparisons with '{str_op}' are not supported{suggestion}"
575            );
576        }
577
578        // need special handling for interval offsets and comparisons
579        let (lhs, mut rhs) = match (left, op, right) {
580            (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
581                let duration = interval_to_duration(v, false)?;
582                return Ok(self
583                    .visit_expr(left)?
584                    .dt()
585                    .offset_by(lit(format!("-{duration}"))));
586            },
587            (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
588                let duration = interval_to_duration(v, false)?;
589                return Ok(self
590                    .visit_expr(left)?
591                    .dt()
592                    .offset_by(lit(format!("{duration}"))));
593            },
594            (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
595                // shortcut interval comparison evaluation (-> bool)
596                let d1 = interval_to_duration(v1, false)?;
597                let d2 = interval_to_duration(v2, false)?;
598                let res = match op {
599                    SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
600                    SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
601                    SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
602                    SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
603                    SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
604                    SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
605                    _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
606                };
607                if res.is_ok() {
608                    return res;
609                }
610                (self.visit_expr(left)?, self.visit_expr(right)?)
611            },
612            _ => (self.visit_expr(left)?, self.visit_expr(right)?),
613        };
614        rhs = self.convert_temporal_strings(&lhs, &rhs);
615
616        Ok(match op {
617            // ----
618            // Bitwise operators
619            // ----
620            SQLBinaryOperator::BitwiseAnd => lhs.and(rhs),  // "x & y"
621            SQLBinaryOperator::BitwiseOr => lhs.or(rhs),  // "x | y"
622            SQLBinaryOperator::Xor => lhs.xor(rhs),  // "x XOR y"
623
624            // ----
625            // Comparison operators
626            // ----
627            SQLBinaryOperator::Eq => lhs.eq(rhs),  // "x = y"
628            SQLBinaryOperator::Gt => lhs.gt(rhs),  // "x > y"
629            SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),  // "x >= y"
630            SQLBinaryOperator::Lt => lhs.lt(rhs),  // "x < y"
631            SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),  // "x <= y"
632            SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),  // "x != y"
633            SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),  // "x <=> y"
634
635            // ----
636            // Logical operators
637            // ----
638            SQLBinaryOperator::And => lhs.and(rhs),  // "x AND y"
639            SQLBinaryOperator::Or => lhs.or(rhs),  // "x OR y"
640
641            // ----
642            // Mathematical operators
643            // ----
644            SQLBinaryOperator::Divide => lhs.true_div(rhs),  // "x / y"
645            SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),  // "x // y"
646            SQLBinaryOperator::Minus => lhs - rhs,  // "x - y"
647            SQLBinaryOperator::Modulo => lhs % rhs,  // "x % y"
648            SQLBinaryOperator::Multiply => lhs * rhs,  // "x * y"
649            SQLBinaryOperator::Plus => lhs + rhs,  // "x + y"
650
651            // ----
652            // Regular expression operators
653            // ----
654            SQLBinaryOperator::PGRegexMatch => match rhs {  // "x ~ y"
655                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
656                _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
657            },
658            SQLBinaryOperator::PGRegexNotMatch => match rhs {  // "x !~ y"
659                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
660                _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
661            },
662            SQLBinaryOperator::PGRegexIMatch => match rhs {  // "x ~* y"
663                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
664                    let pat = lv.extract_str().unwrap();
665                    lhs.str().contains(lit(format!("(?i){pat}")), true)
666                },
667                _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
668            },
669            SQLBinaryOperator::PGRegexNotIMatch => match rhs {  // "x !~* y"
670                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
671                    let pat = lv.extract_str().unwrap();
672                    lhs.str().contains(lit(format!("(?i){pat}")), true).not()
673                },
674                _ => {
675                    polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
676                },
677            },
678            // ----
679            // LIKE/ILIKE operators
680            // ----
681            SQLBinaryOperator::PGLikeMatch  // "x ~~ y"
682            | SQLBinaryOperator::PGNotLikeMatch  // "x !~~ y"
683            | SQLBinaryOperator::PGILikeMatch  // "x ~~* y"
684            | SQLBinaryOperator::PGNotILikeMatch => {  // "x !~~* y"
685                let expr = if matches!(
686                    op,
687                    SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
688                ) {
689                    SQLExpr::Like {
690                        negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
691                        any: false,
692                        expr: Box::new(left.clone()),
693                        pattern: Box::new(right.clone()),
694                        escape_char: None,
695                    }
696                } else {
697                    SQLExpr::ILike {
698                        negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
699                        any: false,
700                        expr: Box::new(left.clone()),
701                        pattern: Box::new(right.clone()),
702                        escape_char: None,
703                    }
704                };
705                self.visit_expr(&expr)?
706            },
707            // ----
708            // String operators
709            // ----
710            SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),  // "x ^@ y"
711            SQLBinaryOperator::StringConcat => {  // "x || y"
712                lhs.cast(DataType::String) + rhs.cast(DataType::String)
713            },
714            // ----
715            // Nested (JSON) field access operators
716            // ----
717            SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {  // "x -> y", "x ->> y"
718                Expr::Literal(lv) if lv.extract_str().is_some() => {
719                    let path = lv.extract_str().unwrap();
720                    let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
721                    if let SQLBinaryOperator::LongArrow = op {
722                        expr = expr.cast(DataType::String);
723                    }
724                    expr
725                },
726                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
727                    let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
728                    if let SQLBinaryOperator::LongArrow = op {
729                        expr = expr.cast(DataType::String);
730                    }
731                    expr
732                },
733                _ => {
734                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
735                },
736            },
737            SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {  // "x #> y", "x #>> y"
738                match rhs {
739                    Expr::Literal(lv) if lv.extract_str().is_some() => {
740                        let path = lv.extract_str().unwrap();
741                        let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
742                        if let SQLBinaryOperator::HashLongArrow = op {
743                            expr = expr.cast(DataType::String);
744                        }
745                        expr
746                    },
747                    _ => {
748                        polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
749                    }
750                }
751            },
752            other => {
753                polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
754            },
755        })
756    }
757
758    /// Visit a SQL unary operator.
759    ///
760    /// e.g. +column or -column
761    fn visit_unary_op(&mut self, op: &SQLUnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
762        let expr = self.visit_expr(expr)?;
763        Ok(match (op, expr.clone()) {
764            // simplify the parse tree by special-casing common unary +/- ops
765            (SQLUnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
766                lit(n)
767            },
768            (
769                SQLUnaryOperator::Plus,
770                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n))),
771            ) => lit(n),
772            (
773                SQLUnaryOperator::Minus,
774                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))),
775            ) => lit(-n),
776            (
777                SQLUnaryOperator::Minus,
778                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n))),
779            ) => lit(-n),
780            // general case
781            (SQLUnaryOperator::Plus, _) => lit(0) + expr,
782            (SQLUnaryOperator::Minus, _) => lit(0) - expr,
783            (SQLUnaryOperator::Not, _) => match &expr {
784                Expr::Column(name)
785                    if self
786                        .active_schema
787                        .and_then(|schema| schema.get(name))
788                        .is_some_and(|dtype| matches!(dtype, DataType::Boolean)) =>
789                {
790                    // if already boolean, can operate bitwise
791                    expr.not()
792                },
793                // otherwise SQL "NOT" expects logical, not bitwise, behaviour (eg: on integers)
794                _ => expr.strict_cast(DataType::Boolean).not(),
795            },
796            other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
797        })
798    }
799
800    /// Visit a SQL function.
801    ///
802    /// e.g. SUM(column) or COUNT(*)
803    ///
804    /// See [SQLFunctionVisitor] for more details
805    fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
806        let mut visitor = SQLFunctionVisitor {
807            func: function,
808            ctx: self.ctx,
809            active_schema: self.active_schema,
810            filter: None,
811        };
812        visitor.visit_function()
813    }
814
815    /// Visit a SQL `ALL` expression.
816    ///
817    /// e.g. `a > ALL(y)`
818    fn visit_all(
819        &mut self,
820        left: &SQLExpr,
821        compare_op: &SQLBinaryOperator,
822        right: &SQLExpr,
823    ) -> PolarsResult<Expr> {
824        let left = self.visit_expr(left)?;
825        let right = self.visit_expr(right)?;
826
827        match compare_op {
828            SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
829            SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
830            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
831            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
832            SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
833            SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
834            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
835        }
836    }
837
838    /// Visit a SQL `ANY` expression.
839    ///
840    /// e.g. `a != ANY(y)`
841    fn visit_any(
842        &mut self,
843        left: &SQLExpr,
844        compare_op: &SQLBinaryOperator,
845        right: &SQLExpr,
846    ) -> PolarsResult<Expr> {
847        let left = self.visit_expr(left)?;
848        let right = self.visit_expr(right)?;
849
850        match compare_op {
851            SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
852            SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
853            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
854            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
855            SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
856            SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
857            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
858        }
859    }
860
861    /// Visit a SQL `ARRAY` list (including `IN` values).
862    fn visit_array_expr(
863        &mut self,
864        elements: &[SQLExpr],
865        result_as_element: bool,
866        dtype_expr_match: Option<&Expr>,
867    ) -> PolarsResult<Expr> {
868        let mut elems = self.array_expr_to_series(elements)?;
869
870        // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
871        // (not yet as versatile as the temporal string conversions in visit_binary_op)
872        if let (Some(Expr::Column(name)), Some(schema)) =
873            (dtype_expr_match, self.active_schema.as_ref())
874        {
875            if elems.dtype() == &DataType::String {
876                if let Some(dtype) = schema.get(name) {
877                    if matches!(
878                        dtype,
879                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
880                    ) {
881                        elems = elems.strict_cast(dtype)?;
882                    }
883                }
884            }
885        }
886
887        // if we are parsing the list as an element in a series, implode.
888        // otherwise, return the series as-is.
889        let res = if result_as_element {
890            elems.implode()?.into_series()
891        } else {
892            elems
893        };
894        Ok(lit(res))
895    }
896
897    /// Visit a SQL `CAST` or `TRY_CAST` expression.
898    ///
899    /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
900    fn visit_cast(
901        &mut self,
902        expr: &SQLExpr,
903        dtype: &SQLDataType,
904        format: &Option<CastFormat>,
905        cast_kind: &CastKind,
906    ) -> PolarsResult<Expr> {
907        if format.is_some() {
908            return Err(
909                polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
910            );
911        }
912        let expr = self.visit_expr(expr)?;
913
914        #[cfg(feature = "json")]
915        if dtype == &SQLDataType::JSON {
916            // @BROKEN: we cannot handle this.
917            return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
918        }
919        let polars_type = map_sql_dtype_to_polars(dtype)?;
920        Ok(match cast_kind {
921            CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
922            CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
923        })
924    }
925
926    /// Visit a SQL literal.
927    ///
928    /// e.g. 1, 'foo', 1.0, NULL
929    ///
930    /// See [SQLValue] and [LiteralValue] for more details
931    fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
932        // note: double-quoted strings will be parsed as identifiers, not literals
933        Ok(match value {
934            SQLValue::Boolean(b) => lit(*b),
935            SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
936            #[cfg(feature = "binary_encoding")]
937            SQLValue::HexStringLiteral(x) => {
938                if x.len() % 2 != 0 {
939                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
940                };
941                lit(hex::decode(x.clone())
942                    .map_err(|_| polars_err!(SQLSyntax: "invalid hex string literal: '{}'", x))?)
943            },
944            SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
945            SQLValue::Number(s, _) => {
946                // Check for existence of decimal separator dot
947                if s.contains('.') {
948                    s.parse::<f64>().map(lit).map_err(|_| ())
949                } else {
950                    s.parse::<i64>().map(lit).map_err(|_| ())
951                }
952                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
953            },
954            SQLValue::SingleQuotedByteStringLiteral(b) => {
955                // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
956                // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
957                // patterned the token name after BigQuery (where b'str' really IS a byte string)
958                bitstring_to_bytes_literal(b)?
959            },
960            SQLValue::SingleQuotedString(s) => lit(s.clone()),
961            other => {
962                polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
963            },
964        })
965    }
966
967    /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
968    fn visit_any_value(
969        &self,
970        value: &SQLValue,
971        op: Option<&SQLUnaryOperator>,
972    ) -> PolarsResult<AnyValue<'_>> {
973        Ok(match value {
974            SQLValue::Boolean(b) => AnyValue::Boolean(*b),
975            SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
976            #[cfg(feature = "binary_encoding")]
977            SQLValue::HexStringLiteral(x) => {
978                if x.len() % 2 != 0 {
979                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
980                };
981                AnyValue::BinaryOwned(
982                    hex::decode(x.clone()).map_err(
983                        |_| polars_err!(SQLSyntax: "invalid hex string literal: '{}'", x),
984                    )?,
985                )
986            },
987            SQLValue::Null => AnyValue::Null,
988            SQLValue::Number(s, _) => {
989                let negate = match op {
990                    Some(SQLUnaryOperator::Minus) => true,
991                    // no op should be taken as plus.
992                    Some(SQLUnaryOperator::Plus) | None => false,
993                    Some(op) => {
994                        polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
995                    },
996                };
997                // Check for existence of decimal separator dot
998                if s.contains('.') {
999                    s.parse::<f64>()
1000                        .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
1001                        .map_err(|_| ())
1002                } else {
1003                    s.parse::<i64>()
1004                        .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
1005                        .map_err(|_| ())
1006                }
1007                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
1008            },
1009            SQLValue::SingleQuotedByteStringLiteral(b) => {
1010                // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
1011                let bytes_literal = bitstring_to_bytes_literal(b)?;
1012                match bytes_literal {
1013                    Expr::Literal(lv) if lv.extract_binary().is_some() => {
1014                        AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
1015                    },
1016                    _ => {
1017                        polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
1018                    },
1019                }
1020            },
1021            SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
1022            other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
1023        })
1024    }
1025
1026    /// Validate a typed string literal (e.g. DATE '2024-01-01'), returning its dtype.
1027    fn resolve_typed_literal_dtype(
1028        &self,
1029        data_type: &SQLDataType,
1030        value: &str,
1031    ) -> PolarsResult<DataType> {
1032        match data_type {
1033            SQLDataType::Date => {
1034                polars_ensure!(is_iso_date(value), SQLSyntax: "invalid DATE literal '{}'", value);
1035                Ok(DataType::Date)
1036            },
1037            SQLDataType::Time(None, TimezoneInfo::None) => {
1038                polars_ensure!(is_iso_time(value), SQLSyntax: "invalid TIME literal '{}'", value);
1039                Ok(DataType::Time)
1040            },
1041            SQLDataType::Timestamp(prec, TimezoneInfo::None) | SQLDataType::Datetime(prec) => {
1042                let fn_name = match data_type {
1043                    SQLDataType::Timestamp(_, _) => "TIMESTAMP",
1044                    _ => "DATETIME",
1045                };
1046                polars_ensure!(
1047                    is_iso_datetime(value),
1048                    SQLSyntax: "invalid {} literal '{}'", fn_name, value,
1049                );
1050                Ok(DataType::Datetime(timeunit_from_precision(prec)?, None))
1051            },
1052            _ => {
1053                polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
1054            },
1055        }
1056    }
1057
1058    /// Visit a SQL `BETWEEN` expression.
1059    /// See [sqlparser::ast::Expr::Between] for more details
1060    fn visit_between(
1061        &mut self,
1062        expr: &SQLExpr,
1063        negated: bool,
1064        low: &SQLExpr,
1065        high: &SQLExpr,
1066    ) -> PolarsResult<Expr> {
1067        let expr = self.visit_expr(expr)?;
1068        let low = self.visit_expr(low)?;
1069        let high = self.visit_expr(high)?;
1070
1071        let low = self.convert_temporal_strings(&expr, &low);
1072        let high = self.convert_temporal_strings(&expr, &high);
1073        Ok(if negated {
1074            expr.clone().lt(low).or(expr.gt(high))
1075        } else {
1076            expr.clone().gt_eq(low).and(expr.lt_eq(high))
1077        })
1078    }
1079
1080    /// Visit a SQL `TRIM` function.
1081    /// See [sqlparser::ast::Expr::Trim] for more details
1082    fn visit_trim(
1083        &mut self,
1084        expr: &SQLExpr,
1085        trim_where: &Option<TrimWhereField>,
1086        trim_what: &Option<Box<SQLExpr>>,
1087        trim_characters: &Option<Vec<SQLExpr>>,
1088    ) -> PolarsResult<Expr> {
1089        if trim_characters.is_some() {
1090            // TODO: allow compact snowflake/bigquery syntax?
1091            return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
1092        };
1093        let expr = self.visit_expr(expr)?;
1094        let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
1095        let trim_what = match trim_what {
1096            Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
1097                Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
1098            },
1099            None => None,
1100            _ => return self.err(&expr),
1101        };
1102        Ok(match (trim_where, trim_what) {
1103            (None | Some(TrimWhereField::Both), None) => {
1104                expr.str().strip_chars(lit(LiteralValue::untyped_null()))
1105            },
1106            (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
1107            (Some(TrimWhereField::Leading), None) => expr
1108                .str()
1109                .strip_chars_start(lit(LiteralValue::untyped_null())),
1110            (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
1111            (Some(TrimWhereField::Trailing), None) => expr
1112                .str()
1113                .strip_chars_end(lit(LiteralValue::untyped_null())),
1114            (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
1115        })
1116    }
1117
1118    fn visit_substring(
1119        &mut self,
1120        expr: &SQLExpr,
1121        substring_from: Option<&SQLExpr>,
1122        substring_for: Option<&SQLExpr>,
1123    ) -> PolarsResult<Expr> {
1124        let e = self.visit_expr(expr)?;
1125
1126        match (substring_from, substring_for) {
1127            // SUBSTRING(expr FROM start FOR length)
1128            (Some(from_expr), Some(for_expr)) => {
1129                let start = self.visit_expr(from_expr)?;
1130                let length = self.visit_expr(for_expr)?;
1131
1132                // note: SQL is 1-indexed, so we need to adjust the offsets accordingly
1133                Ok(match (start.clone(), length.clone()) {
1134                    (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1135                    (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1136                        polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", n)
1137                    },
1138                    (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => {
1139                        e.str().slice(lit(n - 1), length)
1140                    },
1141                    (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => e
1142                        .str()
1143                        .slice(lit(0), (length + lit(n - 1)).clip_min(lit(0))),
1144                    (Expr::Literal(_), _) => {
1145                        polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1146                    },
1147                    (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1148                        polars_bail!(SQLSyntax: "invalid 'length' for SUBSTRING")
1149                    },
1150                    _ => {
1151                        let adjusted_start = start - lit(1);
1152                        when(adjusted_start.clone().lt(lit(0)))
1153                            .then(e.clone().str().slice(
1154                                lit(0),
1155                                (length.clone() + adjusted_start.clone()).clip_min(lit(0)),
1156                            ))
1157                            .otherwise(e.str().slice(adjusted_start, length))
1158                    },
1159                })
1160            },
1161            // SUBSTRING(expr FROM start)
1162            (Some(from_expr), None) => {
1163                let start = self.visit_expr(from_expr)?;
1164
1165                Ok(match start {
1166                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1167                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1168                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1169                        e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null()))
1170                    },
1171                    Expr::Literal(_) => {
1172                        polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1173                    },
1174                    _ => e
1175                        .str()
1176                        .slice(start - lit(1), lit(LiteralValue::untyped_null())),
1177                })
1178            },
1179            // SUBSTRING(expr) - not valid, but handle gracefully
1180            (None, _) => {
1181                polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found 1)")
1182            },
1183        }
1184    }
1185
1186    /// Visit a SQL subquery inside an `IN` expression.
1187    fn visit_in_subquery(
1188        &mut self,
1189        expr: &SQLExpr,
1190        subquery: &Subquery,
1191        negated: bool,
1192    ) -> PolarsResult<Expr> {
1193        let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
1194        let expr = self.visit_expr(expr)?;
1195        Ok(if negated {
1196            expr.is_in(subquery_result, false).not()
1197        } else {
1198            expr.is_in(subquery_result, false)
1199        })
1200    }
1201
1202    /// Visit `CASE` control flow expression.
1203    fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1204        if let SQLExpr::Case {
1205            case_token: _,
1206            end_token: _,
1207            operand,
1208            conditions,
1209            else_result,
1210        } = expr
1211        {
1212            polars_ensure!(
1213                !conditions.is_empty(),
1214                SQLSyntax: "WHEN and THEN expressions must have at least one element"
1215            );
1216
1217            let mut when_thens = conditions.iter();
1218            let first = when_thens.next();
1219            if first.is_none() {
1220                polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
1221            }
1222            let else_res = match else_result {
1223                Some(else_res) => self.visit_expr(else_res)?,
1224                None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
1225            };
1226            if let Some(operand_expr) = operand {
1227                let first_operand_expr = self.visit_expr(operand_expr)?;
1228
1229                let first = first.unwrap();
1230                let first_cond = first_operand_expr.eq(self.visit_expr(&first.condition)?);
1231                let first_then = self.visit_expr(&first.result)?;
1232                let expr = when(first_cond).then(first_then);
1233                let next = when_thens.next();
1234
1235                let mut when_then = if let Some(case_when) = next {
1236                    let second_operand_expr = self.visit_expr(operand_expr)?;
1237                    let cond = second_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1238                    let res = self.visit_expr(&case_when.result)?;
1239                    expr.when(cond).then(res)
1240                } else {
1241                    return Ok(expr.otherwise(else_res));
1242                };
1243                for case_when in when_thens {
1244                    let new_operand_expr = self.visit_expr(operand_expr)?;
1245                    let cond = new_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1246                    let res = self.visit_expr(&case_when.result)?;
1247                    when_then = when_then.when(cond).then(res);
1248                }
1249                return Ok(when_then.otherwise(else_res));
1250            }
1251
1252            let first = first.unwrap();
1253            let first_cond = self.visit_expr(&first.condition)?;
1254            let first_then = self.visit_expr(&first.result)?;
1255            let expr = when(first_cond).then(first_then);
1256            let next = when_thens.next();
1257
1258            let mut when_then = if let Some(case_when) = next {
1259                let cond = self.visit_expr(&case_when.condition)?;
1260                let res = self.visit_expr(&case_when.result)?;
1261                expr.when(cond).then(res)
1262            } else {
1263                return Ok(expr.otherwise(else_res));
1264            };
1265            for case_when in when_thens {
1266                let cond = self.visit_expr(&case_when.condition)?;
1267                let res = self.visit_expr(&case_when.result)?;
1268                when_then = when_then.when(cond).then(res);
1269            }
1270            Ok(when_then.otherwise(else_res))
1271        } else {
1272            unreachable!()
1273        }
1274    }
1275
1276    fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1277        polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1278    }
1279}
1280
1281/// parse a SQL expression to a polars expression
1282/// # Example
1283/// ```rust
1284/// # use polars_sql::{SQLContext, sql_expr};
1285/// # use polars_core::prelude::*;
1286/// # use polars_lazy::prelude::*;
1287/// # fn main() {
1288///
1289/// let mut ctx = SQLContext::new();
1290/// let df = df! {
1291///    "a" =>  [1, 2, 3],
1292/// }
1293/// .unwrap();
1294/// let expr = sql_expr("MAX(a)").unwrap();
1295/// df.lazy().select(vec![expr]).collect().unwrap();
1296/// # }
1297/// ```
1298pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1299    let mut ctx = SQLContext::new();
1300    let s = s.as_ref();
1301
1302    let mut parser = Parser::new(&GenericDialect);
1303    parser = parser.with_options(ParserOptions {
1304        trailing_commas: true,
1305        ..Default::default()
1306    });
1307
1308    // `sql_expr` should only translate expressions, not statements or clauses
1309    let mut ast = parser.try_with_sql(s).map_err(to_sql_interface_err)?;
1310    if let Token::Word(word) = &ast.peek_token().token {
1311        if keywords::RESERVED_FOR_COLUMN_ALIAS.contains(&word.keyword) {
1312            polars_bail!(SQLInterface: "expected an expression (found '{}' clause)", word.value)
1313        }
1314    }
1315    let expr = ast
1316        .parse_select_item()
1317        .map_err(|_| polars_err!(SQLInterface: "unable to parse '{}' as Expr", s))?;
1318
1319    // ensure all input was consumed; remaining tokens indicate invalid trailing SQL
1320    match &ast.peek_token().token {
1321        Token::EOF => {},
1322        Token::Word(word) if keywords::RESERVED_FOR_COLUMN_ALIAS.contains(&word.keyword) => {
1323            polars_bail!(SQLInterface: "expected an expression (found '{}' clause)", word.value)
1324        },
1325        token => {
1326            polars_bail!(SQLInterface: "invalid expression (found unexpected token '{}')", token)
1327        },
1328    }
1329    Ok(match &expr {
1330        SelectItem::ExprWithAlias { expr, alias } => {
1331            let expr = parse_sql_expr(expr, &mut ctx, None)?;
1332            expr.alias(alias.value.as_str())
1333        },
1334        SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1335        _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s),
1336    })
1337}
1338
1339pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1340    if interval.last_field.is_some()
1341        || interval.leading_field.is_some()
1342        || interval.leading_precision.is_some()
1343        || interval.fractional_seconds_precision.is_some()
1344    {
1345        polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1346    }
1347    let s = match &*interval.value {
1348        SQLExpr::UnaryOp { .. } => {
1349            polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1350        },
1351        SQLExpr::Value(ValueWithSpan {
1352            value: SQLValue::SingleQuotedString(s),
1353            ..
1354        }) => Some(s),
1355        _ => None,
1356    };
1357    match s {
1358        Some(s) if s.contains('-') => {
1359            polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1360        },
1361        Some(s) => {
1362            // years, quarters, and months do not have a fixed duration; these
1363            // interval parts can only be used with respect to a reference point
1364            let duration = Duration::parse_interval(s);
1365            if fixed && duration.months() != 0 {
1366                polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1367            };
1368            Ok(duration)
1369        },
1370        None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1371    }
1372}
1373
1374pub(crate) fn parse_sql_expr(
1375    expr: &SQLExpr,
1376    ctx: &mut SQLContext,
1377    active_schema: Option<&Schema>,
1378) -> PolarsResult<Expr> {
1379    let mut visitor = SQLExprVisitor { ctx, active_schema };
1380    visitor.visit_expr(expr)
1381}
1382
1383pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1384    match expr {
1385        SQLExpr::Array(arr) => {
1386            let mut visitor = SQLExprVisitor {
1387                ctx,
1388                active_schema: None,
1389            };
1390            visitor.array_expr_to_series(arr.elem.as_slice())
1391        },
1392        _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1393    }
1394}
1395
1396pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1397    let field = match field {
1398        // handle 'DATE_PART' and all valid abbreviations/alternates
1399        DateTimeField::Custom(Ident { value, .. }) => {
1400            let value = value.to_ascii_lowercase();
1401            match value.as_str() {
1402                "millennium" | "millennia" => &DateTimeField::Millennium,
1403                "century" | "centuries" | "c" => &DateTimeField::Century,
1404                "decade" | "decades" => &DateTimeField::Decade,
1405                "isoyear" => &DateTimeField::Isoyear,
1406                "year" | "years" | "y" => &DateTimeField::Year,
1407                "quarter" | "quarters" => &DateTimeField::Quarter,
1408                "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1409                "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1410                "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1411                "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1412                "isodow" => &DateTimeField::Isodow,
1413                "day" | "days" | "dayofmonth" | "d" => &DateTimeField::Day,
1414                "hour" | "hours" | "h" => &DateTimeField::Hour,
1415                "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1416                "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1417                "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1418                "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1419                "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1420                #[cfg(feature = "timezones")]
1421                "timezone" => &DateTimeField::Timezone,
1422                "time" => &DateTimeField::Time,
1423                "epoch" => &DateTimeField::Epoch,
1424                _ => {
1425                    polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1426                },
1427            }
1428        },
1429        _ => field,
1430    };
1431    Ok(match field {
1432        DateTimeField::Millennium => expr.dt().millennium(),
1433        DateTimeField::Century => expr.dt().century(),
1434        DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1435        DateTimeField::Isoyear => expr.dt().iso_year(),
1436        DateTimeField::Year | DateTimeField::Years => expr.dt().year(),
1437        DateTimeField::Quarter => expr.dt().quarter(),
1438        DateTimeField::Month | DateTimeField::Months => expr.dt().month(),
1439        DateTimeField::Week(weekday) => {
1440            if weekday.is_some() {
1441                polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1442            }
1443            expr.dt().week()
1444        },
1445        DateTimeField::IsoWeek | DateTimeField::Weeks => expr.dt().week(),
1446        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1447        DateTimeField::DayOfWeek | DateTimeField::Dow => {
1448            let w = expr.dt().weekday();
1449            when(w.clone().eq(typed_lit(7i8)))
1450                .then(typed_lit(0i8))
1451                .otherwise(w)
1452        },
1453        DateTimeField::Isodow => expr.dt().weekday(),
1454        DateTimeField::Day | DateTimeField::Days => expr.dt().day(),
1455        DateTimeField::Hour | DateTimeField::Hours => expr.dt().hour(),
1456        DateTimeField::Minute | DateTimeField::Minutes => expr.dt().minute(),
1457        DateTimeField::Second | DateTimeField::Seconds => expr.dt().second(),
1458        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1459            (expr.clone().dt().second() * typed_lit(1_000f64))
1460                + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1461        },
1462        DateTimeField::Microsecond | DateTimeField::Microseconds => {
1463            (expr.clone().dt().second() * typed_lit(1_000_000f64))
1464                + expr.dt().nanosecond().div(typed_lit(1_000f64))
1465        },
1466        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1467            (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1468        },
1469        DateTimeField::Time => expr.dt().time(),
1470        #[cfg(feature = "timezones")]
1471        DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(false),
1472        DateTimeField::Epoch => {
1473            expr.clone()
1474                .dt()
1475                .timestamp(TimeUnit::Nanoseconds)
1476                .div(typed_lit(1_000_000_000i64))
1477                + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1478        },
1479        _ => {
1480            polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1481        },
1482    })
1483}
1484
1485/// Allow an expression that represents a 1-indexed parameter to
1486/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1487pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1488    match idx {
1489        Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1490        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1491            if null_if_zero {
1492                lit(LiteralValue::untyped_null())
1493            } else {
1494                idx
1495            }
1496        },
1497        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1498        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1499        // TODO: when 'saturating_sub' is available, should be able
1500        //  to streamline the when/then/otherwise block below -
1501        _ => when(idx.clone().gt(lit(0)))
1502            .then(idx.clone() - lit(1))
1503            .otherwise(if null_if_zero {
1504                when(idx.clone().eq(lit(0)))
1505                    .then(lit(LiteralValue::untyped_null()))
1506                    .otherwise(idx.clone())
1507            } else {
1508                idx.clone()
1509            }),
1510    }
1511}
1512
1513fn resolve_column<'a>(
1514    ctx: &'a mut SQLContext,
1515    ident_root: &'a Ident,
1516    name: &'a str,
1517    dtype: &'a DataType,
1518) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1519    let resolved = ctx.resolve_name(&ident_root.value, name);
1520    let resolved = resolved.as_str();
1521    Ok((
1522        if name != resolved {
1523            col(resolved).alias(name)
1524        } else {
1525            col(name)
1526        },
1527        Some(dtype),
1528    ))
1529}
1530
1531pub(crate) fn resolve_compound_identifier(
1532    ctx: &mut SQLContext,
1533    idents: &[Ident],
1534    active_schema: Option<&Schema>,
1535) -> PolarsResult<Vec<Expr>> {
1536    // inference priority: table > struct > column
1537    let ident_root = &idents[0];
1538    let mut remaining_idents = idents.iter().skip(1);
1539    let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1540
1541    // get schema from table (or the active/default schema)
1542    let schema = if let Some(ref mut lf) = lf {
1543        lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)?
1544    } else {
1545        Arc::new(active_schema.cloned().unwrap_or_default())
1546    };
1547
1548    // handle simple/unqualified column reference with no schema
1549    if lf.is_none() && schema.is_empty() {
1550        let (mut column, mut dtype): (Expr, Option<&DataType>) =
1551            (col(ident_root.value.as_str()), None);
1552
1553        // traverse the remaining struct field path (if any)
1554        for ident in remaining_idents {
1555            let name = ident.value.as_str();
1556            match dtype {
1557                Some(DataType::Struct(fields)) if name == "*" => {
1558                    return Ok(fields
1559                        .iter()
1560                        .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1561                        .collect());
1562                },
1563                Some(DataType::Struct(fields)) => {
1564                    dtype = fields
1565                        .iter()
1566                        .find(|fld| fld.name == name)
1567                        .map(|fld| &fld.dtype);
1568                },
1569                Some(dtype) if name == "*" => {
1570                    polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1571                },
1572                _ => dtype = None,
1573            }
1574            column = column.struct_().field_by_name(name);
1575        }
1576        return Ok(vec![column]);
1577    }
1578
1579    let name = &remaining_idents.next().unwrap().value;
1580
1581    // handle "table.*" wildcard expansion
1582    if lf.is_some() && name == "*" {
1583        return schema
1584            .iter_names_and_dtypes()
1585            .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).map(|(expr, _)| expr))
1586            .collect();
1587    }
1588
1589    // resolve column/struct reference
1590    let col_dtype: PolarsResult<(Expr, Option<&DataType>)> =
1591        match (lf.is_none(), schema.get(&ident_root.value)) {
1592            // root is a column/struct in schema (no table)
1593            (true, Some(dtype)) => {
1594                remaining_idents = idents.iter().skip(1);
1595                Ok((col(ident_root.value.as_str()), Some(dtype)))
1596            },
1597            // root is not in schema and no table found
1598            (true, None) => {
1599                polars_bail!(
1600                    SQLInterface: "no table or struct column named '{}' found",
1601                    ident_root
1602                )
1603            },
1604            // root is a table, resolve column from table schema
1605            (false, _) => {
1606                if let Some((_, col_name, dtype)) = schema.get_full(name) {
1607                    resolve_column(ctx, ident_root, col_name, dtype)
1608                } else {
1609                    polars_bail!(
1610                        SQLInterface: "no column named '{}' found in table '{}'",
1611                        name, ident_root
1612                    )
1613                }
1614            },
1615        };
1616
1617    // additional ident levels index into struct fields (eg: "df.col.field.nested_field")
1618    let (mut column, mut dtype) = col_dtype?;
1619    for ident in remaining_idents {
1620        let name = ident.value.as_str();
1621        match dtype {
1622            Some(DataType::Struct(fields)) if name == "*" => {
1623                return Ok(fields
1624                    .iter()
1625                    .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1626                    .collect());
1627            },
1628            Some(DataType::Struct(fields)) => {
1629                dtype = fields
1630                    .iter()
1631                    .find(|fld| fld.name == name)
1632                    .map(|fld| &fld.dtype);
1633            },
1634            Some(dtype) if name == "*" => {
1635                polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1636            },
1637            _ => {
1638                dtype = None;
1639            },
1640        }
1641        column = column.struct_().field_by_name(name);
1642    }
1643    Ok(vec![column])
1644}