use crate::logical_plan::consumer::{
SubstraitConsumer, from_substrait_func_args, from_substrait_rex_vec,
from_substrait_sorts, substrait_fun_name,
};
use datafusion::common::{
DFSchema, ScalarValue, not_impl_err, plan_datafusion_err, plan_err, substrait_err,
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::expr::WindowFunctionParams;
use datafusion::logical_expr::{
Expr, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, expr,
};
use substrait::proto::expression::WindowFunction;
use substrait::proto::expression::window_function::{Bound, BoundsType};
use substrait::proto::expression::{
window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind,
};
/// Converts a Substrait [`WindowFunction`] expression into a DataFusion window
/// function [`Expr`].
///
/// The function is first resolved through the consumer's extensions using
/// `window.function_reference`. The resulting name is then looked up in the
/// function registry, trying a window UDF before an aggregate UDF. Sorts,
/// frame bounds, arguments and partitions are all converted against
/// `input_schema`.
///
/// When the Substrait bounds type is unspecified, the frame uses `ROWS` if
/// there is no ordering and `RANGE` otherwise. A `count` call with no
/// arguments is given a single `Int64(1)` literal argument.
///
/// # Errors
///
/// Returns an error if the function reference is unknown, or if the function
/// is registered neither as a UDWF nor as a UDAF. It also fails if
/// `bounds_type` is not a valid `BoundsType`, or if any nested conversion
/// (sorts, bounds, arguments, partitions, order-by regularization) fails.
pub async fn from_window_function(
    consumer: &impl SubstraitConsumer,
    window: &WindowFunction,
    input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
    let fn_signature = match consumer
        .get_extensions()
        .functions
        .get(&window.function_reference)
    {
        Some(signature) => signature,
        None => {
            return plan_err!(
                "Window function not found: function reference = {:?}",
                window.function_reference
            );
        }
    };
    let name = substrait_fun_name(fn_signature);

    // Prefer a dedicated window UDF; fall back to an aggregate used as a window.
    let fun = match consumer.get_function_registry().udwf(name) {
        Ok(udwf) => WindowFunctionDefinition::WindowUDF(udwf),
        Err(_) => match consumer.get_function_registry().udaf(name) {
            Ok(udaf) => WindowFunctionDefinition::AggregateUDF(udaf),
            Err(_) => {
                return not_impl_err!(
                    "Window function {} is not supported: function anchor = {:?}",
                    name,
                    window.function_reference
                );
            }
        },
    };

    let mut order_by =
        from_substrait_sorts(consumer, &window.sorts, input_schema).await?;

    let bounds_type = BoundsType::try_from(window.bounds_type).map_err(|e| {
        plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
    })?;
    let units = match bounds_type {
        BoundsType::Rows => WindowFrameUnits::Rows,
        BoundsType::Range => WindowFrameUnits::Range,
        // No explicit units: pick RANGE only when an ordering exists.
        BoundsType::Unspecified if order_by.is_empty() => WindowFrameUnits::Rows,
        BoundsType::Unspecified => WindowFrameUnits::Range,
    };

    let lower = from_substrait_bound(&window.lower_bound, true)?;
    let upper = from_substrait_bound(&window.upper_bound, false)?;
    let window_frame =
        datafusion::logical_expr::WindowFrame::new_bounds(units, lower, upper);
    window_frame.regularize_order_bys(&mut order_by)?;

    // An argument-less `count` gets an explicit literal argument.
    let is_argless_count = fun.name() == "count" && window.arguments.is_empty();
    let args = if is_argless_count {
        vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)]
    } else {
        from_substrait_func_args(consumer, &window.arguments, input_schema).await?
    };

    let partition_by =
        from_substrait_rex_vec(consumer, &window.partitions, input_schema).await?;

    let params = WindowFunctionParams {
        args,
        partition_by,
        order_by,
        window_frame,
        filter: None,
        null_treatment: None,
        distinct: false,
    };
    Ok(Expr::from(expr::WindowFunction { fun, params }))
}
/// Converts an optional Substrait window frame bound into a DataFusion
/// [`WindowFrameBound`].
///
/// `is_lower` tells which side of the frame is being converted. It only
/// matters for an `Unbounded` bound or a missing (`None`) bound. These become
/// a `Preceding` bound with a `ScalarValue::Null` offset on the lower side,
/// and a `Following` bound with a `ScalarValue::Null` offset on the upper
/// side. `Preceding`/`Following` offsets are carried over as
/// `ScalarValue::UInt64`.
///
/// # Errors
///
/// Returns a plan error if a `Preceding` or `Following` offset is not
/// strictly positive. Returns a Substrait error if a bound is present but its
/// `kind` is missing.
fn from_substrait_bound(
    bound: &Option<Bound>,
    is_lower: bool,
) -> datafusion::common::Result<WindowFrameBound> {
    // An absent bound is handled exactly like an explicit `Unbounded` one.
    let Some(b) = bound else {
        return Ok(unbounded_frame_bound(is_lower));
    };
    let Some(kind) = &b.kind else {
        return substrait_err!("WindowFunction missing Substrait Bound kind");
    };

    match kind {
        BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => {
            Ok(WindowFrameBound::CurrentRow)
        }
        BoundKind::Preceding(SubstraitBound::Preceding { offset }) => {
            // Negative offsets fail the conversion and zero is filtered out,
            // so only strictly positive offsets survive.
            let Some(distance) = u64::try_from(*offset).ok().filter(|&d| d > 0) else {
                return plan_err!("Preceding bound must be positive");
            };
            Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some(distance))))
        }
        BoundKind::Following(SubstraitBound::Following { offset }) => {
            let Some(distance) = u64::try_from(*offset).ok().filter(|&d| d > 0) else {
                return plan_err!("Following bound must be positive");
            };
            Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some(distance))))
        }
        BoundKind::Unbounded(SubstraitBound::Unbounded {}) => {
            Ok(unbounded_frame_bound(is_lower))
        }
    }
}

/// Returns the unbounded frame edge for one side of a window frame: `Preceding`
/// with a `ScalarValue::Null` offset when `is_lower`, otherwise `Following`
/// with a `ScalarValue::Null` offset.
fn unbounded_frame_bound(is_lower: bool) -> WindowFrameBound {
    if is_lower {
        WindowFrameBound::Preceding(ScalarValue::Null)
    } else {
        WindowFrameBound::Following(ScalarValue::Null)
    }
}