datafusion_python/expr/
window.rs1use std::fmt::{self, Display, Formatter};
19
20use datafusion::common::{DataFusionError, ScalarValue};
21use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
22use pyo3::exceptions::PyNotImplementedError;
23use pyo3::prelude::*;
24use pyo3::IntoPyObjectExt;
25
26use super::py_expr_list;
27use crate::common::data_type::PyScalarValue;
28use crate::common::df_schema::PyDFSchema;
29use crate::errors::{py_type_err, PyDataFusionResult};
30use crate::expr::logical_node::LogicalNode;
31use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
32use crate::expr::PyExpr;
33use crate::sql::logical::PyLogicalPlan;
34
35#[pyclass(frozen, name = "WindowExpr", module = "datafusion.expr", subclass)]
36#[derive(Clone)]
37pub struct PyWindowExpr {
38 window: Window,
39}
40
41#[pyclass(frozen, name = "WindowFrame", module = "datafusion.expr", subclass)]
42#[derive(Clone)]
43pub struct PyWindowFrame {
44 window_frame: WindowFrame,
45}
46
47impl From<PyWindowFrame> for WindowFrame {
48 fn from(window_frame: PyWindowFrame) -> Self {
49 window_frame.window_frame
50 }
51}
52
53impl From<WindowFrame> for PyWindowFrame {
54 fn from(window_frame: WindowFrame) -> PyWindowFrame {
55 PyWindowFrame { window_frame }
56 }
57}
58
59#[pyclass(
60 frozen,
61 name = "WindowFrameBound",
62 module = "datafusion.expr",
63 subclass
64)]
65#[derive(Clone)]
66pub struct PyWindowFrameBound {
67 frame_bound: WindowFrameBound,
68}
69
70impl From<PyWindowExpr> for Window {
71 fn from(window: PyWindowExpr) -> Window {
72 window.window
73 }
74}
75
76impl From<Window> for PyWindowExpr {
77 fn from(window: Window) -> PyWindowExpr {
78 PyWindowExpr { window }
79 }
80}
81
82impl From<WindowFrameBound> for PyWindowFrameBound {
83 fn from(frame_bound: WindowFrameBound) -> Self {
84 PyWindowFrameBound { frame_bound }
85 }
86}
87
// Human-readable rendering of a `PyWindowExpr`: an "Over" header followed by
// the Debug output of the node's window expressions and of its schema.
88impl Display for PyWindowExpr {
89 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
// NOTE(review): the format string literal below spans several source lines,
// so the raw line breaks (and any indentation inside the literal) are emitted
// verbatim in addition to the explicit `\n` — confirm this layout is intended.
90 write!(
91 f,
92 "Over\n
93 Window Expr: {:?}
94 Schema: {:?}",
95 &self.window.window_expr, &self.window.schema
96 )
97 }
98}
99
100impl Display for PyWindowFrame {
101 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
102 write!(
103 f,
104 "OVER ({} BETWEEN {} AND {})",
105 self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
106 )
107 }
108}
109
110#[pymethods]
111impl PyWindowExpr {
112 pub fn schema(&self) -> PyResult<PyDFSchema> {
114 Ok(self.window.schema.as_ref().clone().into())
115 }
116
117 pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
119 py_expr_list(&self.window.window_expr)
120 }
121
122 pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
124 match expr.expr.unalias() {
125 Expr::WindowFunction(boxed_window_fn) => {
126 py_sort_expr_list(&boxed_window_fn.params.order_by)
127 }
128 other => Err(not_window_function_err(other)),
129 }
130 }
131
132 pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
134 match expr.expr.unalias() {
135 Expr::WindowFunction(boxed_window_fn) => {
136 py_expr_list(&boxed_window_fn.params.partition_by)
137 }
138 other => Err(not_window_function_err(other)),
139 }
140 }
141
142 pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
144 match expr.expr.unalias() {
145 Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
146 other => Err(not_window_function_err(other)),
147 }
148 }
149
150 pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
152 match expr.expr.unalias() {
153 Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
154 other => Err(not_window_function_err(other)),
155 }
156 }
157
158 pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
160 match expr.expr.unalias() {
161 Expr::WindowFunction(boxed_window_fn) => {
162 Some(boxed_window_fn.params.window_frame.into())
163 }
164 _ => None,
165 }
166 }
167}
168
169fn not_window_function_err(expr: Expr) -> PyErr {
170 py_type_err(format!(
171 "Provided {} Expr {:?} is not a WindowFunction type",
172 expr.variant_name(),
173 expr
174 ))
175}
176
177#[pymethods]
178impl PyWindowFrame {
179 #[new]
180 #[pyo3(signature=(unit, start_bound, end_bound))]
181 pub fn new(
182 unit: &str,
183 start_bound: Option<PyScalarValue>,
184 end_bound: Option<PyScalarValue>,
185 ) -> PyResult<Self> {
186 let units = unit.to_ascii_lowercase();
187 let units = match units.as_str() {
188 "rows" => WindowFrameUnits::Rows,
189 "range" => WindowFrameUnits::Range,
190 "groups" => WindowFrameUnits::Groups,
191 _ => {
192 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
193 }
194 };
195 let start_bound = match start_bound {
196 Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
197 None => match units {
198 WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
199 WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
200 WindowFrameUnits::Groups => {
201 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
202 }
203 },
204 };
205 let end_bound = match end_bound {
206 Some(end_bound) => WindowFrameBound::Following(end_bound.0),
207 None => match units {
208 WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
209 WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
210 WindowFrameUnits::Groups => {
211 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
212 }
213 },
214 };
215 Ok(PyWindowFrame {
216 window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
217 })
218 }
219
220 pub fn get_frame_units(&self) -> PyResult<String> {
222 Ok(self.window_frame.units.to_string())
223 }
224 pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
226 Ok(self.window_frame.start_bound.clone().into())
227 }
228 pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
230 Ok(self.window_frame.end_bound.clone().into())
231 }
232
233 fn __repr__(&self) -> String {
235 format!("{self}")
236 }
237}
238
239#[pymethods]
240impl PyWindowFrameBound {
241 pub fn is_current_row(&self) -> bool {
243 matches!(self.frame_bound, WindowFrameBound::CurrentRow)
244 }
245
246 pub fn is_preceding(&self) -> bool {
248 matches!(self.frame_bound, WindowFrameBound::Preceding(_))
249 }
250
251 pub fn is_following(&self) -> bool {
253 matches!(self.frame_bound, WindowFrameBound::Following(_))
254 }
255 pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
257 match &self.frame_bound {
258 WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
259 x if x.is_null() => Ok(None),
260 ScalarValue::UInt64(v) => Ok(*v),
261 ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
263 ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
264 Ok(s) => Ok(Some(s)),
265 Err(_e) => Err(DataFusionError::Plan(format!(
266 "Unable to parse u64 from Utf8 value '{s}'"
267 ))
268 .into()),
269 },
270 ref x => {
271 Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
272 }
273 },
274 WindowFrameBound::CurrentRow => Ok(None),
275 }
276 }
277 pub fn is_unbounded(&self) -> PyResult<bool> {
279 match &self.frame_bound {
280 WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
281 WindowFrameBound::CurrentRow => Ok(false),
282 }
283 }
284}
285
286impl LogicalNode for PyWindowExpr {
287 fn inputs(&self) -> Vec<PyLogicalPlan> {
288 vec![self.window.input.as_ref().clone().into()]
289 }
290
291 fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
292 self.clone().into_bound_py_any(py)
293 }
294}