polars_sql/
functions.rs

1use std::ops::Sub;
2
3use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
4use polars_core::prelude::{
5    DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail, polars_err,
6};
7use polars_lazy::dsl::Expr;
8use polars_ops::chunked_array::UnicodeForm;
9use polars_ops::series::RoundMode;
10use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when};
11use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
12use polars_plan::prelude::{StrptimeOptions, col, cols, lit};
13use polars_utils::pl_str::PlSmallStr;
14use sqlparser::ast::helpers::attached_token::AttachedToken;
15use sqlparser::ast::{
16    DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
17    FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
18    OrderByExpr, Value as SQLValue, WindowSpec, WindowType,
19};
20use sqlparser::tokenizer::Span;
21
22use crate::SQLContext;
23use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
24
25pub(crate) struct SQLFunctionVisitor<'a> {
26    pub(crate) func: &'a SQLFunction,
27    pub(crate) ctx: &'a mut SQLContext,
28    pub(crate) active_schema: Option<&'a Schema>,
29}
30
/// SQL functions that are supported by Polars.
///
/// Each variant identifies one SQL function; several SQL names (aliases) may map
/// onto the same variant.
pub(crate) enum PolarsSQLFunctions {
    // ----
    // Bitwise functions
    // ----
    /// SQL 'bit_and' function.
    /// Returns the bitwise AND of the input expressions.
    /// ```sql
    /// SELECT BIT_AND(column_1, column_2) FROM df;
    /// ```
    BitAnd,
    /// SQL 'bit_count' function.
    /// Returns the number of set bits in the input expression.
    /// ```sql
    /// SELECT BIT_COUNT(column_1) FROM df;
    /// ```
    #[cfg(feature = "bitwise")]
    BitCount,
    /// SQL 'bit_or' function.
    /// Returns the bitwise OR of the input expressions.
    /// ```sql
    /// SELECT BIT_OR(column_1, column_2) FROM df;
    /// ```
    BitOr,
    /// SQL 'bit_xor' function.
    /// Returns the bitwise XOR of the input expressions.
    /// ```sql
    /// SELECT BIT_XOR(column_1, column_2) FROM df;
    /// ```
    BitXor,

    // ----
    // Math functions
    // ----
    /// SQL 'abs' function.
    /// Returns the absolute value of the input expression.
    /// ```sql
    /// SELECT ABS(column_1) FROM df;
    /// ```
    Abs,
    /// SQL 'ceil' function.
    /// Returns the smallest integer greater than or equal to the input
    /// (rounds up, towards positive infinity).
    /// ```sql
    /// SELECT CEIL(column_1) FROM df;
    /// ```
    Ceil,
    /// SQL 'div' function.
    /// Returns the integer quotient of the division.
    /// ```sql
    /// SELECT DIV(column_1, 2) FROM df;
    /// ```
    Div,
    /// SQL 'exp' function.
    /// Computes the exponential of the given value.
    /// ```sql
    /// SELECT EXP(column_1) FROM df;
    /// ```
    Exp,
    /// SQL 'floor' function.
    /// Returns the largest integer less than or equal to the input
    /// (rounds down, towards negative infinity).
    /// ```sql
    /// SELECT FLOOR(column_1) FROM df;
    /// ```
    Floor,
    /// SQL 'pi' function.
    /// Returns a (very good) approximation of 𝜋.
    /// ```sql
    /// SELECT PI() FROM df;
    /// ```
    Pi,
    /// SQL 'ln' function.
    /// Computes the natural logarithm of the given value.
    /// ```sql
    /// SELECT LN(column_1) FROM df;
    /// ```
    Ln,
    /// SQL 'log2' function.
    /// Computes the logarithm of the given value in base 2.
    /// ```sql
    /// SELECT LOG2(column_1) FROM df;
    /// ```
    Log2,
    /// SQL 'log10' function.
    /// Computes the logarithm of the given value in base 10.
    /// ```sql
    /// SELECT LOG10(column_1) FROM df;
    /// ```
    Log10,
    /// SQL 'log' function.
    /// Computes the `base` logarithm of the given value.
    /// ```sql
    /// SELECT LOG(column_1, 10) FROM df;
    /// ```
    Log,
    /// SQL 'log1p' function.
    /// Computes the natural logarithm of "given value plus one".
    /// ```sql
    /// SELECT LOG1P(column_1) FROM df;
    /// ```
    Log1p,
    /// SQL 'pow' function.
    /// Returns the value to the power of the given exponent.
    /// ```sql
    /// SELECT POW(column_1, 2) FROM df;
    /// ```
    Pow,
    /// SQL 'mod' function.
    /// Returns the remainder of a numeric expression divided by another numeric expression.
    /// ```sql
    /// SELECT MOD(column_1, 2) FROM df;
    /// ```
    Mod,
    /// SQL 'sqrt' function.
    /// Returns the square root (√) of a number.
    /// ```sql
    /// SELECT SQRT(column_1) FROM df;
    /// ```
    Sqrt,
    /// SQL 'cbrt' function.
    /// Returns the cube root (∛) of a number.
    /// ```sql
    /// SELECT CBRT(column_1) FROM df;
    /// ```
    Cbrt,
    /// SQL 'round' function.
    /// Rounds a number to `x` decimals (default: 0);
    /// .5 is rounded away from zero.
    /// ```sql
    /// SELECT ROUND(column_1, 3) FROM df;
    /// ```
    Round,
    /// SQL 'sign' function.
    /// Returns the sign of the argument as -1, 0, or +1.
    /// ```sql
    /// SELECT SIGN(column_1) FROM df;
    /// ```
    Sign,

    // ----
    // Trig functions
    // ----
    /// SQL 'cos' function.
    /// Compute the cosine of the input expression (in radians).
    /// ```sql
    /// SELECT COS(column_1) FROM df;
    /// ```
    Cos,
    /// SQL 'cot' function.
    /// Compute the cotangent of the input expression (in radians).
    /// ```sql
    /// SELECT COT(column_1) FROM df;
    /// ```
    Cot,
    /// SQL 'sin' function.
    /// Compute the sine of the input expression (in radians).
    /// ```sql
    /// SELECT SIN(column_1) FROM df;
    /// ```
    Sin,
    /// SQL 'tan' function.
    /// Compute the tangent of the input expression (in radians).
    /// ```sql
    /// SELECT TAN(column_1) FROM df;
    /// ```
    Tan,
    /// SQL 'cosd' function.
    /// Compute the cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT COSD(column_1) FROM df;
    /// ```
    CosD,
    /// SQL 'cotd' function.
    /// Compute the cotangent of the input expression (in degrees).
    /// ```sql
    /// SELECT COTD(column_1) FROM df;
    /// ```
    CotD,
    /// SQL 'sind' function.
    /// Compute the sine of the input expression (in degrees).
    /// ```sql
    /// SELECT SIND(column_1) FROM df;
    /// ```
    SinD,
    /// SQL 'tand' function.
    /// Compute the tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT TAND(column_1) FROM df;
    /// ```
    TanD,
    /// SQL 'acos' function.
    /// Compute inverse cosine of the input expression (in radians).
    /// ```sql
    /// SELECT ACOS(column_1) FROM df;
    /// ```
    Acos,
    /// SQL 'asin' function.
    /// Compute inverse sine of the input expression (in radians).
    /// ```sql
    /// SELECT ASIN(column_1) FROM df;
    /// ```
    Asin,
    /// SQL 'atan' function.
    /// Compute inverse tangent of the input expression (in radians).
    /// ```sql
    /// SELECT ATAN(column_1) FROM df;
    /// ```
    Atan,
    /// SQL 'atan2' function.
    /// Compute the inverse tangent of column_1/column_2 (in radians).
    /// ```sql
    /// SELECT ATAN2(column_1, column_2) FROM df;
    /// ```
    Atan2,
    /// SQL 'acosd' function.
    /// Compute inverse cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT ACOSD(column_1) FROM df;
    /// ```
    AcosD,
    /// SQL 'asind' function.
    /// Compute inverse sine of the input expression (in degrees).
    /// ```sql
    /// SELECT ASIND(column_1) FROM df;
    /// ```
    AsinD,
    /// SQL 'atand' function.
    /// Compute inverse tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT ATAND(column_1) FROM df;
    /// ```
    AtanD,
    /// SQL 'atan2d' function.
    /// Compute the inverse tangent of column_1/column_2 (in degrees).
    /// ```sql
    /// SELECT ATAN2D(column_1, column_2) FROM df;
    /// ```
    Atan2D,
    /// SQL 'degrees' function.
    /// Convert the input expression from radians to degrees.
    /// ```sql
    /// SELECT DEGREES(column_1) FROM df;
    /// ```
    Degrees,
    /// SQL 'radians' function.
    /// Convert the input expression from degrees to radians.
    /// ```sql
    /// SELECT RADIANS(column_1) FROM df;
    /// ```
    Radians,

    // ----
    // Temporal functions
    // ----
    /// SQL 'date_part' function.
    /// Extracts a part of a date (or datetime) such as 'year', 'month', etc.
    /// ```sql
    /// SELECT DATE_PART('year', column_1) FROM df;
    /// SELECT DATE_PART('day', column_1) FROM df;
    /// ```
    DatePart,
    /// SQL 'strftime' function.
    /// Converts a datetime to a string using a format string.
    /// ```sql
    /// SELECT STRFTIME(column_1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strftime,

    // ----
    // String functions
    // ----
    /// SQL 'bit_length' function.
    /// Returns the length of a given string in bits.
    /// ```sql
    /// SELECT BIT_LENGTH(column_1) FROM df;
    /// ```
    BitLength,
    /// SQL 'concat' function.
    /// Returns all input expressions concatenated together as a string.
    /// ```sql
    /// SELECT CONCAT(column_1, column_2) FROM df;
    /// ```
    Concat,
    /// SQL 'concat_ws' function.
    /// Returns all input expressions concatenated together
    /// (and interleaved with a separator) as a string.
    /// ```sql
    /// SELECT CONCAT_WS(':', column_1, column_2, column_3) FROM df;
    /// ```
    ConcatWS,
    /// SQL 'date' function.
    /// Converts a formatted string date to an actual Date type; ISO-8601 format is assumed
    /// unless a strftime-compatible formatting string is provided as the second parameter.
    /// ```sql
    /// SELECT DATE('2021-03-15') FROM df;
    /// SELECT DATE('2021-15-03', '%Y-%d-%m') FROM df;
    /// SELECT DATE('2021-03', '%Y-%m') FROM df;
    /// ```
    Date,
    /// SQL 'ends_with' function.
    /// Returns True if the value ends with the second argument.
    /// ```sql
    /// SELECT ENDS_WITH(column_1, 'a') FROM df;
    /// SELECT column_2 from df WHERE ENDS_WITH(column_1, 'a');
    /// ```
    EndsWith,
    /// SQL 'initcap' function.
    /// Returns the value with the first letter capitalized.
    /// ```sql
    /// SELECT INITCAP(column_1) FROM df;
    /// ```
    #[cfg(feature = "nightly")]
    InitCap,
    /// SQL 'left' function.
    /// Returns the first (leftmost) `n` characters.
    /// ```sql
    /// SELECT LEFT(column_1, 3) FROM df;
    /// ```
    Left,
    /// SQL 'length' function (characters).
    /// Returns the character length of the string.
    /// ```sql
    /// SELECT LENGTH(column_1) FROM df;
    /// ```
    Length,
    /// SQL 'lower' function.
    /// Returns a lowercased column.
    /// ```sql
    /// SELECT LOWER(column_1) FROM df;
    /// ```
    Lower,
    /// SQL 'ltrim' function.
    /// Strips whitespace from the left.
    /// ```sql
    /// SELECT LTRIM(column_1) FROM df;
    /// ```
    LTrim,
    /// SQL 'normalize' function.
    /// Convert string to Unicode normalization form
    /// (one of NFC, NFKC, NFD, or NFKD - unquoted).
    /// ```sql
    /// SELECT NORMALIZE(column_1, NFC) FROM df;
    /// ```
    Normalize,
    /// SQL 'octet_length' function.
    /// Returns the length of a given string in bytes.
    /// ```sql
    /// SELECT OCTET_LENGTH(column_1) FROM df;
    /// ```
    OctetLength,
    /// SQL 'regexp_like' function.
    /// True if `pattern` matches the value (optional: `flags`).
    /// ```sql
    /// SELECT REGEXP_LIKE(column_1, 'xyz', 'i') FROM df;
    /// ```
    RegexpLike,
    /// SQL 'replace' function.
    /// Replace a given substring with another string.
    /// ```sql
    /// SELECT REPLACE(column_1, 'old', 'new') FROM df;
    /// ```
    Replace,
    /// SQL 'reverse' function.
    /// Return the reversed string.
    /// ```sql
    /// SELECT REVERSE(column_1) FROM df;
    /// ```
    Reverse,
    /// SQL 'right' function.
    /// Returns the last (rightmost) `n` characters.
    /// ```sql
    /// SELECT RIGHT(column_1, 3) FROM df;
    /// ```
    Right,
    /// SQL 'rtrim' function.
    /// Strips whitespace from the right.
    /// ```sql
    /// SELECT RTRIM(column_1) FROM df;
    /// ```
    RTrim,
    /// SQL 'split_part' function.
    /// Splits a string into an array of strings using the given delimiter
    /// and returns the `n`-th part (1-indexed).
    /// ```sql
    /// SELECT SPLIT_PART(column_1, ',', 2) FROM df;
    /// ```
    SplitPart,
    /// SQL 'starts_with' function.
    /// Returns True if the value starts with the second argument.
    /// ```sql
    /// SELECT STARTS_WITH(column_1, 'a') FROM df;
    /// SELECT column_2 from df WHERE STARTS_WITH(column_1, 'a');
    /// ```
    StartsWith,
    /// SQL 'strpos' function.
    /// Returns the index of the given substring in the target string.
    /// ```sql
    /// SELECT STRPOS(column_1,'xyz') FROM df;
    /// ```
    StrPos,
    /// SQL 'substr' function.
    /// Returns a portion of the data (first character = 1) in the range.
    ///   \[start, start + length]
    /// ```sql
    /// SELECT SUBSTR(column_1, 3, 5) FROM df;
    /// ```
    Substring,
    /// SQL 'string_to_array' function.
    /// Splits a string into an array of strings using the given delimiter.
    /// ```sql
    /// SELECT STRING_TO_ARRAY(column_1, ',') FROM df;
    /// ```
    StringToArray,
    /// SQL 'strptime' function.
    /// Converts a string to a datetime using a format string.
    /// ```sql
    /// SELECT STRPTIME(column_1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strptime,
    /// SQL 'time' function.
    /// Converts a formatted string time to an actual Time type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIME('10:30:45') FROM df;
    /// SELECT TIME('20.30', '%H.%M') FROM df;
    /// ```
    Time,
    /// SQL 'timestamp' function.
    /// Converts a formatted string datetime to an actual Datetime type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIMESTAMP('2021-03-15 10:30:45') FROM df;
    /// SELECT TIMESTAMP('2021-15-03 00:01:02', '%Y-%d-%m %H:%M:%S') FROM df;
    /// ```
    Timestamp,
    /// SQL 'upper' function.
    /// Returns an uppercased column.
    /// ```sql
    /// SELECT UPPER(column_1) FROM df;
    /// ```
    Upper,

    // ----
    // Conditional functions
    // ----
    /// SQL 'coalesce' function.
    /// Returns the first non-null value in the provided values/columns.
    /// ```sql
    /// SELECT COALESCE(column_1, ...) FROM df;
    /// ```
    Coalesce,
    /// SQL 'greatest' function.
    /// Returns the greatest value in the list of expressions.
    /// ```sql
    /// SELECT GREATEST(column_1, column_2, ...) FROM df;
    /// ```
    Greatest,
    /// SQL 'if' function.
    /// Returns expr1 if the boolean condition provided as the first
    /// parameter evaluates to true, and expr2 otherwise.
    /// ```sql
    /// SELECT IF(column < 0, expr1, expr2) FROM df;
    /// ```
    If,
    /// SQL 'ifnull' function.
    /// If an expression value is NULL, return an alternative value.
    /// ```sql
    /// SELECT IFNULL(string_col, 'n/a') FROM df;
    /// ```
    IfNull,
    /// SQL 'least' function.
    /// Returns the smallest value in the list of expressions.
    /// ```sql
    /// SELECT LEAST(column_1, column_2, ...) FROM df;
    /// ```
    Least,
    /// SQL 'nullif' function.
    /// Returns NULL if two expressions are equal, otherwise returns the first.
    /// ```sql
    /// SELECT NULLIF(column_1, column_2) FROM df;
    /// ```
    NullIf,

    // ----
    // Aggregate functions
    // ----
    /// SQL 'avg' function.
    /// Returns the average (mean) of all the elements in the grouping.
    /// ```sql
    /// SELECT AVG(column_1) FROM df;
    /// ```
    Avg,
    /// SQL 'corr' function.
    /// Returns the Pearson correlation coefficient between two columns.
    /// ```sql
    /// SELECT CORR(column_1, column_2) FROM df;
    /// ```
    Corr,
    /// SQL 'count' function.
    /// Returns the number of elements in the grouping.
    /// ```sql
    /// SELECT COUNT(column_1) FROM df;
    /// SELECT COUNT(*) FROM df;
    /// SELECT COUNT(DISTINCT column_1) FROM df;
    /// SELECT COUNT(DISTINCT *) FROM df;
    /// ```
    Count,
    /// SQL 'covar_pop' function.
    /// Returns the population covariance between two columns.
    /// ```sql
    /// SELECT COVAR_POP(column_1, column_2) FROM df;
    /// ```
    CovarPop,
    /// SQL 'covar_samp' function.
    /// Returns the sample covariance between two columns.
    /// ```sql
    /// SELECT COVAR_SAMP(column_1, column_2) FROM df;
    /// ```
    CovarSamp,
    /// SQL 'first' function.
    /// Returns the first element of the grouping.
    /// ```sql
    /// SELECT FIRST(column_1) FROM df;
    /// ```
    First,
    /// SQL 'last' function.
    /// Returns the last element of the grouping.
    /// ```sql
    /// SELECT LAST(column_1) FROM df;
    /// ```
    Last,
    /// SQL 'max' function.
    /// Returns the greatest (maximum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MAX(column_1) FROM df;
    /// ```
    Max,
    /// SQL 'median' function.
    /// Returns the median element from the grouping.
    /// ```sql
    /// SELECT MEDIAN(column_1) FROM df;
    /// ```
    Median,
    /// SQL 'quantile_cont' function.
    /// Returns the continuous quantile element from the grouping
    /// (interpolated value between two closest values).
    /// ```sql
    /// SELECT QUANTILE_CONT(column_1) FROM df;
    /// ```
    QuantileCont,
    /// SQL 'quantile_disc' function.
    /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value,
    /// and returns the value associated with the subinterval where the quantile value falls.
    /// ```sql
    /// SELECT QUANTILE_DISC(column_1) FROM df;
    /// ```
    QuantileDisc,
    /// SQL 'min' function.
    /// Returns the smallest (minimum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MIN(column_1) FROM df;
    /// ```
    Min,
    /// SQL 'stddev' function.
    /// Returns the standard deviation of all the elements in the grouping.
    /// ```sql
    /// SELECT STDDEV(column_1) FROM df;
    /// ```
    StdDev,
    /// SQL 'sum' function.
    /// Returns the sum of all the elements in the grouping.
    /// ```sql
    /// SELECT SUM(column_1) FROM df;
    /// ```
    Sum,
    /// SQL 'variance' function.
    /// Returns the variance of all the elements in the grouping.
    /// ```sql
    /// SELECT VARIANCE(column_1) FROM df;
    /// ```
    Variance,

    // ----
    // Array functions
    // ----
    /// SQL 'array_length' function.
    /// Returns the length of the array.
    /// ```sql
    /// SELECT ARRAY_LENGTH(column_1) FROM df;
    /// ```
    ArrayLength,
    /// SQL 'array_lower' function.
    /// Returns the minimum value in an array; equivalent to `array_min`.
    /// ```sql
    /// SELECT ARRAY_LOWER(column_1) FROM df;
    /// ```
    ArrayMin,
    /// SQL 'array_upper' function.
    /// Returns the maximum value in an array; equivalent to `array_max`.
    /// ```sql
    /// SELECT ARRAY_UPPER(column_1) FROM df;
    /// ```
    ArrayMax,
    /// SQL 'array_sum' function.
    /// Returns the sum of all values in an array.
    /// ```sql
    /// SELECT ARRAY_SUM(column_1) FROM df;
    /// ```
    ArraySum,
    /// SQL 'array_mean' function.
    /// Returns the mean of all values in an array.
    /// ```sql
    /// SELECT ARRAY_MEAN(column_1) FROM df;
    /// ```
    ArrayMean,
    /// SQL 'array_reverse' function.
    /// Returns the array with the elements in reverse order.
    /// ```sql
    /// SELECT ARRAY_REVERSE(column_1) FROM df;
    /// ```
    ArrayReverse,
    /// SQL 'array_unique' function.
    /// Returns the array with the unique elements.
    /// ```sql
    /// SELECT ARRAY_UNIQUE(column_1) FROM df;
    /// ```
    ArrayUnique,
    /// SQL 'unnest' function.
    /// Unnests/explodes an array column into multiple rows.
    /// ```sql
    /// SELECT unnest(column_1) FROM df;
    /// ```
    Explode,
    /// SQL 'array_agg' function.
    /// Concatenates the input expressions, including nulls, into an array.
    /// ```sql
    /// SELECT ARRAY_AGG(column_1, column_2, ...) FROM df;
    /// ```
    ArrayAgg,
    /// SQL 'array_to_string' function.
    /// Takes all elements of the array and joins them into one string.
    /// ```sql
    /// SELECT ARRAY_TO_STRING(column_1, ',') FROM df;
    /// SELECT ARRAY_TO_STRING(column_1, ',', 'n/a') FROM df;
    /// ```
    ArrayToString,
    /// SQL 'array_get' function.
    /// Returns the value at the given index in the array.
    /// ```sql
    /// SELECT ARRAY_GET(column_1, 1) FROM df;
    /// ```
    ArrayGet,
    /// SQL 'array_contains' function.
    /// Returns true if the array contains the value.
    /// ```sql
    /// SELECT ARRAY_CONTAINS(column_1, 'foo') FROM df;
    /// ```
    ArrayContains,

    // ----
    // Column selection
    // ----
    /// SQL 'columns' function (column selection).
    Columns,

    // ----
    // User-defined
    // ----
    /// A user-defined function; carries the (lowercased) name under which the
    /// function was found in the SQL context's function registry.
    Udf(String),
}
701
702impl PolarsSQLFunctions {
703    pub(crate) fn keywords() -> &'static [&'static str] {
704        &[
705            "abs",
706            "acos",
707            "acosd",
708            "array_contains",
709            "array_get",
710            "array_length",
711            "array_lower",
712            "array_mean",
713            "array_reverse",
714            "array_sum",
715            "array_to_string",
716            "array_unique",
717            "array_upper",
718            "asin",
719            "asind",
720            "atan",
721            "atan2",
722            "atan2d",
723            "atand",
724            "avg",
725            "bit_and",
726            "bit_count",
727            "bit_length",
728            "bit_or",
729            "bit_xor",
730            "cbrt",
731            "ceil",
732            "ceiling",
733            "char_length",
734            "character_length",
735            "coalesce",
736            "columns",
737            "concat",
738            "concat_ws",
739            "corr",
740            "cos",
741            "cosd",
742            "cot",
743            "cotd",
744            "count",
745            "covar",
746            "covar_pop",
747            "covar_samp",
748            "date",
749            "date_part",
750            "degrees",
751            "ends_with",
752            "exp",
753            "first",
754            "floor",
755            "greatest",
756            "if",
757            "ifnull",
758            "initcap",
759            "last",
760            "least",
761            "left",
762            "length",
763            "ln",
764            "log",
765            "log10",
766            "log1p",
767            "log2",
768            "lower",
769            "ltrim",
770            "max",
771            "median",
772            "quantile_disc",
773            "min",
774            "mod",
775            "nullif",
776            "octet_length",
777            "pi",
778            "pow",
779            "power",
780            "quantile_cont",
781            "quantile_disc",
782            "radians",
783            "regexp_like",
784            "replace",
785            "reverse",
786            "right",
787            "round",
788            "rtrim",
789            "sign",
790            "sin",
791            "sind",
792            "sqrt",
793            "starts_with",
794            "stddev",
795            "stddev_samp",
796            "stdev",
797            "stdev_samp",
798            "strftime",
799            "strpos",
800            "strptime",
801            "substr",
802            "sum",
803            "tan",
804            "tand",
805            "unnest",
806            "upper",
807            "var",
808            "var_samp",
809            "variance",
810        ]
811    }
812}
813
814impl PolarsSQLFunctions {
815    fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
816        let function_name = function.name.0[0].value.to_lowercase();
817        Ok(match function_name.as_str() {
818            // ----
819            // Bitwise functions
820            // ----
821            "bit_and" | "bitand" => Self::BitAnd,
822            #[cfg(feature = "bitwise")]
823            "bit_count" | "bitcount" => Self::BitCount,
824            "bit_or" | "bitor" => Self::BitOr,
825            "bit_xor" | "bitxor" | "xor" => Self::BitXor,
826
827            // ----
828            // Math functions
829            // ----
830            "abs" => Self::Abs,
831            "cbrt" => Self::Cbrt,
832            "ceil" | "ceiling" => Self::Ceil,
833            "div" => Self::Div,
834            "exp" => Self::Exp,
835            "floor" => Self::Floor,
836            "ln" => Self::Ln,
837            "log" => Self::Log,
838            "log10" => Self::Log10,
839            "log1p" => Self::Log1p,
840            "log2" => Self::Log2,
841            "mod" => Self::Mod,
842            "pi" => Self::Pi,
843            "pow" | "power" => Self::Pow,
844            "round" => Self::Round,
845            "sign" => Self::Sign,
846            "sqrt" => Self::Sqrt,
847
848            // ----
849            // Trig functions
850            // ----
851            "cos" => Self::Cos,
852            "cot" => Self::Cot,
853            "sin" => Self::Sin,
854            "tan" => Self::Tan,
855            "cosd" => Self::CosD,
856            "cotd" => Self::CotD,
857            "sind" => Self::SinD,
858            "tand" => Self::TanD,
859            "acos" => Self::Acos,
860            "asin" => Self::Asin,
861            "atan" => Self::Atan,
862            "atan2" => Self::Atan2,
863            "acosd" => Self::AcosD,
864            "asind" => Self::AsinD,
865            "atand" => Self::AtanD,
866            "atan2d" => Self::Atan2D,
867            "degrees" => Self::Degrees,
868            "radians" => Self::Radians,
869
870            // ----
871            // Conditional functions
872            // ----
873            "coalesce" => Self::Coalesce,
874            "greatest" => Self::Greatest,
875            "if" => Self::If,
876            "ifnull" => Self::IfNull,
877            "least" => Self::Least,
878            "nullif" => Self::NullIf,
879
880            // ----
881            // Date functions
882            // ----
883            "date_part" => Self::DatePart,
884            "strftime" => Self::Strftime,
885
886            // ----
887            // String functions
888            // ----
889            "bit_length" => Self::BitLength,
890            "concat" => Self::Concat,
891            "concat_ws" => Self::ConcatWS,
892            "date" => Self::Date,
893            "timestamp" | "datetime" => Self::Timestamp,
894            "ends_with" => Self::EndsWith,
895            #[cfg(feature = "nightly")]
896            "initcap" => Self::InitCap,
897            "length" | "char_length" | "character_length" => Self::Length,
898            "left" => Self::Left,
899            "lower" => Self::Lower,
900            "ltrim" => Self::LTrim,
901            "normalize" => Self::Normalize,
902            "octet_length" => Self::OctetLength,
903            "strpos" => Self::StrPos,
904            "regexp_like" => Self::RegexpLike,
905            "replace" => Self::Replace,
906            "reverse" => Self::Reverse,
907            "right" => Self::Right,
908            "rtrim" => Self::RTrim,
909            "split_part" => Self::SplitPart,
910            "starts_with" => Self::StartsWith,
911            "string_to_array" => Self::StringToArray,
912            "strptime" => Self::Strptime,
913            "substr" => Self::Substring,
914            "time" => Self::Time,
915            "upper" => Self::Upper,
916
917            // ----
918            // Aggregate functions
919            // ----
920            "avg" => Self::Avg,
921            "corr" => Self::Corr,
922            "count" => Self::Count,
923            "covar_pop" => Self::CovarPop,
924            "covar" | "covar_samp" => Self::CovarSamp,
925            "first" => Self::First,
926            "last" => Self::Last,
927            "max" => Self::Max,
928            "median" => Self::Median,
929            "quantile_cont" => Self::QuantileCont,
930            "quantile_disc" => Self::QuantileDisc,
931            "min" => Self::Min,
932            "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
933            "sum" => Self::Sum,
934            "var" | "variance" | "var_samp" => Self::Variance,
935
936            // ----
937            // Array functions
938            // ----
939            "array_agg" => Self::ArrayAgg,
940            "array_contains" => Self::ArrayContains,
941            "array_get" => Self::ArrayGet,
942            "array_length" => Self::ArrayLength,
943            "array_lower" => Self::ArrayMin,
944            "array_mean" => Self::ArrayMean,
945            "array_reverse" => Self::ArrayReverse,
946            "array_sum" => Self::ArraySum,
947            "array_to_string" => Self::ArrayToString,
948            "array_unique" => Self::ArrayUnique,
949            "array_upper" => Self::ArrayMax,
950            "unnest" => Self::Explode,
951
952            // ----
953            // Column selection
954            // ----
955            "columns" => Self::Columns,
956
957            other => {
958                if ctx.function_registry.contains(other) {
959                    Self::Udf(other.to_string())
960                } else {
961                    polars_bail!(SQLInterface: "unsupported function '{}'", other);
962                }
963            },
964        })
965    }
966}
967
968impl SQLFunctionVisitor<'_> {
969    pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
970        use PolarsSQLFunctions::*;
971        use polars_lazy::prelude::Literal;
972
973        let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
974        let function = self.func;
975
976        // TODO: implement the following functions where possible
977        if !function.within_group.is_empty() {
978            polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
979        }
980        if function.filter.is_some() {
981            polars_bail!(SQLInterface: "'FILTER' is not currently supported")
982        }
983        if function.null_treatment.is_some() {
984            polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
985        }
986
987        let log_with_base =
988            |e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
989        match function_name {
990            // ----
991            // Bitwise functions
992            // ----
993            BitAnd => self.visit_binary::<Expr>(Expr::and),
994            #[cfg(feature = "bitwise")]
995            BitCount => self.visit_unary(Expr::bitwise_count_ones),
996            BitOr => self.visit_binary::<Expr>(Expr::or),
997            BitXor => self.visit_binary::<Expr>(Expr::xor),
998
999            // ----
1000            // Math functions
1001            // ----
1002            Abs => self.visit_unary(Expr::abs),
1003            Cbrt => self.visit_unary(Expr::cbrt),
1004            Ceil => self.visit_unary(Expr::ceil),
1005            Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1006            Exp => self.visit_unary(Expr::exp),
1007            Floor => self.visit_unary(Expr::floor),
1008            Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1009            Log => self.visit_binary(Expr::log),
1010            Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1011            Log1p => self.visit_unary(Expr::log1p),
1012            Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1013            Pi => self.visit_nullary(Expr::pi),
1014            Mod => self.visit_binary(|e1, e2| e1 % e2),
1015            Pow => self.visit_binary::<Expr>(Expr::pow),
1016            Round => {
1017                let args = extract_args(function)?;
1018                match args.len() {
1019                    1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1020                    2 => self.try_visit_binary(|e, decimals| {
1021                        Ok(e.round(match decimals {
1022                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1023                                if n >= 0 { n as u32 } else {
1024                                    polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1])
1025                                }
1026                            },
1027                            _ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1028                        }, RoundMode::default()))
1029                    }),
1030                    _ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1031                }
1032            },
1033            Sign => self.visit_unary(Expr::sign),
1034            Sqrt => self.visit_unary(Expr::sqrt),
1035
1036            // ----
1037            // Trig functions
1038            // ----
1039            Acos => self.visit_unary(Expr::arccos),
1040            AcosD => self.visit_unary(|e| e.arccos().degrees()),
1041            Asin => self.visit_unary(Expr::arcsin),
1042            AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1043            Atan => self.visit_unary(Expr::arctan),
1044            Atan2 => self.visit_binary(Expr::arctan2),
1045            Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1046            AtanD => self.visit_unary(|e| e.arctan().degrees()),
1047            Cos => self.visit_unary(Expr::cos),
1048            CosD => self.visit_unary(|e| e.radians().cos()),
1049            Cot => self.visit_unary(Expr::cot),
1050            CotD => self.visit_unary(|e| e.radians().cot()),
1051            Degrees => self.visit_unary(Expr::degrees),
1052            Radians => self.visit_unary(Expr::radians),
1053            Sin => self.visit_unary(Expr::sin),
1054            SinD => self.visit_unary(|e| e.radians().sin()),
1055            Tan => self.visit_unary(Expr::tan),
1056            TanD => self.visit_unary(|e| e.radians().tan()),
1057
1058            // ----
1059            // Conditional functions
1060            // ----
1061            Coalesce => self.visit_variadic(coalesce),
1062            Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1063            If => {
1064                let args = extract_args(function)?;
1065                match args.len() {
1066                    3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1067                        Ok(when(cond).then(expr1).otherwise(expr2))
1068                    }),
1069                    _ => {
1070                        polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1071                        )
1072                    },
1073                }
1074            },
1075            IfNull => {
1076                let args = extract_args(function)?;
1077                match args.len() {
1078                    2 => self.visit_variadic(coalesce),
1079                    _ => {
1080                        polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1081                    },
1082                }
1083            },
1084            Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1085            NullIf => {
1086                let args = extract_args(function)?;
1087                match args.len() {
1088                    2 => self.visit_binary(|l: Expr, r: Expr| {
1089                        when(l.clone().eq(r))
1090                            .then(lit(LiteralValue::untyped_null()))
1091                            .otherwise(l)
1092                    }),
1093                    _ => {
1094                        polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1095                    },
1096                }
1097            },
1098
1099            // ----
1100            // Date functions
1101            // ----
1102            DatePart => self.try_visit_binary(|part, e| {
1103                match part {
1104                    Expr::Literal(p) if p.extract_str().is_some() => {
1105                        let p = p.extract_str().unwrap();
1106                        // note: 'DATE_PART' and 'EXTRACT' are minor syntactic
1107                        // variations on otherwise identical functionality
1108                        parse_extract_date_part(
1109                            e,
1110                            &DateTimeField::Custom(Ident {
1111                                value: p.to_string(),
1112                                quote_style: None,
1113                                span: Span::empty(),
1114                            }),
1115                        )
1116                    },
1117                    _ => {
1118                        polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1119                    },
1120                }
1121            }),
1122            Strftime => {
1123                let args = extract_args(function)?;
1124                match args.len() {
1125                    2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1126                    _ => {
1127                        polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1128                    },
1129                }
1130            },
1131
1132            // ----
1133            // String functions
1134            // ----
1135            BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1136            Concat => {
1137                let args = extract_args(function)?;
1138                if args.is_empty() {
1139                    polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1140                } else {
1141                    self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1142                }
1143            },
1144            ConcatWS => {
1145                let args = extract_args(function)?;
1146                if args.len() < 2 {
1147                    polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1148                } else {
1149                    self.try_visit_variadic(|exprs: &[Expr]| {
1150                        match &exprs[0] {
1151                            Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1152                            _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1153                        }
1154                    })
1155                }
1156            },
1157            Date => {
1158                let args = extract_args(function)?;
1159                match args.len() {
1160                    1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1161                    2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1162                    _ => {
1163                        polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1164                    },
1165                }
1166            },
1167            EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1168            #[cfg(feature = "nightly")]
1169            InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1170            Left => self.try_visit_binary(|e, length| {
1171                Ok(match length {
1172                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1173                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1174                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1175                        let len = if n > 0 {
1176                            lit(n)
1177                        } else {
1178                            (e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1179                        };
1180                        e.str().slice(lit(0), len)
1181                    },
1182                    Expr::Literal(v) => {
1183                        polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1184                    },
1185                    _ => when(length.clone().gt_eq(lit(0)))
1186                        .then(e.clone().str().slice(lit(0), length.clone().abs()))
1187                        .otherwise(e.clone().str().slice(
1188                            lit(0),
1189                            (e.str().len_chars() + length.clone()).clip_min(lit(0)),
1190                        )),
1191                })
1192            }),
1193            Length => self.visit_unary(|e| e.str().len_chars()),
1194            Lower => self.visit_unary(|e| e.str().to_lowercase()),
1195            LTrim => {
1196                let args = extract_args(function)?;
1197                match args.len() {
1198                    1 => self.visit_unary(|e| {
1199                        e.str().strip_chars_start(lit(LiteralValue::untyped_null()))
1200                    }),
1201                    2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)),
1202                    _ => {
1203                        polars_bail!(SQLSyntax: "LTRIM expects 1-2 arguments (found {})", args.len())
1204                    },
1205                }
1206            },
1207            Normalize => {
1208                let args = extract_args(function)?;
1209                match args.len() {
1210                    1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1211                    2 => {
1212                        let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1213                            value: s,
1214                            quote_style: None,
1215                            span: _,
1216                        })) = args[1]
1217                        {
1218                            match s.to_uppercase().as_str() {
1219                                "NFC" => UnicodeForm::NFC,
1220                                "NFD" => UnicodeForm::NFD,
1221                                "NFKC" => UnicodeForm::NFKC,
1222                                "NFKD" => UnicodeForm::NFKD,
1223                                _ => {
1224                                    polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1225                                },
1226                            }
1227                        } else {
1228                            polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1229                        };
1230                        self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1231                    },
1232                    _ => {
1233                        polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1234                    },
1235                }
1236            },
1237            OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1238            StrPos => {
1239                // note: SQL is 1-indexed; returns zero if no match found
1240                self.visit_binary(|expr, substring| {
1241                    (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1242                })
1243            },
1244            RegexpLike => {
1245                let args = extract_args(function)?;
1246                match args.len() {
1247                    2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1248                    3 => self.try_visit_ternary(|e, pat, flags| {
1249                        Ok(e.str().contains(
1250                            match (pat, flags) {
1251                                (Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1252                                    let s = s_lv.extract_str().unwrap();
1253                                    let f = f_lv.extract_str().unwrap();
1254                                    if f.is_empty() {
1255                                        polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1256                                    };
1257                                    lit(format!("(?{f}){s}"))
1258                                },
1259                                _ => {
1260                                    polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1261                                },
1262                            },
1263                            true))
1264                    }),
1265                    _ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1266                }
1267            },
1268            Replace => {
1269                let args = extract_args(function)?;
1270                match args.len() {
1271                    3 => self
1272                        .try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1273                    _ => {
1274                        polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1275                    },
1276                }
1277            },
1278            Reverse => self.visit_unary(|e| e.str().reverse()),
1279            Right => self.try_visit_binary(|e, length| {
1280                Ok(match length {
1281                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1282                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1283                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1284                        let n: i64 = n.try_into().unwrap();
1285                        let offset = if n < 0 {
1286                            lit(n.abs())
1287                        } else {
1288                            e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1289                        };
1290                        e.str().slice(offset, lit(LiteralValue::untyped_null()))
1291                    },
1292                    Expr::Literal(v) => {
1293                        polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1294                    },
1295                    _ => when(length.clone().lt(lit(0)))
1296                        .then(
1297                            e.clone()
1298                                .str()
1299                                .slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1300                        )
1301                        .otherwise(e.clone().str().slice(
1302                            e.str().len_chars().cast(DataType::Int32) - length.clone(),
1303                            lit(LiteralValue::untyped_null()),
1304                        )),
1305                })
1306            }),
1307            RTrim => {
1308                let args = extract_args(function)?;
1309                match args.len() {
1310                    1 => self.visit_unary(|e| {
1311                        e.str().strip_chars_end(lit(LiteralValue::untyped_null()))
1312                    }),
1313                    2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
1314                    _ => {
1315                        polars_bail!(SQLSyntax: "RTRIM expects 1-2 arguments (found {})", args.len())
1316                    },
1317                }
1318            },
1319            SplitPart => {
1320                let args = extract_args(function)?;
1321                match args.len() {
1322                    3 => self.try_visit_ternary(|e, sep, idx| {
1323                        let idx = adjust_one_indexed_param(idx, true);
1324                        Ok(when(e.clone().is_not_null())
1325                            .then(
1326                                e.clone()
1327                                    .str()
1328                                    .split(sep)
1329                                    .list()
1330                                    .get(idx, true)
1331                                    .fill_null(lit("")),
1332                            )
1333                            .otherwise(e))
1334                    }),
1335                    _ => {
1336                        polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1337                    },
1338                }
1339            },
1340            StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1341            StringToArray => {
1342                let args = extract_args(function)?;
1343                match args.len() {
1344                    2 => self.visit_binary(|e, sep| e.str().split(sep)),
1345                    _ => {
1346                        polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1347                    },
1348                }
1349            },
1350            Strptime => {
1351                let args = extract_args(function)?;
1352                match args.len() {
1353                    2 => self.visit_binary(|e, fmt: String| {
1354                        e.str().strptime(
1355                            DataType::Datetime(TimeUnit::Microseconds, None),
1356                            StrptimeOptions {
1357                                format: Some(fmt.into()),
1358                                ..Default::default()
1359                            },
1360                            lit("latest"),
1361                        )
1362                    }),
1363                    _ => {
1364                        polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1365                    },
1366                }
1367            },
1368            Time => {
1369                let args = extract_args(function)?;
1370                match args.len() {
1371                    1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1372                    2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1373                    _ => {
1374                        polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1375                    },
1376                }
1377            },
1378            Timestamp => {
1379                let args = extract_args(function)?;
1380                match args.len() {
1381                    1 => self.visit_unary(|e| {
1382                        e.str()
1383                            .to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1384                    }),
1385                    2 => self
1386                        .visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1387                    _ => {
1388                        polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1389                    },
1390                }
1391            },
1392            Substring => {
1393                let args = extract_args(function)?;
1394                match args.len() {
1395                    // note: SQL is 1-indexed, hence the need for adjustments
1396                    2 => self.try_visit_binary(|e, start| {
1397                        Ok(match start {
1398                            Expr::Literal(lv) if lv.is_null() => lit(lv),
1399                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1400                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1401                            Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1402                            _ => start.clone() + lit(1),
1403                        })
1404                    }),
1405                    3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1406                        Ok(match (start.clone(), length.clone()) {
1407                            (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1408                            (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1409                                polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1410                            },
1411                            (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1412                            (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1413                                e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1414                            },
1415                            (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1416                            (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1417                                polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1418                            },
1419                            _ => {
1420                                let adjusted_start = start - lit(1);
1421                                when(adjusted_start.clone().lt(lit(0)))
1422                                    .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1423                                    .otherwise(e.str().slice(adjusted_start, length))
1424                            }
1425                        })
1426                    }),
1427                    _ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1428                }
1429            },
1430            Upper => self.visit_unary(|e| e.str().to_uppercase()),
1431
1432            // ----
1433            // Aggregate functions
1434            // ----
1435            Avg => self.visit_unary(Expr::mean),
1436            Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1437            Count => self.visit_count(),
1438            CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1439            CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1440            First => self.visit_unary(Expr::first),
1441            Last => self.visit_unary(Expr::last),
1442            Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1443            Median => self.visit_unary(Expr::median),
1444            QuantileCont => {
1445                let args = extract_args(function)?;
1446                match args.len() {
1447                    2 => self.try_visit_binary(|e, q| {
1448                        let value = match q {
1449                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1450                                if (0.0..=1.0).contains(&f) {
1451                                    Expr::from(f)
1452                                } else {
1453                                    polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1454                                }
1455                            },
1456                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1457                                if (0..=1).contains(&n) {
1458                                    Expr::from(n as f64)
1459                                } else {
1460                                    polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1461                                }
1462                            },
1463                            _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
1464                        };
1465                        Ok(e.quantile(value, QuantileMethod::Linear))
1466                    }),
1467                    _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
1468                }
1469            },
1470            QuantileDisc => {
1471                let args = extract_args(function)?;
1472                match args.len() {
1473                    2 => self.try_visit_binary(|e, q| {
1474                        let value = match q {
1475                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1476                                if (0.0..=1.0).contains(&f) {
1477                                    Expr::from(f)
1478                                } else {
1479                                    polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1480                                }
1481                            },
1482                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1483                                if (0..=1).contains(&n) {
1484                                    Expr::from(n as f64)
1485                                } else {
1486                                    polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1487                                }
1488                            },
1489                            _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1])
1490                        };
1491                        Ok(e.quantile(value, QuantileMethod::Equiprobable))
1492                    }),
1493                    _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()),
1494                }
1495            },
1496            Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1497            StdDev => self.visit_unary(|e| e.std(1)),
1498            Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1499            Variance => self.visit_unary(|e| e.var(1)),
1500
1501            // ----
1502            // Array functions
1503            // ----
1504            ArrayAgg => self.visit_arr_agg(),
1505            ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1506            ArrayGet => {
1507                // note: SQL is 1-indexed, not 0-indexed
1508                self.visit_binary(|e, idx: Expr| {
1509                    let idx = adjust_one_indexed_param(idx, true);
1510                    e.list().get(idx, true)
1511                })
1512            },
1513            ArrayLength => self.visit_unary(|e| e.list().len()),
1514            ArrayMax => self.visit_unary(|e| e.list().max()),
1515            ArrayMean => self.visit_unary(|e| e.list().mean()),
1516            ArrayMin => self.visit_unary(|e| e.list().min()),
1517            ArrayReverse => self.visit_unary(|e| e.list().reverse()),
1518            ArraySum => self.visit_unary(|e| e.list().sum()),
1519            ArrayToString => self.visit_arr_to_string(),
1520            ArrayUnique => self.visit_unary(|e| e.list().unique()),
1521            Explode => self.visit_unary(|e| e.explode()),
1522
1523            // ----
1524            // Column selection
1525            // ----
1526            Columns => {
1527                let active_schema = self.active_schema;
1528                self.try_visit_unary(|e: Expr| match e {
1529                    Expr::Literal(lv) if lv.extract_str().is_some() => {
1530                        let pat = lv.extract_str().unwrap();
1531                        if pat == "*" {
1532                            polars_bail!(
1533                                SQLSyntax: "COLUMNS('*') is not a valid regex; \
1534                                did you mean COLUMNS(*)?"
1535                            )
1536                        };
1537                        let pat = match pat {
1538                            _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1539                            _ if pat.starts_with('^') => format!("{pat}.*$"),
1540                            _ if pat.ends_with('$') => format!("^.*{pat}"),
1541                            _ => format!("^.*{pat}.*$"),
1542                        };
1543                        if let Some(active_schema) = &active_schema {
1544                            let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1545                            let col_names = active_schema
1546                                .iter_names()
1547                                .filter(|name| rx.is_match(name))
1548                                .cloned()
1549                                .collect::<Vec<_>>();
1550
1551                            Ok(if col_names.len() == 1 {
1552                                col(col_names.into_iter().next().unwrap())
1553                            } else {
1554                                cols(col_names).as_expr()
1555                            })
1556                        } else {
1557                            Ok(col(pat.as_str()))
1558                        }
1559                    },
1560                    Expr::Selector(s) => Ok(s.as_expr()),
1561                    _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1562                })
1563            },
1564
1565            // ----
1566            // User-defined
1567            // ----
1568            Udf(func_name) => self.visit_udf(&func_name),
1569        }
1570    }
1571
1572    fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1573        let args = extract_args(self.func)?
1574            .into_iter()
1575            .map(|arg| {
1576                if let FunctionArgExpr::Expr(e) = arg {
1577                    parse_sql_expr(e, self.ctx, self.active_schema)
1578                } else {
1579                    polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1580                }
1581            })
1582            .collect::<PolarsResult<Vec<_>>>()?;
1583
1584        Ok(self
1585            .ctx
1586            .function_registry
1587            .get_udf(func_name)?
1588            .ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1589            .call(args))
1590    }
1591
1592    /// Window specs without partition bys are essentially cumulative functions
1593    /// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false)
1594    fn apply_cumulative_window(
1595        &mut self,
1596        f: impl Fn(Expr) -> Expr,
1597        cumulative_f: impl Fn(Expr, bool) -> Expr,
1598        WindowSpec {
1599            partition_by,
1600            order_by,
1601            ..
1602        }: &WindowSpec,
1603    ) -> PolarsResult<Expr> {
1604        if !order_by.is_empty() && partition_by.is_empty() {
1605            let (order_by, desc): (Vec<Expr>, Vec<bool>) = order_by
1606                .iter()
1607                .map(|o| {
1608                    let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
1609                    Ok(match o.asc {
1610                        Some(b) => (expr, !b),
1611                        None => (expr, false),
1612                    })
1613                })
1614                .collect::<PolarsResult<Vec<_>>>()?
1615                .into_iter()
1616                .unzip();
1617            self.visit_unary_no_window(|e| {
1618                cumulative_f(
1619                    e.sort_by(
1620                        &order_by,
1621                        SortMultipleOptions::default().with_order_descending_multi(desc.clone()),
1622                    ),
1623                    false,
1624                )
1625            })
1626        } else {
1627            self.visit_unary(f)
1628        }
1629    }
1630
1631    fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1632        self.try_visit_unary(|e| Ok(f(e)))
1633    }
1634
1635    fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1636        let args = extract_args(self.func)?;
1637        match args.as_slice() {
1638            [FunctionArgExpr::Expr(sql_expr)] => {
1639                f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?)
1640            },
1641            [FunctionArgExpr::Wildcard] => f(parse_sql_expr(
1642                &SQLExpr::Wildcard(AttachedToken::empty()),
1643                self.ctx,
1644                self.active_schema,
1645            )?),
1646            _ => self.not_supported_error(),
1647        }
1648        .and_then(|e| self.apply_window_spec(e, &self.func.over))
1649    }
1650
1651    /// Some functions have cumulative equivalents that can be applied to window specs
1652    /// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false)
1653    /// visit_unary_with_cumulative_window will take in a function & a cumulative function
1654    /// if there is a cumulative window spec, it will apply the cumulative function,
1655    /// otherwise it will apply the function
1656    fn visit_unary_with_opt_cumulative(
1657        &mut self,
1658        f: impl Fn(Expr) -> Expr,
1659        cumulative_f: impl Fn(Expr, bool) -> Expr,
1660    ) -> PolarsResult<Expr> {
1661        match self.func.over.as_ref() {
1662            Some(WindowType::WindowSpec(spec)) => {
1663                self.apply_cumulative_window(f, cumulative_f, spec)
1664            },
1665            Some(WindowType::NamedWindow(named_window)) => polars_bail!(
1666                SQLInterface: "Named windows are not currently supported; found {:?}",
1667                named_window
1668            ),
1669            _ => self.visit_unary(f),
1670        }
1671    }
1672
1673    fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1674        let args = extract_args(self.func)?;
1675        match args.as_slice() {
1676            [FunctionArgExpr::Expr(sql_expr)] => {
1677                let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1678                // apply the function on the inner expr -- e.g. SUM(a) -> SUM
1679                Ok(f(expr))
1680            },
1681            _ => self.not_supported_error(),
1682        }
1683    }
1684
1685    fn visit_binary<Arg: FromSQLExpr>(
1686        &mut self,
1687        f: impl Fn(Expr, Arg) -> Expr,
1688    ) -> PolarsResult<Expr> {
1689        self.try_visit_binary(|e, a| Ok(f(e, a)))
1690    }
1691
1692    fn try_visit_binary<Arg: FromSQLExpr>(
1693        &mut self,
1694        f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
1695    ) -> PolarsResult<Expr> {
1696        let args = extract_args(self.func)?;
1697        match args.as_slice() {
1698            [
1699                FunctionArgExpr::Expr(sql_expr1),
1700                FunctionArgExpr::Expr(sql_expr2),
1701            ] => {
1702                let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1703                let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1704                f(expr1, expr2)
1705            },
1706            _ => self.not_supported_error(),
1707        }
1708    }
1709
1710    fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
1711        self.try_visit_variadic(|e| Ok(f(e)))
1712    }
1713
1714    fn try_visit_variadic(
1715        &mut self,
1716        f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
1717    ) -> PolarsResult<Expr> {
1718        let args = extract_args(self.func)?;
1719        let mut expr_args = vec![];
1720        for arg in args {
1721            if let FunctionArgExpr::Expr(sql_expr) = arg {
1722                expr_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
1723            } else {
1724                return self.not_supported_error();
1725            };
1726        }
1727        f(&expr_args)
1728    }
1729
1730    fn try_visit_ternary<Arg: FromSQLExpr>(
1731        &mut self,
1732        f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
1733    ) -> PolarsResult<Expr> {
1734        let args = extract_args(self.func)?;
1735        match args.as_slice() {
1736            [
1737                FunctionArgExpr::Expr(sql_expr1),
1738                FunctionArgExpr::Expr(sql_expr2),
1739                FunctionArgExpr::Expr(sql_expr3),
1740            ] => {
1741                let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1742                let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1743                let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
1744                f(expr1, expr2, expr3)
1745            },
1746            _ => self.not_supported_error(),
1747        }
1748    }
1749
1750    fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
1751        let args = extract_args(self.func)?;
1752        if !args.is_empty() {
1753            return self.not_supported_error();
1754        }
1755        Ok(f())
1756    }
1757
1758    fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
1759        let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
1760        match args.as_slice() {
1761            [FunctionArgExpr::Expr(sql_expr)] => {
1762                let mut base = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1763                if is_distinct {
1764                    base = base.unique_stable();
1765                }
1766                for clause in clauses {
1767                    match clause {
1768                        FunctionArgumentClause::OrderBy(order_exprs) => {
1769                            base = self.apply_order_by(base, order_exprs.as_slice())?;
1770                        },
1771                        FunctionArgumentClause::Limit(limit_expr) => {
1772                            let limit = parse_sql_expr(&limit_expr, self.ctx, self.active_schema)?;
1773                            match limit {
1774                                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))
1775                                    if n >= 0 =>
1776                                {
1777                                    base = base.head(Some(n as usize))
1778                                },
1779                                _ => {
1780                                    polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer")
1781                                },
1782                            };
1783                        },
1784                        _ => {},
1785                    }
1786                }
1787                Ok(base.implode())
1788            },
1789            _ => {
1790                polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
1791            },
1792        }
1793    }
1794
1795    fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
1796        let args = extract_args(self.func)?;
1797        match args.len() {
1798            2 => self.try_visit_binary(|e, sep| {
1799                Ok(e.cast(DataType::List(Box::from(DataType::String)))
1800                    .list()
1801                    .join(sep, true))
1802            }),
1803            #[cfg(feature = "list_eval")]
1804            3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
1805                Expr::Literal(lv) if lv.extract_str().is_some() => {
1806                    Ok(if lv.extract_str().unwrap().is_empty() {
1807                        e.cast(DataType::List(Box::from(DataType::String)))
1808                            .list()
1809                            .join(sep, true)
1810                    } else {
1811                        e.cast(DataType::List(Box::from(DataType::String)))
1812                            .list()
1813                            .eval(col("").fill_null(lit(lv.extract_str().unwrap())))
1814                            .list()
1815                            .join(sep, false)
1816                    })
1817                },
1818                _ => {
1819                    polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
1820                },
1821            }),
1822            _ => {
1823                polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
1824            },
1825        }
1826    }
1827
1828    fn visit_count(&mut self) -> PolarsResult<Expr> {
1829        let (args, is_distinct) = extract_args_distinct(self.func)?;
1830        let count_expr = match (is_distinct, args.as_slice()) {
1831            // count(*), count()
1832            (false, [FunctionArgExpr::Wildcard] | []) => len(),
1833            // count(column_name)
1834            (false, [FunctionArgExpr::Expr(sql_expr)]) => {
1835                let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1836                expr.count()
1837            },
1838            // count(distinct column_name)
1839            (true, [FunctionArgExpr::Expr(sql_expr)]) => {
1840                let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1841                expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
1842            },
1843            _ => self.not_supported_error()?,
1844        };
1845        self.apply_window_spec(count_expr, &self.func.over)
1846    }
1847
1848    fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
1849        let mut by = Vec::with_capacity(order_by.len());
1850        let mut descending = Vec::with_capacity(order_by.len());
1851        let mut nulls_last = Vec::with_capacity(order_by.len());
1852
1853        for ob in order_by {
1854            // note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise
1855            // https://www.postgresql.org/docs/current/queries-order.html
1856            let desc_order = !ob.asc.unwrap_or(true);
1857            by.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?);
1858            nulls_last.push(!ob.nulls_first.unwrap_or(desc_order));
1859            descending.push(desc_order);
1860        }
1861        Ok(expr.sort_by(
1862            by,
1863            SortMultipleOptions::default()
1864                .with_order_descending_multi(descending)
1865                .with_nulls_last_multi(nulls_last)
1866                .with_maintain_order(true),
1867        ))
1868    }
1869
1870    fn apply_window_spec(
1871        &mut self,
1872        expr: Expr,
1873        window_type: &Option<WindowType>,
1874    ) -> PolarsResult<Expr> {
1875        Ok(match &window_type {
1876            Some(WindowType::WindowSpec(window_spec)) => {
1877                if window_spec.partition_by.is_empty() {
1878                    let exprs = window_spec
1879                        .order_by
1880                        .iter()
1881                        .map(|o| {
1882                            let e = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
1883                            Ok(o.asc.map_or(e.clone(), |b| {
1884                                e.sort(SortOptions::default().with_order_descending(!b))
1885                            }))
1886                        })
1887                        .collect::<PolarsResult<Vec<_>>>()?;
1888                    expr.over(exprs)
1889                } else {
1890                    // Process for simple window specification, partition by first
1891                    let partition_by = window_spec
1892                        .partition_by
1893                        .iter()
1894                        .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1895                        .collect::<PolarsResult<Vec<_>>>()?;
1896                    expr.over(partition_by)
1897                }
1898            },
1899            Some(WindowType::NamedWindow(named_window)) => polars_bail!(
1900                SQLInterface: "Named windows are not currently supported; found {:?}",
1901                named_window
1902            ),
1903            None => expr,
1904        })
1905    }
1906
1907    fn not_supported_error(&self) -> PolarsResult<Expr> {
1908        polars_bail!(
1909            SQLInterface:
1910            "no function matches the given name and arguments: `{}`",
1911            self.func.to_string()
1912        );
1913    }
1914}
1915
1916fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
1917    let (args, _, _) = _extract_func_args(func, false, false)?;
1918    Ok(args)
1919}
1920
1921fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
1922    let (args, is_distinct, _) = _extract_func_args(func, true, false)?;
1923    Ok((args, is_distinct))
1924}
1925
1926fn extract_args_and_clauses(
1927    func: &SQLFunction,
1928) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
1929    _extract_func_args(func, true, true)
1930}
1931
1932fn _extract_func_args(
1933    func: &SQLFunction,
1934    get_distinct: bool,
1935    get_clauses: bool,
1936) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
1937    match &func.args {
1938        FunctionArguments::List(FunctionArgumentList {
1939            args,
1940            duplicate_treatment,
1941            clauses,
1942        }) => {
1943            let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
1944            if !(get_clauses || get_distinct) && is_distinct {
1945                polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
1946            } else if !get_clauses && !clauses.is_empty() {
1947                polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
1948            } else {
1949                let unpacked_args = args
1950                    .iter()
1951                    .map(|arg| match arg {
1952                        FunctionArg::Named { arg, .. } => arg,
1953                        FunctionArg::ExprNamed { arg, .. } => arg,
1954                        FunctionArg::Unnamed(arg) => arg,
1955                    })
1956                    .collect();
1957                Ok((unpacked_args, is_distinct, clauses.clone()))
1958            }
1959        },
1960        FunctionArguments::Subquery { .. } => {
1961            Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
1962        },
1963        FunctionArguments::None => Ok((vec![], false, vec![])),
1964    }
1965}
1966
1967pub(crate) trait FromSQLExpr {
1968    fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
1969    where
1970        Self: Sized;
1971}
1972
1973impl FromSQLExpr for f64 {
1974    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
1975    where
1976        Self: Sized,
1977    {
1978        match expr {
1979            SQLExpr::Value(v) => match v {
1980                SQLValue::Number(s, _) => s
1981                    .parse()
1982                    .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
1983                _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
1984            },
1985            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
1986        }
1987    }
1988}
1989
1990impl FromSQLExpr for bool {
1991    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
1992    where
1993        Self: Sized,
1994    {
1995        match expr {
1996            SQLExpr::Value(v) => match v {
1997                SQLValue::Boolean(v) => Ok(*v),
1998                _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
1999            },
2000            _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2001        }
2002    }
2003}
2004
2005impl FromSQLExpr for String {
2006    fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2007    where
2008        Self: Sized,
2009    {
2010        match expr {
2011            SQLExpr::Value(v) => match v {
2012                SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2013                _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2014            },
2015            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2016        }
2017    }
2018}
2019
2020impl FromSQLExpr for StrptimeOptions {
2021    fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2022    where
2023        Self: Sized,
2024    {
2025        match expr {
2026            SQLExpr::Value(v) => match v {
2027                SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2028                    format: Some(PlSmallStr::from_str(s)),
2029                    ..StrptimeOptions::default()
2030                }),
2031                _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2032            },
2033            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2034        }
2035    }
2036}
2037
2038impl FromSQLExpr for Expr {
2039    fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2040    where
2041        Self: Sized,
2042    {
2043        parse_sql_expr(expr, ctx, None)
2044    }
2045}