1use std::ops::Sub;
2
3use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
4use polars_core::prelude::{
5 DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail, polars_err,
6};
7use polars_lazy::dsl::Expr;
8use polars_ops::chunked_array::UnicodeForm;
9use polars_ops::series::RoundMode;
10use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when};
11use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
12use polars_plan::prelude::{StrptimeOptions, col, cols, lit};
13use polars_utils::pl_str::PlSmallStr;
14use sqlparser::ast::helpers::attached_token::AttachedToken;
15use sqlparser::ast::{
16 DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
17 FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
18 OrderByExpr, Value as SQLValue, WindowSpec, WindowType,
19};
20use sqlparser::tokenizer::Span;
21
22use crate::SQLContext;
23use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
24
25pub(crate) struct SQLFunctionVisitor<'a> {
26 pub(crate) func: &'a SQLFunction,
27 pub(crate) ctx: &'a mut SQLContext,
28 pub(crate) active_schema: Option<&'a Schema>,
29}
30
/// The SQL functions recognised by the Polars SQL interface.
///
/// Each variant notes the (lowercased) SQL spelling(s) that resolve to it.
pub(crate) enum PolarsSQLFunctions {
    // ----
    // Bitwise functions
    // ----
    /// SQL `bit_and` / `bitand`.
    BitAnd,
    /// SQL `bit_count` / `bitcount` (only with the "bitwise" feature).
    #[cfg(feature = "bitwise")]
    BitCount,
    /// SQL `bit_or` / `bitor`.
    BitOr,
    /// SQL `bit_xor` / `bitxor` / `xor`.
    BitXor,

    // ----
    // Math functions
    // ----
    /// SQL `abs`.
    Abs,
    /// SQL `ceil` / `ceiling`.
    Ceil,
    /// SQL `div`.
    Div,
    /// SQL `exp`.
    Exp,
    /// SQL `floor`.
    Floor,
    /// SQL `pi`.
    Pi,
    /// SQL `ln`.
    Ln,
    /// SQL `log2`.
    Log2,
    /// SQL `log10`.
    Log10,
    /// SQL `log`.
    Log,
    /// SQL `log1p`.
    Log1p,
    /// SQL `pow` / `power`.
    Pow,
    /// SQL `mod`.
    Mod,
    /// SQL `sqrt`.
    Sqrt,
    /// SQL `cbrt`.
    Cbrt,
    /// SQL `round`.
    Round,
    /// SQL `sign`.
    Sign,

    // ----
    // Trig functions
    // ----
    /// SQL `cos`.
    Cos,
    /// SQL `cot`.
    Cot,
    /// SQL `sin`.
    Sin,
    /// SQL `tan`.
    Tan,
    /// SQL `cosd`.
    CosD,
    /// SQL `cotd`.
    CotD,
    /// SQL `sind`.
    SinD,
    /// SQL `tand`.
    TanD,
    /// SQL `acos`.
    Acos,
    /// SQL `asin`.
    Asin,
    /// SQL `atan`.
    Atan,
    /// SQL `atan2`.
    Atan2,
    /// SQL `acosd`.
    AcosD,
    /// SQL `asind`.
    AsinD,
    /// SQL `atand`.
    AtanD,
    /// SQL `atan2d`.
    Atan2D,
    /// SQL `degrees`.
    Degrees,
    /// SQL `radians`.
    Radians,

    // ----
    // Temporal functions
    // ----
    /// SQL `date_part`.
    DatePart,
    /// SQL `strftime`.
    Strftime,

    // ----
    // String functions
    // ----
    /// SQL `bit_length`.
    BitLength,
    /// SQL `concat`.
    Concat,
    /// SQL `concat_ws`.
    ConcatWS,
    /// SQL `date`.
    Date,
    /// SQL `ends_with`.
    EndsWith,
    /// SQL `initcap` (only with the "nightly" feature).
    #[cfg(feature = "nightly")]
    InitCap,
    /// SQL `left`.
    Left,
    /// SQL `length` / `char_length` / `character_length`.
    Length,
    /// SQL `lower`.
    Lower,
    /// SQL `ltrim`.
    LTrim,
    /// SQL `normalize`.
    Normalize,
    /// SQL `octet_length`.
    OctetLength,
    /// SQL `regexp_like`.
    RegexpLike,
    /// SQL `replace`.
    Replace,
    /// SQL `reverse`.
    Reverse,
    /// SQL `right`.
    Right,
    /// SQL `rtrim`.
    RTrim,
    /// SQL `split_part`.
    SplitPart,
    /// SQL `starts_with`.
    StartsWith,
    /// SQL `strpos`.
    StrPos,
    /// SQL `substr`.
    Substring,
    /// SQL `string_to_array`.
    StringToArray,
    /// SQL `strptime`.
    Strptime,
    /// SQL `time`.
    Time,
    /// SQL `timestamp` / `datetime`.
    Timestamp,
    /// SQL `upper`.
    Upper,

    // ----
    // Conditional functions
    // ----
    /// SQL `coalesce`.
    Coalesce,
    /// SQL `greatest`.
    Greatest,
    /// SQL `if`.
    If,
    /// SQL `ifnull`.
    IfNull,
    /// SQL `least`.
    Least,
    /// SQL `nullif`.
    NullIf,

    // ----
    // Aggregate functions
    // ----
    /// SQL `avg`.
    Avg,
    /// SQL `corr`.
    Corr,
    /// SQL `count`.
    Count,
    /// SQL `covar_pop`.
    CovarPop,
    /// SQL `covar` / `covar_samp`.
    CovarSamp,
    /// SQL `first`.
    First,
    /// SQL `last`.
    Last,
    /// SQL `max`.
    Max,
    /// SQL `median`.
    Median,
    /// SQL `quantile_cont`.
    QuantileCont,
    /// SQL `quantile_disc`.
    QuantileDisc,
    /// SQL `min`.
    Min,
    /// SQL `stdev` / `stddev` / `stdev_samp` / `stddev_samp`.
    StdDev,
    /// SQL `sum`.
    Sum,
    /// SQL `var` / `variance` / `var_samp`.
    Variance,

    // ----
    // Array functions
    // ----
    /// SQL `array_length`.
    ArrayLength,
    /// SQL `array_lower`.
    ArrayMin,
    /// SQL `array_upper`.
    ArrayMax,
    /// SQL `array_sum`.
    ArraySum,
    /// SQL `array_mean`.
    ArrayMean,
    /// SQL `array_reverse`.
    ArrayReverse,
    /// SQL `array_unique`.
    ArrayUnique,
    /// SQL `unnest`.
    Explode,
    /// SQL `array_agg`.
    ArrayAgg,
    /// SQL `array_to_string`.
    ArrayToString,
    /// SQL `array_get`.
    ArrayGet,
    /// SQL `array_contains`.
    ArrayContains,

    // ----
    // Column selection
    // ----
    /// SQL `columns`.
    Columns,

    // ----
    // User-defined
    // ----
    /// A function found in the context's function registry; carries the
    /// lowercased function name it was looked up with.
    Udf(String),
}
701
702impl PolarsSQLFunctions {
703 pub(crate) fn keywords() -> &'static [&'static str] {
704 &[
705 "abs",
706 "acos",
707 "acosd",
708 "array_contains",
709 "array_get",
710 "array_length",
711 "array_lower",
712 "array_mean",
713 "array_reverse",
714 "array_sum",
715 "array_to_string",
716 "array_unique",
717 "array_upper",
718 "asin",
719 "asind",
720 "atan",
721 "atan2",
722 "atan2d",
723 "atand",
724 "avg",
725 "bit_and",
726 "bit_count",
727 "bit_length",
728 "bit_or",
729 "bit_xor",
730 "cbrt",
731 "ceil",
732 "ceiling",
733 "char_length",
734 "character_length",
735 "coalesce",
736 "columns",
737 "concat",
738 "concat_ws",
739 "corr",
740 "cos",
741 "cosd",
742 "cot",
743 "cotd",
744 "count",
745 "covar",
746 "covar_pop",
747 "covar_samp",
748 "date",
749 "date_part",
750 "degrees",
751 "ends_with",
752 "exp",
753 "first",
754 "floor",
755 "greatest",
756 "if",
757 "ifnull",
758 "initcap",
759 "last",
760 "least",
761 "left",
762 "length",
763 "ln",
764 "log",
765 "log10",
766 "log1p",
767 "log2",
768 "lower",
769 "ltrim",
770 "max",
771 "median",
772 "quantile_disc",
773 "min",
774 "mod",
775 "nullif",
776 "octet_length",
777 "pi",
778 "pow",
779 "power",
780 "quantile_cont",
781 "quantile_disc",
782 "radians",
783 "regexp_like",
784 "replace",
785 "reverse",
786 "right",
787 "round",
788 "rtrim",
789 "sign",
790 "sin",
791 "sind",
792 "sqrt",
793 "starts_with",
794 "stddev",
795 "stddev_samp",
796 "stdev",
797 "stdev_samp",
798 "strftime",
799 "strpos",
800 "strptime",
801 "substr",
802 "sum",
803 "tan",
804 "tand",
805 "unnest",
806 "upper",
807 "var",
808 "var_samp",
809 "variance",
810 ]
811 }
812}
813
814impl PolarsSQLFunctions {
815 fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
816 let function_name = function.name.0[0].value.to_lowercase();
817 Ok(match function_name.as_str() {
818 "bit_and" | "bitand" => Self::BitAnd,
822 #[cfg(feature = "bitwise")]
823 "bit_count" | "bitcount" => Self::BitCount,
824 "bit_or" | "bitor" => Self::BitOr,
825 "bit_xor" | "bitxor" | "xor" => Self::BitXor,
826
827 "abs" => Self::Abs,
831 "cbrt" => Self::Cbrt,
832 "ceil" | "ceiling" => Self::Ceil,
833 "div" => Self::Div,
834 "exp" => Self::Exp,
835 "floor" => Self::Floor,
836 "ln" => Self::Ln,
837 "log" => Self::Log,
838 "log10" => Self::Log10,
839 "log1p" => Self::Log1p,
840 "log2" => Self::Log2,
841 "mod" => Self::Mod,
842 "pi" => Self::Pi,
843 "pow" | "power" => Self::Pow,
844 "round" => Self::Round,
845 "sign" => Self::Sign,
846 "sqrt" => Self::Sqrt,
847
848 "cos" => Self::Cos,
852 "cot" => Self::Cot,
853 "sin" => Self::Sin,
854 "tan" => Self::Tan,
855 "cosd" => Self::CosD,
856 "cotd" => Self::CotD,
857 "sind" => Self::SinD,
858 "tand" => Self::TanD,
859 "acos" => Self::Acos,
860 "asin" => Self::Asin,
861 "atan" => Self::Atan,
862 "atan2" => Self::Atan2,
863 "acosd" => Self::AcosD,
864 "asind" => Self::AsinD,
865 "atand" => Self::AtanD,
866 "atan2d" => Self::Atan2D,
867 "degrees" => Self::Degrees,
868 "radians" => Self::Radians,
869
870 "coalesce" => Self::Coalesce,
874 "greatest" => Self::Greatest,
875 "if" => Self::If,
876 "ifnull" => Self::IfNull,
877 "least" => Self::Least,
878 "nullif" => Self::NullIf,
879
880 "date_part" => Self::DatePart,
884 "strftime" => Self::Strftime,
885
886 "bit_length" => Self::BitLength,
890 "concat" => Self::Concat,
891 "concat_ws" => Self::ConcatWS,
892 "date" => Self::Date,
893 "timestamp" | "datetime" => Self::Timestamp,
894 "ends_with" => Self::EndsWith,
895 #[cfg(feature = "nightly")]
896 "initcap" => Self::InitCap,
897 "length" | "char_length" | "character_length" => Self::Length,
898 "left" => Self::Left,
899 "lower" => Self::Lower,
900 "ltrim" => Self::LTrim,
901 "normalize" => Self::Normalize,
902 "octet_length" => Self::OctetLength,
903 "strpos" => Self::StrPos,
904 "regexp_like" => Self::RegexpLike,
905 "replace" => Self::Replace,
906 "reverse" => Self::Reverse,
907 "right" => Self::Right,
908 "rtrim" => Self::RTrim,
909 "split_part" => Self::SplitPart,
910 "starts_with" => Self::StartsWith,
911 "string_to_array" => Self::StringToArray,
912 "strptime" => Self::Strptime,
913 "substr" => Self::Substring,
914 "time" => Self::Time,
915 "upper" => Self::Upper,
916
917 "avg" => Self::Avg,
921 "corr" => Self::Corr,
922 "count" => Self::Count,
923 "covar_pop" => Self::CovarPop,
924 "covar" | "covar_samp" => Self::CovarSamp,
925 "first" => Self::First,
926 "last" => Self::Last,
927 "max" => Self::Max,
928 "median" => Self::Median,
929 "quantile_cont" => Self::QuantileCont,
930 "quantile_disc" => Self::QuantileDisc,
931 "min" => Self::Min,
932 "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
933 "sum" => Self::Sum,
934 "var" | "variance" | "var_samp" => Self::Variance,
935
936 "array_agg" => Self::ArrayAgg,
940 "array_contains" => Self::ArrayContains,
941 "array_get" => Self::ArrayGet,
942 "array_length" => Self::ArrayLength,
943 "array_lower" => Self::ArrayMin,
944 "array_mean" => Self::ArrayMean,
945 "array_reverse" => Self::ArrayReverse,
946 "array_sum" => Self::ArraySum,
947 "array_to_string" => Self::ArrayToString,
948 "array_unique" => Self::ArrayUnique,
949 "array_upper" => Self::ArrayMax,
950 "unnest" => Self::Explode,
951
952 "columns" => Self::Columns,
956
957 other => {
958 if ctx.function_registry.contains(other) {
959 Self::Udf(other.to_string())
960 } else {
961 polars_bail!(SQLInterface: "unsupported function '{}'", other);
962 }
963 },
964 })
965 }
966}
967
968impl SQLFunctionVisitor<'_> {
969 pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
970 use PolarsSQLFunctions::*;
971 use polars_lazy::prelude::Literal;
972
973 let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
974 let function = self.func;
975
976 if !function.within_group.is_empty() {
978 polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
979 }
980 if function.filter.is_some() {
981 polars_bail!(SQLInterface: "'FILTER' is not currently supported")
982 }
983 if function.null_treatment.is_some() {
984 polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
985 }
986
987 let log_with_base =
988 |e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
989 match function_name {
990 BitAnd => self.visit_binary::<Expr>(Expr::and),
994 #[cfg(feature = "bitwise")]
995 BitCount => self.visit_unary(Expr::bitwise_count_ones),
996 BitOr => self.visit_binary::<Expr>(Expr::or),
997 BitXor => self.visit_binary::<Expr>(Expr::xor),
998
999 Abs => self.visit_unary(Expr::abs),
1003 Cbrt => self.visit_unary(Expr::cbrt),
1004 Ceil => self.visit_unary(Expr::ceil),
1005 Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1006 Exp => self.visit_unary(Expr::exp),
1007 Floor => self.visit_unary(Expr::floor),
1008 Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1009 Log => self.visit_binary(Expr::log),
1010 Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1011 Log1p => self.visit_unary(Expr::log1p),
1012 Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1013 Pi => self.visit_nullary(Expr::pi),
1014 Mod => self.visit_binary(|e1, e2| e1 % e2),
1015 Pow => self.visit_binary::<Expr>(Expr::pow),
1016 Round => {
1017 let args = extract_args(function)?;
1018 match args.len() {
1019 1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1020 2 => self.try_visit_binary(|e, decimals| {
1021 Ok(e.round(match decimals {
1022 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1023 if n >= 0 { n as u32 } else {
1024 polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1])
1025 }
1026 },
1027 _ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1028 }, RoundMode::default()))
1029 }),
1030 _ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1031 }
1032 },
1033 Sign => self.visit_unary(Expr::sign),
1034 Sqrt => self.visit_unary(Expr::sqrt),
1035
1036 Acos => self.visit_unary(Expr::arccos),
1040 AcosD => self.visit_unary(|e| e.arccos().degrees()),
1041 Asin => self.visit_unary(Expr::arcsin),
1042 AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1043 Atan => self.visit_unary(Expr::arctan),
1044 Atan2 => self.visit_binary(Expr::arctan2),
1045 Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1046 AtanD => self.visit_unary(|e| e.arctan().degrees()),
1047 Cos => self.visit_unary(Expr::cos),
1048 CosD => self.visit_unary(|e| e.radians().cos()),
1049 Cot => self.visit_unary(Expr::cot),
1050 CotD => self.visit_unary(|e| e.radians().cot()),
1051 Degrees => self.visit_unary(Expr::degrees),
1052 Radians => self.visit_unary(Expr::radians),
1053 Sin => self.visit_unary(Expr::sin),
1054 SinD => self.visit_unary(|e| e.radians().sin()),
1055 Tan => self.visit_unary(Expr::tan),
1056 TanD => self.visit_unary(|e| e.radians().tan()),
1057
1058 Coalesce => self.visit_variadic(coalesce),
1062 Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1063 If => {
1064 let args = extract_args(function)?;
1065 match args.len() {
1066 3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1067 Ok(when(cond).then(expr1).otherwise(expr2))
1068 }),
1069 _ => {
1070 polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1071 )
1072 },
1073 }
1074 },
1075 IfNull => {
1076 let args = extract_args(function)?;
1077 match args.len() {
1078 2 => self.visit_variadic(coalesce),
1079 _ => {
1080 polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1081 },
1082 }
1083 },
1084 Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1085 NullIf => {
1086 let args = extract_args(function)?;
1087 match args.len() {
1088 2 => self.visit_binary(|l: Expr, r: Expr| {
1089 when(l.clone().eq(r))
1090 .then(lit(LiteralValue::untyped_null()))
1091 .otherwise(l)
1092 }),
1093 _ => {
1094 polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1095 },
1096 }
1097 },
1098
1099 DatePart => self.try_visit_binary(|part, e| {
1103 match part {
1104 Expr::Literal(p) if p.extract_str().is_some() => {
1105 let p = p.extract_str().unwrap();
1106 parse_extract_date_part(
1109 e,
1110 &DateTimeField::Custom(Ident {
1111 value: p.to_string(),
1112 quote_style: None,
1113 span: Span::empty(),
1114 }),
1115 )
1116 },
1117 _ => {
1118 polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1119 },
1120 }
1121 }),
1122 Strftime => {
1123 let args = extract_args(function)?;
1124 match args.len() {
1125 2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1126 _ => {
1127 polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1128 },
1129 }
1130 },
1131
1132 BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1136 Concat => {
1137 let args = extract_args(function)?;
1138 if args.is_empty() {
1139 polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1140 } else {
1141 self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1142 }
1143 },
1144 ConcatWS => {
1145 let args = extract_args(function)?;
1146 if args.len() < 2 {
1147 polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1148 } else {
1149 self.try_visit_variadic(|exprs: &[Expr]| {
1150 match &exprs[0] {
1151 Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1152 _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1153 }
1154 })
1155 }
1156 },
1157 Date => {
1158 let args = extract_args(function)?;
1159 match args.len() {
1160 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1161 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1162 _ => {
1163 polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1164 },
1165 }
1166 },
1167 EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1168 #[cfg(feature = "nightly")]
1169 InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1170 Left => self.try_visit_binary(|e, length| {
1171 Ok(match length {
1172 Expr::Literal(lv) if lv.is_null() => lit(lv),
1173 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1174 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1175 let len = if n > 0 {
1176 lit(n)
1177 } else {
1178 (e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1179 };
1180 e.str().slice(lit(0), len)
1181 },
1182 Expr::Literal(v) => {
1183 polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1184 },
1185 _ => when(length.clone().gt_eq(lit(0)))
1186 .then(e.clone().str().slice(lit(0), length.clone().abs()))
1187 .otherwise(e.clone().str().slice(
1188 lit(0),
1189 (e.str().len_chars() + length.clone()).clip_min(lit(0)),
1190 )),
1191 })
1192 }),
1193 Length => self.visit_unary(|e| e.str().len_chars()),
1194 Lower => self.visit_unary(|e| e.str().to_lowercase()),
1195 LTrim => {
1196 let args = extract_args(function)?;
1197 match args.len() {
1198 1 => self.visit_unary(|e| {
1199 e.str().strip_chars_start(lit(LiteralValue::untyped_null()))
1200 }),
1201 2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)),
1202 _ => {
1203 polars_bail!(SQLSyntax: "LTRIM expects 1-2 arguments (found {})", args.len())
1204 },
1205 }
1206 },
1207 Normalize => {
1208 let args = extract_args(function)?;
1209 match args.len() {
1210 1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1211 2 => {
1212 let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1213 value: s,
1214 quote_style: None,
1215 span: _,
1216 })) = args[1]
1217 {
1218 match s.to_uppercase().as_str() {
1219 "NFC" => UnicodeForm::NFC,
1220 "NFD" => UnicodeForm::NFD,
1221 "NFKC" => UnicodeForm::NFKC,
1222 "NFKD" => UnicodeForm::NFKD,
1223 _ => {
1224 polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1225 },
1226 }
1227 } else {
1228 polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1229 };
1230 self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1231 },
1232 _ => {
1233 polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1234 },
1235 }
1236 },
1237 OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1238 StrPos => {
1239 self.visit_binary(|expr, substring| {
1241 (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1242 })
1243 },
1244 RegexpLike => {
1245 let args = extract_args(function)?;
1246 match args.len() {
1247 2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1248 3 => self.try_visit_ternary(|e, pat, flags| {
1249 Ok(e.str().contains(
1250 match (pat, flags) {
1251 (Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1252 let s = s_lv.extract_str().unwrap();
1253 let f = f_lv.extract_str().unwrap();
1254 if f.is_empty() {
1255 polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1256 };
1257 lit(format!("(?{f}){s}"))
1258 },
1259 _ => {
1260 polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1261 },
1262 },
1263 true))
1264 }),
1265 _ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1266 }
1267 },
1268 Replace => {
1269 let args = extract_args(function)?;
1270 match args.len() {
1271 3 => self
1272 .try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1273 _ => {
1274 polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1275 },
1276 }
1277 },
1278 Reverse => self.visit_unary(|e| e.str().reverse()),
1279 Right => self.try_visit_binary(|e, length| {
1280 Ok(match length {
1281 Expr::Literal(lv) if lv.is_null() => lit(lv),
1282 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1283 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1284 let n: i64 = n.try_into().unwrap();
1285 let offset = if n < 0 {
1286 lit(n.abs())
1287 } else {
1288 e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1289 };
1290 e.str().slice(offset, lit(LiteralValue::untyped_null()))
1291 },
1292 Expr::Literal(v) => {
1293 polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1294 },
1295 _ => when(length.clone().lt(lit(0)))
1296 .then(
1297 e.clone()
1298 .str()
1299 .slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1300 )
1301 .otherwise(e.clone().str().slice(
1302 e.str().len_chars().cast(DataType::Int32) - length.clone(),
1303 lit(LiteralValue::untyped_null()),
1304 )),
1305 })
1306 }),
1307 RTrim => {
1308 let args = extract_args(function)?;
1309 match args.len() {
1310 1 => self.visit_unary(|e| {
1311 e.str().strip_chars_end(lit(LiteralValue::untyped_null()))
1312 }),
1313 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
1314 _ => {
1315 polars_bail!(SQLSyntax: "RTRIM expects 1-2 arguments (found {})", args.len())
1316 },
1317 }
1318 },
1319 SplitPart => {
1320 let args = extract_args(function)?;
1321 match args.len() {
1322 3 => self.try_visit_ternary(|e, sep, idx| {
1323 let idx = adjust_one_indexed_param(idx, true);
1324 Ok(when(e.clone().is_not_null())
1325 .then(
1326 e.clone()
1327 .str()
1328 .split(sep)
1329 .list()
1330 .get(idx, true)
1331 .fill_null(lit("")),
1332 )
1333 .otherwise(e))
1334 }),
1335 _ => {
1336 polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1337 },
1338 }
1339 },
1340 StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1341 StringToArray => {
1342 let args = extract_args(function)?;
1343 match args.len() {
1344 2 => self.visit_binary(|e, sep| e.str().split(sep)),
1345 _ => {
1346 polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1347 },
1348 }
1349 },
1350 Strptime => {
1351 let args = extract_args(function)?;
1352 match args.len() {
1353 2 => self.visit_binary(|e, fmt: String| {
1354 e.str().strptime(
1355 DataType::Datetime(TimeUnit::Microseconds, None),
1356 StrptimeOptions {
1357 format: Some(fmt.into()),
1358 ..Default::default()
1359 },
1360 lit("latest"),
1361 )
1362 }),
1363 _ => {
1364 polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1365 },
1366 }
1367 },
1368 Time => {
1369 let args = extract_args(function)?;
1370 match args.len() {
1371 1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1372 2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1373 _ => {
1374 polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1375 },
1376 }
1377 },
1378 Timestamp => {
1379 let args = extract_args(function)?;
1380 match args.len() {
1381 1 => self.visit_unary(|e| {
1382 e.str()
1383 .to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1384 }),
1385 2 => self
1386 .visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1387 _ => {
1388 polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1389 },
1390 }
1391 },
1392 Substring => {
1393 let args = extract_args(function)?;
1394 match args.len() {
1395 2 => self.try_visit_binary(|e, start| {
1397 Ok(match start {
1398 Expr::Literal(lv) if lv.is_null() => lit(lv),
1399 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1400 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1401 Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1402 _ => start.clone() + lit(1),
1403 })
1404 }),
1405 3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1406 Ok(match (start.clone(), length.clone()) {
1407 (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1408 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1409 polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1410 },
1411 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1412 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1413 e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1414 },
1415 (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1416 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1417 polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1418 },
1419 _ => {
1420 let adjusted_start = start - lit(1);
1421 when(adjusted_start.clone().lt(lit(0)))
1422 .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1423 .otherwise(e.str().slice(adjusted_start, length))
1424 }
1425 })
1426 }),
1427 _ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1428 }
1429 },
1430 Upper => self.visit_unary(|e| e.str().to_uppercase()),
1431
1432 Avg => self.visit_unary(Expr::mean),
1436 Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1437 Count => self.visit_count(),
1438 CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1439 CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1440 First => self.visit_unary(Expr::first),
1441 Last => self.visit_unary(Expr::last),
1442 Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1443 Median => self.visit_unary(Expr::median),
1444 QuantileCont => {
1445 let args = extract_args(function)?;
1446 match args.len() {
1447 2 => self.try_visit_binary(|e, q| {
1448 let value = match q {
1449 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1450 if (0.0..=1.0).contains(&f) {
1451 Expr::from(f)
1452 } else {
1453 polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1454 }
1455 },
1456 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1457 if (0..=1).contains(&n) {
1458 Expr::from(n as f64)
1459 } else {
1460 polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1461 }
1462 },
1463 _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
1464 };
1465 Ok(e.quantile(value, QuantileMethod::Linear))
1466 }),
1467 _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
1468 }
1469 },
1470 QuantileDisc => {
1471 let args = extract_args(function)?;
1472 match args.len() {
1473 2 => self.try_visit_binary(|e, q| {
1474 let value = match q {
1475 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1476 if (0.0..=1.0).contains(&f) {
1477 Expr::from(f)
1478 } else {
1479 polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1480 }
1481 },
1482 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1483 if (0..=1).contains(&n) {
1484 Expr::from(n as f64)
1485 } else {
1486 polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1487 }
1488 },
1489 _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1])
1490 };
1491 Ok(e.quantile(value, QuantileMethod::Equiprobable))
1492 }),
1493 _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()),
1494 }
1495 },
1496 Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1497 StdDev => self.visit_unary(|e| e.std(1)),
1498 Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1499 Variance => self.visit_unary(|e| e.var(1)),
1500
1501 ArrayAgg => self.visit_arr_agg(),
1505 ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1506 ArrayGet => {
1507 self.visit_binary(|e, idx: Expr| {
1509 let idx = adjust_one_indexed_param(idx, true);
1510 e.list().get(idx, true)
1511 })
1512 },
1513 ArrayLength => self.visit_unary(|e| e.list().len()),
1514 ArrayMax => self.visit_unary(|e| e.list().max()),
1515 ArrayMean => self.visit_unary(|e| e.list().mean()),
1516 ArrayMin => self.visit_unary(|e| e.list().min()),
1517 ArrayReverse => self.visit_unary(|e| e.list().reverse()),
1518 ArraySum => self.visit_unary(|e| e.list().sum()),
1519 ArrayToString => self.visit_arr_to_string(),
1520 ArrayUnique => self.visit_unary(|e| e.list().unique()),
1521 Explode => self.visit_unary(|e| e.explode()),
1522
1523 Columns => {
1527 let active_schema = self.active_schema;
1528 self.try_visit_unary(|e: Expr| match e {
1529 Expr::Literal(lv) if lv.extract_str().is_some() => {
1530 let pat = lv.extract_str().unwrap();
1531 if pat == "*" {
1532 polars_bail!(
1533 SQLSyntax: "COLUMNS('*') is not a valid regex; \
1534 did you mean COLUMNS(*)?"
1535 )
1536 };
1537 let pat = match pat {
1538 _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1539 _ if pat.starts_with('^') => format!("{pat}.*$"),
1540 _ if pat.ends_with('$') => format!("^.*{pat}"),
1541 _ => format!("^.*{pat}.*$"),
1542 };
1543 if let Some(active_schema) = &active_schema {
1544 let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1545 let col_names = active_schema
1546 .iter_names()
1547 .filter(|name| rx.is_match(name))
1548 .cloned()
1549 .collect::<Vec<_>>();
1550
1551 Ok(if col_names.len() == 1 {
1552 col(col_names.into_iter().next().unwrap())
1553 } else {
1554 cols(col_names).as_expr()
1555 })
1556 } else {
1557 Ok(col(pat.as_str()))
1558 }
1559 },
1560 Expr::Selector(s) => Ok(s.as_expr()),
1561 _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1562 })
1563 },
1564
1565 Udf(func_name) => self.visit_udf(&func_name),
1569 }
1570 }
1571
1572 fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1573 let args = extract_args(self.func)?
1574 .into_iter()
1575 .map(|arg| {
1576 if let FunctionArgExpr::Expr(e) = arg {
1577 parse_sql_expr(e, self.ctx, self.active_schema)
1578 } else {
1579 polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1580 }
1581 })
1582 .collect::<PolarsResult<Vec<_>>>()?;
1583
1584 Ok(self
1585 .ctx
1586 .function_registry
1587 .get_udf(func_name)?
1588 .ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1589 .call(args))
1590 }
1591
1592 fn apply_cumulative_window(
1595 &mut self,
1596 f: impl Fn(Expr) -> Expr,
1597 cumulative_f: impl Fn(Expr, bool) -> Expr,
1598 WindowSpec {
1599 partition_by,
1600 order_by,
1601 ..
1602 }: &WindowSpec,
1603 ) -> PolarsResult<Expr> {
1604 if !order_by.is_empty() && partition_by.is_empty() {
1605 let (order_by, desc): (Vec<Expr>, Vec<bool>) = order_by
1606 .iter()
1607 .map(|o| {
1608 let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
1609 Ok(match o.asc {
1610 Some(b) => (expr, !b),
1611 None => (expr, false),
1612 })
1613 })
1614 .collect::<PolarsResult<Vec<_>>>()?
1615 .into_iter()
1616 .unzip();
1617 self.visit_unary_no_window(|e| {
1618 cumulative_f(
1619 e.sort_by(
1620 &order_by,
1621 SortMultipleOptions::default().with_order_descending_multi(desc.clone()),
1622 ),
1623 false,
1624 )
1625 })
1626 } else {
1627 self.visit_unary(f)
1628 }
1629 }
1630
1631 fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1632 self.try_visit_unary(|e| Ok(f(e)))
1633 }
1634
1635 fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1636 let args = extract_args(self.func)?;
1637 match args.as_slice() {
1638 [FunctionArgExpr::Expr(sql_expr)] => {
1639 f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?)
1640 },
1641 [FunctionArgExpr::Wildcard] => f(parse_sql_expr(
1642 &SQLExpr::Wildcard(AttachedToken::empty()),
1643 self.ctx,
1644 self.active_schema,
1645 )?),
1646 _ => self.not_supported_error(),
1647 }
1648 .and_then(|e| self.apply_window_spec(e, &self.func.over))
1649 }
1650
1651 fn visit_unary_with_opt_cumulative(
1657 &mut self,
1658 f: impl Fn(Expr) -> Expr,
1659 cumulative_f: impl Fn(Expr, bool) -> Expr,
1660 ) -> PolarsResult<Expr> {
1661 match self.func.over.as_ref() {
1662 Some(WindowType::WindowSpec(spec)) => {
1663 self.apply_cumulative_window(f, cumulative_f, spec)
1664 },
1665 Some(WindowType::NamedWindow(named_window)) => polars_bail!(
1666 SQLInterface: "Named windows are not currently supported; found {:?}",
1667 named_window
1668 ),
1669 _ => self.visit_unary(f),
1670 }
1671 }
1672
1673 fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1674 let args = extract_args(self.func)?;
1675 match args.as_slice() {
1676 [FunctionArgExpr::Expr(sql_expr)] => {
1677 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1678 Ok(f(expr))
1680 },
1681 _ => self.not_supported_error(),
1682 }
1683 }
1684
1685 fn visit_binary<Arg: FromSQLExpr>(
1686 &mut self,
1687 f: impl Fn(Expr, Arg) -> Expr,
1688 ) -> PolarsResult<Expr> {
1689 self.try_visit_binary(|e, a| Ok(f(e, a)))
1690 }
1691
1692 fn try_visit_binary<Arg: FromSQLExpr>(
1693 &mut self,
1694 f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
1695 ) -> PolarsResult<Expr> {
1696 let args = extract_args(self.func)?;
1697 match args.as_slice() {
1698 [
1699 FunctionArgExpr::Expr(sql_expr1),
1700 FunctionArgExpr::Expr(sql_expr2),
1701 ] => {
1702 let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1703 let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1704 f(expr1, expr2)
1705 },
1706 _ => self.not_supported_error(),
1707 }
1708 }
1709
1710 fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
1711 self.try_visit_variadic(|e| Ok(f(e)))
1712 }
1713
1714 fn try_visit_variadic(
1715 &mut self,
1716 f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
1717 ) -> PolarsResult<Expr> {
1718 let args = extract_args(self.func)?;
1719 let mut expr_args = vec![];
1720 for arg in args {
1721 if let FunctionArgExpr::Expr(sql_expr) = arg {
1722 expr_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
1723 } else {
1724 return self.not_supported_error();
1725 };
1726 }
1727 f(&expr_args)
1728 }
1729
1730 fn try_visit_ternary<Arg: FromSQLExpr>(
1731 &mut self,
1732 f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
1733 ) -> PolarsResult<Expr> {
1734 let args = extract_args(self.func)?;
1735 match args.as_slice() {
1736 [
1737 FunctionArgExpr::Expr(sql_expr1),
1738 FunctionArgExpr::Expr(sql_expr2),
1739 FunctionArgExpr::Expr(sql_expr3),
1740 ] => {
1741 let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1742 let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1743 let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
1744 f(expr1, expr2, expr3)
1745 },
1746 _ => self.not_supported_error(),
1747 }
1748 }
1749
1750 fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
1751 let args = extract_args(self.func)?;
1752 if !args.is_empty() {
1753 return self.not_supported_error();
1754 }
1755 Ok(f())
1756 }
1757
1758 fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
1759 let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
1760 match args.as_slice() {
1761 [FunctionArgExpr::Expr(sql_expr)] => {
1762 let mut base = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1763 if is_distinct {
1764 base = base.unique_stable();
1765 }
1766 for clause in clauses {
1767 match clause {
1768 FunctionArgumentClause::OrderBy(order_exprs) => {
1769 base = self.apply_order_by(base, order_exprs.as_slice())?;
1770 },
1771 FunctionArgumentClause::Limit(limit_expr) => {
1772 let limit = parse_sql_expr(&limit_expr, self.ctx, self.active_schema)?;
1773 match limit {
1774 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))
1775 if n >= 0 =>
1776 {
1777 base = base.head(Some(n as usize))
1778 },
1779 _ => {
1780 polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer")
1781 },
1782 };
1783 },
1784 _ => {},
1785 }
1786 }
1787 Ok(base.implode())
1788 },
1789 _ => {
1790 polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
1791 },
1792 }
1793 }
1794
1795 fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
1796 let args = extract_args(self.func)?;
1797 match args.len() {
1798 2 => self.try_visit_binary(|e, sep| {
1799 Ok(e.cast(DataType::List(Box::from(DataType::String)))
1800 .list()
1801 .join(sep, true))
1802 }),
1803 #[cfg(feature = "list_eval")]
1804 3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
1805 Expr::Literal(lv) if lv.extract_str().is_some() => {
1806 Ok(if lv.extract_str().unwrap().is_empty() {
1807 e.cast(DataType::List(Box::from(DataType::String)))
1808 .list()
1809 .join(sep, true)
1810 } else {
1811 e.cast(DataType::List(Box::from(DataType::String)))
1812 .list()
1813 .eval(col("").fill_null(lit(lv.extract_str().unwrap())))
1814 .list()
1815 .join(sep, false)
1816 })
1817 },
1818 _ => {
1819 polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
1820 },
1821 }),
1822 _ => {
1823 polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
1824 },
1825 }
1826 }
1827
1828 fn visit_count(&mut self) -> PolarsResult<Expr> {
1829 let (args, is_distinct) = extract_args_distinct(self.func)?;
1830 let count_expr = match (is_distinct, args.as_slice()) {
1831 (false, [FunctionArgExpr::Wildcard] | []) => len(),
1833 (false, [FunctionArgExpr::Expr(sql_expr)]) => {
1835 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1836 expr.count()
1837 },
1838 (true, [FunctionArgExpr::Expr(sql_expr)]) => {
1840 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1841 expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
1842 },
1843 _ => self.not_supported_error()?,
1844 };
1845 self.apply_window_spec(count_expr, &self.func.over)
1846 }
1847
1848 fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
1849 let mut by = Vec::with_capacity(order_by.len());
1850 let mut descending = Vec::with_capacity(order_by.len());
1851 let mut nulls_last = Vec::with_capacity(order_by.len());
1852
1853 for ob in order_by {
1854 let desc_order = !ob.asc.unwrap_or(true);
1857 by.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?);
1858 nulls_last.push(!ob.nulls_first.unwrap_or(desc_order));
1859 descending.push(desc_order);
1860 }
1861 Ok(expr.sort_by(
1862 by,
1863 SortMultipleOptions::default()
1864 .with_order_descending_multi(descending)
1865 .with_nulls_last_multi(nulls_last)
1866 .with_maintain_order(true),
1867 ))
1868 }
1869
1870 fn apply_window_spec(
1871 &mut self,
1872 expr: Expr,
1873 window_type: &Option<WindowType>,
1874 ) -> PolarsResult<Expr> {
1875 Ok(match &window_type {
1876 Some(WindowType::WindowSpec(window_spec)) => {
1877 if window_spec.partition_by.is_empty() {
1878 let exprs = window_spec
1879 .order_by
1880 .iter()
1881 .map(|o| {
1882 let e = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
1883 Ok(o.asc.map_or(e.clone(), |b| {
1884 e.sort(SortOptions::default().with_order_descending(!b))
1885 }))
1886 })
1887 .collect::<PolarsResult<Vec<_>>>()?;
1888 expr.over(exprs)
1889 } else {
1890 let partition_by = window_spec
1892 .partition_by
1893 .iter()
1894 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1895 .collect::<PolarsResult<Vec<_>>>()?;
1896 expr.over(partition_by)
1897 }
1898 },
1899 Some(WindowType::NamedWindow(named_window)) => polars_bail!(
1900 SQLInterface: "Named windows are not currently supported; found {:?}",
1901 named_window
1902 ),
1903 None => expr,
1904 })
1905 }
1906
1907 fn not_supported_error(&self) -> PolarsResult<Expr> {
1908 polars_bail!(
1909 SQLInterface:
1910 "no function matches the given name and arguments: `{}`",
1911 self.func.to_string()
1912 );
1913 }
1914}
1915
1916fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
1917 let (args, _, _) = _extract_func_args(func, false, false)?;
1918 Ok(args)
1919}
1920
1921fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
1922 let (args, is_distinct, _) = _extract_func_args(func, true, false)?;
1923 Ok((args, is_distinct))
1924}
1925
1926fn extract_args_and_clauses(
1927 func: &SQLFunction,
1928) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
1929 _extract_func_args(func, true, true)
1930}
1931
1932fn _extract_func_args(
1933 func: &SQLFunction,
1934 get_distinct: bool,
1935 get_clauses: bool,
1936) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
1937 match &func.args {
1938 FunctionArguments::List(FunctionArgumentList {
1939 args,
1940 duplicate_treatment,
1941 clauses,
1942 }) => {
1943 let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
1944 if !(get_clauses || get_distinct) && is_distinct {
1945 polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
1946 } else if !get_clauses && !clauses.is_empty() {
1947 polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
1948 } else {
1949 let unpacked_args = args
1950 .iter()
1951 .map(|arg| match arg {
1952 FunctionArg::Named { arg, .. } => arg,
1953 FunctionArg::ExprNamed { arg, .. } => arg,
1954 FunctionArg::Unnamed(arg) => arg,
1955 })
1956 .collect();
1957 Ok((unpacked_args, is_distinct, clauses.clone()))
1958 }
1959 },
1960 FunctionArguments::Subquery { .. } => {
1961 Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
1962 },
1963 FunctionArguments::None => Ok((vec![], false, vec![])),
1964 }
1965}
1966
/// Conversion from a parsed SQL expression into a value usable as a function argument.
///
/// The visitor's binary and ternary helpers use this trait to turn their trailing
/// arguments into the type the mapping closure expects.
pub(crate) trait FromSQLExpr {
    /// Build `Self` from `expr`, with mutable access to the SQL context `ctx`.
    fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
    where
        Self: Sized;
}
1972
1973impl FromSQLExpr for f64 {
1974 fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
1975 where
1976 Self: Sized,
1977 {
1978 match expr {
1979 SQLExpr::Value(v) => match v {
1980 SQLValue::Number(s, _) => s
1981 .parse()
1982 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
1983 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
1984 },
1985 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
1986 }
1987 }
1988}
1989
1990impl FromSQLExpr for bool {
1991 fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
1992 where
1993 Self: Sized,
1994 {
1995 match expr {
1996 SQLExpr::Value(v) => match v {
1997 SQLValue::Boolean(v) => Ok(*v),
1998 _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
1999 },
2000 _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2001 }
2002 }
2003}
2004
2005impl FromSQLExpr for String {
2006 fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2007 where
2008 Self: Sized,
2009 {
2010 match expr {
2011 SQLExpr::Value(v) => match v {
2012 SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2013 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2014 },
2015 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2016 }
2017 }
2018}
2019
2020impl FromSQLExpr for StrptimeOptions {
2021 fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2022 where
2023 Self: Sized,
2024 {
2025 match expr {
2026 SQLExpr::Value(v) => match v {
2027 SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2028 format: Some(PlSmallStr::from_str(s)),
2029 ..StrptimeOptions::default()
2030 }),
2031 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2032 },
2033 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2034 }
2035 }
2036}
2037
2038impl FromSQLExpr for Expr {
2039 fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2040 where
2041 Self: Sized,
2042 {
2043 parse_sql_expr(expr, ctx, None)
2044 }
2045}