datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::ptr_eq::PtrEq;
29use crate::select_expr::SelectExpr;
30use crate::{
31    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
32    AggregateUDF, Expr, LimitEffect, LogicalPlan, Operator, PartitionEvaluator,
33    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
34};
35use crate::{
36    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
37};
38use arrow::compute::kernels::cast_utils::{
39    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
40};
41use arrow::datatypes::{DataType, Field, FieldRef};
42use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
43use datafusion_functions_window_common::field::WindowUDFFieldArgs;
44use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
45use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
46use std::any::Any;
47use std::collections::HashMap;
48use std::fmt::Debug;
49use std::hash::Hash;
50use std::ops::Not;
51use std::sync::Arc;
52
53/// Create a column expression based on a qualified or unqualified column name. Will
54/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
55///
56/// For example:
57///
58/// ```rust
59/// # use datafusion_expr::col;
60/// let c1 = col("a");
61/// let c2 = col("A");
62/// assert_eq!(c1, c2);
63///
64/// // note how quoting with double quotes preserves the case
65/// let c3 = col(r#""A""#);
66/// assert_ne!(c1, c3);
67/// ```
68pub fn col(ident: impl Into<Column>) -> Expr {
69    Expr::Column(ident.into())
70}
71
72/// Create an out reference column which hold a reference that has been resolved to a field
73/// outside of the current plan.
74/// The expression created by this function does not preserve the metadata of the outer column.
75/// Please use `out_ref_col_with_metadata` if you want to preserve the metadata.
76pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
77    out_ref_col_with_metadata(dt, HashMap::new(), ident)
78}
79
80/// Create an out reference column from an existing field (preserving metadata)
81pub fn out_ref_col_with_metadata(
82    dt: DataType,
83    metadata: HashMap<String, String>,
84    ident: impl Into<Column>,
85) -> Expr {
86    let column = ident.into();
87    let field: FieldRef =
88        Arc::new(Field::new(column.name(), dt, true).with_metadata(metadata));
89    Expr::OuterReferenceColumn(field, column)
90}
91
92/// Create an unqualified column expression from the provided name, without normalizing
93/// the column.
94///
95/// For example:
96///
97/// ```rust
98/// # use datafusion_expr::{col, ident};
99/// let c1 = ident("A"); // not normalized staying as column 'A'
100/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
101/// assert_ne!(c1, c2);
102///
103/// let c3 = col(r#""A""#);
104/// assert_eq!(c1, c3);
105///
106/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
107/// let c5 = ident("t1.a"); // parses as column 't1.a'
108/// assert_ne!(c4, c5);
109/// ```
110pub fn ident(name: impl Into<String>) -> Expr {
111    Expr::Column(Column::from_name(name))
112}
113
114/// Create placeholder value that will be filled in (such as `$1`)
115///
116/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
117///
118/// # Example
119///
120/// ```rust
121/// # use datafusion_expr::{placeholder};
122/// let p = placeholder("$1"); // $1, refers to parameter 1
123/// assert_eq!(p.to_string(), "$1")
124/// ```
125pub fn placeholder(id: impl Into<String>) -> Expr {
126    Expr::Placeholder(Placeholder {
127        id: id.into(),
128        field: None,
129    })
130}
131
132/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
133///
134/// # Example
135///
136/// ```rust
137/// # use datafusion_expr::{wildcard};
138/// let p = wildcard();
139/// assert_eq!(p.to_string(), "*")
140/// ```
141pub fn wildcard() -> SelectExpr {
142    SelectExpr::Wildcard(WildcardOptions::default())
143}
144
145/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
146pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
147    SelectExpr::Wildcard(options)
148}
149
150/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
151///
152/// # Example
153///
154/// ```rust
155/// # use datafusion_common::TableReference;
156/// # use datafusion_expr::{qualified_wildcard};
157/// let p = qualified_wildcard(TableReference::bare("t"));
158/// assert_eq!(p.to_string(), "t.*")
159/// ```
160pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
161    SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
162}
163
164/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
165pub fn qualified_wildcard_with_options(
166    qualifier: impl Into<TableReference>,
167    options: WildcardOptions,
168) -> SelectExpr {
169    SelectExpr::QualifiedWildcard(qualifier.into(), options)
170}
171
172/// Return a new expression `left <op> right`
173pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
174    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
175}
176
177/// Return a new expression with a logical AND
178pub fn and(left: Expr, right: Expr) -> Expr {
179    Expr::BinaryExpr(BinaryExpr::new(
180        Box::new(left),
181        Operator::And,
182        Box::new(right),
183    ))
184}
185
186/// Return a new expression with a logical OR
187pub fn or(left: Expr, right: Expr) -> Expr {
188    Expr::BinaryExpr(BinaryExpr::new(
189        Box::new(left),
190        Operator::Or,
191        Box::new(right),
192    ))
193}
194
195/// Return a new expression with a logical NOT
196pub fn not(expr: Expr) -> Expr {
197    expr.not()
198}
199
200/// Return a new expression with bitwise AND
201pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
202    Expr::BinaryExpr(BinaryExpr::new(
203        Box::new(left),
204        Operator::BitwiseAnd,
205        Box::new(right),
206    ))
207}
208
209/// Return a new expression with bitwise OR
210pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
211    Expr::BinaryExpr(BinaryExpr::new(
212        Box::new(left),
213        Operator::BitwiseOr,
214        Box::new(right),
215    ))
216}
217
218/// Return a new expression with bitwise XOR
219pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
220    Expr::BinaryExpr(BinaryExpr::new(
221        Box::new(left),
222        Operator::BitwiseXor,
223        Box::new(right),
224    ))
225}
226
227/// Return a new expression with bitwise SHIFT RIGHT
228pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
229    Expr::BinaryExpr(BinaryExpr::new(
230        Box::new(left),
231        Operator::BitwiseShiftRight,
232        Box::new(right),
233    ))
234}
235
236/// Return a new expression with bitwise SHIFT LEFT
237pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
238    Expr::BinaryExpr(BinaryExpr::new(
239        Box::new(left),
240        Operator::BitwiseShiftLeft,
241        Box::new(right),
242    ))
243}
244
245/// Create an in_list expression
246pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
247    Expr::InList(InList::new(Box::new(expr), list, negated))
248}
249
250/// Create an EXISTS subquery expression
251pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
252    let outer_ref_columns = subquery.all_out_ref_exprs();
253    Expr::Exists(Exists {
254        subquery: Subquery {
255            subquery,
256            outer_ref_columns,
257            spans: Spans::new(),
258        },
259        negated: false,
260    })
261}
262
263/// Create a NOT EXISTS subquery expression
264pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
265    let outer_ref_columns = subquery.all_out_ref_exprs();
266    Expr::Exists(Exists {
267        subquery: Subquery {
268            subquery,
269            outer_ref_columns,
270            spans: Spans::new(),
271        },
272        negated: true,
273    })
274}
275
276/// Create an IN subquery expression
277pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
278    let outer_ref_columns = subquery.all_out_ref_exprs();
279    Expr::InSubquery(InSubquery::new(
280        Box::new(expr),
281        Subquery {
282            subquery,
283            outer_ref_columns,
284            spans: Spans::new(),
285        },
286        false,
287    ))
288}
289
290/// Create a NOT IN subquery expression
291pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
292    let outer_ref_columns = subquery.all_out_ref_exprs();
293    Expr::InSubquery(InSubquery::new(
294        Box::new(expr),
295        Subquery {
296            subquery,
297            outer_ref_columns,
298            spans: Spans::new(),
299        },
300        true,
301    ))
302}
303
304/// Create a scalar subquery expression
305pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
306    let outer_ref_columns = subquery.all_out_ref_exprs();
307    Expr::ScalarSubquery(Subquery {
308        subquery,
309        outer_ref_columns,
310        spans: Spans::new(),
311    })
312}
313
314/// Create a grouping set
315pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
316    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
317}
318
319/// Create a grouping set for all combination of `exprs`
320pub fn cube(exprs: Vec<Expr>) -> Expr {
321    Expr::GroupingSet(GroupingSet::Cube(exprs))
322}
323
324/// Create a grouping set for rollup
325pub fn rollup(exprs: Vec<Expr>) -> Expr {
326    Expr::GroupingSet(GroupingSet::Rollup(exprs))
327}
328
329/// Create a cast expression
330pub fn cast(expr: Expr, data_type: DataType) -> Expr {
331    Expr::Cast(Cast::new(Box::new(expr), data_type))
332}
333
334/// Create a try cast expression
335pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
336    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
337}
338
339/// Create is null expression
340pub fn is_null(expr: Expr) -> Expr {
341    Expr::IsNull(Box::new(expr))
342}
343
344/// Create is true expression
345pub fn is_true(expr: Expr) -> Expr {
346    Expr::IsTrue(Box::new(expr))
347}
348
349/// Create is not true expression
350pub fn is_not_true(expr: Expr) -> Expr {
351    Expr::IsNotTrue(Box::new(expr))
352}
353
354/// Create is false expression
355pub fn is_false(expr: Expr) -> Expr {
356    Expr::IsFalse(Box::new(expr))
357}
358
359/// Create is not false expression
360pub fn is_not_false(expr: Expr) -> Expr {
361    Expr::IsNotFalse(Box::new(expr))
362}
363
364/// Create is unknown expression
365pub fn is_unknown(expr: Expr) -> Expr {
366    Expr::IsUnknown(Box::new(expr))
367}
368
369/// Create is not unknown expression
370pub fn is_not_unknown(expr: Expr) -> Expr {
371    Expr::IsNotUnknown(Box::new(expr))
372}
373
374/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
375pub fn case(expr: Expr) -> CaseBuilder {
376    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
377}
378
379/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
380pub fn when(when: Expr, then: Expr) -> CaseBuilder {
381    CaseBuilder::new(None, vec![when], vec![then], None)
382}
383
384/// Create a Unnest expression
385pub fn unnest(expr: Expr) -> Expr {
386    Expr::Unnest(Unnest {
387        expr: Box::new(expr),
388    })
389}
390
391/// Convenience method to create a new user defined scalar function (UDF) with a
392/// specific signature and specific return type.
393///
394/// Note this function does not expose all available features of [`ScalarUDF`],
395/// such as
396///
397/// * computing return types based on input types
398/// * multiple [`Signature`]s
399/// * aliases
400///
401/// See [`ScalarUDF`] for details and examples on how to use the full
402/// functionality.
403pub fn create_udf(
404    name: &str,
405    input_types: Vec<DataType>,
406    return_type: DataType,
407    volatility: Volatility,
408    fun: ScalarFunctionImplementation,
409) -> ScalarUDF {
410    ScalarUDF::from(SimpleScalarUDF::new(
411        name,
412        input_types,
413        return_type,
414        volatility,
415        fun,
416    ))
417}
418
419/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
420/// return type.
421#[derive(PartialEq, Eq, Hash)]
422pub struct SimpleScalarUDF {
423    name: String,
424    signature: Signature,
425    return_type: DataType,
426    fun: PtrEq<ScalarFunctionImplementation>,
427}
428
429impl Debug for SimpleScalarUDF {
430    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
431        f.debug_struct("SimpleScalarUDF")
432            .field("name", &self.name)
433            .field("signature", &self.signature)
434            .field("return_type", &self.return_type)
435            .field("fun", &"<FUNC>")
436            .finish()
437    }
438}
439
440impl SimpleScalarUDF {
441    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
442    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
443    pub fn new(
444        name: impl Into<String>,
445        input_types: Vec<DataType>,
446        return_type: DataType,
447        volatility: Volatility,
448        fun: ScalarFunctionImplementation,
449    ) -> Self {
450        Self::new_with_signature(
451            name,
452            Signature::exact(input_types, volatility),
453            return_type,
454            fun,
455        )
456    }
457
458    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
459    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
460    pub fn new_with_signature(
461        name: impl Into<String>,
462        signature: Signature,
463        return_type: DataType,
464        fun: ScalarFunctionImplementation,
465    ) -> Self {
466        Self {
467            name: name.into(),
468            signature,
469            return_type,
470            fun: fun.into(),
471        }
472    }
473}
474
475impl ScalarUDFImpl for SimpleScalarUDF {
476    fn as_any(&self) -> &dyn Any {
477        self
478    }
479
480    fn name(&self) -> &str {
481        &self.name
482    }
483
484    fn signature(&self) -> &Signature {
485        &self.signature
486    }
487
488    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
489        Ok(self.return_type.clone())
490    }
491
492    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
493        (self.fun)(&args.args)
494    }
495}
496
497/// Creates a new UDAF with a specific signature, state type and return type.
498/// The signature and state type must match the `Accumulator's implementation`.
499pub fn create_udaf(
500    name: &str,
501    input_type: Vec<DataType>,
502    return_type: Arc<DataType>,
503    volatility: Volatility,
504    accumulator: AccumulatorFactoryFunction,
505    state_type: Arc<Vec<DataType>>,
506) -> AggregateUDF {
507    let return_type = Arc::unwrap_or_clone(return_type);
508    let state_type = Arc::unwrap_or_clone(state_type);
509    let state_fields = state_type
510        .into_iter()
511        .enumerate()
512        .map(|(i, t)| Field::new(format!("{i}"), t, true))
513        .map(Arc::new)
514        .collect::<Vec<_>>();
515    AggregateUDF::from(SimpleAggregateUDF::new(
516        name,
517        input_type,
518        return_type,
519        volatility,
520        accumulator,
521        state_fields,
522    ))
523}
524
525/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
526/// return type.
527#[derive(PartialEq, Eq, Hash)]
528pub struct SimpleAggregateUDF {
529    name: String,
530    signature: Signature,
531    return_type: DataType,
532    accumulator: PtrEq<AccumulatorFactoryFunction>,
533    state_fields: Vec<FieldRef>,
534}
535
536impl Debug for SimpleAggregateUDF {
537    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
538        f.debug_struct("SimpleAggregateUDF")
539            .field("name", &self.name)
540            .field("signature", &self.signature)
541            .field("return_type", &self.return_type)
542            .field("fun", &"<FUNC>")
543            .finish()
544    }
545}
546
547impl SimpleAggregateUDF {
548    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
549    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
550    pub fn new(
551        name: impl Into<String>,
552        input_type: Vec<DataType>,
553        return_type: DataType,
554        volatility: Volatility,
555        accumulator: AccumulatorFactoryFunction,
556        state_fields: Vec<FieldRef>,
557    ) -> Self {
558        let name = name.into();
559        let signature = Signature::exact(input_type, volatility);
560        Self {
561            name,
562            signature,
563            return_type,
564            accumulator: accumulator.into(),
565            state_fields,
566        }
567    }
568
569    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
570    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
571    pub fn new_with_signature(
572        name: impl Into<String>,
573        signature: Signature,
574        return_type: DataType,
575        accumulator: AccumulatorFactoryFunction,
576        state_fields: Vec<FieldRef>,
577    ) -> Self {
578        let name = name.into();
579        Self {
580            name,
581            signature,
582            return_type,
583            accumulator: accumulator.into(),
584            state_fields,
585        }
586    }
587}
588
589impl AggregateUDFImpl for SimpleAggregateUDF {
590    fn as_any(&self) -> &dyn Any {
591        self
592    }
593
594    fn name(&self) -> &str {
595        &self.name
596    }
597
598    fn signature(&self) -> &Signature {
599        &self.signature
600    }
601
602    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
603        Ok(self.return_type.clone())
604    }
605
606    fn accumulator(
607        &self,
608        acc_args: AccumulatorArgs,
609    ) -> Result<Box<dyn crate::Accumulator>> {
610        (self.accumulator)(acc_args)
611    }
612
613    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
614        Ok(self.state_fields.clone())
615    }
616}
617
618/// Creates a new UDWF with a specific signature, state type and return type.
619///
620/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
621///
622/// [`PartitionEvaluator`]: crate::PartitionEvaluator
623pub fn create_udwf(
624    name: &str,
625    input_type: DataType,
626    return_type: Arc<DataType>,
627    volatility: Volatility,
628    partition_evaluator_factory: PartitionEvaluatorFactory,
629) -> WindowUDF {
630    let return_type = Arc::unwrap_or_clone(return_type);
631    WindowUDF::from(SimpleWindowUDF::new(
632        name,
633        input_type,
634        return_type,
635        volatility,
636        partition_evaluator_factory,
637    ))
638}
639
640/// Implements [`WindowUDFImpl`] for functions that have a single signature and
641/// return type.
642#[derive(PartialEq, Eq, Hash)]
643pub struct SimpleWindowUDF {
644    name: String,
645    signature: Signature,
646    return_type: DataType,
647    partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
648}
649
650impl Debug for SimpleWindowUDF {
651    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
652        f.debug_struct("WindowUDF")
653            .field("name", &self.name)
654            .field("signature", &self.signature)
655            .field("return_type", &"<func>")
656            .field("partition_evaluator_factory", &"<FUNC>")
657            .finish()
658    }
659}
660
661impl SimpleWindowUDF {
662    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
663    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
664    pub fn new(
665        name: impl Into<String>,
666        input_type: DataType,
667        return_type: DataType,
668        volatility: Volatility,
669        partition_evaluator_factory: PartitionEvaluatorFactory,
670    ) -> Self {
671        let name = name.into();
672        let signature = Signature::exact([input_type].to_vec(), volatility);
673        Self {
674            name,
675            signature,
676            return_type,
677            partition_evaluator_factory: partition_evaluator_factory.into(),
678        }
679    }
680}
681
682impl WindowUDFImpl for SimpleWindowUDF {
683    fn as_any(&self) -> &dyn Any {
684        self
685    }
686
687    fn name(&self) -> &str {
688        &self.name
689    }
690
691    fn signature(&self) -> &Signature {
692        &self.signature
693    }
694
695    fn partition_evaluator(
696        &self,
697        _partition_evaluator_args: PartitionEvaluatorArgs,
698    ) -> Result<Box<dyn PartitionEvaluator>> {
699        (self.partition_evaluator_factory)()
700    }
701
702    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
703        Ok(Arc::new(Field::new(
704            field_args.name(),
705            self.return_type.clone(),
706            true,
707        )))
708    }
709
710    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
711        LimitEffect::Unknown
712    }
713}
714
715pub fn interval_year_month_lit(value: &str) -> Expr {
716    let interval = parse_interval_year_month(value).ok();
717    Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
718}
719
720pub fn interval_datetime_lit(value: &str) -> Expr {
721    let interval = parse_interval_day_time(value).ok();
722    Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
723}
724
725pub fn interval_month_day_nano_lit(value: &str) -> Expr {
726    let interval = parse_interval_month_day_nano(value).ok();
727    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
728}
729
730/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
731///
732/// Adds methods to [`Expr`] that make it easy to set optional options
733/// such as `ORDER BY`, `FILTER` and `DISTINCT`
734///
735/// # Example
736/// ```no_run
737/// # use datafusion_common::Result;
738/// # use datafusion_expr::expr::NullTreatment;
739/// # use datafusion_expr::test::function_stub::count;
740/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
741/// # // first_value is an aggregate function in another crate
742/// # fn first_value(_arg: Expr) -> Expr {
743/// unimplemented!() }
744/// # fn main() -> Result<()> {
745/// // Create an aggregate count, filtering on column y > 5
746/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
747///
748/// // Find the first value in an aggregate sorted by column y
749/// // equivalent to:
750/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
751/// let sort_expr = col("y").sort(true, true);
752/// let agg = first_value(col("x"))
753///     .order_by(vec![sort_expr])
754///     .null_treatment(NullTreatment::IgnoreNulls)
755///     .build()?;
756///
757/// // Create a window expression for percent rank partitioned on column a
758/// // equivalent to:
759/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
760/// // percent_rank is an udwf function in another crate
761/// # fn percent_rank() -> Expr {
762/// unimplemented!() }
763/// let window = percent_rank()
764///     .partition_by(vec![col("a")])
765///     .order_by(vec![col("b").sort(true, true)])
766///     .null_treatment(NullTreatment::IgnoreNulls)
767///     .build()?;
768/// #     Ok(())
769/// # }
770/// ```
771pub trait ExprFunctionExt {
772    /// Add `ORDER BY <order_by>`
773    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
774    /// Add `FILTER <filter>`
775    fn filter(self, filter: Expr) -> ExprFuncBuilder;
776    /// Add `DISTINCT`
777    fn distinct(self) -> ExprFuncBuilder;
778    /// Add `RESPECT NULLS` or `IGNORE NULLS`
779    fn null_treatment(
780        self,
781        null_treatment: impl Into<Option<NullTreatment>>,
782    ) -> ExprFuncBuilder;
783    /// Add `PARTITION BY`
784    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
785    /// Add appropriate window frame conditions
786    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
787}
788
789#[derive(Debug, Clone)]
790pub enum ExprFuncKind {
791    Aggregate(AggregateFunction),
792    Window(Box<WindowFunction>),
793}
794
795/// Implementation of [`ExprFunctionExt`].
796///
797/// See [`ExprFunctionExt`] for usage and examples
798#[derive(Debug, Clone)]
799pub struct ExprFuncBuilder {
800    fun: Option<ExprFuncKind>,
801    order_by: Option<Vec<Sort>>,
802    filter: Option<Expr>,
803    distinct: bool,
804    null_treatment: Option<NullTreatment>,
805    partition_by: Option<Vec<Expr>>,
806    window_frame: Option<WindowFrame>,
807}
808
809impl ExprFuncBuilder {
810    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
811    fn new(fun: Option<ExprFuncKind>) -> Self {
812        Self {
813            fun,
814            order_by: None,
815            filter: None,
816            distinct: false,
817            null_treatment: None,
818            partition_by: None,
819            window_frame: None,
820        }
821    }
822
823    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
824    ///
825    /// # Errors:
826    ///
827    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
828    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
829    pub fn build(self) -> Result<Expr> {
830        let Self {
831            fun,
832            order_by,
833            filter,
834            distinct,
835            null_treatment,
836            partition_by,
837            window_frame,
838        } = self;
839
840        let Some(fun) = fun else {
841            return plan_err!(
842                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
843            );
844        };
845
846        let fun_expr = match fun {
847            ExprFuncKind::Aggregate(mut udaf) => {
848                udaf.params.order_by = order_by.unwrap_or_default();
849                udaf.params.filter = filter.map(Box::new);
850                udaf.params.distinct = distinct;
851                udaf.params.null_treatment = null_treatment;
852                Expr::AggregateFunction(udaf)
853            }
854            ExprFuncKind::Window(mut udwf) => {
855                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
856                udwf.params.partition_by = partition_by.unwrap_or_default();
857                udwf.params.order_by = order_by.unwrap_or_default();
858                udwf.params.window_frame =
859                    window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
860                udwf.params.filter = filter.map(Box::new);
861                udwf.params.null_treatment = null_treatment;
862                udwf.params.distinct = distinct;
863                Expr::WindowFunction(udwf)
864            }
865        };
866
867        Ok(fun_expr)
868    }
869}
870
871impl ExprFunctionExt for ExprFuncBuilder {
872    /// Add `ORDER BY <order_by>`
873    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
874        self.order_by = Some(order_by);
875        self
876    }
877
878    /// Add `FILTER <filter>`
879    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
880        self.filter = Some(filter);
881        self
882    }
883
884    /// Add `DISTINCT`
885    fn distinct(mut self) -> ExprFuncBuilder {
886        self.distinct = true;
887        self
888    }
889
890    /// Add `RESPECT NULLS` or `IGNORE NULLS`
891    fn null_treatment(
892        mut self,
893        null_treatment: impl Into<Option<NullTreatment>>,
894    ) -> ExprFuncBuilder {
895        self.null_treatment = null_treatment.into();
896        self
897    }
898
899    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
900        self.partition_by = Some(partition_by);
901        self
902    }
903
904    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
905        self.window_frame = Some(window_frame);
906        self
907    }
908}
909
910impl ExprFunctionExt for Expr {
911    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
912        let mut builder = match self {
913            Expr::AggregateFunction(udaf) => {
914                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
915            }
916            Expr::WindowFunction(udwf) => {
917                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
918            }
919            _ => ExprFuncBuilder::new(None),
920        };
921        if builder.fun.is_some() {
922            builder.order_by = Some(order_by);
923        }
924        builder
925    }
926    fn filter(self, filter: Expr) -> ExprFuncBuilder {
927        match self {
928            Expr::AggregateFunction(udaf) => {
929                let mut builder =
930                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
931                builder.filter = Some(filter);
932                builder
933            }
934            _ => ExprFuncBuilder::new(None),
935        }
936    }
937    fn distinct(self) -> ExprFuncBuilder {
938        match self {
939            Expr::AggregateFunction(udaf) => {
940                let mut builder =
941                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
942                builder.distinct = true;
943                builder
944            }
945            _ => ExprFuncBuilder::new(None),
946        }
947    }
948    fn null_treatment(
949        self,
950        null_treatment: impl Into<Option<NullTreatment>>,
951    ) -> ExprFuncBuilder {
952        let mut builder = match self {
953            Expr::AggregateFunction(udaf) => {
954                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
955            }
956            Expr::WindowFunction(udwf) => {
957                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
958            }
959            _ => ExprFuncBuilder::new(None),
960        };
961        if builder.fun.is_some() {
962            builder.null_treatment = null_treatment.into();
963        }
964        builder
965    }
966
967    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
968        match self {
969            Expr::WindowFunction(udwf) => {
970                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
971                builder.partition_by = Some(partition_by);
972                builder
973            }
974            _ => ExprFuncBuilder::new(None),
975        }
976    }
977
978    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
979        match self {
980            Expr::WindowFunction(udwf) => {
981                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
982                builder.window_frame = Some(window_frame);
983                builder
984            }
985            _ => ExprFuncBuilder::new(None),
986        }
987    }
988}
989
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn filter_is_null_and_is_not_null() {
        // `col` normalizes the name, `ident` keeps it verbatim; both are
        // lowercase here so the rendered names match the input.
        let first = col("col1");
        let second = ident("col2");

        assert_eq!(first.is_null().to_string(), "col1 IS NULL");
        assert_eq!(second.is_not_null().to_string(), "col2 IS NOT NULL");
    }
}