datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::ptr_eq::PtrEq;
29use crate::select_expr::SelectExpr;
30use crate::{
31    AggregateUDF, Expr, LimitEffect, LogicalPlan, Operator, PartitionEvaluator,
32    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
33    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
34};
35use crate::{
36    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
37};
38use arrow::compute::kernels::cast_utils::{
39    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
40};
41use arrow::datatypes::{DataType, Field, FieldRef};
42use datafusion_common::{Column, Result, ScalarValue, Spans, TableReference, plan_err};
43use datafusion_functions_window_common::field::WindowUDFFieldArgs;
44use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
45use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
46use std::any::Any;
47use std::collections::HashMap;
48use std::fmt::Debug;
49use std::hash::Hash;
50use std::ops::Not;
51use std::sync::Arc;
52
53/// Create a column expression based on a qualified or unqualified column name. Will
54/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
55///
56/// For example:
57///
58/// ```rust
59/// # use datafusion_expr::col;
60/// let c1 = col("a");
61/// let c2 = col("A");
62/// assert_eq!(c1, c2);
63///
64/// // note how quoting with double quotes preserves the case
65/// let c3 = col(r#""A""#);
66/// assert_ne!(c1, c3);
67/// ```
68pub fn col(ident: impl Into<Column>) -> Expr {
69    Expr::Column(ident.into())
70}
71
72/// Create an out reference column which hold a reference that has been resolved to a field
73/// outside of the current plan.
74/// The expression created by this function does not preserve the metadata of the outer column.
75/// Please use `out_ref_col_with_metadata` if you want to preserve the metadata.
76pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
77    out_ref_col_with_metadata(dt, HashMap::new(), ident)
78}
79
80/// Create an out reference column from an existing field (preserving metadata)
81pub fn out_ref_col_with_metadata(
82    dt: DataType,
83    metadata: HashMap<String, String>,
84    ident: impl Into<Column>,
85) -> Expr {
86    let column = ident.into();
87    let field: FieldRef =
88        Arc::new(Field::new(column.name(), dt, true).with_metadata(metadata));
89    Expr::OuterReferenceColumn(field, column)
90}
91
92/// Create an unqualified column expression from the provided name, without normalizing
93/// the column.
94///
95/// For example:
96///
97/// ```rust
98/// # use datafusion_expr::{col, ident};
99/// let c1 = ident("A"); // not normalized staying as column 'A'
100/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
101/// assert_ne!(c1, c2);
102///
103/// let c3 = col(r#""A""#);
104/// assert_eq!(c1, c3);
105///
106/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
107/// let c5 = ident("t1.a"); // parses as column 't1.a'
108/// assert_ne!(c4, c5);
109/// ```
110pub fn ident(name: impl Into<String>) -> Expr {
111    Expr::Column(Column::from_name(name))
112}
113
114/// Create placeholder value that will be filled in (such as `$1`)
115///
116/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
117///
118/// # Example
119///
120/// ```rust
121/// # use datafusion_expr::{placeholder};
122/// let p = placeholder("$1"); // $1, refers to parameter 1
123/// assert_eq!(p.to_string(), "$1")
124/// ```
125pub fn placeholder(id: impl Into<String>) -> Expr {
126    Expr::Placeholder(Placeholder {
127        id: id.into(),
128        field: None,
129    })
130}
131
132/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
133///
134/// # Example
135///
136/// ```rust
137/// # use datafusion_expr::{wildcard};
138/// let p = wildcard();
139/// assert_eq!(p.to_string(), "*")
140/// ```
141pub fn wildcard() -> SelectExpr {
142    SelectExpr::Wildcard(WildcardOptions::default())
143}
144
145/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
146pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
147    SelectExpr::Wildcard(options)
148}
149
150/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
151///
152/// # Example
153///
154/// ```rust
155/// # use datafusion_common::TableReference;
156/// # use datafusion_expr::{qualified_wildcard};
157/// let p = qualified_wildcard(TableReference::bare("t"));
158/// assert_eq!(p.to_string(), "t.*")
159/// ```
160pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
161    SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
162}
163
164/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
165pub fn qualified_wildcard_with_options(
166    qualifier: impl Into<TableReference>,
167    options: WildcardOptions,
168) -> SelectExpr {
169    SelectExpr::QualifiedWildcard(qualifier.into(), options)
170}
171
172/// Return a new expression `left <op> right`
173pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
174    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
175}
176
177/// Return a new expression with a logical AND
178pub fn and(left: Expr, right: Expr) -> Expr {
179    Expr::BinaryExpr(BinaryExpr::new(
180        Box::new(left),
181        Operator::And,
182        Box::new(right),
183    ))
184}
185
186/// Return a new expression with a logical OR
187pub fn or(left: Expr, right: Expr) -> Expr {
188    Expr::BinaryExpr(BinaryExpr::new(
189        Box::new(left),
190        Operator::Or,
191        Box::new(right),
192    ))
193}
194
195/// Return a new expression with a logical NOT
196pub fn not(expr: Expr) -> Expr {
197    expr.not()
198}
199
200/// Return a new expression with bitwise AND
201pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
202    Expr::BinaryExpr(BinaryExpr::new(
203        Box::new(left),
204        Operator::BitwiseAnd,
205        Box::new(right),
206    ))
207}
208
209/// Return a new expression with bitwise OR
210pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
211    Expr::BinaryExpr(BinaryExpr::new(
212        Box::new(left),
213        Operator::BitwiseOr,
214        Box::new(right),
215    ))
216}
217
218/// Return a new expression with bitwise XOR
219pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
220    Expr::BinaryExpr(BinaryExpr::new(
221        Box::new(left),
222        Operator::BitwiseXor,
223        Box::new(right),
224    ))
225}
226
227/// Return a new expression with bitwise SHIFT RIGHT
228pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
229    Expr::BinaryExpr(BinaryExpr::new(
230        Box::new(left),
231        Operator::BitwiseShiftRight,
232        Box::new(right),
233    ))
234}
235
236/// Return a new expression with bitwise SHIFT LEFT
237pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
238    Expr::BinaryExpr(BinaryExpr::new(
239        Box::new(left),
240        Operator::BitwiseShiftLeft,
241        Box::new(right),
242    ))
243}
244
245/// Create an in_list expression
246pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
247    Expr::InList(InList::new(Box::new(expr), list, negated))
248}
249
250/// Create an EXISTS subquery expression
251pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
252    let outer_ref_columns = subquery.all_out_ref_exprs();
253    Expr::Exists(Exists {
254        subquery: Subquery {
255            subquery,
256            outer_ref_columns,
257            spans: Spans::new(),
258        },
259        negated: false,
260    })
261}
262
263/// Create a NOT EXISTS subquery expression
264pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
265    let outer_ref_columns = subquery.all_out_ref_exprs();
266    Expr::Exists(Exists {
267        subquery: Subquery {
268            subquery,
269            outer_ref_columns,
270            spans: Spans::new(),
271        },
272        negated: true,
273    })
274}
275
276/// Create an IN subquery expression
277pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
278    let outer_ref_columns = subquery.all_out_ref_exprs();
279    Expr::InSubquery(InSubquery::new(
280        Box::new(expr),
281        Subquery {
282            subquery,
283            outer_ref_columns,
284            spans: Spans::new(),
285        },
286        false,
287    ))
288}
289
290/// Create a NOT IN subquery expression
291pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
292    let outer_ref_columns = subquery.all_out_ref_exprs();
293    Expr::InSubquery(InSubquery::new(
294        Box::new(expr),
295        Subquery {
296            subquery,
297            outer_ref_columns,
298            spans: Spans::new(),
299        },
300        true,
301    ))
302}
303
304/// Create a scalar subquery expression
305pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
306    let outer_ref_columns = subquery.all_out_ref_exprs();
307    Expr::ScalarSubquery(Subquery {
308        subquery,
309        outer_ref_columns,
310        spans: Spans::new(),
311    })
312}
313
314/// Create a grouping set
315pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
316    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
317}
318
319/// Create a grouping set for all combination of `exprs`
320pub fn cube(exprs: Vec<Expr>) -> Expr {
321    Expr::GroupingSet(GroupingSet::Cube(exprs))
322}
323
324/// Create a grouping set for rollup
325pub fn rollup(exprs: Vec<Expr>) -> Expr {
326    Expr::GroupingSet(GroupingSet::Rollup(exprs))
327}
328
329/// Create a cast expression
330pub fn cast(expr: Expr, data_type: DataType) -> Expr {
331    Expr::Cast(Cast::new(Box::new(expr), data_type))
332}
333
334/// Create a try cast expression
335pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
336    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
337}
338
339/// Create is null expression
340pub fn is_null(expr: Expr) -> Expr {
341    Expr::IsNull(Box::new(expr))
342}
343
344/// Create is not null expression
345pub fn is_not_null(expr: Expr) -> Expr {
346    Expr::IsNotNull(Box::new(expr))
347}
348
349/// Create is true expression
350pub fn is_true(expr: Expr) -> Expr {
351    Expr::IsTrue(Box::new(expr))
352}
353
354/// Create is not true expression
355pub fn is_not_true(expr: Expr) -> Expr {
356    Expr::IsNotTrue(Box::new(expr))
357}
358
359/// Create is false expression
360pub fn is_false(expr: Expr) -> Expr {
361    Expr::IsFalse(Box::new(expr))
362}
363
364/// Create is not false expression
365pub fn is_not_false(expr: Expr) -> Expr {
366    Expr::IsNotFalse(Box::new(expr))
367}
368
369/// Create is unknown expression
370pub fn is_unknown(expr: Expr) -> Expr {
371    Expr::IsUnknown(Box::new(expr))
372}
373
374/// Create is not unknown expression
375pub fn is_not_unknown(expr: Expr) -> Expr {
376    Expr::IsNotUnknown(Box::new(expr))
377}
378
379/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
380pub fn case(expr: Expr) -> CaseBuilder {
381    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
382}
383
384/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
385pub fn when(when: Expr, then: Expr) -> CaseBuilder {
386    CaseBuilder::new(None, vec![when], vec![then], None)
387}
388
389/// Create a Unnest expression
390pub fn unnest(expr: Expr) -> Expr {
391    Expr::Unnest(Unnest {
392        expr: Box::new(expr),
393    })
394}
395
396/// Convenience method to create a new user defined scalar function (UDF) with a
397/// specific signature and specific return type.
398///
399/// Note this function does not expose all available features of [`ScalarUDF`],
400/// such as
401///
402/// * computing return types based on input types
403/// * multiple [`Signature`]s
404/// * aliases
405///
406/// See [`ScalarUDF`] for details and examples on how to use the full
407/// functionality.
408pub fn create_udf(
409    name: &str,
410    input_types: Vec<DataType>,
411    return_type: DataType,
412    volatility: Volatility,
413    fun: ScalarFunctionImplementation,
414) -> ScalarUDF {
415    ScalarUDF::from(SimpleScalarUDF::new(
416        name,
417        input_types,
418        return_type,
419        volatility,
420        fun,
421    ))
422}
423
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
/// return type.
///
/// Instances are created with [`SimpleScalarUDF::new`] or
/// [`SimpleScalarUDF::new_with_signature`].
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleScalarUDF {
    /// Name reported by [`ScalarUDFImpl::name`]
    name: String,
    /// The single signature reported by [`ScalarUDFImpl::signature`]
    signature: Signature,
    /// Fixed return type, returned regardless of the argument types
    return_type: DataType,
    /// The function implementation, wrapped in [`PtrEq`] so the struct can derive
    /// `PartialEq`/`Eq`/`Hash` (presumably compared by pointer identity; see
    /// [`PtrEq`] to confirm)
    fun: PtrEq<ScalarFunctionImplementation>,
}
433
434impl Debug for SimpleScalarUDF {
435    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
436        f.debug_struct("SimpleScalarUDF")
437            .field("name", &self.name)
438            .field("signature", &self.signature)
439            .field("return_type", &self.return_type)
440            .field("fun", &"<FUNC>")
441            .finish()
442    }
443}
444
445impl SimpleScalarUDF {
446    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
447    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
448    pub fn new(
449        name: impl Into<String>,
450        input_types: Vec<DataType>,
451        return_type: DataType,
452        volatility: Volatility,
453        fun: ScalarFunctionImplementation,
454    ) -> Self {
455        Self::new_with_signature(
456            name,
457            Signature::exact(input_types, volatility),
458            return_type,
459            fun,
460        )
461    }
462
463    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
464    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
465    pub fn new_with_signature(
466        name: impl Into<String>,
467        signature: Signature,
468        return_type: DataType,
469        fun: ScalarFunctionImplementation,
470    ) -> Self {
471        Self {
472            name: name.into(),
473            signature,
474            return_type,
475            fun: fun.into(),
476        }
477    }
478}
479
480impl ScalarUDFImpl for SimpleScalarUDF {
481    fn as_any(&self) -> &dyn Any {
482        self
483    }
484
485    fn name(&self) -> &str {
486        &self.name
487    }
488
489    fn signature(&self) -> &Signature {
490        &self.signature
491    }
492
493    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
494        Ok(self.return_type.clone())
495    }
496
497    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
498        (self.fun)(&args.args)
499    }
500}
501
502/// Creates a new UDAF with a specific signature, state type and return type.
503/// The signature and state type must match the `Accumulator's implementation`.
504pub fn create_udaf(
505    name: &str,
506    input_type: Vec<DataType>,
507    return_type: Arc<DataType>,
508    volatility: Volatility,
509    accumulator: AccumulatorFactoryFunction,
510    state_type: Arc<Vec<DataType>>,
511) -> AggregateUDF {
512    let return_type = Arc::unwrap_or_clone(return_type);
513    let state_type = Arc::unwrap_or_clone(state_type);
514    let state_fields = state_type
515        .into_iter()
516        .enumerate()
517        .map(|(i, t)| Field::new(format!("{i}"), t, true))
518        .map(Arc::new)
519        .collect::<Vec<_>>();
520    AggregateUDF::from(SimpleAggregateUDF::new(
521        name,
522        input_type,
523        return_type,
524        volatility,
525        accumulator,
526        state_fields,
527    ))
528}
529
/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
/// return type.
///
/// Instances are created with [`SimpleAggregateUDF::new`] or
/// [`SimpleAggregateUDF::new_with_signature`].
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleAggregateUDF {
    /// Name reported by [`AggregateUDFImpl::name`]
    name: String,
    /// The single signature reported by [`AggregateUDFImpl::signature`]
    signature: Signature,
    /// Fixed return type, returned regardless of the argument types
    return_type: DataType,
    /// Factory invoked to create an accumulator, wrapped in [`PtrEq`] so the struct
    /// can derive `PartialEq`/`Eq`/`Hash` (presumably compared by pointer identity;
    /// see [`PtrEq`] to confirm)
    accumulator: PtrEq<AccumulatorFactoryFunction>,
    /// Fields returned by [`AggregateUDFImpl::state_fields`]
    state_fields: Vec<FieldRef>,
}
540
541impl Debug for SimpleAggregateUDF {
542    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
543        f.debug_struct("SimpleAggregateUDF")
544            .field("name", &self.name)
545            .field("signature", &self.signature)
546            .field("return_type", &self.return_type)
547            .field("fun", &"<FUNC>")
548            .finish()
549    }
550}
551
552impl SimpleAggregateUDF {
553    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
554    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
555    pub fn new(
556        name: impl Into<String>,
557        input_type: Vec<DataType>,
558        return_type: DataType,
559        volatility: Volatility,
560        accumulator: AccumulatorFactoryFunction,
561        state_fields: Vec<FieldRef>,
562    ) -> Self {
563        let name = name.into();
564        let signature = Signature::exact(input_type, volatility);
565        Self {
566            name,
567            signature,
568            return_type,
569            accumulator: accumulator.into(),
570            state_fields,
571        }
572    }
573
574    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
575    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
576    pub fn new_with_signature(
577        name: impl Into<String>,
578        signature: Signature,
579        return_type: DataType,
580        accumulator: AccumulatorFactoryFunction,
581        state_fields: Vec<FieldRef>,
582    ) -> Self {
583        let name = name.into();
584        Self {
585            name,
586            signature,
587            return_type,
588            accumulator: accumulator.into(),
589            state_fields,
590        }
591    }
592}
593
594impl AggregateUDFImpl for SimpleAggregateUDF {
595    fn as_any(&self) -> &dyn Any {
596        self
597    }
598
599    fn name(&self) -> &str {
600        &self.name
601    }
602
603    fn signature(&self) -> &Signature {
604        &self.signature
605    }
606
607    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
608        Ok(self.return_type.clone())
609    }
610
611    fn accumulator(
612        &self,
613        acc_args: AccumulatorArgs,
614    ) -> Result<Box<dyn crate::Accumulator>> {
615        (self.accumulator)(acc_args)
616    }
617
618    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
619        Ok(self.state_fields.clone())
620    }
621}
622
623/// Creates a new UDWF with a specific signature, state type and return type.
624///
625/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
626///
627/// [`PartitionEvaluator`]: crate::PartitionEvaluator
628pub fn create_udwf(
629    name: &str,
630    input_type: DataType,
631    return_type: Arc<DataType>,
632    volatility: Volatility,
633    partition_evaluator_factory: PartitionEvaluatorFactory,
634) -> WindowUDF {
635    let return_type = Arc::unwrap_or_clone(return_type);
636    WindowUDF::from(SimpleWindowUDF::new(
637        name,
638        input_type,
639        return_type,
640        volatility,
641        partition_evaluator_factory,
642    ))
643}
644
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
/// return type.
///
/// Instances are created with [`SimpleWindowUDF::new`].
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleWindowUDF {
    /// Name reported by [`WindowUDFImpl::name`]
    name: String,
    /// The single signature reported by [`WindowUDFImpl::signature`]
    signature: Signature,
    /// Data type of the field produced by [`WindowUDFImpl::field`]
    return_type: DataType,
    /// Factory invoked to create a partition evaluator, wrapped in [`PtrEq`] so the
    /// struct can derive `PartialEq`/`Eq`/`Hash` (presumably compared by pointer
    /// identity; see [`PtrEq`] to confirm)
    partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
}
654
655impl Debug for SimpleWindowUDF {
656    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
657        f.debug_struct("WindowUDF")
658            .field("name", &self.name)
659            .field("signature", &self.signature)
660            .field("return_type", &"<func>")
661            .field("partition_evaluator_factory", &"<FUNC>")
662            .finish()
663    }
664}
665
666impl SimpleWindowUDF {
667    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
668    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
669    pub fn new(
670        name: impl Into<String>,
671        input_type: DataType,
672        return_type: DataType,
673        volatility: Volatility,
674        partition_evaluator_factory: PartitionEvaluatorFactory,
675    ) -> Self {
676        let name = name.into();
677        let signature = Signature::exact([input_type].to_vec(), volatility);
678        Self {
679            name,
680            signature,
681            return_type,
682            partition_evaluator_factory: partition_evaluator_factory.into(),
683        }
684    }
685}
686
687impl WindowUDFImpl for SimpleWindowUDF {
688    fn as_any(&self) -> &dyn Any {
689        self
690    }
691
692    fn name(&self) -> &str {
693        &self.name
694    }
695
696    fn signature(&self) -> &Signature {
697        &self.signature
698    }
699
700    fn partition_evaluator(
701        &self,
702        _partition_evaluator_args: PartitionEvaluatorArgs,
703    ) -> Result<Box<dyn PartitionEvaluator>> {
704        (self.partition_evaluator_factory)()
705    }
706
707    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
708        Ok(Arc::new(Field::new(
709            field_args.name(),
710            self.return_type.clone(),
711            true,
712        )))
713    }
714
715    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
716        LimitEffect::Unknown
717    }
718}
719
720pub fn interval_year_month_lit(value: &str) -> Expr {
721    let interval = parse_interval_year_month(value).ok();
722    Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
723}
724
725pub fn interval_datetime_lit(value: &str) -> Expr {
726    let interval = parse_interval_day_time(value).ok();
727    Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
728}
729
730pub fn interval_month_day_nano_lit(value: &str) -> Expr {
731    let interval = parse_interval_month_day_nano(value).ok();
732    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
733}
734
/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
///
/// Adds methods to [`Expr`] that make it easy to set optional options
/// such as `ORDER BY`, `FILTER` and `DISTINCT`
///
/// Every method consumes `self` and returns an [`ExprFuncBuilder`]; call
/// [`ExprFuncBuilder::build`] to obtain the configured [`Expr`].
///
/// # Example
/// ```no_run
/// # use datafusion_common::Result;
/// # use datafusion_expr::expr::NullTreatment;
/// # use datafusion_expr::test::function_stub::count;
/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
/// # // first_value is an aggregate function in another crate
/// # fn first_value(_arg: Expr) -> Expr {
/// unimplemented!() }
/// # fn main() -> Result<()> {
/// // Create an aggregate count, filtering on column y > 5
/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
///
/// // Find the first value in an aggregate sorted by column y
/// // equivalent to:
/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
/// let sort_expr = col("y").sort(true, true);
/// let agg = first_value(col("x"))
///     .order_by(vec![sort_expr])
///     .null_treatment(NullTreatment::IgnoreNulls)
///     .build()?;
///
/// // Create a window expression for percent rank partitioned on column a
/// // equivalent to:
/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
/// // percent_rank is an udwf function in another crate
/// # fn percent_rank() -> Expr {
/// unimplemented!() }
/// let window = percent_rank()
///     .partition_by(vec![col("a")])
///     .order_by(vec![col("b").sort(true, true)])
///     .null_treatment(NullTreatment::IgnoreNulls)
///     .build()?;
/// #     Ok(())
/// # }
/// ```
pub trait ExprFunctionExt {
    /// Add `ORDER BY <order_by>`
    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
    /// Add `FILTER <filter>`
    fn filter(self, filter: Expr) -> ExprFuncBuilder;
    /// Add `DISTINCT`
    fn distinct(self) -> ExprFuncBuilder;
    /// Add `RESPECT NULLS` or `IGNORE NULLS`
    ///
    /// Accepts either a [`NullTreatment`] or an `Option<NullTreatment>`.
    fn null_treatment(
        self,
        null_treatment: impl Into<Option<NullTreatment>>,
    ) -> ExprFuncBuilder;
    /// Add `PARTITION BY <partition_by>`
    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
    /// Add appropriate window frame conditions
    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
}
793
/// The function expression being configured by an [`ExprFuncBuilder`].
#[derive(Debug, Clone)]
pub enum ExprFuncKind {
    /// An aggregate function, taken from [`Expr::AggregateFunction`]
    Aggregate(AggregateFunction),
    /// A (boxed) window function, taken from [`Expr::WindowFunction`]
    Window(Box<WindowFunction>),
}
799
/// Implementation of [`ExprFunctionExt`].
///
/// See [`ExprFunctionExt`] for usage and examples
#[derive(Debug, Clone)]
pub struct ExprFuncBuilder {
    /// The function being configured; `None` makes [`ExprFuncBuilder::build`]
    /// return an error
    fun: Option<ExprFuncKind>,
    /// `ORDER BY` expressions, if set
    order_by: Option<Vec<Sort>>,
    /// `FILTER` predicate, if set
    filter: Option<Expr>,
    /// Whether `DISTINCT` was requested
    distinct: bool,
    /// `RESPECT NULLS` / `IGNORE NULLS` setting, if set
    null_treatment: Option<NullTreatment>,
    /// `PARTITION BY` expressions, if set (only applied to window functions)
    partition_by: Option<Vec<Expr>>,
    /// Window frame, if set (only applied to window functions)
    window_frame: Option<WindowFrame>,
}
813
814impl ExprFuncBuilder {
815    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
816    fn new(fun: Option<ExprFuncKind>) -> Self {
817        Self {
818            fun,
819            order_by: None,
820            filter: None,
821            distinct: false,
822            null_treatment: None,
823            partition_by: None,
824            window_frame: None,
825        }
826    }
827
828    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
829    ///
830    /// # Errors:
831    ///
832    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
833    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
834    pub fn build(self) -> Result<Expr> {
835        let Self {
836            fun,
837            order_by,
838            filter,
839            distinct,
840            null_treatment,
841            partition_by,
842            window_frame,
843        } = self;
844
845        let Some(fun) = fun else {
846            return plan_err!(
847                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
848            );
849        };
850
851        let fun_expr = match fun {
852            ExprFuncKind::Aggregate(mut udaf) => {
853                udaf.params.order_by = order_by.unwrap_or_default();
854                udaf.params.filter = filter.map(Box::new);
855                udaf.params.distinct = distinct;
856                udaf.params.null_treatment = null_treatment;
857                Expr::AggregateFunction(udaf)
858            }
859            ExprFuncKind::Window(mut udwf) => {
860                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
861                udwf.params.partition_by = partition_by.unwrap_or_default();
862                udwf.params.order_by = order_by.unwrap_or_default();
863                udwf.params.window_frame =
864                    window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
865                udwf.params.filter = filter.map(Box::new);
866                udwf.params.null_treatment = null_treatment;
867                udwf.params.distinct = distinct;
868                Expr::WindowFunction(udwf)
869            }
870        };
871
872        Ok(fun_expr)
873    }
874}
875
876impl ExprFunctionExt for ExprFuncBuilder {
877    /// Add `ORDER BY <order_by>`
878    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
879        self.order_by = Some(order_by);
880        self
881    }
882
883    /// Add `FILTER <filter>`
884    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
885        self.filter = Some(filter);
886        self
887    }
888
889    /// Add `DISTINCT`
890    fn distinct(mut self) -> ExprFuncBuilder {
891        self.distinct = true;
892        self
893    }
894
895    /// Add `RESPECT NULLS` or `IGNORE NULLS`
896    fn null_treatment(
897        mut self,
898        null_treatment: impl Into<Option<NullTreatment>>,
899    ) -> ExprFuncBuilder {
900        self.null_treatment = null_treatment.into();
901        self
902    }
903
904    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
905        self.partition_by = Some(partition_by);
906        self
907    }
908
909    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
910        self.window_frame = Some(window_frame);
911        self
912    }
913}
914
915impl ExprFunctionExt for Expr {
916    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
917        let mut builder = match self {
918            Expr::AggregateFunction(udaf) => {
919                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
920            }
921            Expr::WindowFunction(udwf) => {
922                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
923            }
924            _ => ExprFuncBuilder::new(None),
925        };
926        if builder.fun.is_some() {
927            builder.order_by = Some(order_by);
928        }
929        builder
930    }
931    fn filter(self, filter: Expr) -> ExprFuncBuilder {
932        match self {
933            Expr::AggregateFunction(udaf) => {
934                let mut builder =
935                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
936                builder.filter = Some(filter);
937                builder
938            }
939            _ => ExprFuncBuilder::new(None),
940        }
941    }
942    fn distinct(self) -> ExprFuncBuilder {
943        match self {
944            Expr::AggregateFunction(udaf) => {
945                let mut builder =
946                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
947                builder.distinct = true;
948                builder
949            }
950            _ => ExprFuncBuilder::new(None),
951        }
952    }
953    fn null_treatment(
954        self,
955        null_treatment: impl Into<Option<NullTreatment>>,
956    ) -> ExprFuncBuilder {
957        let mut builder = match self {
958            Expr::AggregateFunction(udaf) => {
959                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
960            }
961            Expr::WindowFunction(udwf) => {
962                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
963            }
964            _ => ExprFuncBuilder::new(None),
965        };
966        if builder.fun.is_some() {
967            builder.null_treatment = null_treatment.into();
968        }
969        builder
970    }
971
972    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
973        match self {
974            Expr::WindowFunction(udwf) => {
975                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
976                builder.partition_by = Some(partition_by);
977                builder
978            }
979            _ => ExprFuncBuilder::new(None),
980        }
981    }
982
983    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
984        match self {
985            Expr::WindowFunction(udwf) => {
986                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
987                builder.window_frame = Some(window_frame);
988                builder
989            }
990            _ => ExprFuncBuilder::new(None),
991        }
992    }
993}
994
#[cfg(test)]
mod test {
    use super::*;

    /// `IS NULL` / `IS NOT NULL` render as expected for columns created with
    /// both `col` and `ident`.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();

        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}