datafusion_substrait/logical_plan/consumer/expr/
window_function.rs1use crate::logical_plan::consumer::{
19 from_substrait_func_args, from_substrait_rex_vec, from_substrait_sorts,
20 substrait_fun_name, SubstraitConsumer,
21};
22use datafusion::common::{
23 not_impl_err, plan_datafusion_err, plan_err, substrait_err, DFSchema, ScalarValue,
24};
25use datafusion::execution::FunctionRegistry;
26use datafusion::logical_expr::expr::WindowFunctionParams;
27use datafusion::logical_expr::{
28 expr, Expr, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
29};
30use substrait::proto::expression::window_function::{Bound, BoundsType};
31use substrait::proto::expression::WindowFunction;
32use substrait::proto::expression::{
33 window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind,
34};
35
36pub async fn from_window_function(
37 consumer: &impl SubstraitConsumer,
38 window: &WindowFunction,
39 input_schema: &DFSchema,
40) -> datafusion::common::Result<Expr> {
41 let Some(fn_signature) = consumer
42 .get_extensions()
43 .functions
44 .get(&window.function_reference)
45 else {
46 return plan_err!(
47 "Window function not found: function reference = {:?}",
48 window.function_reference
49 );
50 };
51 let fn_name = substrait_fun_name(fn_signature);
52
53 let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) {
55 Ok(WindowFunctionDefinition::WindowUDF(udwf))
56 } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) {
57 Ok(WindowFunctionDefinition::AggregateUDF(udaf))
58 } else {
59 not_impl_err!(
60 "Window function {} is not supported: function anchor = {:?}",
61 fn_name,
62 window.function_reference
63 )
64 }?;
65
66 let mut order_by =
67 from_substrait_sorts(consumer, &window.sorts, input_schema).await?;
68
69 let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| {
70 plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
71 })? {
72 BoundsType::Rows => WindowFrameUnits::Rows,
73 BoundsType::Range => WindowFrameUnits::Range,
74 BoundsType::Unspecified => {
75 if order_by.is_empty() {
79 WindowFrameUnits::Rows
80 } else {
81 WindowFrameUnits::Range
82 }
83 }
84 };
85 let window_frame = datafusion::logical_expr::WindowFrame::new_bounds(
86 bound_units,
87 from_substrait_bound(&window.lower_bound, true)?,
88 from_substrait_bound(&window.upper_bound, false)?,
89 );
90
91 window_frame.regularize_order_bys(&mut order_by)?;
92
93 let args = if fun.name() == "count" && window.arguments.is_empty() {
97 vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)]
98 } else {
99 from_substrait_func_args(consumer, &window.arguments, input_schema).await?
100 };
101
102 Ok(Expr::from(expr::WindowFunction {
103 fun,
104 params: WindowFunctionParams {
105 args,
106 partition_by: from_substrait_rex_vec(
107 consumer,
108 &window.partitions,
109 input_schema,
110 )
111 .await?,
112 order_by,
113 window_frame,
114 filter: None,
115 null_treatment: None,
116 distinct: false,
117 },
118 }))
119}
120
121fn from_substrait_bound(
122 bound: &Option<Bound>,
123 is_lower: bool,
124) -> datafusion::common::Result<WindowFrameBound> {
125 match bound {
126 Some(b) => match &b.kind {
127 Some(k) => match k {
128 BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => {
129 Ok(WindowFrameBound::CurrentRow)
130 }
131 BoundKind::Preceding(SubstraitBound::Preceding { offset }) => {
132 if *offset <= 0 {
133 return plan_err!("Preceding bound must be positive");
134 }
135 Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
136 *offset as u64,
137 ))))
138 }
139 BoundKind::Following(SubstraitBound::Following { offset }) => {
140 if *offset <= 0 {
141 return plan_err!("Following bound must be positive");
142 }
143 Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some(
144 *offset as u64,
145 ))))
146 }
147 BoundKind::Unbounded(SubstraitBound::Unbounded {}) => {
148 if is_lower {
149 Ok(WindowFrameBound::Preceding(ScalarValue::Null))
150 } else {
151 Ok(WindowFrameBound::Following(ScalarValue::Null))
152 }
153 }
154 },
155 None => substrait_err!("WindowFunction missing Substrait Bound kind"),
156 },
157 None => {
158 if is_lower {
159 Ok(WindowFrameBound::Preceding(ScalarValue::Null))
160 } else {
161 Ok(WindowFrameBound::Following(ScalarValue::Null))
162 }
163 }
164 }
165}