1use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::prelude::typed_lit;
15use polars_plan::prelude::LiteralValue::Null;
16use polars_time::Duration;
17use rand::distributions::Alphanumeric;
18use rand::{thread_rng, Rng};
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22 BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind,
23 DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident,
24 Interval, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField,
25 UnaryOperator, Value as SQLValue,
26};
27use sqlparser::dialect::GenericDialect;
28use sqlparser::parser::{Parser, ParserOptions};
29
30use crate::functions::SQLFunctionVisitor;
31use crate::types::{
32 bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33};
34use crate::SQLContext;
35
36#[inline]
37#[cold]
38#[must_use]
39pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
41 PolarsError::SQLInterface(err.to_string().into())
42}
43
/// Constraint placed on the shape of a subquery's result.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubqueryRestriction {
    /// The subquery must produce exactly one column.
    SingleColumn,
}
54
55pub(crate) struct SQLExprVisitor<'a> {
57 ctx: &'a mut SQLContext,
58 active_schema: Option<&'a Schema>,
59}
60
61impl SQLExprVisitor<'_> {
62 fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
63 let array_elements = elements
64 .iter()
65 .map(|e| match e {
66 SQLExpr::Value(v) => self.visit_any_value(v, None),
67 SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
68 SQLExpr::Value(v) => self.visit_any_value(v, Some(op)),
69 _ => Err(polars_err!(SQLInterface: "expression {:?} is not currently supported", e)),
70 },
71 SQLExpr::Array(_) => {
72 Err(polars_err!(SQLInterface: "nested array literals are not currently supported:\n{:?}", e))
76 },
77 _ => Err(polars_err!(SQLInterface: "expression {:?} is not currently supported", e)),
78 })
79 .collect::<PolarsResult<Vec<_>>>()?;
80
81 Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
82 }
83
84 fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
85 match expr {
86 SQLExpr::AllOp {
87 left,
88 compare_op,
89 right,
90 } => self.visit_all(left, compare_op, right),
91 SQLExpr::AnyOp {
92 left,
93 compare_op,
94 right,
95 is_some: _,
96 } => self.visit_any(left, compare_op, right),
97 SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
98 SQLExpr::Between {
99 expr,
100 negated,
101 low,
102 high,
103 } => self.visit_between(expr, *negated, low, high),
104 SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
105 SQLExpr::Cast {
106 kind,
107 expr,
108 data_type,
109 format,
110 } => self.visit_cast(expr, data_type, format, kind),
111 SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
112 SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
113 SQLExpr::Extract {
114 field,
115 syntax: _,
116 expr,
117 } => parse_extract_date_part(self.visit_expr(expr)?, field),
118 SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
119 SQLExpr::Function(function) => self.visit_function(function),
120 SQLExpr::Identifier(ident) => self.visit_identifier(ident),
121 SQLExpr::InList {
122 expr,
123 list,
124 negated,
125 } => {
126 let expr = self.visit_expr(expr)?;
127 let elems = self.visit_array_expr(list, false, Some(&expr))?;
128 let is_in = expr.is_in(elems);
129 Ok(if *negated { is_in.not() } else { is_in })
130 },
131 SQLExpr::InSubquery {
132 expr,
133 subquery,
134 negated,
135 } => self.visit_in_subquery(expr, subquery, *negated),
136 SQLExpr::Interval(interval) => self.visit_interval(interval),
137 SQLExpr::IsDistinctFrom(e1, e2) => {
138 Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
139 },
140 SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
141 SQLExpr::IsNotDistinctFrom(e1, e2) => {
142 Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
143 },
144 SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
145 SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
146 SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
147 SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
148 SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
149 SQLExpr::Like {
150 negated,
151 any,
152 expr,
153 pattern,
154 escape_char,
155 } => {
156 if *any {
157 polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
158 }
159 self.visit_like(*negated, expr, pattern, escape_char, false)
160 },
161 SQLExpr::ILike {
162 negated,
163 any,
164 expr,
165 pattern,
166 escape_char,
167 } => {
168 if *any {
169 polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
170 }
171 self.visit_like(*negated, expr, pattern, escape_char, true)
172 },
173 SQLExpr::Nested(expr) => self.visit_expr(expr),
174 SQLExpr::Position { expr, r#in } => Ok(
175 (self
177 .visit_expr(r#in)?
178 .str()
179 .find(self.visit_expr(expr)?, true)
180 + typed_lit(1u32))
181 .fill_null(typed_lit(0u32)),
182 ),
183 SQLExpr::RLike {
184 negated,
186 expr,
187 pattern,
188 regexp: _,
189 } => {
190 let matches = self
191 .visit_expr(expr)?
192 .str()
193 .contains(self.visit_expr(pattern)?, true);
194 Ok(if *negated { matches.not() } else { matches })
195 },
196 SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript),
197 SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
198 SQLExpr::Trim {
199 expr,
200 trim_where,
201 trim_what,
202 trim_characters,
203 } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
204 SQLExpr::TypedString { data_type, value } => match data_type {
205 SQLDataType::Date => {
206 if is_iso_date(value) {
207 Ok(lit(value.as_str()).cast(DataType::Date))
208 } else {
209 polars_bail!(SQLSyntax: "invalid DATE literal '{}'", value)
210 }
211 },
212 SQLDataType::Time(None, TimezoneInfo::None) => {
213 if is_iso_time(value) {
214 Ok(lit(value.as_str()).str().to_time(StrptimeOptions {
215 strict: true,
216 ..Default::default()
217 }))
218 } else {
219 polars_bail!(SQLSyntax: "invalid TIME literal '{}'", value)
220 }
221 },
222 SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
223 if is_iso_datetime(value) {
224 Ok(lit(value.as_str()).str().to_datetime(
225 None,
226 None,
227 StrptimeOptions {
228 strict: true,
229 ..Default::default()
230 },
231 lit("latest"),
232 ))
233 } else {
234 let fn_name = match data_type {
235 SQLDataType::Timestamp(_, _) => "TIMESTAMP",
236 SQLDataType::Datetime(_) => "DATETIME",
237 _ => unreachable!(),
238 };
239 polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, value)
240 }
241 },
242 _ => {
243 polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
244 },
245 },
246 SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
247 SQLExpr::Value(value) => self.visit_literal(value),
248 SQLExpr::Wildcard(_) => Ok(Expr::Wildcard),
249 e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
250 other => {
251 polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
252 },
253 }
254 }
255
256 fn visit_subquery(
257 &mut self,
258 subquery: &Subquery,
259 restriction: SubqueryRestriction,
260 ) -> PolarsResult<Expr> {
261 if subquery.with.is_some() {
262 polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
263 }
264 let mut lf = self.ctx.execute_query_no_ctes(subquery)?;
265 let schema = self.ctx.get_frame_schema(&mut lf)?;
266
267 if restriction == SubqueryRestriction::SingleColumn {
268 if schema.len() != 1 {
269 polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
270 }
271 let rand_string: String = thread_rng()
272 .sample_iter(&Alphanumeric)
273 .take(16)
274 .map(char::from)
275 .collect();
276
277 let schema_entry = schema.get_at_index(0);
278 if let Some((old_name, _)) = schema_entry {
279 let new_name = String::from(old_name.as_str()) + rand_string.as_str();
280 lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
281 return Ok(Expr::SubPlan(
282 SpecialEq::new(Arc::new(lf.logical_plan)),
283 vec![new_name],
284 ));
285 }
286 };
287 polars_bail!(SQLInterface: "subquery type not supported");
288 }
289
290 fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
294 Ok(col(ident.value.as_str()))
295 }
296
297 fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
301 Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
302 }
303
304 fn visit_interval(&self, interval: &Interval) -> PolarsResult<Expr> {
305 if interval.last_field.is_some()
306 || interval.leading_field.is_some()
307 || interval.leading_precision.is_some()
308 || interval.fractional_seconds_precision.is_some()
309 {
310 polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
311 }
312 let s = match &*interval.value {
313 SQLExpr::UnaryOp { .. } => {
314 polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
315 },
316 SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s),
317 _ => None,
318 };
319 match s {
320 Some(s) if s.contains('-') => {
321 polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
322 },
323 Some(s) => Ok(lit(Duration::parse_interval(s))),
324 None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
325 }
326 }
327
328 fn visit_like(
329 &mut self,
330 negated: bool,
331 expr: &SQLExpr,
332 pattern: &SQLExpr,
333 escape_char: &Option<String>,
334 case_insensitive: bool,
335 ) -> PolarsResult<Expr> {
336 if escape_char.is_some() {
337 polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
338 }
339 let pat = match self.visit_expr(pattern) {
340 Ok(Expr::Literal(LiteralValue::String(s))) => s,
341 _ => {
342 polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
343 },
344 };
345 if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
346 let op = if negated {
348 BinaryOperator::NotEq
349 } else {
350 BinaryOperator::Eq
351 };
352 self.visit_binary_op(expr, &op, pattern)
353 } else {
354 let mut rx = regex::escape(pat.as_str())
356 .replace('%', ".*")
357 .replace('_', ".");
358
359 rx = format!(
360 "^{}{}$",
361 if case_insensitive { "(?is)" } else { "(?s)" },
362 rx
363 );
364
365 let expr = self.visit_expr(expr)?;
366 let matches = expr.str().contains(lit(rx), true);
367 Ok(if negated { matches.not() } else { matches })
368 }
369 }
370
371 fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
372 let expr = self.visit_expr(expr)?;
373 Ok(match subscript {
374 Subscript::Index { index } => {
375 let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
376 expr.list().get(idx, true)
377 },
378 Subscript::Slice { .. } => {
379 polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
380 },
381 })
382 }
383
384 fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
388 if let (Some(name), Some(s), expr_dtype) = match (left, right) {
389 (Expr::Column(name), Expr::Literal(LiteralValue::String(s))) => {
391 (Some(name.clone()), Some(s), None)
392 },
393 (Expr::Cast { expr, dtype, .. }, Expr::Literal(LiteralValue::String(s))) => {
395 match &**expr {
396 Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
397 _ => (None, Some(s), Some(dtype)),
398 }
399 },
400 _ => (None, None, None),
401 } {
402 if expr_dtype.is_none() && self.active_schema.is_none() {
403 right.clone()
404 } else {
405 let left_dtype = expr_dtype.or_else(|| {
406 self.active_schema
407 .as_ref()
408 .and_then(|schema| schema.get(&name))
409 });
410 match left_dtype {
411 Some(DataType::Time) if is_iso_time(s) => {
412 right.clone().str().to_time(StrptimeOptions {
413 strict: true,
414 ..Default::default()
415 })
416 },
417 Some(DataType::Date) if is_iso_date(s) => {
418 right.clone().str().to_date(StrptimeOptions {
419 strict: true,
420 ..Default::default()
421 })
422 },
423 Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
424 if s.len() == 10 {
425 lit(format!("{}T00:00:00", s))
427 } else {
428 lit(s.replacen(' ', "T", 1))
429 }
430 .str()
431 .to_datetime(
432 Some(*tu),
433 tz.clone(),
434 StrptimeOptions {
435 strict: true,
436 ..Default::default()
437 },
438 lit("latest"),
439 )
440 },
441 _ => right.clone(),
442 }
443 }
444 } else {
445 right.clone()
446 }
447 }
448
449 fn struct_field_access_expr(
450 &mut self,
451 expr: &Expr,
452 path: &str,
453 infer_index: bool,
454 ) -> PolarsResult<Expr> {
455 let path_elems = if path.starts_with('{') && path.ends_with('}') {
456 path.trim_matches(|c| c == '{' || c == '}')
457 } else {
458 path
459 }
460 .split(',');
461
462 let mut expr = expr.clone();
463 for p in path_elems {
464 let p = p.trim();
465 expr = if infer_index {
466 match p.parse::<i64>() {
467 Ok(idx) => expr.list().get(lit(idx), true),
468 Err(_) => expr.struct_().field_by_name(p),
469 }
470 } else {
471 expr.struct_().field_by_name(p)
472 }
473 }
474 Ok(expr)
475 }
476
477 fn visit_binary_op(
481 &mut self,
482 left: &SQLExpr,
483 op: &BinaryOperator,
484 right: &SQLExpr,
485 ) -> PolarsResult<Expr> {
486 let lhs = self.visit_expr(left)?;
487 let mut rhs = self.visit_expr(right)?;
488 rhs = self.convert_temporal_strings(&lhs, &rhs);
489
490 Ok(match op {
491 SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), SQLBinaryOperator::BitwiseOr => lhs.or(rhs), SQLBinaryOperator::Xor => lhs.xor(rhs), SQLBinaryOperator::And => lhs.and(rhs), SQLBinaryOperator::Divide => lhs / rhs, SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), SQLBinaryOperator::Eq => lhs.eq(rhs), SQLBinaryOperator::Gt => lhs.gt(rhs), SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), SQLBinaryOperator::Lt => lhs.lt(rhs), SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), SQLBinaryOperator::Minus => lhs - rhs, SQLBinaryOperator::Modulo => lhs % rhs, SQLBinaryOperator::Multiply => lhs * rhs, SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), SQLBinaryOperator::Or => lhs.or(rhs), SQLBinaryOperator::Plus => lhs + rhs, SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), SQLBinaryOperator::StringConcat => { lhs.cast(DataType::String) + rhs.cast(DataType::String)
518 },
519 SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), SQLBinaryOperator::PGRegexMatch => match rhs { Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true),
525 _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
526 },
527 SQLBinaryOperator::PGRegexNotMatch => match rhs { Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(),
529 _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
530 },
531 SQLBinaryOperator::PGRegexIMatch => match rhs { Expr::Literal(LiteralValue::String(pat)) => {
533 lhs.str().contains(lit(format!("(?i){}", pat)), true)
534 },
535 _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
536 },
537 SQLBinaryOperator::PGRegexNotIMatch => match rhs { Expr::Literal(LiteralValue::String(pat)) => {
539 lhs.str().contains(lit(format!("(?i){}", pat)), true).not()
540 },
541 _ => {
542 polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
543 },
544 },
545 SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch | SQLBinaryOperator::PGILikeMatch | SQLBinaryOperator::PGNotILikeMatch => { let expr = if matches!(
553 op,
554 SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
555 ) {
556 SQLExpr::Like {
557 negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
558 any: false,
559 expr: Box::new(left.clone()),
560 pattern: Box::new(right.clone()),
561 escape_char: None,
562 }
563 } else {
564 SQLExpr::ILike {
565 negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
566 any: false,
567 expr: Box::new(left.clone()),
568 pattern: Box::new(right.clone()),
569 escape_char: None,
570 }
571 };
572 self.visit_expr(&expr)?
573 },
574 SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { Expr::Literal(LiteralValue::String(path)) => {
579 let mut expr = self.struct_field_access_expr(&lhs, &path, false)?;
580 if let SQLBinaryOperator::LongArrow = op {
581 expr = expr.cast(DataType::String);
582 }
583 expr
584 },
585 Expr::Literal(LiteralValue::Int(idx)) => {
586 let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
587 if let SQLBinaryOperator::LongArrow = op {
588 expr = expr.cast(DataType::String);
589 }
590 expr
591 },
592 _ => {
593 polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
594 },
595 },
596 SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { if let Expr::Literal(LiteralValue::String(path)) = rhs {
598 let mut expr = self.struct_field_access_expr(&lhs, &path, true)?;
599 if let SQLBinaryOperator::HashLongArrow = op {
600 expr = expr.cast(DataType::String);
601 }
602 expr
603 } else {
604 polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
605 }
606 },
607 other => {
608 polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
609 },
610 })
611 }
612
613 fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
617 let expr = self.visit_expr(expr)?;
618 Ok(match (op, expr.clone()) {
619 (UnaryOperator::Plus, Expr::Literal(LiteralValue::Int(n))) => lit(n),
621 (UnaryOperator::Plus, Expr::Literal(LiteralValue::Float(n))) => lit(n),
622 (UnaryOperator::Minus, Expr::Literal(LiteralValue::Int(n))) => lit(-n),
623 (UnaryOperator::Minus, Expr::Literal(LiteralValue::Float(n))) => lit(-n),
624 (UnaryOperator::Plus, _) => lit(0) + expr,
626 (UnaryOperator::Minus, _) => lit(0) - expr,
627 (UnaryOperator::Not, _) => expr.not(),
628 other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
629 })
630 }
631
632 fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
638 let mut visitor = SQLFunctionVisitor {
639 func: function,
640 ctx: self.ctx,
641 active_schema: self.active_schema,
642 };
643 visitor.visit_function()
644 }
645
646 fn visit_all(
650 &mut self,
651 left: &SQLExpr,
652 compare_op: &BinaryOperator,
653 right: &SQLExpr,
654 ) -> PolarsResult<Expr> {
655 let left = self.visit_expr(left)?;
656 let right = self.visit_expr(right)?;
657
658 match compare_op {
659 BinaryOperator::Gt => Ok(left.gt(right.max())),
660 BinaryOperator::Lt => Ok(left.lt(right.min())),
661 BinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
662 BinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
663 BinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
664 BinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
665 _ => polars_bail!(SQLInterface: "invalid comparison operator"),
666 }
667 }
668
669 fn visit_any(
673 &mut self,
674 left: &SQLExpr,
675 compare_op: &BinaryOperator,
676 right: &SQLExpr,
677 ) -> PolarsResult<Expr> {
678 let left = self.visit_expr(left)?;
679 let right = self.visit_expr(right)?;
680
681 match compare_op {
682 BinaryOperator::Gt => Ok(left.gt(right.min())),
683 BinaryOperator::Lt => Ok(left.lt(right.max())),
684 BinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
685 BinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
686 BinaryOperator::Eq => Ok(left.is_in(right)),
687 BinaryOperator::NotEq => Ok(left.is_in(right).not()),
688 _ => polars_bail!(SQLInterface: "invalid comparison operator"),
689 }
690 }
691
692 fn visit_array_expr(
694 &mut self,
695 elements: &[SQLExpr],
696 result_as_element: bool,
697 dtype_expr_match: Option<&Expr>,
698 ) -> PolarsResult<Expr> {
699 let mut elems = self.array_expr_to_series(elements)?;
700
701 if let (Some(Expr::Column(name)), Some(schema)) =
704 (dtype_expr_match, self.active_schema.as_ref())
705 {
706 if elems.dtype() == &DataType::String {
707 if let Some(dtype) = schema.get(name) {
708 if matches!(
709 dtype,
710 DataType::Date | DataType::Time | DataType::Datetime(_, _)
711 ) {
712 elems = elems.strict_cast(dtype)?;
713 }
714 }
715 }
716 }
717
718 let res = if result_as_element {
721 elems.implode()?.into_series()
722 } else {
723 elems
724 };
725 Ok(lit(res))
726 }
727
728 fn visit_cast(
732 &mut self,
733 expr: &SQLExpr,
734 dtype: &SQLDataType,
735 format: &Option<CastFormat>,
736 cast_kind: &CastKind,
737 ) -> PolarsResult<Expr> {
738 if format.is_some() {
739 return Err(
740 polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
741 );
742 }
743 let expr = self.visit_expr(expr)?;
744
745 #[cfg(feature = "json")]
746 if dtype == &SQLDataType::JSON {
747 return Ok(expr.str().json_decode(None, None));
748 }
749 let polars_type = map_sql_dtype_to_polars(dtype)?;
750 Ok(match cast_kind {
751 CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
752 CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
753 })
754 }
755
756 fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
762 Ok(match value {
764 SQLValue::Boolean(b) => lit(*b),
765 SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
766 #[cfg(feature = "binary_encoding")]
767 SQLValue::HexStringLiteral(x) => {
768 if x.len() % 2 != 0 {
769 polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
770 };
771 lit(hex::decode(x.clone()).unwrap())
772 },
773 SQLValue::Null => Expr::Literal(LiteralValue::Null),
774 SQLValue::Number(s, _) => {
775 if s.contains('.') {
777 s.parse::<f64>().map(lit).map_err(|_| ())
778 } else {
779 s.parse::<i64>().map(lit).map_err(|_| ())
780 }
781 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
782 },
783 SQLValue::SingleQuotedByteStringLiteral(b) => {
784 bitstring_to_bytes_literal(b)?
788 },
789 SQLValue::SingleQuotedString(s) => lit(s.clone()),
790 other => {
791 polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
792 },
793 })
794 }
795
796 fn visit_any_value(
798 &self,
799 value: &SQLValue,
800 op: Option<&UnaryOperator>,
801 ) -> PolarsResult<AnyValue> {
802 Ok(match value {
803 SQLValue::Boolean(b) => AnyValue::Boolean(*b),
804 SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
805 #[cfg(feature = "binary_encoding")]
806 SQLValue::HexStringLiteral(x) => {
807 if x.len() % 2 != 0 {
808 polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
809 };
810 AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
811 },
812 SQLValue::Null => AnyValue::Null,
813 SQLValue::Number(s, _) => {
814 let negate = match op {
815 Some(UnaryOperator::Minus) => true,
816 Some(UnaryOperator::Plus) | None => false,
818 Some(op) => {
819 polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
820 },
821 };
822 if s.contains('.') {
824 s.parse::<f64>()
825 .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
826 .map_err(|_| ())
827 } else {
828 s.parse::<i64>()
829 .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
830 .map_err(|_| ())
831 }
832 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
833 },
834 SQLValue::SingleQuotedByteStringLiteral(b) => {
835 let bytes_literal = bitstring_to_bytes_literal(b)?;
837 match bytes_literal {
838 Expr::Literal(LiteralValue::Binary(v)) => AnyValue::BinaryOwned(v.to_vec()),
839 _ => {
840 polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
841 },
842 }
843 },
844 SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
845 other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
846 })
847 }
848
849 fn visit_between(
852 &mut self,
853 expr: &SQLExpr,
854 negated: bool,
855 low: &SQLExpr,
856 high: &SQLExpr,
857 ) -> PolarsResult<Expr> {
858 let expr = self.visit_expr(expr)?;
859 let low = self.visit_expr(low)?;
860 let high = self.visit_expr(high)?;
861
862 let low = self.convert_temporal_strings(&expr, &low);
863 let high = self.convert_temporal_strings(&expr, &high);
864 Ok(if negated {
865 expr.clone().lt(low).or(expr.gt(high))
866 } else {
867 expr.clone().gt_eq(low).and(expr.lt_eq(high))
868 })
869 }
870
871 fn visit_trim(
874 &mut self,
875 expr: &SQLExpr,
876 trim_where: &Option<TrimWhereField>,
877 trim_what: &Option<Box<SQLExpr>>,
878 trim_characters: &Option<Vec<SQLExpr>>,
879 ) -> PolarsResult<Expr> {
880 if trim_characters.is_some() {
881 return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
883 };
884 let expr = self.visit_expr(expr)?;
885 let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
886 let trim_what = match trim_what {
887 Some(Expr::Literal(LiteralValue::String(val))) => Some(val),
888 None => None,
889 _ => return self.err(&expr),
890 };
891 Ok(match (trim_where, trim_what) {
892 (None | Some(TrimWhereField::Both), None) => expr.str().strip_chars(lit(Null)),
893 (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
894 (Some(TrimWhereField::Leading), None) => expr.str().strip_chars_start(lit(Null)),
895 (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
896 (Some(TrimWhereField::Trailing), None) => expr.str().strip_chars_end(lit(Null)),
897 (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
898 })
899 }
900
901 fn visit_in_subquery(
903 &mut self,
904 expr: &SQLExpr,
905 subquery: &Subquery,
906 negated: bool,
907 ) -> PolarsResult<Expr> {
908 let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
909 let expr = self.visit_expr(expr)?;
910 Ok(if negated {
911 expr.is_in(subquery_result).not()
912 } else {
913 expr.is_in(subquery_result)
914 })
915 }
916
917 fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
919 if let SQLExpr::Case {
920 operand,
921 conditions,
922 results,
923 else_result,
924 } = expr
925 {
926 polars_ensure!(
927 conditions.len() == results.len(),
928 SQLSyntax: "WHEN and THEN expressions must have the same length"
929 );
930 polars_ensure!(
931 !conditions.is_empty(),
932 SQLSyntax: "WHEN and THEN expressions must have at least one element"
933 );
934
935 let mut when_thens = conditions.iter().zip(results.iter());
936 let first = when_thens.next();
937 if first.is_none() {
938 polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
939 }
940 let else_res = match else_result {
941 Some(else_res) => self.visit_expr(else_res)?,
942 None => lit(Null), };
944 if let Some(operand_expr) = operand {
945 let first_operand_expr = self.visit_expr(operand_expr)?;
946
947 let first = first.unwrap();
948 let first_cond = first_operand_expr.eq(self.visit_expr(first.0)?);
949 let first_then = self.visit_expr(first.1)?;
950 let expr = when(first_cond).then(first_then);
951 let next = when_thens.next();
952
953 let mut when_then = if let Some((cond, res)) = next {
954 let second_operand_expr = self.visit_expr(operand_expr)?;
955 let cond = second_operand_expr.eq(self.visit_expr(cond)?);
956 let res = self.visit_expr(res)?;
957 expr.when(cond).then(res)
958 } else {
959 return Ok(expr.otherwise(else_res));
960 };
961 for (cond, res) in when_thens {
962 let new_operand_expr = self.visit_expr(operand_expr)?;
963 let cond = new_operand_expr.eq(self.visit_expr(cond)?);
964 let res = self.visit_expr(res)?;
965 when_then = when_then.when(cond).then(res);
966 }
967 return Ok(when_then.otherwise(else_res));
968 }
969
970 let first = first.unwrap();
971 let first_cond = self.visit_expr(first.0)?;
972 let first_then = self.visit_expr(first.1)?;
973 let expr = when(first_cond).then(first_then);
974 let next = when_thens.next();
975
976 let mut when_then = if let Some((cond, res)) = next {
977 let cond = self.visit_expr(cond)?;
978 let res = self.visit_expr(res)?;
979 expr.when(cond).then(res)
980 } else {
981 return Ok(expr.otherwise(else_res));
982 };
983 for (cond, res) in when_thens {
984 let cond = self.visit_expr(cond)?;
985 let res = self.visit_expr(res)?;
986 when_then = when_then.when(cond).then(res);
987 }
988 Ok(when_then.otherwise(else_res))
989 } else {
990 unreachable!()
991 }
992 }
993
994 fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
995 polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
996 }
997}
998
999pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1017 let mut ctx = SQLContext::new();
1018
1019 let mut parser = Parser::new(&GenericDialect);
1020 parser = parser.with_options(ParserOptions {
1021 trailing_commas: true,
1022 ..Default::default()
1023 });
1024
1025 let mut ast = parser
1026 .try_with_sql(s.as_ref())
1027 .map_err(to_sql_interface_err)?;
1028 let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1029
1030 Ok(match &expr {
1031 SelectItem::ExprWithAlias { expr, alias } => {
1032 let expr = parse_sql_expr(expr, &mut ctx, None)?;
1033 expr.alias(alias.value.as_str())
1034 },
1035 SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1036 _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1037 })
1038}
1039
1040pub(crate) fn parse_sql_expr(
1041 expr: &SQLExpr,
1042 ctx: &mut SQLContext,
1043 active_schema: Option<&Schema>,
1044) -> PolarsResult<Expr> {
1045 let mut visitor = SQLExprVisitor { ctx, active_schema };
1046 visitor.visit_expr(expr)
1047}
1048
1049pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1050 match expr {
1051 SQLExpr::Array(arr) => {
1052 let mut visitor = SQLExprVisitor {
1053 ctx,
1054 active_schema: None,
1055 };
1056 visitor.array_expr_to_series(arr.elem.as_slice())
1057 },
1058 _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1059 }
1060}
1061
1062pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1063 let field = match field {
1064 DateTimeField::Custom(Ident { value, .. }) => {
1066 let value = value.to_ascii_lowercase();
1067 match value.as_str() {
1068 "millennium" | "millennia" => &DateTimeField::Millennium,
1069 "century" | "centuries" => &DateTimeField::Century,
1070 "decade" | "decades" => &DateTimeField::Decade,
1071 "isoyear" => &DateTimeField::Isoyear,
1072 "year" | "years" | "y" => &DateTimeField::Year,
1073 "quarter" | "quarters" => &DateTimeField::Quarter,
1074 "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1075 "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1076 "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1077 "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1078 "isodow" => &DateTimeField::Isodow,
1079 "day" | "days" | "d" => &DateTimeField::Day,
1080 "hour" | "hours" | "h" => &DateTimeField::Hour,
1081 "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1082 "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1083 "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1084 "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1085 "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1086 #[cfg(feature = "timezones")]
1087 "timezone" => &DateTimeField::Timezone,
1088 "time" => &DateTimeField::Time,
1089 "epoch" => &DateTimeField::Epoch,
1090 _ => {
1091 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1092 },
1093 }
1094 },
1095 _ => field,
1096 };
1097 Ok(match field {
1098 DateTimeField::Millennium => expr.dt().millennium(),
1099 DateTimeField::Century => expr.dt().century(),
1100 DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1101 DateTimeField::Isoyear => expr.dt().iso_year(),
1102 DateTimeField::Year => expr.dt().year(),
1103 DateTimeField::Quarter => expr.dt().quarter(),
1104 DateTimeField::Month => expr.dt().month(),
1105 DateTimeField::Week(weekday) => {
1106 if weekday.is_some() {
1107 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1108 }
1109 expr.dt().week()
1110 },
1111 DateTimeField::IsoWeek => expr.dt().week(),
1112 DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1113 DateTimeField::DayOfWeek | DateTimeField::Dow => {
1114 let w = expr.dt().weekday();
1115 when(w.clone().eq(typed_lit(7i8)))
1116 .then(typed_lit(0i8))
1117 .otherwise(w)
1118 },
1119 DateTimeField::Isodow => expr.dt().weekday(),
1120 DateTimeField::Day => expr.dt().day(),
1121 DateTimeField::Hour => expr.dt().hour(),
1122 DateTimeField::Minute => expr.dt().minute(),
1123 DateTimeField::Second => expr.dt().second(),
1124 DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1125 (expr.clone().dt().second() * typed_lit(1_000f64))
1126 + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1127 },
1128 DateTimeField::Microsecond | DateTimeField::Microseconds => {
1129 (expr.clone().dt().second() * typed_lit(1_000_000f64))
1130 + expr.dt().nanosecond().div(typed_lit(1_000f64))
1131 },
1132 DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1133 (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1134 },
1135 DateTimeField::Time => expr.dt().time(),
1136 #[cfg(feature = "timezones")]
1137 DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(),
1138 DateTimeField::Epoch => {
1139 expr.clone()
1140 .dt()
1141 .timestamp(TimeUnit::Nanoseconds)
1142 .div(typed_lit(1_000_000_000i64))
1143 + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1144 },
1145 _ => {
1146 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1147 },
1148 })
1149}
1150
1151pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1154 match idx {
1155 Expr::Literal(Null) => lit(Null),
1156 Expr::Literal(LiteralValue::Int(0)) => {
1157 if null_if_zero {
1158 lit(Null)
1159 } else {
1160 idx
1161 }
1162 },
1163 Expr::Literal(LiteralValue::Int(n)) if n < 0 => idx,
1164 Expr::Literal(LiteralValue::Int(n)) => lit(n - 1),
1165 _ => when(idx.clone().gt(lit(0)))
1168 .then(idx.clone() - lit(1))
1169 .otherwise(if null_if_zero {
1170 when(idx.clone().eq(lit(0)))
1171 .then(lit(Null))
1172 .otherwise(idx.clone())
1173 } else {
1174 idx.clone()
1175 }),
1176 }
1177}
1178
1179fn resolve_column<'a>(
1180 ctx: &'a mut SQLContext,
1181 ident_root: &'a Ident,
1182 name: &'a str,
1183 dtype: &'a DataType,
1184) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1185 let resolved = ctx.resolve_name(&ident_root.value, name);
1186 let resolved = resolved.as_str();
1187 Ok((
1188 if name != resolved {
1189 col(resolved).alias(name)
1190 } else {
1191 col(name)
1192 },
1193 Some(dtype),
1194 ))
1195}
1196
1197pub(crate) fn resolve_compound_identifier(
1198 ctx: &mut SQLContext,
1199 idents: &[Ident],
1200 active_schema: Option<&Schema>,
1201) -> PolarsResult<Vec<Expr>> {
1202 let ident_root = &idents[0];
1204 let mut remaining_idents = idents.iter().skip(1);
1205 let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1206
1207 let schema = if let Some(ref mut lf) = lf {
1208 lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)
1209 } else {
1210 Ok(Arc::new(if let Some(active_schema) = active_schema {
1211 active_schema.clone()
1212 } else {
1213 Schema::default()
1214 }))
1215 }?;
1216
1217 let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() {
1218 Ok((col(ident_root.value.as_str()), None))
1219 } else {
1220 let name = &remaining_idents.next().unwrap().value;
1221 if lf.is_some() && name == "*" {
1222 return Ok(schema
1223 .iter_names_and_dtypes()
1224 .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0)
1225 .collect::<Vec<_>>());
1226 } else if let Some((_, name, dtype)) = schema.get_full(name) {
1227 resolve_column(ctx, ident_root, name, dtype)
1228 } else if lf.is_none() {
1229 remaining_idents = idents.iter().skip(1);
1230 Ok((
1231 col(ident_root.value.as_str()),
1232 schema.get(&ident_root.value),
1233 ))
1234 } else {
1235 polars_bail!(
1236 SQLInterface: "no column named '{}' found in table '{}'",
1237 name,
1238 ident_root
1239 )
1240 }
1241 };
1242
1243 let (mut column, mut dtype) = col_dtype?;
1245 for ident in remaining_idents {
1246 let name = ident.value.as_str();
1247 match dtype {
1248 Some(DataType::Struct(fields)) if name == "*" => {
1249 return Ok(fields
1250 .iter()
1251 .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1252 .collect())
1253 },
1254 Some(DataType::Struct(fields)) => {
1255 dtype = fields
1256 .iter()
1257 .find(|fld| fld.name == name)
1258 .map(|fld| &fld.dtype);
1259 },
1260 Some(dtype) if name == "*" => {
1261 polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1262 },
1263 _ => {
1264 dtype = None;
1265 },
1266 }
1267 column = column.struct_().field_by_name(name);
1268 }
1269 Ok(vec![column])
1270}