use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::logical_expr::expr::WindowFunction;
use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};
use crate::common::df_schema::PyDFSchema;
use crate::errors::py_type_err;
use crate::expr::logical_node::LogicalNode;
use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
use crate::expr::PyExpr;
use crate::sql::logical::PyLogicalPlan;
use super::py_expr_list;
use crate::errors::py_datafusion_err;
/// Python-facing wrapper around a DataFusion `Window` logical plan node,
/// exposed to Python as `datafusion.expr.WindowExpr`.
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowExpr {
// The wrapped DataFusion plan node (input plan, window expressions, schema).
window: Window,
}
/// Python-facing wrapper around a DataFusion `WindowFrame` (units plus start and
/// end bounds), exposed to Python as `datafusion.expr.WindowFrame`.
#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
// The wrapped DataFusion frame definition.
window_frame: WindowFrame,
}
impl From<PyWindowFrame> for WindowFrame {
fn from(window_frame: PyWindowFrame) -> Self {
window_frame.window_frame
}
}
impl From<WindowFrame> for PyWindowFrame {
fn from(window_frame: WindowFrame) -> PyWindowFrame {
PyWindowFrame { window_frame }
}
}
/// Python-facing wrapper around a single DataFusion `WindowFrameBound`
/// (`PRECEDING`, `FOLLOWING` or `CURRENT ROW`), exposed to Python as
/// `datafusion.expr.WindowFrameBound`.
#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
// The wrapped DataFusion bound.
frame_bound: WindowFrameBound,
}
impl From<PyWindowExpr> for Window {
fn from(window: PyWindowExpr) -> Window {
window.window
}
}
impl From<Window> for PyWindowExpr {
fn from(window: Window) -> PyWindowExpr {
PyWindowExpr { window }
}
}
impl From<WindowFrameBound> for PyWindowFrameBound {
fn from(frame_bound: WindowFrameBound) -> Self {
PyWindowFrameBound { frame_bound }
}
}
impl Display for PyWindowExpr {
    /// Debug-style dump of the window expressions and output schema,
    /// preceded by an `Over` header line and a blank line.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let exprs = &self.window.window_expr;
        let schema = &self.window.schema;
        write!(f, "Over\n\nWindow Expr: {exprs:?}\nSchema: {schema:?}")
    }
}
impl Display for PyWindowFrame {
    /// Render the frame SQL-style: `OVER (<units> BETWEEN <start> AND <end>)`.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let frame = &self.window_frame;
        let (units, lower, upper) = (&frame.units, &frame.start_bound, &frame.end_bound);
        write!(f, "OVER ({} BETWEEN {} AND {})", units, lower, upper)
    }
}
#[pymethods]
impl PyWindowExpr {
    /// Output schema of the wrapped `Window` node.
    pub fn schema(&self) -> PyResult<PyDFSchema> {
        let schema = (*self.window.schema).clone();
        Ok(schema.into())
    }

    /// The window expressions computed by the wrapped `Window` node.
    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
        let exprs = &self.window.window_expr;
        py_expr_list(exprs)
    }

    /// The `order_by` list of `expr`, which must be a window function
    /// (aliases are stripped first).
    ///
    /// Errors via `not_window_function_err` if `expr` is not a window function.
    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
        let unaliased = expr.expr.unalias();
        if let Expr::WindowFunction(WindowFunction { order_by, .. }) = &unaliased {
            return py_sort_expr_list(order_by);
        }
        Err(not_window_function_err(unaliased))
    }

    /// The `partition_by` list of `expr`, which must be a window function
    /// (aliases are stripped first).
    ///
    /// Errors via `not_window_function_err` if `expr` is not a window function.
    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
        let unaliased = expr.expr.unalias();
        if let Expr::WindowFunction(WindowFunction { partition_by, .. }) = &unaliased {
            return py_expr_list(partition_by);
        }
        Err(not_window_function_err(unaliased))
    }

    /// The argument list of `expr`, which must be a window function
    /// (aliases are stripped first).
    ///
    /// Errors via `not_window_function_err` if `expr` is not a window function.
    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
        let unaliased = expr.expr.unalias();
        if let Expr::WindowFunction(WindowFunction { args, .. }) = &unaliased {
            return py_expr_list(args);
        }
        Err(not_window_function_err(unaliased))
    }

    /// The display name of the function behind window expression `expr`
    /// (aliases are stripped first).
    ///
    /// Errors via `not_window_function_err` if `expr` is not a window function.
    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
        let unaliased = expr.expr.unalias();
        if let Expr::WindowFunction(WindowFunction { fun, .. }) = &unaliased {
            return Ok(fun.to_string());
        }
        Err(not_window_function_err(unaliased))
    }

    /// The frame of window expression `expr` (aliases are stripped first).
    ///
    /// Unlike the other accessors this returns `None`, not an error, when
    /// `expr` is not a window function.
    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
        if let Expr::WindowFunction(WindowFunction { window_frame, .. }) = expr.expr.unalias() {
            Some(window_frame.into())
        } else {
            None
        }
    }
}
/// Build the error returned when an expression given to a window accessor is
/// not an `Expr::WindowFunction`; the message names the actual variant and
/// includes the expression's debug form. Created through `py_type_err`.
fn not_window_function_err(expr: Expr) -> PyErr {
    let variant = expr.variant_name();
    let message = format!("Provided {variant} Expr {expr:?} is not a WindowFunction type");
    py_type_err(message)
}
#[pymethods]
impl PyWindowFrame {
#[new]
#[pyo3(signature=(unit, start_bound, end_bound))]
pub fn new(
unit: &str,
start_bound: Option<ScalarValue>,
end_bound: Option<ScalarValue>,
) -> PyResult<Self> {
let units = unit.to_ascii_lowercase();
let units = match units.as_str() {
"rows" => WindowFrameUnits::Rows,
"range" => WindowFrameUnits::Range,
"groups" => WindowFrameUnits::Groups,
_ => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
};
let start_bound = match start_bound {
Some(start_bound) => WindowFrameBound::Preceding(start_bound),
None => match units {
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Groups => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
},
};
let end_bound = match end_bound {
Some(end_bound) => WindowFrameBound::Following(end_bound),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Groups => {
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
"{:?}",
units,
))));
}
},
};
Ok(PyWindowFrame {
window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
})
}
pub fn get_frame_units(&self) -> PyResult<String> {
Ok(self.window_frame.units.to_string())
}
pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
Ok(self.window_frame.start_bound.clone().into())
}
pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
Ok(self.window_frame.end_bound.clone().into())
}
fn __repr__(&self) -> String {
format!("{}", self)
}
}
#[pymethods]
impl PyWindowFrameBound {
    /// Whether this bound is `CURRENT ROW`.
    pub fn is_current_row(&self) -> bool {
        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
    }

    /// Whether this bound is a `PRECEDING` bound (bounded or unbounded).
    pub fn is_preceding(&self) -> bool {
        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
    }

    /// Whether this bound is a `FOLLOWING` bound (bounded or unbounded).
    pub fn is_following(&self) -> bool {
        matches!(self.frame_bound, WindowFrameBound::Following(_))
    }

    /// Offset of a `PRECEDING`/`FOLLOWING` bound as an unsigned count.
    ///
    /// Returns `Ok(None)` for `CURRENT ROW` and for bounds whose offset scalar
    /// is null (unbounded). Accepts:
    /// - integer offsets of any width;
    /// - `Utf8` strings that parse as `u64`.
    ///
    /// # Errors
    /// Returns a `Plan` error if the offset is any of:
    /// - a negative integer;
    /// - a string that does not parse as `u64`;
    /// - any other scalar type.
    pub fn get_offset(&self) -> PyResult<Option<u64>> {
        let val = match &self.frame_bound {
            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => val,
            WindowFrameBound::CurrentRow => return Ok(None),
        };
        match val {
            x if x.is_null() => Ok(None),
            ScalarValue::UInt64(v) => Ok(*v),
            ScalarValue::UInt32(v) => Ok(v.map(u64::from)),
            ScalarValue::UInt16(v) => Ok(v.map(u64::from)),
            ScalarValue::UInt8(v) => Ok(v.map(u64::from)),
            // Signed offsets are range-checked. A bare `as u64` cast would
            // silently turn e.g. -1 into 18446744073709551615.
            ScalarValue::Int64(Some(n)) => non_negative_offset(*n),
            ScalarValue::Int32(Some(n)) => non_negative_offset(i64::from(*n)),
            ScalarValue::Int16(Some(n)) => non_negative_offset(i64::from(*n)),
            ScalarValue::Int8(Some(n)) => non_negative_offset(i64::from(*n)),
            ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
                Ok(offset) => Ok(Some(offset)),
                Err(_e) => Err(DataFusionError::Plan(format!(
                    "Unable to parse u64 from Utf8 value '{s}'"
                ))
                .into()),
            },
            other => {
                Err(DataFusionError::Plan(format!("Unexpected window frame bound: {other}")).into())
            }
        }
    }

    /// Whether this bound is UNBOUNDED, i.e. a `PRECEDING`/`FOLLOWING` bound
    /// with a null offset. `CURRENT ROW` is never unbounded.
    pub fn is_unbounded(&self) -> PyResult<bool> {
        match &self.frame_bound {
            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
            WindowFrameBound::CurrentRow => Ok(false),
        }
    }
}

/// Convert a signed frame offset to `u64`. Negative values produce a `Plan`
/// error instead of wrapping around.
fn non_negative_offset(n: i64) -> PyResult<Option<u64>> {
    if n < 0 {
        Err(DataFusionError::Plan(format!(
            "Window frame bound offset must be non-negative, got {n}"
        ))
        .into())
    } else {
        // Lossless: `n` is known to be in 0..=i64::MAX here.
        Ok(Some(n as u64))
    }
}
impl LogicalNode for PyWindowExpr {
    /// The single child plan that feeds the wrapped `Window` node.
    fn inputs(&self) -> Vec<PyLogicalPlan> {
        let input: PyLogicalPlan = (*self.window.input).clone().into();
        vec![input]
    }

    /// Convert a clone of this node into a Python object.
    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
        let node = self.clone();
        Ok(node.into_py(py))
    }
}