1use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use polars_utils::unique_column_name;
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use sqlparser::ast::{
21 AccessExpr, BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
22 DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
23 SelectItem, Subscript, TimezoneInfo, TrimWhereField, TypedString, UnaryOperator,
24 Value as SQLValue, ValueWithSpan,
25};
26use sqlparser::dialect::GenericDialect;
27use sqlparser::parser::{Parser, ParserOptions};
28
29use crate::SQLContext;
30use crate::functions::SQLFunctionVisitor;
31use crate::types::{
32 bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33};
34
35#[inline]
36#[cold]
37#[must_use]
38pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
40 PolarsError::SQLInterface(err.to_string().into())
41}
42
/// Restrictions that can be requested when a subquery is translated.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)]
pub enum SubqueryRestriction {
    /// Single-column handling: `visit_subquery` uses the first column of the subquery result.
    SingleColumn,
}
53
/// Visitor that translates sqlparser expressions (`SQLExpr`) into polars expressions (`Expr`).
pub(crate) struct SQLExprVisitor<'a> {
    /// SQL context; used to execute subqueries and to resolve compound identifiers.
    ctx: &'a mut SQLContext,
    /// Schema consulted (when present) to look up the dtype of referenced columns.
    active_schema: Option<&'a Schema>,
}
59
60impl SQLExprVisitor<'_> {
61 fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
62 let mut array_elements = Vec::with_capacity(elements.len());
63 for e in elements {
64 let val = match e {
65 SQLExpr::Value(ValueWithSpan { value: v, .. }) => self.visit_any_value(v, None),
66 SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
67 SQLExpr::Value(ValueWithSpan { value: v, .. }) => {
68 self.visit_any_value(v, Some(op))
69 },
70 _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
71 },
72 SQLExpr::Array(values) => {
73 let srs = self.array_expr_to_series(&values.elem)?;
74 Ok(AnyValue::List(srs))
75 },
76 _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
77 }?
78 .into_static();
79 array_elements.push(val);
80 }
81 Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
82 }
83
84 fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
85 match expr {
86 SQLExpr::AllOp {
87 left,
88 compare_op,
89 right,
90 } => self.visit_all(left, compare_op, right),
91 SQLExpr::AnyOp {
92 left,
93 compare_op,
94 right,
95 is_some: _,
96 } => self.visit_any(left, compare_op, right),
97 SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
98 SQLExpr::Between {
99 expr,
100 negated,
101 low,
102 high,
103 } => self.visit_between(expr, *negated, low, high),
104 SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
105 SQLExpr::Cast {
106 kind,
107 expr,
108 data_type,
109 format,
110 } => self.visit_cast(expr, data_type, format, kind),
111 SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
112 SQLExpr::CompoundFieldAccess { root, access_chain } => {
113 if access_chain.len() == 1 {
115 match &access_chain[0] {
116 AccessExpr::Subscript(subscript) => {
117 return self.visit_subscript(root, subscript);
118 },
119 AccessExpr::Dot(_) => {
120 polars_bail!(SQLSyntax: "dot-notation field access is currently unsupported: {:?}", access_chain[0])
121 },
122 }
123 }
124 polars_bail!(SQLSyntax: "complex field access chains are currently unsupported: {:?}", access_chain[0])
126 },
127 SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
128 SQLExpr::Extract {
129 field,
130 syntax: _,
131 expr,
132 } => parse_extract_date_part(self.visit_expr(expr)?, field),
133 SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
134 SQLExpr::Function(function) => self.visit_function(function),
135 SQLExpr::Identifier(ident) => self.visit_identifier(ident),
136 SQLExpr::InList {
137 expr,
138 list,
139 negated,
140 } => {
141 let expr = self.visit_expr(expr)?;
142 let elems = self.visit_array_expr(list, true, Some(&expr))?;
143 let is_in = expr.is_in(elems, false);
144 Ok(if *negated { is_in.not() } else { is_in })
145 },
146 SQLExpr::InSubquery {
147 expr,
148 subquery,
149 negated,
150 } => self.visit_in_subquery(expr, subquery, *negated),
151 SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
152 SQLExpr::IsDistinctFrom(e1, e2) => {
153 Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
154 },
155 SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
156 SQLExpr::IsNotDistinctFrom(e1, e2) => {
157 Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
158 },
159 SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
160 SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
161 SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
162 SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
163 SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
164 SQLExpr::Like {
165 negated,
166 any,
167 expr,
168 pattern,
169 escape_char,
170 } => {
171 if *any {
172 polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
173 }
174 let escape_str = escape_char.as_ref().and_then(|v| match v {
175 SQLValue::SingleQuotedString(s) => Some(s.clone()),
176 _ => None,
177 });
178 self.visit_like(*negated, expr, pattern, &escape_str, false)
179 },
180 SQLExpr::ILike {
181 negated,
182 any,
183 expr,
184 pattern,
185 escape_char,
186 } => {
187 if *any {
188 polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
189 }
190 let escape_str = escape_char.as_ref().and_then(|v| match v {
191 SQLValue::SingleQuotedString(s) => Some(s.clone()),
192 _ => None,
193 });
194 self.visit_like(*negated, expr, pattern, &escape_str, true)
195 },
196 SQLExpr::Nested(expr) => self.visit_expr(expr),
197 SQLExpr::Position { expr, r#in } => Ok(
198 (self
200 .visit_expr(r#in)?
201 .str()
202 .find(self.visit_expr(expr)?, true)
203 + typed_lit(1u32))
204 .fill_null(typed_lit(0u32)),
205 ),
206 SQLExpr::RLike {
207 negated,
209 expr,
210 pattern,
211 regexp: _,
212 } => {
213 let matches = self
214 .visit_expr(expr)?
215 .str()
216 .contains(self.visit_expr(pattern)?, true);
217 Ok(if *negated { matches.not() } else { matches })
218 },
219 SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
220 SQLExpr::Substring {
221 expr,
222 substring_from,
223 substring_for,
224 ..
225 } => self.visit_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
226 SQLExpr::Trim {
227 expr,
228 trim_where,
229 trim_what,
230 trim_characters,
231 } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
232 SQLExpr::TypedString(TypedString {
233 data_type,
234 value:
235 ValueWithSpan {
236 value: SQLValue::SingleQuotedString(v),
237 ..
238 },
239 uses_odbc_syntax: _,
240 }) => match data_type {
241 SQLDataType::Date => {
242 if is_iso_date(v) {
243 Ok(lit(v.as_str()).cast(DataType::Date))
244 } else {
245 polars_bail!(SQLSyntax: "invalid DATE literal '{}'", v)
246 }
247 },
248 SQLDataType::Time(None, TimezoneInfo::None) => {
249 if is_iso_time(v) {
250 Ok(lit(v.as_str()).str().to_time(StrptimeOptions {
251 strict: true,
252 ..Default::default()
253 }))
254 } else {
255 polars_bail!(SQLSyntax: "invalid TIME literal '{}'", v)
256 }
257 },
258 SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
259 if is_iso_datetime(v) {
260 Ok(lit(v.as_str()).str().to_datetime(
261 None,
262 None,
263 StrptimeOptions {
264 strict: true,
265 ..Default::default()
266 },
267 lit("latest"),
268 ))
269 } else {
270 let fn_name = match data_type {
271 SQLDataType::Timestamp(_, _) => "TIMESTAMP",
272 SQLDataType::Datetime(_) => "DATETIME",
273 _ => unreachable!(),
274 };
275 polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, v)
276 }
277 },
278 _ => {
279 polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
280 },
281 },
282 SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
283 SQLExpr::Value(ValueWithSpan { value, .. }) => self.visit_literal(value),
284 SQLExpr::Wildcard(_) => Ok(all().as_expr()),
285 e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
286 other => {
287 polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
288 },
289 }
290 }
291
292 fn visit_subquery(
293 &mut self,
294 subquery: &Subquery,
295 restriction: SubqueryRestriction,
296 ) -> PolarsResult<Expr> {
297 if subquery.with.is_some() {
298 polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
299 }
300 let lf = self
303 .ctx
304 .execute_isolated(|ctx| ctx.execute_query_no_ctes(subquery))?;
305
306 if restriction == SubqueryRestriction::SingleColumn {
307 let new_name = unique_column_name();
308 return Ok(Expr::SubPlan(
309 SpecialEq::new(Arc::new(lf.logical_plan)),
310 vec![(
312 new_name.clone(),
313 first().as_expr().implode().alias(new_name.clone()),
314 )],
315 ));
316 };
317 polars_bail!(SQLInterface: "subquery type not supported");
318 }
319
320 fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
324 Ok(col(ident.value.as_str()))
325 }
326
327 fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
331 Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
332 }
333
334 fn visit_like(
335 &mut self,
336 negated: bool,
337 expr: &SQLExpr,
338 pattern: &SQLExpr,
339 escape_char: &Option<String>,
340 case_insensitive: bool,
341 ) -> PolarsResult<Expr> {
342 if escape_char.is_some() {
343 polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
344 }
345 let pat = match self.visit_expr(pattern) {
346 Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
347 PlSmallStr::from_str(lv.extract_str().unwrap())
348 },
349 _ => {
350 polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
351 },
352 };
353 if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
354 let op = if negated {
356 SQLBinaryOperator::NotEq
357 } else {
358 SQLBinaryOperator::Eq
359 };
360 self.visit_binary_op(expr, &op, pattern)
361 } else {
362 let mut rx = regex::escape(pat.as_str())
364 .replace('%', ".*")
365 .replace('_', ".");
366
367 rx = format!(
368 "^{}{}$",
369 if case_insensitive { "(?is)" } else { "(?s)" },
370 rx
371 );
372
373 let expr = self.visit_expr(expr)?;
374 let matches = expr.str().contains(lit(rx), true);
375 Ok(if negated { matches.not() } else { matches })
376 }
377 }
378
379 fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
380 let expr = self.visit_expr(expr)?;
381 Ok(match subscript {
382 Subscript::Index { index } => {
383 let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
384 expr.list().get(idx, true)
385 },
386 Subscript::Slice { .. } => {
387 polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
388 },
389 })
390 }
391
392 fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
399 if let (Some(name), Some(s), expr_dtype) = match (left, right) {
400 (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
402 (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
403 },
404 (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
406 let s = lv.extract_str().unwrap();
407 match &**expr {
408 Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
409 _ => (None, Some(s), Some(dtype)),
410 }
411 },
412 _ => (None, None, None),
413 } {
414 if expr_dtype.is_none() && self.active_schema.is_none() {
415 right.clone()
416 } else {
417 let left_dtype = expr_dtype.map_or_else(
418 || {
419 self.active_schema
420 .as_ref()
421 .and_then(|schema| schema.get(&name))
422 },
423 |dt| dt.as_literal(),
424 );
425 match left_dtype {
426 Some(DataType::Time) if is_iso_time(s) => {
427 right.clone().str().to_time(StrptimeOptions {
428 strict: true,
429 ..Default::default()
430 })
431 },
432 Some(DataType::Date) if is_iso_date(s) => {
433 right.clone().str().to_date(StrptimeOptions {
434 strict: true,
435 ..Default::default()
436 })
437 },
438 Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
439 if s.len() == 10 {
440 lit(format!("{s}T00:00:00"))
442 } else {
443 lit(s.replacen(' ', "T", 1))
444 }
445 .str()
446 .to_datetime(
447 Some(*tu),
448 tz.clone(),
449 StrptimeOptions {
450 strict: true,
451 ..Default::default()
452 },
453 lit("latest"),
454 )
455 },
456 _ => right.clone(),
457 }
458 }
459 } else {
460 right.clone()
461 }
462 }
463
464 fn struct_field_access_expr(
465 &mut self,
466 expr: &Expr,
467 path: &str,
468 infer_index: bool,
469 ) -> PolarsResult<Expr> {
470 let path_elems = if path.starts_with('{') && path.ends_with('}') {
471 path.trim_matches(|c| c == '{' || c == '}')
472 } else {
473 path
474 }
475 .split(',');
476
477 let mut expr = expr.clone();
478 for p in path_elems {
479 let p = p.trim();
480 expr = if infer_index {
481 match p.parse::<i64>() {
482 Ok(idx) => expr.list().get(lit(idx), true),
483 Err(_) => expr.struct_().field_by_name(p),
484 }
485 } else {
486 expr.struct_().field_by_name(p)
487 }
488 }
489 Ok(expr)
490 }
491
492 fn visit_binary_op(
496 &mut self,
497 left: &SQLExpr,
498 op: &SQLBinaryOperator,
499 right: &SQLExpr,
500 ) -> PolarsResult<Expr> {
501 if matches!(left, SQLExpr::Subquery(_)) || matches!(right, SQLExpr::Subquery(_)) {
503 let (suggestion, str_op) = match op {
504 SQLBinaryOperator::NotEq => ("; use 'NOT IN' instead", "!=".to_string()),
505 SQLBinaryOperator::Eq => ("; use 'IN' instead", format!("{op}")),
506 _ => ("", format!("{op}")),
507 };
508 polars_bail!(
509 SQLSyntax: "subquery comparisons with '{str_op}' are not supported{suggestion}"
510 );
511 }
512
513 let (lhs, mut rhs) = match (left, op, right) {
515 (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
516 let duration = interval_to_duration(v, false)?;
517 return Ok(self
518 .visit_expr(left)?
519 .dt()
520 .offset_by(lit(format!("-{duration}"))));
521 },
522 (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
523 let duration = interval_to_duration(v, false)?;
524 return Ok(self
525 .visit_expr(left)?
526 .dt()
527 .offset_by(lit(format!("{duration}"))));
528 },
529 (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
530 let d1 = interval_to_duration(v1, false)?;
532 let d2 = interval_to_duration(v2, false)?;
533 let res = match op {
534 SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
535 SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
536 SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
537 SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
538 SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
539 SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
540 _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
541 };
542 if res.is_ok() {
543 return res;
544 }
545 (self.visit_expr(left)?, self.visit_expr(right)?)
546 },
547 _ => (self.visit_expr(left)?, self.visit_expr(right)?),
548 };
549 rhs = self.convert_temporal_strings(&lhs, &rhs);
550
551 Ok(match op {
552 SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), SQLBinaryOperator::BitwiseOr => lhs.or(rhs), SQLBinaryOperator::Xor => lhs.xor(rhs), SQLBinaryOperator::And => lhs.and(rhs), SQLBinaryOperator::Divide => lhs / rhs, SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), SQLBinaryOperator::Eq => lhs.eq(rhs), SQLBinaryOperator::Gt => lhs.gt(rhs), SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), SQLBinaryOperator::Lt => lhs.lt(rhs), SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), SQLBinaryOperator::Minus => lhs - rhs, SQLBinaryOperator::Modulo => lhs % rhs, SQLBinaryOperator::Multiply => lhs * rhs, SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), SQLBinaryOperator::Or => lhs.or(rhs), SQLBinaryOperator::Plus => lhs + rhs, SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), SQLBinaryOperator::StringConcat => { lhs.cast(DataType::String) + rhs.cast(DataType::String)
579 },
580 SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), SQLBinaryOperator::PGRegexMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
586 _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
587 },
588 SQLBinaryOperator::PGRegexNotMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
590 _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
591 },
592 SQLBinaryOperator::PGRegexIMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => {
594 let pat = lv.extract_str().unwrap();
595 lhs.str().contains(lit(format!("(?i){pat}")), true)
596 },
597 _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
598 },
599 SQLBinaryOperator::PGRegexNotIMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => {
601 let pat = lv.extract_str().unwrap();
602 lhs.str().contains(lit(format!("(?i){pat}")), true).not()
603 },
604 _ => {
605 polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
606 },
607 },
608 SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch | SQLBinaryOperator::PGILikeMatch | SQLBinaryOperator::PGNotILikeMatch => { let expr = if matches!(
616 op,
617 SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
618 ) {
619 SQLExpr::Like {
620 negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
621 any: false,
622 expr: Box::new(left.clone()),
623 pattern: Box::new(right.clone()),
624 escape_char: None,
625 }
626 } else {
627 SQLExpr::ILike {
628 negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
629 any: false,
630 expr: Box::new(left.clone()),
631 pattern: Box::new(right.clone()),
632 escape_char: None,
633 }
634 };
635 self.visit_expr(&expr)?
636 },
637 SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { Expr::Literal(lv) if lv.extract_str().is_some() => {
642 let path = lv.extract_str().unwrap();
643 let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
644 if let SQLBinaryOperator::LongArrow = op {
645 expr = expr.cast(DataType::String);
646 }
647 expr
648 },
649 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
650 let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
651 if let SQLBinaryOperator::LongArrow = op {
652 expr = expr.cast(DataType::String);
653 }
654 expr
655 },
656 _ => {
657 polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
658 },
659 },
660 SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { match rhs {
662 Expr::Literal(lv) if lv.extract_str().is_some() => {
663 let path = lv.extract_str().unwrap();
664 let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
665 if let SQLBinaryOperator::HashLongArrow = op {
666 expr = expr.cast(DataType::String);
667 }
668 expr
669 },
670 _ => {
671 polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
672 }
673 }
674 },
675 other => {
676 polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
677 },
678 })
679 }
680
681 fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
685 let expr = self.visit_expr(expr)?;
686 Ok(match (op, expr.clone()) {
687 (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
689 lit(n)
690 },
691 (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
692 lit(n)
693 },
694 (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
695 lit(-n)
696 },
697 (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
698 lit(-n)
699 },
700 (UnaryOperator::Plus, _) => lit(0) + expr,
702 (UnaryOperator::Minus, _) => lit(0) - expr,
703 (UnaryOperator::Not, _) => match &expr {
704 Expr::Column(name)
705 if self
706 .active_schema
707 .and_then(|schema| schema.get(name))
708 .is_some_and(|dtype| matches!(dtype, DataType::Boolean)) =>
709 {
710 expr.not()
712 },
713 _ => expr.strict_cast(DataType::Boolean).not(),
715 },
716 other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
717 })
718 }
719
720 fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
726 let mut visitor = SQLFunctionVisitor {
727 func: function,
728 ctx: self.ctx,
729 active_schema: self.active_schema,
730 };
731 visitor.visit_function()
732 }
733
734 fn visit_all(
738 &mut self,
739 left: &SQLExpr,
740 compare_op: &SQLBinaryOperator,
741 right: &SQLExpr,
742 ) -> PolarsResult<Expr> {
743 let left = self.visit_expr(left)?;
744 let right = self.visit_expr(right)?;
745
746 match compare_op {
747 SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
748 SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
749 SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
750 SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
751 SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
752 SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
753 _ => polars_bail!(SQLInterface: "invalid comparison operator"),
754 }
755 }
756
757 fn visit_any(
761 &mut self,
762 left: &SQLExpr,
763 compare_op: &SQLBinaryOperator,
764 right: &SQLExpr,
765 ) -> PolarsResult<Expr> {
766 let left = self.visit_expr(left)?;
767 let right = self.visit_expr(right)?;
768
769 match compare_op {
770 SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
771 SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
772 SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
773 SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
774 SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
775 SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
776 _ => polars_bail!(SQLInterface: "invalid comparison operator"),
777 }
778 }
779
780 fn visit_array_expr(
782 &mut self,
783 elements: &[SQLExpr],
784 result_as_element: bool,
785 dtype_expr_match: Option<&Expr>,
786 ) -> PolarsResult<Expr> {
787 let mut elems = self.array_expr_to_series(elements)?;
788
789 if let (Some(Expr::Column(name)), Some(schema)) =
792 (dtype_expr_match, self.active_schema.as_ref())
793 {
794 if elems.dtype() == &DataType::String {
795 if let Some(dtype) = schema.get(name) {
796 if matches!(
797 dtype,
798 DataType::Date | DataType::Time | DataType::Datetime(_, _)
799 ) {
800 elems = elems.strict_cast(dtype)?;
801 }
802 }
803 }
804 }
805
806 let res = if result_as_element {
809 elems.implode()?.into_series()
810 } else {
811 elems
812 };
813 Ok(lit(res))
814 }
815
816 fn visit_cast(
820 &mut self,
821 expr: &SQLExpr,
822 dtype: &SQLDataType,
823 format: &Option<CastFormat>,
824 cast_kind: &CastKind,
825 ) -> PolarsResult<Expr> {
826 if format.is_some() {
827 return Err(
828 polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
829 );
830 }
831 let expr = self.visit_expr(expr)?;
832
833 #[cfg(feature = "json")]
834 if dtype == &SQLDataType::JSON {
835 return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
837 }
838 let polars_type = map_sql_dtype_to_polars(dtype)?;
839 Ok(match cast_kind {
840 CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
841 CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
842 })
843 }
844
845 fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
851 Ok(match value {
853 SQLValue::Boolean(b) => lit(*b),
854 SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
855 #[cfg(feature = "binary_encoding")]
856 SQLValue::HexStringLiteral(x) => {
857 if x.len() % 2 != 0 {
858 polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
859 };
860 lit(hex::decode(x.clone()).unwrap())
861 },
862 SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
863 SQLValue::Number(s, _) => {
864 if s.contains('.') {
866 s.parse::<f64>().map(lit).map_err(|_| ())
867 } else {
868 s.parse::<i64>().map(lit).map_err(|_| ())
869 }
870 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
871 },
872 SQLValue::SingleQuotedByteStringLiteral(b) => {
873 bitstring_to_bytes_literal(b)?
877 },
878 SQLValue::SingleQuotedString(s) => lit(s.clone()),
879 other => {
880 polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
881 },
882 })
883 }
884
885 fn visit_any_value(
887 &self,
888 value: &SQLValue,
889 op: Option<&UnaryOperator>,
890 ) -> PolarsResult<AnyValue<'_>> {
891 Ok(match value {
892 SQLValue::Boolean(b) => AnyValue::Boolean(*b),
893 SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
894 #[cfg(feature = "binary_encoding")]
895 SQLValue::HexStringLiteral(x) => {
896 if x.len() % 2 != 0 {
897 polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
898 };
899 AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
900 },
901 SQLValue::Null => AnyValue::Null,
902 SQLValue::Number(s, _) => {
903 let negate = match op {
904 Some(UnaryOperator::Minus) => true,
905 Some(UnaryOperator::Plus) | None => false,
907 Some(op) => {
908 polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
909 },
910 };
911 if s.contains('.') {
913 s.parse::<f64>()
914 .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
915 .map_err(|_| ())
916 } else {
917 s.parse::<i64>()
918 .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
919 .map_err(|_| ())
920 }
921 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
922 },
923 SQLValue::SingleQuotedByteStringLiteral(b) => {
924 let bytes_literal = bitstring_to_bytes_literal(b)?;
926 match bytes_literal {
927 Expr::Literal(lv) if lv.extract_binary().is_some() => {
928 AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
929 },
930 _ => {
931 polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
932 },
933 }
934 },
935 SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
936 other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
937 })
938 }
939
940 fn visit_between(
943 &mut self,
944 expr: &SQLExpr,
945 negated: bool,
946 low: &SQLExpr,
947 high: &SQLExpr,
948 ) -> PolarsResult<Expr> {
949 let expr = self.visit_expr(expr)?;
950 let low = self.visit_expr(low)?;
951 let high = self.visit_expr(high)?;
952
953 let low = self.convert_temporal_strings(&expr, &low);
954 let high = self.convert_temporal_strings(&expr, &high);
955 Ok(if negated {
956 expr.clone().lt(low).or(expr.gt(high))
957 } else {
958 expr.clone().gt_eq(low).and(expr.lt_eq(high))
959 })
960 }
961
962 fn visit_trim(
965 &mut self,
966 expr: &SQLExpr,
967 trim_where: &Option<TrimWhereField>,
968 trim_what: &Option<Box<SQLExpr>>,
969 trim_characters: &Option<Vec<SQLExpr>>,
970 ) -> PolarsResult<Expr> {
971 if trim_characters.is_some() {
972 return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
974 };
975 let expr = self.visit_expr(expr)?;
976 let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
977 let trim_what = match trim_what {
978 Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
979 Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
980 },
981 None => None,
982 _ => return self.err(&expr),
983 };
984 Ok(match (trim_where, trim_what) {
985 (None | Some(TrimWhereField::Both), None) => {
986 expr.str().strip_chars(lit(LiteralValue::untyped_null()))
987 },
988 (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
989 (Some(TrimWhereField::Leading), None) => expr
990 .str()
991 .strip_chars_start(lit(LiteralValue::untyped_null())),
992 (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
993 (Some(TrimWhereField::Trailing), None) => expr
994 .str()
995 .strip_chars_end(lit(LiteralValue::untyped_null())),
996 (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
997 })
998 }
999
1000 fn visit_substring(
1001 &mut self,
1002 expr: &SQLExpr,
1003 substring_from: Option<&SQLExpr>,
1004 substring_for: Option<&SQLExpr>,
1005 ) -> PolarsResult<Expr> {
1006 let e = self.visit_expr(expr)?;
1007
1008 match (substring_from, substring_for) {
1009 (Some(from_expr), Some(for_expr)) => {
1011 let start = self.visit_expr(from_expr)?;
1012 let length = self.visit_expr(for_expr)?;
1013
1014 Ok(match (start.clone(), length.clone()) {
1016 (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1017 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1018 polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", n)
1019 },
1020 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => {
1021 e.str().slice(lit(n - 1), length)
1022 },
1023 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => e
1024 .str()
1025 .slice(lit(0), (length + lit(n - 1)).clip_min(lit(0))),
1026 (Expr::Literal(_), _) => {
1027 polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1028 },
1029 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1030 polars_bail!(SQLSyntax: "invalid 'length' for SUBSTRING")
1031 },
1032 _ => {
1033 let adjusted_start = start - lit(1);
1034 when(adjusted_start.clone().lt(lit(0)))
1035 .then(e.clone().str().slice(
1036 lit(0),
1037 (length.clone() + adjusted_start.clone()).clip_min(lit(0)),
1038 ))
1039 .otherwise(e.str().slice(adjusted_start, length))
1040 },
1041 })
1042 },
1043 (Some(from_expr), None) => {
1045 let start = self.visit_expr(from_expr)?;
1046
1047 Ok(match start {
1048 Expr::Literal(lv) if lv.is_null() => lit(lv),
1049 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1050 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1051 e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null()))
1052 },
1053 Expr::Literal(_) => {
1054 polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1055 },
1056 _ => e
1057 .str()
1058 .slice(start - lit(1), lit(LiteralValue::untyped_null())),
1059 })
1060 },
1061 (None, _) => {
1063 polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found 1)")
1064 },
1065 }
1066 }
1067
1068 fn visit_in_subquery(
1070 &mut self,
1071 expr: &SQLExpr,
1072 subquery: &Subquery,
1073 negated: bool,
1074 ) -> PolarsResult<Expr> {
1075 let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
1076 let expr = self.visit_expr(expr)?;
1077 Ok(if negated {
1078 expr.is_in(subquery_result, false).not()
1079 } else {
1080 expr.is_in(subquery_result, false)
1081 })
1082 }
1083
1084 fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1086 if let SQLExpr::Case {
1087 case_token: _,
1088 end_token: _,
1089 operand,
1090 conditions,
1091 else_result,
1092 } = expr
1093 {
1094 polars_ensure!(
1095 !conditions.is_empty(),
1096 SQLSyntax: "WHEN and THEN expressions must have at least one element"
1097 );
1098
1099 let mut when_thens = conditions.iter();
1100 let first = when_thens.next();
1101 if first.is_none() {
1102 polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
1103 }
1104 let else_res = match else_result {
1105 Some(else_res) => self.visit_expr(else_res)?,
1106 None => lit(LiteralValue::untyped_null()), };
1108 if let Some(operand_expr) = operand {
1109 let first_operand_expr = self.visit_expr(operand_expr)?;
1110
1111 let first = first.unwrap();
1112 let first_cond = first_operand_expr.eq(self.visit_expr(&first.condition)?);
1113 let first_then = self.visit_expr(&first.result)?;
1114 let expr = when(first_cond).then(first_then);
1115 let next = when_thens.next();
1116
1117 let mut when_then = if let Some(case_when) = next {
1118 let second_operand_expr = self.visit_expr(operand_expr)?;
1119 let cond = second_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1120 let res = self.visit_expr(&case_when.result)?;
1121 expr.when(cond).then(res)
1122 } else {
1123 return Ok(expr.otherwise(else_res));
1124 };
1125 for case_when in when_thens {
1126 let new_operand_expr = self.visit_expr(operand_expr)?;
1127 let cond = new_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1128 let res = self.visit_expr(&case_when.result)?;
1129 when_then = when_then.when(cond).then(res);
1130 }
1131 return Ok(when_then.otherwise(else_res));
1132 }
1133
1134 let first = first.unwrap();
1135 let first_cond = self.visit_expr(&first.condition)?;
1136 let first_then = self.visit_expr(&first.result)?;
1137 let expr = when(first_cond).then(first_then);
1138 let next = when_thens.next();
1139
1140 let mut when_then = if let Some(case_when) = next {
1141 let cond = self.visit_expr(&case_when.condition)?;
1142 let res = self.visit_expr(&case_when.result)?;
1143 expr.when(cond).then(res)
1144 } else {
1145 return Ok(expr.otherwise(else_res));
1146 };
1147 for case_when in when_thens {
1148 let cond = self.visit_expr(&case_when.condition)?;
1149 let res = self.visit_expr(&case_when.result)?;
1150 when_then = when_then.when(cond).then(res);
1151 }
1152 Ok(when_then.otherwise(else_res))
1153 } else {
1154 unreachable!()
1155 }
1156 }
1157
1158 fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1159 polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1160 }
1161}
1162
1163pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1181 let mut ctx = SQLContext::new();
1182
1183 let mut parser = Parser::new(&GenericDialect);
1184 parser = parser.with_options(ParserOptions {
1185 trailing_commas: true,
1186 ..Default::default()
1187 });
1188
1189 let mut ast = parser
1190 .try_with_sql(s.as_ref())
1191 .map_err(to_sql_interface_err)?;
1192 let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1193
1194 Ok(match &expr {
1195 SelectItem::ExprWithAlias { expr, alias } => {
1196 let expr = parse_sql_expr(expr, &mut ctx, None)?;
1197 expr.alias(alias.value.as_str())
1198 },
1199 SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1200 _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1201 })
1202}
1203
1204pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1205 if interval.last_field.is_some()
1206 || interval.leading_field.is_some()
1207 || interval.leading_precision.is_some()
1208 || interval.fractional_seconds_precision.is_some()
1209 {
1210 polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1211 }
1212 let s = match &*interval.value {
1213 SQLExpr::UnaryOp { .. } => {
1214 polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1215 },
1216 SQLExpr::Value(ValueWithSpan {
1217 value: SQLValue::SingleQuotedString(s),
1218 ..
1219 }) => Some(s),
1220 _ => None,
1221 };
1222 match s {
1223 Some(s) if s.contains('-') => {
1224 polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1225 },
1226 Some(s) => {
1227 let duration = Duration::parse_interval(s);
1230 if fixed && duration.months() != 0 {
1231 polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1232 };
1233 Ok(duration)
1234 },
1235 None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1236 }
1237}
1238
1239pub(crate) fn parse_sql_expr(
1240 expr: &SQLExpr,
1241 ctx: &mut SQLContext,
1242 active_schema: Option<&Schema>,
1243) -> PolarsResult<Expr> {
1244 let mut visitor = SQLExprVisitor { ctx, active_schema };
1245 visitor.visit_expr(expr)
1246}
1247
1248pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1249 match expr {
1250 SQLExpr::Array(arr) => {
1251 let mut visitor = SQLExprVisitor {
1252 ctx,
1253 active_schema: None,
1254 };
1255 visitor.array_expr_to_series(arr.elem.as_slice())
1256 },
1257 _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1258 }
1259}
1260
1261pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1262 let field = match field {
1263 DateTimeField::Custom(Ident { value, .. }) => {
1265 let value = value.to_ascii_lowercase();
1266 match value.as_str() {
1267 "millennium" | "millennia" => &DateTimeField::Millennium,
1268 "century" | "centuries" => &DateTimeField::Century,
1269 "decade" | "decades" => &DateTimeField::Decade,
1270 "isoyear" => &DateTimeField::Isoyear,
1271 "year" | "years" | "y" => &DateTimeField::Year,
1272 "quarter" | "quarters" => &DateTimeField::Quarter,
1273 "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1274 "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1275 "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1276 "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1277 "isodow" => &DateTimeField::Isodow,
1278 "day" | "days" | "d" => &DateTimeField::Day,
1279 "hour" | "hours" | "h" => &DateTimeField::Hour,
1280 "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1281 "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1282 "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1283 "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1284 "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1285 #[cfg(feature = "timezones")]
1286 "timezone" => &DateTimeField::Timezone,
1287 "time" => &DateTimeField::Time,
1288 "epoch" => &DateTimeField::Epoch,
1289 _ => {
1290 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1291 },
1292 }
1293 },
1294 _ => field,
1295 };
1296 Ok(match field {
1297 DateTimeField::Millennium => expr.dt().millennium(),
1298 DateTimeField::Century => expr.dt().century(),
1299 DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1300 DateTimeField::Isoyear => expr.dt().iso_year(),
1301 DateTimeField::Year | DateTimeField::Years => expr.dt().year(),
1302 DateTimeField::Quarter => expr.dt().quarter(),
1303 DateTimeField::Month | DateTimeField::Months => expr.dt().month(),
1304 DateTimeField::Week(weekday) => {
1305 if weekday.is_some() {
1306 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1307 }
1308 expr.dt().week()
1309 },
1310 DateTimeField::IsoWeek | DateTimeField::Weeks => expr.dt().week(),
1311 DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1312 DateTimeField::DayOfWeek | DateTimeField::Dow => {
1313 let w = expr.dt().weekday();
1314 when(w.clone().eq(typed_lit(7i8)))
1315 .then(typed_lit(0i8))
1316 .otherwise(w)
1317 },
1318 DateTimeField::Isodow => expr.dt().weekday(),
1319 DateTimeField::Day | DateTimeField::Days => expr.dt().day(),
1320 DateTimeField::Hour | DateTimeField::Hours => expr.dt().hour(),
1321 DateTimeField::Minute | DateTimeField::Minutes => expr.dt().minute(),
1322 DateTimeField::Second | DateTimeField::Seconds => expr.dt().second(),
1323 DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1324 (expr.clone().dt().second() * typed_lit(1_000f64))
1325 + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1326 },
1327 DateTimeField::Microsecond | DateTimeField::Microseconds => {
1328 (expr.clone().dt().second() * typed_lit(1_000_000f64))
1329 + expr.dt().nanosecond().div(typed_lit(1_000f64))
1330 },
1331 DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1332 (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1333 },
1334 DateTimeField::Time => expr.dt().time(),
1335 #[cfg(feature = "timezones")]
1336 DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(false),
1337 DateTimeField::Epoch => {
1338 expr.clone()
1339 .dt()
1340 .timestamp(TimeUnit::Nanoseconds)
1341 .div(typed_lit(1_000_000_000i64))
1342 + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1343 },
1344 _ => {
1345 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1346 },
1347 })
1348}
1349
1350pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1353 match idx {
1354 Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1355 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1356 if null_if_zero {
1357 lit(LiteralValue::untyped_null())
1358 } else {
1359 idx
1360 }
1361 },
1362 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1363 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1364 _ => when(idx.clone().gt(lit(0)))
1367 .then(idx.clone() - lit(1))
1368 .otherwise(if null_if_zero {
1369 when(idx.clone().eq(lit(0)))
1370 .then(lit(LiteralValue::untyped_null()))
1371 .otherwise(idx.clone())
1372 } else {
1373 idx.clone()
1374 }),
1375 }
1376}
1377
1378fn resolve_column<'a>(
1379 ctx: &'a mut SQLContext,
1380 ident_root: &'a Ident,
1381 name: &'a str,
1382 dtype: &'a DataType,
1383) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1384 let resolved = ctx.resolve_name(&ident_root.value, name);
1385 let resolved = resolved.as_str();
1386 Ok((
1387 if name != resolved {
1388 col(resolved).alias(name)
1389 } else {
1390 col(name)
1391 },
1392 Some(dtype),
1393 ))
1394}
1395
1396pub(crate) fn resolve_compound_identifier(
1397 ctx: &mut SQLContext,
1398 idents: &[Ident],
1399 active_schema: Option<&Schema>,
1400) -> PolarsResult<Vec<Expr>> {
1401 let ident_root = &idents[0];
1403 let mut remaining_idents = idents.iter().skip(1);
1404 let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1405
1406 let schema = if let Some(ref mut lf) = lf {
1408 lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)?
1409 } else {
1410 Arc::new(active_schema.cloned().unwrap_or_default())
1411 };
1412
1413 if lf.is_none() && schema.is_empty() {
1415 let (mut column, mut dtype): (Expr, Option<&DataType>) =
1416 (col(ident_root.value.as_str()), None);
1417
1418 for ident in remaining_idents {
1420 let name = ident.value.as_str();
1421 match dtype {
1422 Some(DataType::Struct(fields)) if name == "*" => {
1423 return Ok(fields
1424 .iter()
1425 .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1426 .collect());
1427 },
1428 Some(DataType::Struct(fields)) => {
1429 dtype = fields
1430 .iter()
1431 .find(|fld| fld.name == name)
1432 .map(|fld| &fld.dtype);
1433 },
1434 Some(dtype) if name == "*" => {
1435 polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1436 },
1437 _ => dtype = None,
1438 }
1439 column = column.struct_().field_by_name(name);
1440 }
1441 return Ok(vec![column]);
1442 }
1443
1444 let name = &remaining_idents.next().unwrap().value;
1445
1446 if lf.is_some() && name == "*" {
1448 return schema
1449 .iter_names_and_dtypes()
1450 .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).map(|(expr, _)| expr))
1451 .collect();
1452 }
1453
1454 let col_dtype: PolarsResult<(Expr, Option<&DataType>)> =
1456 match (lf.is_none(), schema.get(&ident_root.value)) {
1457 (true, Some(dtype)) => {
1459 remaining_idents = idents.iter().skip(1);
1460 Ok((col(ident_root.value.as_str()), Some(dtype)))
1461 },
1462 (true, None) => {
1464 polars_bail!(
1465 SQLInterface: "no table or struct column named '{}' found",
1466 ident_root
1467 )
1468 },
1469 (false, _) => {
1471 if let Some((_, col_name, dtype)) = schema.get_full(name) {
1472 resolve_column(ctx, ident_root, col_name, dtype)
1473 } else {
1474 polars_bail!(
1475 SQLInterface: "no column named '{}' found in table '{}'",
1476 name, ident_root
1477 )
1478 }
1479 },
1480 };
1481
1482 let (mut column, mut dtype) = col_dtype?;
1484 for ident in remaining_idents {
1485 let name = ident.value.as_str();
1486 match dtype {
1487 Some(DataType::Struct(fields)) if name == "*" => {
1488 return Ok(fields
1489 .iter()
1490 .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1491 .collect());
1492 },
1493 Some(DataType::Struct(fields)) => {
1494 dtype = fields
1495 .iter()
1496 .find(|fld| fld.name == name)
1497 .map(|fld| &fld.dtype);
1498 },
1499 Some(dtype) if name == "*" => {
1500 polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1501 },
1502 _ => {
1503 dtype = None;
1504 },
1505 }
1506 column = column.struct_().field_by_name(name);
1507 }
1508 Ok(vec![column])
1509}