datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::ptr_eq::PtrEq;
29use crate::select_expr::SelectExpr;
30use crate::{
31    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
32    AggregateUDF, Expr, LimitEffect, LogicalPlan, Operator, PartitionEvaluator,
33    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
34};
35use crate::{
36    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
37};
38use arrow::compute::kernels::cast_utils::{
39    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
40};
41use arrow::datatypes::{DataType, Field, FieldRef};
42use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
43use datafusion_functions_window_common::field::WindowUDFFieldArgs;
44use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
45use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
46use sqlparser::ast::NullTreatment;
47use std::any::Any;
48use std::fmt::Debug;
49use std::hash::Hash;
50use std::ops::Not;
51use std::sync::Arc;
52
53/// Create a column expression based on a qualified or unqualified column name. Will
54/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
55///
56/// For example:
57///
58/// ```rust
59/// # use datafusion_expr::col;
60/// let c1 = col("a");
61/// let c2 = col("A");
62/// assert_eq!(c1, c2);
63///
64/// // note how quoting with double quotes preserves the case
65/// let c3 = col(r#""A""#);
66/// assert_ne!(c1, c3);
67/// ```
68pub fn col(ident: impl Into<Column>) -> Expr {
69    Expr::Column(ident.into())
70}
71
72/// Create an out reference column which hold a reference that has been resolved to a field
73/// outside of the current plan.
74pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
75    Expr::OuterReferenceColumn(dt, ident.into())
76}
77
78/// Create an unqualified column expression from the provided name, without normalizing
79/// the column.
80///
81/// For example:
82///
83/// ```rust
84/// # use datafusion_expr::{col, ident};
85/// let c1 = ident("A"); // not normalized staying as column 'A'
86/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
87/// assert_ne!(c1, c2);
88///
89/// let c3 = col(r#""A""#);
90/// assert_eq!(c1, c3);
91///
92/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
93/// let c5 = ident("t1.a"); // parses as column 't1.a'
94/// assert_ne!(c4, c5);
95/// ```
96pub fn ident(name: impl Into<String>) -> Expr {
97    Expr::Column(Column::from_name(name))
98}
99
100/// Create placeholder value that will be filled in (such as `$1`)
101///
102/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
103///
104/// # Example
105///
106/// ```rust
107/// # use datafusion_expr::{placeholder};
108/// let p = placeholder("$0"); // $0, refers to parameter 1
109/// assert_eq!(p.to_string(), "$0")
110/// ```
111pub fn placeholder(id: impl Into<String>) -> Expr {
112    Expr::Placeholder(Placeholder {
113        id: id.into(),
114        data_type: None,
115    })
116}
117
118/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
119///
120/// # Example
121///
122/// ```rust
123/// # use datafusion_expr::{wildcard};
124/// let p = wildcard();
125/// assert_eq!(p.to_string(), "*")
126/// ```
127pub fn wildcard() -> SelectExpr {
128    SelectExpr::Wildcard(WildcardOptions::default())
129}
130
131/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
132pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
133    SelectExpr::Wildcard(options)
134}
135
136/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
137///
138/// # Example
139///
140/// ```rust
141/// # use datafusion_common::TableReference;
142/// # use datafusion_expr::{qualified_wildcard};
143/// let p = qualified_wildcard(TableReference::bare("t"));
144/// assert_eq!(p.to_string(), "t.*")
145/// ```
146pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
147    SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
148}
149
150/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
151pub fn qualified_wildcard_with_options(
152    qualifier: impl Into<TableReference>,
153    options: WildcardOptions,
154) -> SelectExpr {
155    SelectExpr::QualifiedWildcard(qualifier.into(), options)
156}
157
158/// Return a new expression `left <op> right`
159pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
160    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
161}
162
163/// Return a new expression with a logical AND
164pub fn and(left: Expr, right: Expr) -> Expr {
165    Expr::BinaryExpr(BinaryExpr::new(
166        Box::new(left),
167        Operator::And,
168        Box::new(right),
169    ))
170}
171
172/// Return a new expression with a logical OR
173pub fn or(left: Expr, right: Expr) -> Expr {
174    Expr::BinaryExpr(BinaryExpr::new(
175        Box::new(left),
176        Operator::Or,
177        Box::new(right),
178    ))
179}
180
181/// Return a new expression with a logical NOT
182pub fn not(expr: Expr) -> Expr {
183    expr.not()
184}
185
186/// Return a new expression with bitwise AND
187pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
188    Expr::BinaryExpr(BinaryExpr::new(
189        Box::new(left),
190        Operator::BitwiseAnd,
191        Box::new(right),
192    ))
193}
194
195/// Return a new expression with bitwise OR
196pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
197    Expr::BinaryExpr(BinaryExpr::new(
198        Box::new(left),
199        Operator::BitwiseOr,
200        Box::new(right),
201    ))
202}
203
204/// Return a new expression with bitwise XOR
205pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
206    Expr::BinaryExpr(BinaryExpr::new(
207        Box::new(left),
208        Operator::BitwiseXor,
209        Box::new(right),
210    ))
211}
212
213/// Return a new expression with bitwise SHIFT RIGHT
214pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
215    Expr::BinaryExpr(BinaryExpr::new(
216        Box::new(left),
217        Operator::BitwiseShiftRight,
218        Box::new(right),
219    ))
220}
221
222/// Return a new expression with bitwise SHIFT LEFT
223pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
224    Expr::BinaryExpr(BinaryExpr::new(
225        Box::new(left),
226        Operator::BitwiseShiftLeft,
227        Box::new(right),
228    ))
229}
230
231/// Create an in_list expression
232pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
233    Expr::InList(InList::new(Box::new(expr), list, negated))
234}
235
236/// Create an EXISTS subquery expression
237pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
238    let outer_ref_columns = subquery.all_out_ref_exprs();
239    Expr::Exists(Exists {
240        subquery: Subquery {
241            subquery,
242            outer_ref_columns,
243            spans: Spans::new(),
244        },
245        negated: false,
246    })
247}
248
249/// Create a NOT EXISTS subquery expression
250pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
251    let outer_ref_columns = subquery.all_out_ref_exprs();
252    Expr::Exists(Exists {
253        subquery: Subquery {
254            subquery,
255            outer_ref_columns,
256            spans: Spans::new(),
257        },
258        negated: true,
259    })
260}
261
262/// Create an IN subquery expression
263pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
264    let outer_ref_columns = subquery.all_out_ref_exprs();
265    Expr::InSubquery(InSubquery::new(
266        Box::new(expr),
267        Subquery {
268            subquery,
269            outer_ref_columns,
270            spans: Spans::new(),
271        },
272        false,
273    ))
274}
275
276/// Create a NOT IN subquery expression
277pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
278    let outer_ref_columns = subquery.all_out_ref_exprs();
279    Expr::InSubquery(InSubquery::new(
280        Box::new(expr),
281        Subquery {
282            subquery,
283            outer_ref_columns,
284            spans: Spans::new(),
285        },
286        true,
287    ))
288}
289
290/// Create a scalar subquery expression
291pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
292    let outer_ref_columns = subquery.all_out_ref_exprs();
293    Expr::ScalarSubquery(Subquery {
294        subquery,
295        outer_ref_columns,
296        spans: Spans::new(),
297    })
298}
299
300/// Create a grouping set
301pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
302    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
303}
304
305/// Create a grouping set for all combination of `exprs`
306pub fn cube(exprs: Vec<Expr>) -> Expr {
307    Expr::GroupingSet(GroupingSet::Cube(exprs))
308}
309
310/// Create a grouping set for rollup
311pub fn rollup(exprs: Vec<Expr>) -> Expr {
312    Expr::GroupingSet(GroupingSet::Rollup(exprs))
313}
314
315/// Create a cast expression
316pub fn cast(expr: Expr, data_type: DataType) -> Expr {
317    Expr::Cast(Cast::new(Box::new(expr), data_type))
318}
319
320/// Create a try cast expression
321pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
322    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
323}
324
325/// Create is null expression
326pub fn is_null(expr: Expr) -> Expr {
327    Expr::IsNull(Box::new(expr))
328}
329
330/// Create is true expression
331pub fn is_true(expr: Expr) -> Expr {
332    Expr::IsTrue(Box::new(expr))
333}
334
335/// Create is not true expression
336pub fn is_not_true(expr: Expr) -> Expr {
337    Expr::IsNotTrue(Box::new(expr))
338}
339
340/// Create is false expression
341pub fn is_false(expr: Expr) -> Expr {
342    Expr::IsFalse(Box::new(expr))
343}
344
345/// Create is not false expression
346pub fn is_not_false(expr: Expr) -> Expr {
347    Expr::IsNotFalse(Box::new(expr))
348}
349
350/// Create is unknown expression
351pub fn is_unknown(expr: Expr) -> Expr {
352    Expr::IsUnknown(Box::new(expr))
353}
354
355/// Create is not unknown expression
356pub fn is_not_unknown(expr: Expr) -> Expr {
357    Expr::IsNotUnknown(Box::new(expr))
358}
359
360/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
361pub fn case(expr: Expr) -> CaseBuilder {
362    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
363}
364
365/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
366pub fn when(when: Expr, then: Expr) -> CaseBuilder {
367    CaseBuilder::new(None, vec![when], vec![then], None)
368}
369
370/// Create a Unnest expression
371pub fn unnest(expr: Expr) -> Expr {
372    Expr::Unnest(Unnest {
373        expr: Box::new(expr),
374    })
375}
376
377/// Convenience method to create a new user defined scalar function (UDF) with a
378/// specific signature and specific return type.
379///
380/// Note this function does not expose all available features of [`ScalarUDF`],
381/// such as
382///
383/// * computing return types based on input types
384/// * multiple [`Signature`]s
385/// * aliases
386///
387/// See [`ScalarUDF`] for details and examples on how to use the full
388/// functionality.
389pub fn create_udf(
390    name: &str,
391    input_types: Vec<DataType>,
392    return_type: DataType,
393    volatility: Volatility,
394    fun: ScalarFunctionImplementation,
395) -> ScalarUDF {
396    ScalarUDF::from(SimpleScalarUDF::new(
397        name,
398        input_types,
399        return_type,
400        volatility,
401        fun,
402    ))
403}
404
405/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
406/// return type.
407#[derive(PartialEq, Eq, Hash)]
408pub struct SimpleScalarUDF {
409    name: String,
410    signature: Signature,
411    return_type: DataType,
412    fun: PtrEq<ScalarFunctionImplementation>,
413}
414
415impl Debug for SimpleScalarUDF {
416    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
417        f.debug_struct("SimpleScalarUDF")
418            .field("name", &self.name)
419            .field("signature", &self.signature)
420            .field("return_type", &self.return_type)
421            .field("fun", &"<FUNC>")
422            .finish()
423    }
424}
425
426impl SimpleScalarUDF {
427    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
428    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
429    pub fn new(
430        name: impl Into<String>,
431        input_types: Vec<DataType>,
432        return_type: DataType,
433        volatility: Volatility,
434        fun: ScalarFunctionImplementation,
435    ) -> Self {
436        Self::new_with_signature(
437            name,
438            Signature::exact(input_types, volatility),
439            return_type,
440            fun,
441        )
442    }
443
444    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
445    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
446    pub fn new_with_signature(
447        name: impl Into<String>,
448        signature: Signature,
449        return_type: DataType,
450        fun: ScalarFunctionImplementation,
451    ) -> Self {
452        Self {
453            name: name.into(),
454            signature,
455            return_type,
456            fun: fun.into(),
457        }
458    }
459}
460
461impl ScalarUDFImpl for SimpleScalarUDF {
462    fn as_any(&self) -> &dyn Any {
463        self
464    }
465
466    fn name(&self) -> &str {
467        &self.name
468    }
469
470    fn signature(&self) -> &Signature {
471        &self.signature
472    }
473
474    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
475        Ok(self.return_type.clone())
476    }
477
478    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
479        (self.fun)(&args.args)
480    }
481}
482
483/// Creates a new UDAF with a specific signature, state type and return type.
484/// The signature and state type must match the `Accumulator's implementation`.
485pub fn create_udaf(
486    name: &str,
487    input_type: Vec<DataType>,
488    return_type: Arc<DataType>,
489    volatility: Volatility,
490    accumulator: AccumulatorFactoryFunction,
491    state_type: Arc<Vec<DataType>>,
492) -> AggregateUDF {
493    let return_type = Arc::unwrap_or_clone(return_type);
494    let state_type = Arc::unwrap_or_clone(state_type);
495    let state_fields = state_type
496        .into_iter()
497        .enumerate()
498        .map(|(i, t)| Field::new(format!("{i}"), t, true))
499        .map(Arc::new)
500        .collect::<Vec<_>>();
501    AggregateUDF::from(SimpleAggregateUDF::new(
502        name,
503        input_type,
504        return_type,
505        volatility,
506        accumulator,
507        state_fields,
508    ))
509}
510
511/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
512/// return type.
513#[derive(PartialEq, Eq, Hash)]
514pub struct SimpleAggregateUDF {
515    name: String,
516    signature: Signature,
517    return_type: DataType,
518    accumulator: PtrEq<AccumulatorFactoryFunction>,
519    state_fields: Vec<FieldRef>,
520}
521
522impl Debug for SimpleAggregateUDF {
523    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
524        f.debug_struct("SimpleAggregateUDF")
525            .field("name", &self.name)
526            .field("signature", &self.signature)
527            .field("return_type", &self.return_type)
528            .field("fun", &"<FUNC>")
529            .finish()
530    }
531}
532
533impl SimpleAggregateUDF {
534    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
535    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
536    pub fn new(
537        name: impl Into<String>,
538        input_type: Vec<DataType>,
539        return_type: DataType,
540        volatility: Volatility,
541        accumulator: AccumulatorFactoryFunction,
542        state_fields: Vec<FieldRef>,
543    ) -> Self {
544        let name = name.into();
545        let signature = Signature::exact(input_type, volatility);
546        Self {
547            name,
548            signature,
549            return_type,
550            accumulator: accumulator.into(),
551            state_fields,
552        }
553    }
554
555    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
556    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
557    pub fn new_with_signature(
558        name: impl Into<String>,
559        signature: Signature,
560        return_type: DataType,
561        accumulator: AccumulatorFactoryFunction,
562        state_fields: Vec<FieldRef>,
563    ) -> Self {
564        let name = name.into();
565        Self {
566            name,
567            signature,
568            return_type,
569            accumulator: accumulator.into(),
570            state_fields,
571        }
572    }
573}
574
575impl AggregateUDFImpl for SimpleAggregateUDF {
576    fn as_any(&self) -> &dyn Any {
577        self
578    }
579
580    fn name(&self) -> &str {
581        &self.name
582    }
583
584    fn signature(&self) -> &Signature {
585        &self.signature
586    }
587
588    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
589        Ok(self.return_type.clone())
590    }
591
592    fn accumulator(
593        &self,
594        acc_args: AccumulatorArgs,
595    ) -> Result<Box<dyn crate::Accumulator>> {
596        (self.accumulator)(acc_args)
597    }
598
599    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
600        Ok(self.state_fields.clone())
601    }
602}
603
604/// Creates a new UDWF with a specific signature, state type and return type.
605///
606/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
607///
608/// [`PartitionEvaluator`]: crate::PartitionEvaluator
609pub fn create_udwf(
610    name: &str,
611    input_type: DataType,
612    return_type: Arc<DataType>,
613    volatility: Volatility,
614    partition_evaluator_factory: PartitionEvaluatorFactory,
615) -> WindowUDF {
616    let return_type = Arc::unwrap_or_clone(return_type);
617    WindowUDF::from(SimpleWindowUDF::new(
618        name,
619        input_type,
620        return_type,
621        volatility,
622        partition_evaluator_factory,
623    ))
624}
625
626/// Implements [`WindowUDFImpl`] for functions that have a single signature and
627/// return type.
628#[derive(PartialEq, Eq, Hash)]
629pub struct SimpleWindowUDF {
630    name: String,
631    signature: Signature,
632    return_type: DataType,
633    partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
634}
635
636impl Debug for SimpleWindowUDF {
637    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
638        f.debug_struct("WindowUDF")
639            .field("name", &self.name)
640            .field("signature", &self.signature)
641            .field("return_type", &"<func>")
642            .field("partition_evaluator_factory", &"<FUNC>")
643            .finish()
644    }
645}
646
647impl SimpleWindowUDF {
648    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
649    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
650    pub fn new(
651        name: impl Into<String>,
652        input_type: DataType,
653        return_type: DataType,
654        volatility: Volatility,
655        partition_evaluator_factory: PartitionEvaluatorFactory,
656    ) -> Self {
657        let name = name.into();
658        let signature = Signature::exact([input_type].to_vec(), volatility);
659        Self {
660            name,
661            signature,
662            return_type,
663            partition_evaluator_factory: partition_evaluator_factory.into(),
664        }
665    }
666}
667
668impl WindowUDFImpl for SimpleWindowUDF {
669    fn as_any(&self) -> &dyn Any {
670        self
671    }
672
673    fn name(&self) -> &str {
674        &self.name
675    }
676
677    fn signature(&self) -> &Signature {
678        &self.signature
679    }
680
681    fn partition_evaluator(
682        &self,
683        _partition_evaluator_args: PartitionEvaluatorArgs,
684    ) -> Result<Box<dyn PartitionEvaluator>> {
685        (self.partition_evaluator_factory)()
686    }
687
688    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
689        Ok(Arc::new(Field::new(
690            field_args.name(),
691            self.return_type.clone(),
692            true,
693        )))
694    }
695
696    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
697        LimitEffect::Unknown
698    }
699}
700
701pub fn interval_year_month_lit(value: &str) -> Expr {
702    let interval = parse_interval_year_month(value).ok();
703    Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
704}
705
706pub fn interval_datetime_lit(value: &str) -> Expr {
707    let interval = parse_interval_day_time(value).ok();
708    Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
709}
710
711pub fn interval_month_day_nano_lit(value: &str) -> Expr {
712    let interval = parse_interval_month_day_nano(value).ok();
713    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
714}
715
716/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
717///
718/// Adds methods to [`Expr`] that make it easy to set optional options
719/// such as `ORDER BY`, `FILTER` and `DISTINCT`
720///
721/// # Example
722/// ```no_run
723/// # use datafusion_common::Result;
724/// # use datafusion_expr::test::function_stub::count;
725/// # use sqlparser::ast::NullTreatment;
726/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
727/// # // first_value is an aggregate function in another crate
728/// # fn first_value(_arg: Expr) -> Expr {
729/// unimplemented!() }
730/// # fn main() -> Result<()> {
731/// // Create an aggregate count, filtering on column y > 5
732/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
733///
734/// // Find the first value in an aggregate sorted by column y
735/// // equivalent to:
736/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
737/// let sort_expr = col("y").sort(true, true);
738/// let agg = first_value(col("x"))
739///     .order_by(vec![sort_expr])
740///     .null_treatment(NullTreatment::IgnoreNulls)
741///     .build()?;
742///
743/// // Create a window expression for percent rank partitioned on column a
744/// // equivalent to:
745/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
746/// // percent_rank is an udwf function in another crate
747/// # fn percent_rank() -> Expr {
748/// unimplemented!() }
749/// let window = percent_rank()
750///     .partition_by(vec![col("a")])
751///     .order_by(vec![col("b").sort(true, true)])
752///     .null_treatment(NullTreatment::IgnoreNulls)
753///     .build()?;
754/// #     Ok(())
755/// # }
756/// ```
757pub trait ExprFunctionExt {
758    /// Add `ORDER BY <order_by>`
759    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
760    /// Add `FILTER <filter>`
761    fn filter(self, filter: Expr) -> ExprFuncBuilder;
762    /// Add `DISTINCT`
763    fn distinct(self) -> ExprFuncBuilder;
764    /// Add `RESPECT NULLS` or `IGNORE NULLS`
765    fn null_treatment(
766        self,
767        null_treatment: impl Into<Option<NullTreatment>>,
768    ) -> ExprFuncBuilder;
769    /// Add `PARTITION BY`
770    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
771    /// Add appropriate window frame conditions
772    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
773}
774
775#[derive(Debug, Clone)]
776pub enum ExprFuncKind {
777    Aggregate(AggregateFunction),
778    Window(Box<WindowFunction>),
779}
780
781/// Implementation of [`ExprFunctionExt`].
782///
783/// See [`ExprFunctionExt`] for usage and examples
784#[derive(Debug, Clone)]
785pub struct ExprFuncBuilder {
786    fun: Option<ExprFuncKind>,
787    order_by: Option<Vec<Sort>>,
788    filter: Option<Expr>,
789    distinct: bool,
790    null_treatment: Option<NullTreatment>,
791    partition_by: Option<Vec<Expr>>,
792    window_frame: Option<WindowFrame>,
793}
794
795impl ExprFuncBuilder {
796    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
797    fn new(fun: Option<ExprFuncKind>) -> Self {
798        Self {
799            fun,
800            order_by: None,
801            filter: None,
802            distinct: false,
803            null_treatment: None,
804            partition_by: None,
805            window_frame: None,
806        }
807    }
808
809    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
810    ///
811    /// # Errors:
812    ///
813    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
814    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
815    pub fn build(self) -> Result<Expr> {
816        let Self {
817            fun,
818            order_by,
819            filter,
820            distinct,
821            null_treatment,
822            partition_by,
823            window_frame,
824        } = self;
825
826        let Some(fun) = fun else {
827            return plan_err!(
828                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
829            );
830        };
831
832        let fun_expr = match fun {
833            ExprFuncKind::Aggregate(mut udaf) => {
834                udaf.params.order_by = order_by.unwrap_or_default();
835                udaf.params.filter = filter.map(Box::new);
836                udaf.params.distinct = distinct;
837                udaf.params.null_treatment = null_treatment;
838                Expr::AggregateFunction(udaf)
839            }
840            ExprFuncKind::Window(mut udwf) => {
841                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
842                udwf.params.partition_by = partition_by.unwrap_or_default();
843                udwf.params.order_by = order_by.unwrap_or_default();
844                udwf.params.window_frame =
845                    window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
846                udwf.params.filter = filter.map(Box::new);
847                udwf.params.null_treatment = null_treatment;
848                udwf.params.distinct = distinct;
849                Expr::WindowFunction(udwf)
850            }
851        };
852
853        Ok(fun_expr)
854    }
855}
856
857impl ExprFunctionExt for ExprFuncBuilder {
858    /// Add `ORDER BY <order_by>`
859    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
860        self.order_by = Some(order_by);
861        self
862    }
863
864    /// Add `FILTER <filter>`
865    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
866        self.filter = Some(filter);
867        self
868    }
869
870    /// Add `DISTINCT`
871    fn distinct(mut self) -> ExprFuncBuilder {
872        self.distinct = true;
873        self
874    }
875
876    /// Add `RESPECT NULLS` or `IGNORE NULLS`
877    fn null_treatment(
878        mut self,
879        null_treatment: impl Into<Option<NullTreatment>>,
880    ) -> ExprFuncBuilder {
881        self.null_treatment = null_treatment.into();
882        self
883    }
884
885    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
886        self.partition_by = Some(partition_by);
887        self
888    }
889
890    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
891        self.window_frame = Some(window_frame);
892        self
893    }
894}
895
896impl ExprFunctionExt for Expr {
897    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
898        let mut builder = match self {
899            Expr::AggregateFunction(udaf) => {
900                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
901            }
902            Expr::WindowFunction(udwf) => {
903                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
904            }
905            _ => ExprFuncBuilder::new(None),
906        };
907        if builder.fun.is_some() {
908            builder.order_by = Some(order_by);
909        }
910        builder
911    }
912    fn filter(self, filter: Expr) -> ExprFuncBuilder {
913        match self {
914            Expr::AggregateFunction(udaf) => {
915                let mut builder =
916                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
917                builder.filter = Some(filter);
918                builder
919            }
920            _ => ExprFuncBuilder::new(None),
921        }
922    }
923    fn distinct(self) -> ExprFuncBuilder {
924        match self {
925            Expr::AggregateFunction(udaf) => {
926                let mut builder =
927                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
928                builder.distinct = true;
929                builder
930            }
931            _ => ExprFuncBuilder::new(None),
932        }
933    }
934    fn null_treatment(
935        self,
936        null_treatment: impl Into<Option<NullTreatment>>,
937    ) -> ExprFuncBuilder {
938        let mut builder = match self {
939            Expr::AggregateFunction(udaf) => {
940                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
941            }
942            Expr::WindowFunction(udwf) => {
943                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
944            }
945            _ => ExprFuncBuilder::new(None),
946        };
947        if builder.fun.is_some() {
948            builder.null_treatment = null_treatment.into();
949        }
950        builder
951    }
952
953    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
954        match self {
955            Expr::WindowFunction(udwf) => {
956                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
957                builder.partition_by = Some(partition_by);
958                builder
959            }
960            _ => ExprFuncBuilder::new(None),
961        }
962    }
963
964    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
965        match self {
966            Expr::WindowFunction(udwf) => {
967                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
968                builder.window_frame = Some(window_frame);
969                builder
970            }
971            _ => ExprFuncBuilder::new(None),
972        }
973    }
974}
975
#[cfg(test)]
mod test {
    use super::*;

    /// `IS NULL` / `IS NOT NULL` expressions built from `col` and `ident`
    /// render as the expected SQL text.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();

        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}