1use datafusion::common::{DataFusionError, ScalarValue};
19use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams};
20use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
21use pyo3::{prelude::*, IntoPyObjectExt};
22use std::fmt::{self, Display, Formatter};
23
24use crate::common::data_type::PyScalarValue;
25use crate::common::df_schema::PyDFSchema;
26use crate::errors::{py_type_err, PyDataFusionResult};
27use crate::expr::logical_node::LogicalNode;
28use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
29use crate::expr::PyExpr;
30use crate::sql::logical::PyLogicalPlan;
31
32use super::py_expr_list;
33
34use crate::errors::py_datafusion_err;
35
/// Wrapper around a DataFusion [`Window`] logical plan node, exposed to Python as
/// `WindowExpr` in the `datafusion.expr` module.
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowExpr {
    // The wrapped plan node; its input, window expressions and schema are read by the methods below.
    window: Window,
}
41
/// Wrapper around a DataFusion [`WindowFrame`], exposed to Python as `WindowFrame`
/// in the `datafusion.expr` module.
#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
    // The wrapped frame: its units plus start and end bounds.
    window_frame: WindowFrame,
}
47
48impl From<PyWindowFrame> for WindowFrame {
49 fn from(window_frame: PyWindowFrame) -> Self {
50 window_frame.window_frame
51 }
52}
53
54impl From<WindowFrame> for PyWindowFrame {
55 fn from(window_frame: WindowFrame) -> PyWindowFrame {
56 PyWindowFrame { window_frame }
57 }
58}
59
/// Wrapper around a single DataFusion [`WindowFrameBound`], exposed to Python as
/// `WindowFrameBound` in the `datafusion.expr` module.
#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
    // The wrapped bound: `Preceding`, `Following` or `CurrentRow`.
    frame_bound: WindowFrameBound,
}
65
66impl From<PyWindowExpr> for Window {
67 fn from(window: PyWindowExpr) -> Window {
68 window.window
69 }
70}
71
72impl From<Window> for PyWindowExpr {
73 fn from(window: Window) -> PyWindowExpr {
74 PyWindowExpr { window }
75 }
76}
77
78impl From<WindowFrameBound> for PyWindowFrameBound {
79 fn from(frame_bound: WindowFrameBound) -> Self {
80 PyWindowFrameBound { frame_bound }
81 }
82}
83
84impl Display for PyWindowExpr {
85 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
86 write!(
87 f,
88 "Over\n
89 Window Expr: {:?}
90 Schema: {:?}",
91 &self.window.window_expr, &self.window.schema
92 )
93 }
94}
95
96impl Display for PyWindowFrame {
97 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
98 write!(
99 f,
100 "OVER ({} BETWEEN {} AND {})",
101 self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
102 )
103 }
104}
105
106#[pymethods]
107impl PyWindowExpr {
108 pub fn schema(&self) -> PyResult<PyDFSchema> {
110 Ok(self.window.schema.as_ref().clone().into())
111 }
112
113 pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
115 py_expr_list(&self.window.window_expr)
116 }
117
118 pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
120 match expr.expr.unalias() {
121 Expr::WindowFunction(WindowFunction {
122 params: WindowFunctionParams { order_by, .. },
123 ..
124 }) => py_sort_expr_list(&order_by),
125 other => Err(not_window_function_err(other)),
126 }
127 }
128
129 pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
131 match expr.expr.unalias() {
132 Expr::WindowFunction(WindowFunction {
133 params: WindowFunctionParams { partition_by, .. },
134 ..
135 }) => py_expr_list(&partition_by),
136 other => Err(not_window_function_err(other)),
137 }
138 }
139
140 pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
142 match expr.expr.unalias() {
143 Expr::WindowFunction(WindowFunction {
144 params: WindowFunctionParams { args, .. },
145 ..
146 }) => py_expr_list(&args),
147 other => Err(not_window_function_err(other)),
148 }
149 }
150
151 pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
153 match expr.expr.unalias() {
154 Expr::WindowFunction(WindowFunction { fun, .. }) => Ok(fun.to_string()),
155 other => Err(not_window_function_err(other)),
156 }
157 }
158
159 pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
161 match expr.expr.unalias() {
162 Expr::WindowFunction(WindowFunction {
163 params: WindowFunctionParams { window_frame, .. },
164 ..
165 }) => Some(window_frame.into()),
166 _ => None,
167 }
168 }
169}
170
171fn not_window_function_err(expr: Expr) -> PyErr {
172 py_type_err(format!(
173 "Provided {} Expr {:?} is not a WindowFunction type",
174 expr.variant_name(),
175 expr
176 ))
177}
178
179#[pymethods]
180impl PyWindowFrame {
181 #[new]
182 #[pyo3(signature=(unit, start_bound, end_bound))]
183 pub fn new(
184 unit: &str,
185 start_bound: Option<PyScalarValue>,
186 end_bound: Option<PyScalarValue>,
187 ) -> PyResult<Self> {
188 let units = unit.to_ascii_lowercase();
189 let units = match units.as_str() {
190 "rows" => WindowFrameUnits::Rows,
191 "range" => WindowFrameUnits::Range,
192 "groups" => WindowFrameUnits::Groups,
193 _ => {
194 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
195 "{:?}",
196 units,
197 ))));
198 }
199 };
200 let start_bound = match start_bound {
201 Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
202 None => match units {
203 WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
204 WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
205 WindowFrameUnits::Groups => {
206 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
207 "{:?}",
208 units,
209 ))));
210 }
211 },
212 };
213 let end_bound = match end_bound {
214 Some(end_bound) => WindowFrameBound::Following(end_bound.0),
215 None => match units {
216 WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
217 WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
218 WindowFrameUnits::Groups => {
219 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
220 "{:?}",
221 units,
222 ))));
223 }
224 },
225 };
226 Ok(PyWindowFrame {
227 window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
228 })
229 }
230
231 pub fn get_frame_units(&self) -> PyResult<String> {
233 Ok(self.window_frame.units.to_string())
234 }
235 pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
237 Ok(self.window_frame.start_bound.clone().into())
238 }
239 pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
241 Ok(self.window_frame.end_bound.clone().into())
242 }
243
244 fn __repr__(&self) -> String {
246 format!("{}", self)
247 }
248}
249
250#[pymethods]
251impl PyWindowFrameBound {
252 pub fn is_current_row(&self) -> bool {
254 matches!(self.frame_bound, WindowFrameBound::CurrentRow)
255 }
256
257 pub fn is_preceding(&self) -> bool {
259 matches!(self.frame_bound, WindowFrameBound::Preceding(_))
260 }
261
262 pub fn is_following(&self) -> bool {
264 matches!(self.frame_bound, WindowFrameBound::Following(_))
265 }
266 pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
268 match &self.frame_bound {
269 WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
270 x if x.is_null() => Ok(None),
271 ScalarValue::UInt64(v) => Ok(*v),
272 ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
274 ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
275 Ok(s) => Ok(Some(s)),
276 Err(_e) => Err(DataFusionError::Plan(format!(
277 "Unable to parse u64 from Utf8 value '{s}'"
278 ))
279 .into()),
280 },
281 ref x => {
282 Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
283 }
284 },
285 WindowFrameBound::CurrentRow => Ok(None),
286 }
287 }
288 pub fn is_unbounded(&self) -> PyResult<bool> {
290 match &self.frame_bound {
291 WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
292 WindowFrameBound::CurrentRow => Ok(false),
293 }
294 }
295}
296
297impl LogicalNode for PyWindowExpr {
298 fn inputs(&self) -> Vec<PyLogicalPlan> {
299 vec![self.window.input.as_ref().clone().into()]
300 }
301
302 fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
303 self.clone().into_bound_py_any(py)
304 }
305}