datafusion_substrait/logical_plan/consumer/expr/window_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::logical_plan::consumer::{
19    from_substrait_func_args, from_substrait_rex_vec, from_substrait_sorts,
20    substrait_fun_name, SubstraitConsumer,
21};
22use datafusion::common::{
23    not_impl_err, plan_datafusion_err, plan_err, substrait_err, DFSchema, ScalarValue,
24};
25use datafusion::execution::FunctionRegistry;
26use datafusion::logical_expr::expr::WindowFunctionParams;
27use datafusion::logical_expr::{
28    expr, Expr, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
29};
30use substrait::proto::expression::window_function::{Bound, BoundsType};
31use substrait::proto::expression::WindowFunction;
32use substrait::proto::expression::{
33    window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind,
34};
35
36pub async fn from_window_function(
37    consumer: &impl SubstraitConsumer,
38    window: &WindowFunction,
39    input_schema: &DFSchema,
40) -> datafusion::common::Result<Expr> {
41    let Some(fn_signature) = consumer
42        .get_extensions()
43        .functions
44        .get(&window.function_reference)
45    else {
46        return plan_err!(
47            "Window function not found: function reference = {:?}",
48            window.function_reference
49        );
50    };
51    let fn_name = substrait_fun_name(fn_signature);
52
53    // check udwf first, then udaf, then built-in window and aggregate functions
54    let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) {
55        Ok(WindowFunctionDefinition::WindowUDF(udwf))
56    } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) {
57        Ok(WindowFunctionDefinition::AggregateUDF(udaf))
58    } else {
59        not_impl_err!(
60            "Window function {} is not supported: function anchor = {:?}",
61            fn_name,
62            window.function_reference
63        )
64    }?;
65
66    let mut order_by =
67        from_substrait_sorts(consumer, &window.sorts, input_schema).await?;
68
69    let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| {
70        plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
71    })? {
72        BoundsType::Rows => WindowFrameUnits::Rows,
73        BoundsType::Range => WindowFrameUnits::Range,
74        BoundsType::Unspecified => {
75            // If the plan does not specify the bounds type, then we use a simple logic to determine the units
76            // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
77            // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
78            if order_by.is_empty() {
79                WindowFrameUnits::Rows
80            } else {
81                WindowFrameUnits::Range
82            }
83        }
84    };
85    let window_frame = datafusion::logical_expr::WindowFrame::new_bounds(
86        bound_units,
87        from_substrait_bound(&window.lower_bound, true)?,
88        from_substrait_bound(&window.upper_bound, false)?,
89    );
90
91    window_frame.regularize_order_bys(&mut order_by)?;
92
93    // Datafusion does not support aggregate functions with no arguments, so
94    // we inject a dummy argument that does not affect the query, but allows
95    // us to bypass this limitation.
96    let args = if fun.name() == "count" && window.arguments.is_empty() {
97        vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)]
98    } else {
99        from_substrait_func_args(consumer, &window.arguments, input_schema).await?
100    };
101
102    Ok(Expr::from(expr::WindowFunction {
103        fun,
104        params: WindowFunctionParams {
105            args,
106            partition_by: from_substrait_rex_vec(
107                consumer,
108                &window.partitions,
109                input_schema,
110            )
111            .await?,
112            order_by,
113            window_frame,
114            filter: None,
115            null_treatment: None,
116            distinct: false,
117        },
118    }))
119}
120
121fn from_substrait_bound(
122    bound: &Option<Bound>,
123    is_lower: bool,
124) -> datafusion::common::Result<WindowFrameBound> {
125    match bound {
126        Some(b) => match &b.kind {
127            Some(k) => match k {
128                BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => {
129                    Ok(WindowFrameBound::CurrentRow)
130                }
131                BoundKind::Preceding(SubstraitBound::Preceding { offset }) => {
132                    if *offset <= 0 {
133                        return plan_err!("Preceding bound must be positive");
134                    }
135                    Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
136                        *offset as u64,
137                    ))))
138                }
139                BoundKind::Following(SubstraitBound::Following { offset }) => {
140                    if *offset <= 0 {
141                        return plan_err!("Following bound must be positive");
142                    }
143                    Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some(
144                        *offset as u64,
145                    ))))
146                }
147                BoundKind::Unbounded(SubstraitBound::Unbounded {}) => {
148                    if is_lower {
149                        Ok(WindowFrameBound::Preceding(ScalarValue::Null))
150                    } else {
151                        Ok(WindowFrameBound::Following(ScalarValue::Null))
152                    }
153                }
154            },
155            None => substrait_err!("WindowFunction missing Substrait Bound kind"),
156        },
157        None => {
158            if is_lower {
159                Ok(WindowFrameBound::Preceding(ScalarValue::Null))
160            } else {
161                Ok(WindowFrameBound::Following(ScalarValue::Null))
162            }
163        }
164    }
165}