datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::ptr_eq::PtrEq;
29use crate::select_expr::SelectExpr;
30use crate::{
31    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
32    AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
33    ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
34};
35use crate::{
36    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
37};
38use arrow::compute::kernels::cast_utils::{
39    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
40};
41use arrow::datatypes::{DataType, Field, FieldRef};
42use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
43use datafusion_functions_window_common::field::WindowUDFFieldArgs;
44use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
45use sqlparser::ast::NullTreatment;
46use std::any::Any;
47use std::fmt::Debug;
48use std::hash::Hash;
49use std::ops::Not;
50use std::sync::Arc;
51
52/// Create a column expression based on a qualified or unqualified column name. Will
53/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
54///
55/// For example:
56///
57/// ```rust
58/// # use datafusion_expr::col;
59/// let c1 = col("a");
60/// let c2 = col("A");
61/// assert_eq!(c1, c2);
62///
63/// // note how quoting with double quotes preserves the case
64/// let c3 = col(r#""A""#);
65/// assert_ne!(c1, c3);
66/// ```
67pub fn col(ident: impl Into<Column>) -> Expr {
68    Expr::Column(ident.into())
69}
70
71/// Create an out reference column which hold a reference that has been resolved to a field
72/// outside of the current plan.
73pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
74    Expr::OuterReferenceColumn(dt, ident.into())
75}
76
77/// Create an unqualified column expression from the provided name, without normalizing
78/// the column.
79///
80/// For example:
81///
82/// ```rust
83/// # use datafusion_expr::{col, ident};
84/// let c1 = ident("A"); // not normalized staying as column 'A'
85/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
86/// assert_ne!(c1, c2);
87///
88/// let c3 = col(r#""A""#);
89/// assert_eq!(c1, c3);
90///
91/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
92/// let c5 = ident("t1.a"); // parses as column 't1.a'
93/// assert_ne!(c4, c5);
94/// ```
95pub fn ident(name: impl Into<String>) -> Expr {
96    Expr::Column(Column::from_name(name))
97}
98
99/// Create placeholder value that will be filled in (such as `$1`)
100///
101/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
102///
103/// # Example
104///
105/// ```rust
106/// # use datafusion_expr::{placeholder};
107/// let p = placeholder("$0"); // $0, refers to parameter 1
108/// assert_eq!(p.to_string(), "$0")
109/// ```
110pub fn placeholder(id: impl Into<String>) -> Expr {
111    Expr::Placeholder(Placeholder {
112        id: id.into(),
113        data_type: None,
114    })
115}
116
117/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
118///
119/// # Example
120///
121/// ```rust
122/// # use datafusion_expr::{wildcard};
123/// let p = wildcard();
124/// assert_eq!(p.to_string(), "*")
125/// ```
126pub fn wildcard() -> SelectExpr {
127    SelectExpr::Wildcard(WildcardOptions::default())
128}
129
130/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
131pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
132    SelectExpr::Wildcard(options)
133}
134
135/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
136///
137/// # Example
138///
139/// ```rust
140/// # use datafusion_common::TableReference;
141/// # use datafusion_expr::{qualified_wildcard};
142/// let p = qualified_wildcard(TableReference::bare("t"));
143/// assert_eq!(p.to_string(), "t.*")
144/// ```
145pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
146    SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
147}
148
149/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
150pub fn qualified_wildcard_with_options(
151    qualifier: impl Into<TableReference>,
152    options: WildcardOptions,
153) -> SelectExpr {
154    SelectExpr::QualifiedWildcard(qualifier.into(), options)
155}
156
157/// Return a new expression `left <op> right`
158pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
159    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
160}
161
162/// Return a new expression with a logical AND
163pub fn and(left: Expr, right: Expr) -> Expr {
164    Expr::BinaryExpr(BinaryExpr::new(
165        Box::new(left),
166        Operator::And,
167        Box::new(right),
168    ))
169}
170
171/// Return a new expression with a logical OR
172pub fn or(left: Expr, right: Expr) -> Expr {
173    Expr::BinaryExpr(BinaryExpr::new(
174        Box::new(left),
175        Operator::Or,
176        Box::new(right),
177    ))
178}
179
180/// Return a new expression with a logical NOT
181pub fn not(expr: Expr) -> Expr {
182    expr.not()
183}
184
185/// Return a new expression with bitwise AND
186pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
187    Expr::BinaryExpr(BinaryExpr::new(
188        Box::new(left),
189        Operator::BitwiseAnd,
190        Box::new(right),
191    ))
192}
193
194/// Return a new expression with bitwise OR
195pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
196    Expr::BinaryExpr(BinaryExpr::new(
197        Box::new(left),
198        Operator::BitwiseOr,
199        Box::new(right),
200    ))
201}
202
203/// Return a new expression with bitwise XOR
204pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
205    Expr::BinaryExpr(BinaryExpr::new(
206        Box::new(left),
207        Operator::BitwiseXor,
208        Box::new(right),
209    ))
210}
211
212/// Return a new expression with bitwise SHIFT RIGHT
213pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
214    Expr::BinaryExpr(BinaryExpr::new(
215        Box::new(left),
216        Operator::BitwiseShiftRight,
217        Box::new(right),
218    ))
219}
220
221/// Return a new expression with bitwise SHIFT LEFT
222pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
223    Expr::BinaryExpr(BinaryExpr::new(
224        Box::new(left),
225        Operator::BitwiseShiftLeft,
226        Box::new(right),
227    ))
228}
229
230/// Create an in_list expression
231pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
232    Expr::InList(InList::new(Box::new(expr), list, negated))
233}
234
235/// Create an EXISTS subquery expression
236pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
237    let outer_ref_columns = subquery.all_out_ref_exprs();
238    Expr::Exists(Exists {
239        subquery: Subquery {
240            subquery,
241            outer_ref_columns,
242            spans: Spans::new(),
243        },
244        negated: false,
245    })
246}
247
248/// Create a NOT EXISTS subquery expression
249pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
250    let outer_ref_columns = subquery.all_out_ref_exprs();
251    Expr::Exists(Exists {
252        subquery: Subquery {
253            subquery,
254            outer_ref_columns,
255            spans: Spans::new(),
256        },
257        negated: true,
258    })
259}
260
261/// Create an IN subquery expression
262pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
263    let outer_ref_columns = subquery.all_out_ref_exprs();
264    Expr::InSubquery(InSubquery::new(
265        Box::new(expr),
266        Subquery {
267            subquery,
268            outer_ref_columns,
269            spans: Spans::new(),
270        },
271        false,
272    ))
273}
274
275/// Create a NOT IN subquery expression
276pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
277    let outer_ref_columns = subquery.all_out_ref_exprs();
278    Expr::InSubquery(InSubquery::new(
279        Box::new(expr),
280        Subquery {
281            subquery,
282            outer_ref_columns,
283            spans: Spans::new(),
284        },
285        true,
286    ))
287}
288
289/// Create a scalar subquery expression
290pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
291    let outer_ref_columns = subquery.all_out_ref_exprs();
292    Expr::ScalarSubquery(Subquery {
293        subquery,
294        outer_ref_columns,
295        spans: Spans::new(),
296    })
297}
298
299/// Create a grouping set
300pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
301    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
302}
303
304/// Create a grouping set for all combination of `exprs`
305pub fn cube(exprs: Vec<Expr>) -> Expr {
306    Expr::GroupingSet(GroupingSet::Cube(exprs))
307}
308
309/// Create a grouping set for rollup
310pub fn rollup(exprs: Vec<Expr>) -> Expr {
311    Expr::GroupingSet(GroupingSet::Rollup(exprs))
312}
313
314/// Create a cast expression
315pub fn cast(expr: Expr, data_type: DataType) -> Expr {
316    Expr::Cast(Cast::new(Box::new(expr), data_type))
317}
318
319/// Create a try cast expression
320pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
321    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
322}
323
324/// Create is null expression
325pub fn is_null(expr: Expr) -> Expr {
326    Expr::IsNull(Box::new(expr))
327}
328
329/// Create is true expression
330pub fn is_true(expr: Expr) -> Expr {
331    Expr::IsTrue(Box::new(expr))
332}
333
334/// Create is not true expression
335pub fn is_not_true(expr: Expr) -> Expr {
336    Expr::IsNotTrue(Box::new(expr))
337}
338
339/// Create is false expression
340pub fn is_false(expr: Expr) -> Expr {
341    Expr::IsFalse(Box::new(expr))
342}
343
344/// Create is not false expression
345pub fn is_not_false(expr: Expr) -> Expr {
346    Expr::IsNotFalse(Box::new(expr))
347}
348
349/// Create is unknown expression
350pub fn is_unknown(expr: Expr) -> Expr {
351    Expr::IsUnknown(Box::new(expr))
352}
353
354/// Create is not unknown expression
355pub fn is_not_unknown(expr: Expr) -> Expr {
356    Expr::IsNotUnknown(Box::new(expr))
357}
358
359/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
360pub fn case(expr: Expr) -> CaseBuilder {
361    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
362}
363
364/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
365pub fn when(when: Expr, then: Expr) -> CaseBuilder {
366    CaseBuilder::new(None, vec![when], vec![then], None)
367}
368
369/// Create a Unnest expression
370pub fn unnest(expr: Expr) -> Expr {
371    Expr::Unnest(Unnest {
372        expr: Box::new(expr),
373    })
374}
375
376/// Convenience method to create a new user defined scalar function (UDF) with a
377/// specific signature and specific return type.
378///
379/// Note this function does not expose all available features of [`ScalarUDF`],
380/// such as
381///
382/// * computing return types based on input types
383/// * multiple [`Signature`]s
384/// * aliases
385///
386/// See [`ScalarUDF`] for details and examples on how to use the full
387/// functionality.
388pub fn create_udf(
389    name: &str,
390    input_types: Vec<DataType>,
391    return_type: DataType,
392    volatility: Volatility,
393    fun: ScalarFunctionImplementation,
394) -> ScalarUDF {
395    ScalarUDF::from(SimpleScalarUDF::new(
396        name,
397        input_types,
398        return_type,
399        volatility,
400        fun,
401    ))
402}
403
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
/// return type.
///
/// [`create_udf`] builds one of these through [`SimpleScalarUDF::new`].
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleScalarUDF {
    /// Name returned by `ScalarUDFImpl::name`
    name: String,
    /// Signature returned by `ScalarUDFImpl::signature`
    signature: Signature,
    /// Returned from `return_type` regardless of the argument types
    return_type: DataType,
    /// Implementation called by `invoke_with_args`. Wrapped in `PtrEq`,
    /// presumably so the derived `PartialEq`/`Hash` can handle the function
    /// value (e.g. by pointer identity) — confirm in `crate::ptr_eq`
    fun: PtrEq<ScalarFunctionImplementation>,
}
413
414impl Debug for SimpleScalarUDF {
415    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
416        f.debug_struct("SimpleScalarUDF")
417            .field("name", &self.name)
418            .field("signature", &self.signature)
419            .field("return_type", &self.return_type)
420            .field("fun", &"<FUNC>")
421            .finish()
422    }
423}
424
425impl SimpleScalarUDF {
426    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
427    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
428    pub fn new(
429        name: impl Into<String>,
430        input_types: Vec<DataType>,
431        return_type: DataType,
432        volatility: Volatility,
433        fun: ScalarFunctionImplementation,
434    ) -> Self {
435        Self::new_with_signature(
436            name,
437            Signature::exact(input_types, volatility),
438            return_type,
439            fun,
440        )
441    }
442
443    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
444    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
445    pub fn new_with_signature(
446        name: impl Into<String>,
447        signature: Signature,
448        return_type: DataType,
449        fun: ScalarFunctionImplementation,
450    ) -> Self {
451        Self {
452            name: name.into(),
453            signature,
454            return_type,
455            fun: fun.into(),
456        }
457    }
458}
459
460impl ScalarUDFImpl for SimpleScalarUDF {
461    fn as_any(&self) -> &dyn Any {
462        self
463    }
464
465    fn name(&self) -> &str {
466        &self.name
467    }
468
469    fn signature(&self) -> &Signature {
470        &self.signature
471    }
472
473    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
474        Ok(self.return_type.clone())
475    }
476
477    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
478        (self.fun)(&args.args)
479    }
480}
481
482/// Creates a new UDAF with a specific signature, state type and return type.
483/// The signature and state type must match the `Accumulator's implementation`.
484pub fn create_udaf(
485    name: &str,
486    input_type: Vec<DataType>,
487    return_type: Arc<DataType>,
488    volatility: Volatility,
489    accumulator: AccumulatorFactoryFunction,
490    state_type: Arc<Vec<DataType>>,
491) -> AggregateUDF {
492    let return_type = Arc::unwrap_or_clone(return_type);
493    let state_type = Arc::unwrap_or_clone(state_type);
494    let state_fields = state_type
495        .into_iter()
496        .enumerate()
497        .map(|(i, t)| Field::new(format!("{i}"), t, true))
498        .map(Arc::new)
499        .collect::<Vec<_>>();
500    AggregateUDF::from(SimpleAggregateUDF::new(
501        name,
502        input_type,
503        return_type,
504        volatility,
505        accumulator,
506        state_fields,
507    ))
508}
509
/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
/// return type.
///
/// [`create_udaf`] builds one of these through [`SimpleAggregateUDF::new`].
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleAggregateUDF {
    /// Name returned by `AggregateUDFImpl::name`
    name: String,
    /// Signature returned by `AggregateUDFImpl::signature`
    signature: Signature,
    /// Returned from `return_type` regardless of the argument types
    return_type: DataType,
    /// Factory called by `accumulator` to create a new accumulator. Wrapped in
    /// `PtrEq`, presumably so the derived `PartialEq`/`Hash` can handle the
    /// function value — confirm in `crate::ptr_eq`
    accumulator: PtrEq<AccumulatorFactoryFunction>,
    /// Returned (cloned) from `state_fields` regardless of the arguments
    state_fields: Vec<FieldRef>,
}
520
521impl Debug for SimpleAggregateUDF {
522    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
523        f.debug_struct("SimpleAggregateUDF")
524            .field("name", &self.name)
525            .field("signature", &self.signature)
526            .field("return_type", &self.return_type)
527            .field("fun", &"<FUNC>")
528            .finish()
529    }
530}
531
532impl SimpleAggregateUDF {
533    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
534    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
535    pub fn new(
536        name: impl Into<String>,
537        input_type: Vec<DataType>,
538        return_type: DataType,
539        volatility: Volatility,
540        accumulator: AccumulatorFactoryFunction,
541        state_fields: Vec<FieldRef>,
542    ) -> Self {
543        let name = name.into();
544        let signature = Signature::exact(input_type, volatility);
545        Self {
546            name,
547            signature,
548            return_type,
549            accumulator: accumulator.into(),
550            state_fields,
551        }
552    }
553
554    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
555    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
556    pub fn new_with_signature(
557        name: impl Into<String>,
558        signature: Signature,
559        return_type: DataType,
560        accumulator: AccumulatorFactoryFunction,
561        state_fields: Vec<FieldRef>,
562    ) -> Self {
563        let name = name.into();
564        Self {
565            name,
566            signature,
567            return_type,
568            accumulator: accumulator.into(),
569            state_fields,
570        }
571    }
572}
573
574impl AggregateUDFImpl for SimpleAggregateUDF {
575    fn as_any(&self) -> &dyn Any {
576        self
577    }
578
579    fn name(&self) -> &str {
580        &self.name
581    }
582
583    fn signature(&self) -> &Signature {
584        &self.signature
585    }
586
587    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
588        Ok(self.return_type.clone())
589    }
590
591    fn accumulator(
592        &self,
593        acc_args: AccumulatorArgs,
594    ) -> Result<Box<dyn crate::Accumulator>> {
595        (self.accumulator)(acc_args)
596    }
597
598    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
599        Ok(self.state_fields.clone())
600    }
601}
602
603/// Creates a new UDWF with a specific signature, state type and return type.
604///
605/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
606///
607/// [`PartitionEvaluator`]: crate::PartitionEvaluator
608pub fn create_udwf(
609    name: &str,
610    input_type: DataType,
611    return_type: Arc<DataType>,
612    volatility: Volatility,
613    partition_evaluator_factory: PartitionEvaluatorFactory,
614) -> WindowUDF {
615    let return_type = Arc::unwrap_or_clone(return_type);
616    WindowUDF::from(SimpleWindowUDF::new(
617        name,
618        input_type,
619        return_type,
620        volatility,
621        partition_evaluator_factory,
622    ))
623}
624
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
/// return type.
///
/// [`create_udwf`] builds one of these through [`SimpleWindowUDF::new`].
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleWindowUDF {
    /// Name returned by `WindowUDFImpl::name`
    name: String,
    /// Signature returned by `WindowUDFImpl::signature`
    signature: Signature,
    /// Data type of the nullable output field built by `field`
    return_type: DataType,
    /// Factory called by `partition_evaluator`. Wrapped in `PtrEq`, presumably
    /// so the derived `PartialEq`/`Hash` can handle the function value —
    /// confirm in `crate::ptr_eq`
    partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
}
634
635impl Debug for SimpleWindowUDF {
636    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
637        f.debug_struct("WindowUDF")
638            .field("name", &self.name)
639            .field("signature", &self.signature)
640            .field("return_type", &"<func>")
641            .field("partition_evaluator_factory", &"<FUNC>")
642            .finish()
643    }
644}
645
646impl SimpleWindowUDF {
647    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
648    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
649    pub fn new(
650        name: impl Into<String>,
651        input_type: DataType,
652        return_type: DataType,
653        volatility: Volatility,
654        partition_evaluator_factory: PartitionEvaluatorFactory,
655    ) -> Self {
656        let name = name.into();
657        let signature = Signature::exact([input_type].to_vec(), volatility);
658        Self {
659            name,
660            signature,
661            return_type,
662            partition_evaluator_factory: partition_evaluator_factory.into(),
663        }
664    }
665}
666
667impl WindowUDFImpl for SimpleWindowUDF {
668    fn as_any(&self) -> &dyn Any {
669        self
670    }
671
672    fn name(&self) -> &str {
673        &self.name
674    }
675
676    fn signature(&self) -> &Signature {
677        &self.signature
678    }
679
680    fn partition_evaluator(
681        &self,
682        _partition_evaluator_args: PartitionEvaluatorArgs,
683    ) -> Result<Box<dyn PartitionEvaluator>> {
684        (self.partition_evaluator_factory)()
685    }
686
687    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
688        Ok(Arc::new(Field::new(
689            field_args.name(),
690            self.return_type.clone(),
691            true,
692        )))
693    }
694}
695
696pub fn interval_year_month_lit(value: &str) -> Expr {
697    let interval = parse_interval_year_month(value).ok();
698    Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
699}
700
701pub fn interval_datetime_lit(value: &str) -> Expr {
702    let interval = parse_interval_day_time(value).ok();
703    Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
704}
705
706pub fn interval_month_day_nano_lit(value: &str) -> Expr {
707    let interval = parse_interval_month_day_nano(value).ok();
708    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
709}
710
711/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
712///
713/// Adds methods to [`Expr`] that make it easy to set optional options
714/// such as `ORDER BY`, `FILTER` and `DISTINCT`
715///
716/// # Example
717/// ```no_run
718/// # use datafusion_common::Result;
719/// # use datafusion_expr::test::function_stub::count;
720/// # use sqlparser::ast::NullTreatment;
721/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
722/// # // first_value is an aggregate function in another crate
723/// # fn first_value(_arg: Expr) -> Expr {
724/// unimplemented!() }
725/// # fn main() -> Result<()> {
726/// // Create an aggregate count, filtering on column y > 5
727/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
728///
729/// // Find the first value in an aggregate sorted by column y
730/// // equivalent to:
731/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
732/// let sort_expr = col("y").sort(true, true);
733/// let agg = first_value(col("x"))
734///     .order_by(vec![sort_expr])
735///     .null_treatment(NullTreatment::IgnoreNulls)
736///     .build()?;
737///
738/// // Create a window expression for percent rank partitioned on column a
739/// // equivalent to:
740/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
741/// // percent_rank is an udwf function in another crate
742/// # fn percent_rank() -> Expr {
743/// unimplemented!() }
744/// let window = percent_rank()
745///     .partition_by(vec![col("a")])
746///     .order_by(vec![col("b").sort(true, true)])
747///     .null_treatment(NullTreatment::IgnoreNulls)
748///     .build()?;
749/// #     Ok(())
750/// # }
751/// ```
752pub trait ExprFunctionExt {
753    /// Add `ORDER BY <order_by>`
754    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
755    /// Add `FILTER <filter>`
756    fn filter(self, filter: Expr) -> ExprFuncBuilder;
757    /// Add `DISTINCT`
758    fn distinct(self) -> ExprFuncBuilder;
759    /// Add `RESPECT NULLS` or `IGNORE NULLS`
760    fn null_treatment(
761        self,
762        null_treatment: impl Into<Option<NullTreatment>>,
763    ) -> ExprFuncBuilder;
764    /// Add `PARTITION BY`
765    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
766    /// Add appropriate window frame conditions
767    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
768}
769
/// The function expression wrapped by an [`ExprFuncBuilder`].
#[derive(Debug, Clone)]
pub enum ExprFuncKind {
    /// Payload taken from an [`Expr::AggregateFunction`]
    Aggregate(AggregateFunction),
    /// Payload taken from an [`Expr::WindowFunction`]
    Window(Box<WindowFunction>),
}
775
/// Implementation of [`ExprFunctionExt`].
///
/// See [`ExprFunctionExt`] for usage and examples
#[derive(Debug, Clone)]
pub struct ExprFuncBuilder {
    /// The aggregate or window function being configured; `None` when the
    /// builder was started from an unsupported expression (then `build` errors)
    fun: Option<ExprFuncKind>,
    /// `ORDER BY` requested through the builder, if any
    order_by: Option<Vec<Sort>>,
    /// `FILTER` requested through the builder, if any
    filter: Option<Expr>,
    /// Whether `DISTINCT` was requested through the builder
    distinct: bool,
    /// `RESPECT NULLS` / `IGNORE NULLS` requested through the builder, if any
    null_treatment: Option<NullTreatment>,
    /// `PARTITION BY` requested through the builder, if any
    partition_by: Option<Vec<Expr>>,
    /// Explicit window frame requested through the builder, if any
    window_frame: Option<WindowFrame>,
}
789
790impl ExprFuncBuilder {
791    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
792    fn new(fun: Option<ExprFuncKind>) -> Self {
793        Self {
794            fun,
795            order_by: None,
796            filter: None,
797            distinct: false,
798            null_treatment: None,
799            partition_by: None,
800            window_frame: None,
801        }
802    }
803
804    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
805    ///
806    /// # Errors:
807    ///
808    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
809    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
810    pub fn build(self) -> Result<Expr> {
811        let Self {
812            fun,
813            order_by,
814            filter,
815            distinct,
816            null_treatment,
817            partition_by,
818            window_frame,
819        } = self;
820
821        let Some(fun) = fun else {
822            return plan_err!(
823                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
824            );
825        };
826
827        let fun_expr = match fun {
828            ExprFuncKind::Aggregate(mut udaf) => {
829                udaf.params.order_by = order_by.unwrap_or_default();
830                udaf.params.filter = filter.map(Box::new);
831                udaf.params.distinct = distinct;
832                udaf.params.null_treatment = null_treatment;
833                Expr::AggregateFunction(udaf)
834            }
835            ExprFuncKind::Window(mut udwf) => {
836                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
837                udwf.params.partition_by = partition_by.unwrap_or_default();
838                udwf.params.order_by = order_by.unwrap_or_default();
839                udwf.params.window_frame =
840                    window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
841                udwf.params.filter = filter.map(Box::new);
842                udwf.params.null_treatment = null_treatment;
843                udwf.params.distinct = distinct;
844                Expr::WindowFunction(udwf)
845            }
846        };
847
848        Ok(fun_expr)
849    }
850}
851
852impl ExprFunctionExt for ExprFuncBuilder {
853    /// Add `ORDER BY <order_by>`
854    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
855        self.order_by = Some(order_by);
856        self
857    }
858
859    /// Add `FILTER <filter>`
860    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
861        self.filter = Some(filter);
862        self
863    }
864
865    /// Add `DISTINCT`
866    fn distinct(mut self) -> ExprFuncBuilder {
867        self.distinct = true;
868        self
869    }
870
871    /// Add `RESPECT NULLS` or `IGNORE NULLS`
872    fn null_treatment(
873        mut self,
874        null_treatment: impl Into<Option<NullTreatment>>,
875    ) -> ExprFuncBuilder {
876        self.null_treatment = null_treatment.into();
877        self
878    }
879
880    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
881        self.partition_by = Some(partition_by);
882        self
883    }
884
885    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
886        self.window_frame = Some(window_frame);
887        self
888    }
889}
890
891impl ExprFunctionExt for Expr {
892    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
893        let mut builder = match self {
894            Expr::AggregateFunction(udaf) => {
895                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
896            }
897            Expr::WindowFunction(udwf) => {
898                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
899            }
900            _ => ExprFuncBuilder::new(None),
901        };
902        if builder.fun.is_some() {
903            builder.order_by = Some(order_by);
904        }
905        builder
906    }
907    fn filter(self, filter: Expr) -> ExprFuncBuilder {
908        match self {
909            Expr::AggregateFunction(udaf) => {
910                let mut builder =
911                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
912                builder.filter = Some(filter);
913                builder
914            }
915            _ => ExprFuncBuilder::new(None),
916        }
917    }
918    fn distinct(self) -> ExprFuncBuilder {
919        match self {
920            Expr::AggregateFunction(udaf) => {
921                let mut builder =
922                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
923                builder.distinct = true;
924                builder
925            }
926            _ => ExprFuncBuilder::new(None),
927        }
928    }
929    fn null_treatment(
930        self,
931        null_treatment: impl Into<Option<NullTreatment>>,
932    ) -> ExprFuncBuilder {
933        let mut builder = match self {
934            Expr::AggregateFunction(udaf) => {
935                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
936            }
937            Expr::WindowFunction(udwf) => {
938                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
939            }
940            _ => ExprFuncBuilder::new(None),
941        };
942        if builder.fun.is_some() {
943            builder.null_treatment = null_treatment.into();
944        }
945        builder
946    }
947
948    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
949        match self {
950            Expr::WindowFunction(udwf) => {
951                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
952                builder.partition_by = Some(partition_by);
953                builder
954            }
955            _ => ExprFuncBuilder::new(None),
956        }
957    }
958
959    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
960        match self {
961            Expr::WindowFunction(udwf) => {
962                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
963                builder.window_frame = Some(window_frame);
964                builder
965            }
966            _ => ExprFuncBuilder::new(None),
967        }
968    }
969}
970
#[cfg(test)]
mod test {
    use super::*;

    /// `col` and `ident` expressions render their IS [NOT] NULL forms as SQL.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();
        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}