1use crate::expr::{
21 AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22 Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23};
24use crate::function::{
25 AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26 StateFieldsArgs,
27};
28use crate::ptr_eq::PtrEq;
29use crate::select_expr::SelectExpr;
30use crate::{
31 conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
32 AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
33 ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
34};
35use crate::{
36 AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
37};
38use arrow::compute::kernels::cast_utils::{
39 parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
40};
41use arrow::datatypes::{DataType, Field, FieldRef};
42use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
43use datafusion_functions_window_common::field::WindowUDFFieldArgs;
44use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
45use sqlparser::ast::NullTreatment;
46use std::any::Any;
47use std::fmt::Debug;
48use std::hash::Hash;
49use std::ops::Not;
50use std::sync::Arc;
51
52pub fn col(ident: impl Into<Column>) -> Expr {
68 Expr::Column(ident.into())
69}
70
71pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
74 Expr::OuterReferenceColumn(dt, ident.into())
75}
76
77pub fn ident(name: impl Into<String>) -> Expr {
96 Expr::Column(Column::from_name(name))
97}
98
99pub fn placeholder(id: impl Into<String>) -> Expr {
111 Expr::Placeholder(Placeholder {
112 id: id.into(),
113 data_type: None,
114 })
115}
116
117pub fn wildcard() -> SelectExpr {
127 SelectExpr::Wildcard(WildcardOptions::default())
128}
129
130pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
132 SelectExpr::Wildcard(options)
133}
134
135pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
146 SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
147}
148
149pub fn qualified_wildcard_with_options(
151 qualifier: impl Into<TableReference>,
152 options: WildcardOptions,
153) -> SelectExpr {
154 SelectExpr::QualifiedWildcard(qualifier.into(), options)
155}
156
157pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
159 Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
160}
161
162pub fn and(left: Expr, right: Expr) -> Expr {
164 Expr::BinaryExpr(BinaryExpr::new(
165 Box::new(left),
166 Operator::And,
167 Box::new(right),
168 ))
169}
170
171pub fn or(left: Expr, right: Expr) -> Expr {
173 Expr::BinaryExpr(BinaryExpr::new(
174 Box::new(left),
175 Operator::Or,
176 Box::new(right),
177 ))
178}
179
180pub fn not(expr: Expr) -> Expr {
182 expr.not()
183}
184
185pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
187 Expr::BinaryExpr(BinaryExpr::new(
188 Box::new(left),
189 Operator::BitwiseAnd,
190 Box::new(right),
191 ))
192}
193
194pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
196 Expr::BinaryExpr(BinaryExpr::new(
197 Box::new(left),
198 Operator::BitwiseOr,
199 Box::new(right),
200 ))
201}
202
203pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
205 Expr::BinaryExpr(BinaryExpr::new(
206 Box::new(left),
207 Operator::BitwiseXor,
208 Box::new(right),
209 ))
210}
211
212pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
214 Expr::BinaryExpr(BinaryExpr::new(
215 Box::new(left),
216 Operator::BitwiseShiftRight,
217 Box::new(right),
218 ))
219}
220
221pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
223 Expr::BinaryExpr(BinaryExpr::new(
224 Box::new(left),
225 Operator::BitwiseShiftLeft,
226 Box::new(right),
227 ))
228}
229
230pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
232 Expr::InList(InList::new(Box::new(expr), list, negated))
233}
234
235pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
237 let outer_ref_columns = subquery.all_out_ref_exprs();
238 Expr::Exists(Exists {
239 subquery: Subquery {
240 subquery,
241 outer_ref_columns,
242 spans: Spans::new(),
243 },
244 negated: false,
245 })
246}
247
248pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
250 let outer_ref_columns = subquery.all_out_ref_exprs();
251 Expr::Exists(Exists {
252 subquery: Subquery {
253 subquery,
254 outer_ref_columns,
255 spans: Spans::new(),
256 },
257 negated: true,
258 })
259}
260
261pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
263 let outer_ref_columns = subquery.all_out_ref_exprs();
264 Expr::InSubquery(InSubquery::new(
265 Box::new(expr),
266 Subquery {
267 subquery,
268 outer_ref_columns,
269 spans: Spans::new(),
270 },
271 false,
272 ))
273}
274
275pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
277 let outer_ref_columns = subquery.all_out_ref_exprs();
278 Expr::InSubquery(InSubquery::new(
279 Box::new(expr),
280 Subquery {
281 subquery,
282 outer_ref_columns,
283 spans: Spans::new(),
284 },
285 true,
286 ))
287}
288
289pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
291 let outer_ref_columns = subquery.all_out_ref_exprs();
292 Expr::ScalarSubquery(Subquery {
293 subquery,
294 outer_ref_columns,
295 spans: Spans::new(),
296 })
297}
298
299pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
301 Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
302}
303
304pub fn cube(exprs: Vec<Expr>) -> Expr {
306 Expr::GroupingSet(GroupingSet::Cube(exprs))
307}
308
309pub fn rollup(exprs: Vec<Expr>) -> Expr {
311 Expr::GroupingSet(GroupingSet::Rollup(exprs))
312}
313
314pub fn cast(expr: Expr, data_type: DataType) -> Expr {
316 Expr::Cast(Cast::new(Box::new(expr), data_type))
317}
318
319pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
321 Expr::TryCast(TryCast::new(Box::new(expr), data_type))
322}
323
324pub fn is_null(expr: Expr) -> Expr {
326 Expr::IsNull(Box::new(expr))
327}
328
329pub fn is_true(expr: Expr) -> Expr {
331 Expr::IsTrue(Box::new(expr))
332}
333
334pub fn is_not_true(expr: Expr) -> Expr {
336 Expr::IsNotTrue(Box::new(expr))
337}
338
339pub fn is_false(expr: Expr) -> Expr {
341 Expr::IsFalse(Box::new(expr))
342}
343
344pub fn is_not_false(expr: Expr) -> Expr {
346 Expr::IsNotFalse(Box::new(expr))
347}
348
349pub fn is_unknown(expr: Expr) -> Expr {
351 Expr::IsUnknown(Box::new(expr))
352}
353
354pub fn is_not_unknown(expr: Expr) -> Expr {
356 Expr::IsNotUnknown(Box::new(expr))
357}
358
359pub fn case(expr: Expr) -> CaseBuilder {
361 CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
362}
363
364pub fn when(when: Expr, then: Expr) -> CaseBuilder {
366 CaseBuilder::new(None, vec![when], vec![then], None)
367}
368
369pub fn unnest(expr: Expr) -> Expr {
371 Expr::Unnest(Unnest {
372 expr: Box::new(expr),
373 })
374}
375
376pub fn create_udf(
389 name: &str,
390 input_types: Vec<DataType>,
391 return_type: DataType,
392 volatility: Volatility,
393 fun: ScalarFunctionImplementation,
394) -> ScalarUDF {
395 ScalarUDF::from(SimpleScalarUDF::new(
396 name,
397 input_types,
398 return_type,
399 volatility,
400 fun,
401 ))
402}
403
404#[derive(PartialEq, Eq, Hash)]
407pub struct SimpleScalarUDF {
408 name: String,
409 signature: Signature,
410 return_type: DataType,
411 fun: PtrEq<ScalarFunctionImplementation>,
412}
413
414impl Debug for SimpleScalarUDF {
415 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
416 f.debug_struct("SimpleScalarUDF")
417 .field("name", &self.name)
418 .field("signature", &self.signature)
419 .field("return_type", &self.return_type)
420 .field("fun", &"<FUNC>")
421 .finish()
422 }
423}
424
425impl SimpleScalarUDF {
426 pub fn new(
429 name: impl Into<String>,
430 input_types: Vec<DataType>,
431 return_type: DataType,
432 volatility: Volatility,
433 fun: ScalarFunctionImplementation,
434 ) -> Self {
435 Self::new_with_signature(
436 name,
437 Signature::exact(input_types, volatility),
438 return_type,
439 fun,
440 )
441 }
442
443 pub fn new_with_signature(
446 name: impl Into<String>,
447 signature: Signature,
448 return_type: DataType,
449 fun: ScalarFunctionImplementation,
450 ) -> Self {
451 Self {
452 name: name.into(),
453 signature,
454 return_type,
455 fun: fun.into(),
456 }
457 }
458}
459
460impl ScalarUDFImpl for SimpleScalarUDF {
461 fn as_any(&self) -> &dyn Any {
462 self
463 }
464
465 fn name(&self) -> &str {
466 &self.name
467 }
468
469 fn signature(&self) -> &Signature {
470 &self.signature
471 }
472
473 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
474 Ok(self.return_type.clone())
475 }
476
477 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
478 (self.fun)(&args.args)
479 }
480}
481
482pub fn create_udaf(
485 name: &str,
486 input_type: Vec<DataType>,
487 return_type: Arc<DataType>,
488 volatility: Volatility,
489 accumulator: AccumulatorFactoryFunction,
490 state_type: Arc<Vec<DataType>>,
491) -> AggregateUDF {
492 let return_type = Arc::unwrap_or_clone(return_type);
493 let state_type = Arc::unwrap_or_clone(state_type);
494 let state_fields = state_type
495 .into_iter()
496 .enumerate()
497 .map(|(i, t)| Field::new(format!("{i}"), t, true))
498 .map(Arc::new)
499 .collect::<Vec<_>>();
500 AggregateUDF::from(SimpleAggregateUDF::new(
501 name,
502 input_type,
503 return_type,
504 volatility,
505 accumulator,
506 state_fields,
507 ))
508}
509
510#[derive(PartialEq, Eq, Hash)]
513pub struct SimpleAggregateUDF {
514 name: String,
515 signature: Signature,
516 return_type: DataType,
517 accumulator: PtrEq<AccumulatorFactoryFunction>,
518 state_fields: Vec<FieldRef>,
519}
520
521impl Debug for SimpleAggregateUDF {
522 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
523 f.debug_struct("SimpleAggregateUDF")
524 .field("name", &self.name)
525 .field("signature", &self.signature)
526 .field("return_type", &self.return_type)
527 .field("fun", &"<FUNC>")
528 .finish()
529 }
530}
531
532impl SimpleAggregateUDF {
533 pub fn new(
536 name: impl Into<String>,
537 input_type: Vec<DataType>,
538 return_type: DataType,
539 volatility: Volatility,
540 accumulator: AccumulatorFactoryFunction,
541 state_fields: Vec<FieldRef>,
542 ) -> Self {
543 let name = name.into();
544 let signature = Signature::exact(input_type, volatility);
545 Self {
546 name,
547 signature,
548 return_type,
549 accumulator: accumulator.into(),
550 state_fields,
551 }
552 }
553
554 pub fn new_with_signature(
557 name: impl Into<String>,
558 signature: Signature,
559 return_type: DataType,
560 accumulator: AccumulatorFactoryFunction,
561 state_fields: Vec<FieldRef>,
562 ) -> Self {
563 let name = name.into();
564 Self {
565 name,
566 signature,
567 return_type,
568 accumulator: accumulator.into(),
569 state_fields,
570 }
571 }
572}
573
574impl AggregateUDFImpl for SimpleAggregateUDF {
575 fn as_any(&self) -> &dyn Any {
576 self
577 }
578
579 fn name(&self) -> &str {
580 &self.name
581 }
582
583 fn signature(&self) -> &Signature {
584 &self.signature
585 }
586
587 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
588 Ok(self.return_type.clone())
589 }
590
591 fn accumulator(
592 &self,
593 acc_args: AccumulatorArgs,
594 ) -> Result<Box<dyn crate::Accumulator>> {
595 (self.accumulator)(acc_args)
596 }
597
598 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
599 Ok(self.state_fields.clone())
600 }
601}
602
603pub fn create_udwf(
609 name: &str,
610 input_type: DataType,
611 return_type: Arc<DataType>,
612 volatility: Volatility,
613 partition_evaluator_factory: PartitionEvaluatorFactory,
614) -> WindowUDF {
615 let return_type = Arc::unwrap_or_clone(return_type);
616 WindowUDF::from(SimpleWindowUDF::new(
617 name,
618 input_type,
619 return_type,
620 volatility,
621 partition_evaluator_factory,
622 ))
623}
624
625#[derive(PartialEq, Eq, Hash)]
628pub struct SimpleWindowUDF {
629 name: String,
630 signature: Signature,
631 return_type: DataType,
632 partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
633}
634
635impl Debug for SimpleWindowUDF {
636 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
637 f.debug_struct("WindowUDF")
638 .field("name", &self.name)
639 .field("signature", &self.signature)
640 .field("return_type", &"<func>")
641 .field("partition_evaluator_factory", &"<FUNC>")
642 .finish()
643 }
644}
645
646impl SimpleWindowUDF {
647 pub fn new(
650 name: impl Into<String>,
651 input_type: DataType,
652 return_type: DataType,
653 volatility: Volatility,
654 partition_evaluator_factory: PartitionEvaluatorFactory,
655 ) -> Self {
656 let name = name.into();
657 let signature = Signature::exact([input_type].to_vec(), volatility);
658 Self {
659 name,
660 signature,
661 return_type,
662 partition_evaluator_factory: partition_evaluator_factory.into(),
663 }
664 }
665}
666
667impl WindowUDFImpl for SimpleWindowUDF {
668 fn as_any(&self) -> &dyn Any {
669 self
670 }
671
672 fn name(&self) -> &str {
673 &self.name
674 }
675
676 fn signature(&self) -> &Signature {
677 &self.signature
678 }
679
680 fn partition_evaluator(
681 &self,
682 _partition_evaluator_args: PartitionEvaluatorArgs,
683 ) -> Result<Box<dyn PartitionEvaluator>> {
684 (self.partition_evaluator_factory)()
685 }
686
687 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
688 Ok(Arc::new(Field::new(
689 field_args.name(),
690 self.return_type.clone(),
691 true,
692 )))
693 }
694}
695
696pub fn interval_year_month_lit(value: &str) -> Expr {
697 let interval = parse_interval_year_month(value).ok();
698 Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
699}
700
701pub fn interval_datetime_lit(value: &str) -> Expr {
702 let interval = parse_interval_day_time(value).ok();
703 Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
704}
705
706pub fn interval_month_day_nano_lit(value: &str) -> Expr {
707 let interval = parse_interval_month_day_nano(value).ok();
708 Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
709}
710
711pub trait ExprFunctionExt {
753 fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
755 fn filter(self, filter: Expr) -> ExprFuncBuilder;
757 fn distinct(self) -> ExprFuncBuilder;
759 fn null_treatment(
761 self,
762 null_treatment: impl Into<Option<NullTreatment>>,
763 ) -> ExprFuncBuilder;
764 fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
766 fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
768}
769
770#[derive(Debug, Clone)]
771pub enum ExprFuncKind {
772 Aggregate(AggregateFunction),
773 Window(Box<WindowFunction>),
774}
775
776#[derive(Debug, Clone)]
780pub struct ExprFuncBuilder {
781 fun: Option<ExprFuncKind>,
782 order_by: Option<Vec<Sort>>,
783 filter: Option<Expr>,
784 distinct: bool,
785 null_treatment: Option<NullTreatment>,
786 partition_by: Option<Vec<Expr>>,
787 window_frame: Option<WindowFrame>,
788}
789
790impl ExprFuncBuilder {
791 fn new(fun: Option<ExprFuncKind>) -> Self {
793 Self {
794 fun,
795 order_by: None,
796 filter: None,
797 distinct: false,
798 null_treatment: None,
799 partition_by: None,
800 window_frame: None,
801 }
802 }
803
804 pub fn build(self) -> Result<Expr> {
811 let Self {
812 fun,
813 order_by,
814 filter,
815 distinct,
816 null_treatment,
817 partition_by,
818 window_frame,
819 } = self;
820
821 let Some(fun) = fun else {
822 return plan_err!(
823 "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
824 );
825 };
826
827 let fun_expr = match fun {
828 ExprFuncKind::Aggregate(mut udaf) => {
829 udaf.params.order_by = order_by.unwrap_or_default();
830 udaf.params.filter = filter.map(Box::new);
831 udaf.params.distinct = distinct;
832 udaf.params.null_treatment = null_treatment;
833 Expr::AggregateFunction(udaf)
834 }
835 ExprFuncKind::Window(mut udwf) => {
836 let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
837 udwf.params.partition_by = partition_by.unwrap_or_default();
838 udwf.params.order_by = order_by.unwrap_or_default();
839 udwf.params.window_frame =
840 window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
841 udwf.params.filter = filter.map(Box::new);
842 udwf.params.null_treatment = null_treatment;
843 udwf.params.distinct = distinct;
844 Expr::WindowFunction(udwf)
845 }
846 };
847
848 Ok(fun_expr)
849 }
850}
851
852impl ExprFunctionExt for ExprFuncBuilder {
853 fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
855 self.order_by = Some(order_by);
856 self
857 }
858
859 fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
861 self.filter = Some(filter);
862 self
863 }
864
865 fn distinct(mut self) -> ExprFuncBuilder {
867 self.distinct = true;
868 self
869 }
870
871 fn null_treatment(
873 mut self,
874 null_treatment: impl Into<Option<NullTreatment>>,
875 ) -> ExprFuncBuilder {
876 self.null_treatment = null_treatment.into();
877 self
878 }
879
880 fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
881 self.partition_by = Some(partition_by);
882 self
883 }
884
885 fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
886 self.window_frame = Some(window_frame);
887 self
888 }
889}
890
891impl ExprFunctionExt for Expr {
892 fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
893 let mut builder = match self {
894 Expr::AggregateFunction(udaf) => {
895 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
896 }
897 Expr::WindowFunction(udwf) => {
898 ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
899 }
900 _ => ExprFuncBuilder::new(None),
901 };
902 if builder.fun.is_some() {
903 builder.order_by = Some(order_by);
904 }
905 builder
906 }
907 fn filter(self, filter: Expr) -> ExprFuncBuilder {
908 match self {
909 Expr::AggregateFunction(udaf) => {
910 let mut builder =
911 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
912 builder.filter = Some(filter);
913 builder
914 }
915 _ => ExprFuncBuilder::new(None),
916 }
917 }
918 fn distinct(self) -> ExprFuncBuilder {
919 match self {
920 Expr::AggregateFunction(udaf) => {
921 let mut builder =
922 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
923 builder.distinct = true;
924 builder
925 }
926 _ => ExprFuncBuilder::new(None),
927 }
928 }
929 fn null_treatment(
930 self,
931 null_treatment: impl Into<Option<NullTreatment>>,
932 ) -> ExprFuncBuilder {
933 let mut builder = match self {
934 Expr::AggregateFunction(udaf) => {
935 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
936 }
937 Expr::WindowFunction(udwf) => {
938 ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
939 }
940 _ => ExprFuncBuilder::new(None),
941 };
942 if builder.fun.is_some() {
943 builder.null_treatment = null_treatment.into();
944 }
945 builder
946 }
947
948 fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
949 match self {
950 Expr::WindowFunction(udwf) => {
951 let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
952 builder.partition_by = Some(partition_by);
953 builder
954 }
955 _ => ExprFuncBuilder::new(None),
956 }
957 }
958
959 fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
960 match self {
961 Expr::WindowFunction(udwf) => {
962 let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
963 builder.window_frame = Some(window_frame);
964 builder
965 }
966 _ => ExprFuncBuilder::new(None),
967 }
968 }
969}
970
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the `Display` output of `is_null` / `is_not_null` on columns
    /// created with `col` and `ident`.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();

        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}