datafusion_python/expr/
window.rs1use crate::common::data_type::PyScalarValue;
19use crate::common::df_schema::PyDFSchema;
20use crate::errors::{py_type_err, PyDataFusionResult};
21use crate::expr::logical_node::LogicalNode;
22use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
23use crate::expr::PyExpr;
24use crate::sql::logical::PyLogicalPlan;
25use datafusion::common::{DataFusionError, ScalarValue};
26use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
27use pyo3::exceptions::PyNotImplementedError;
28use pyo3::{prelude::*, IntoPyObjectExt};
29use std::fmt::{self, Display, Formatter};
30
31use super::py_expr_list;
32
33#[pyclass(frozen, name = "WindowExpr", module = "datafusion.expr", subclass)]
34#[derive(Clone)]
35pub struct PyWindowExpr {
36 window: Window,
37}
38
39#[pyclass(frozen, name = "WindowFrame", module = "datafusion.expr", subclass)]
40#[derive(Clone)]
41pub struct PyWindowFrame {
42 window_frame: WindowFrame,
43}
44
45impl From<PyWindowFrame> for WindowFrame {
46 fn from(window_frame: PyWindowFrame) -> Self {
47 window_frame.window_frame
48 }
49}
50
51impl From<WindowFrame> for PyWindowFrame {
52 fn from(window_frame: WindowFrame) -> PyWindowFrame {
53 PyWindowFrame { window_frame }
54 }
55}
56
57#[pyclass(
58 frozen,
59 name = "WindowFrameBound",
60 module = "datafusion.expr",
61 subclass
62)]
63#[derive(Clone)]
64pub struct PyWindowFrameBound {
65 frame_bound: WindowFrameBound,
66}
67
68impl From<PyWindowExpr> for Window {
69 fn from(window: PyWindowExpr) -> Window {
70 window.window
71 }
72}
73
74impl From<Window> for PyWindowExpr {
75 fn from(window: Window) -> PyWindowExpr {
76 PyWindowExpr { window }
77 }
78}
79
80impl From<WindowFrameBound> for PyWindowFrameBound {
81 fn from(frame_bound: WindowFrameBound) -> Self {
82 PyWindowFrameBound { frame_bound }
83 }
84}
85
86impl Display for PyWindowExpr {
87 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
88 write!(
89 f,
90 "Over\n
91 Window Expr: {:?}
92 Schema: {:?}",
93 &self.window.window_expr, &self.window.schema
94 )
95 }
96}
97
98impl Display for PyWindowFrame {
99 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
100 write!(
101 f,
102 "OVER ({} BETWEEN {} AND {})",
103 self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
104 )
105 }
106}
107
108#[pymethods]
109impl PyWindowExpr {
110 pub fn schema(&self) -> PyResult<PyDFSchema> {
112 Ok(self.window.schema.as_ref().clone().into())
113 }
114
115 pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
117 py_expr_list(&self.window.window_expr)
118 }
119
120 pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
122 match expr.expr.unalias() {
123 Expr::WindowFunction(boxed_window_fn) => {
124 py_sort_expr_list(&boxed_window_fn.params.order_by)
125 }
126 other => Err(not_window_function_err(other)),
127 }
128 }
129
130 pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
132 match expr.expr.unalias() {
133 Expr::WindowFunction(boxed_window_fn) => {
134 py_expr_list(&boxed_window_fn.params.partition_by)
135 }
136 other => Err(not_window_function_err(other)),
137 }
138 }
139
140 pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
142 match expr.expr.unalias() {
143 Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
144 other => Err(not_window_function_err(other)),
145 }
146 }
147
148 pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
150 match expr.expr.unalias() {
151 Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
152 other => Err(not_window_function_err(other)),
153 }
154 }
155
156 pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
158 match expr.expr.unalias() {
159 Expr::WindowFunction(boxed_window_fn) => {
160 Some(boxed_window_fn.params.window_frame.into())
161 }
162 _ => None,
163 }
164 }
165}
166
167fn not_window_function_err(expr: Expr) -> PyErr {
168 py_type_err(format!(
169 "Provided {} Expr {:?} is not a WindowFunction type",
170 expr.variant_name(),
171 expr
172 ))
173}
174
175#[pymethods]
176impl PyWindowFrame {
177 #[new]
178 #[pyo3(signature=(unit, start_bound, end_bound))]
179 pub fn new(
180 unit: &str,
181 start_bound: Option<PyScalarValue>,
182 end_bound: Option<PyScalarValue>,
183 ) -> PyResult<Self> {
184 let units = unit.to_ascii_lowercase();
185 let units = match units.as_str() {
186 "rows" => WindowFrameUnits::Rows,
187 "range" => WindowFrameUnits::Range,
188 "groups" => WindowFrameUnits::Groups,
189 _ => {
190 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
191 }
192 };
193 let start_bound = match start_bound {
194 Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
195 None => match units {
196 WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
197 WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
198 WindowFrameUnits::Groups => {
199 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
200 }
201 },
202 };
203 let end_bound = match end_bound {
204 Some(end_bound) => WindowFrameBound::Following(end_bound.0),
205 None => match units {
206 WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
207 WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
208 WindowFrameUnits::Groups => {
209 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
210 }
211 },
212 };
213 Ok(PyWindowFrame {
214 window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
215 })
216 }
217
218 pub fn get_frame_units(&self) -> PyResult<String> {
220 Ok(self.window_frame.units.to_string())
221 }
222 pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
224 Ok(self.window_frame.start_bound.clone().into())
225 }
226 pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
228 Ok(self.window_frame.end_bound.clone().into())
229 }
230
231 fn __repr__(&self) -> String {
233 format!("{self}")
234 }
235}
236
237#[pymethods]
238impl PyWindowFrameBound {
239 pub fn is_current_row(&self) -> bool {
241 matches!(self.frame_bound, WindowFrameBound::CurrentRow)
242 }
243
244 pub fn is_preceding(&self) -> bool {
246 matches!(self.frame_bound, WindowFrameBound::Preceding(_))
247 }
248
249 pub fn is_following(&self) -> bool {
251 matches!(self.frame_bound, WindowFrameBound::Following(_))
252 }
253 pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
255 match &self.frame_bound {
256 WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
257 x if x.is_null() => Ok(None),
258 ScalarValue::UInt64(v) => Ok(*v),
259 ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
261 ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
262 Ok(s) => Ok(Some(s)),
263 Err(_e) => Err(DataFusionError::Plan(format!(
264 "Unable to parse u64 from Utf8 value '{s}'"
265 ))
266 .into()),
267 },
268 ref x => {
269 Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
270 }
271 },
272 WindowFrameBound::CurrentRow => Ok(None),
273 }
274 }
275 pub fn is_unbounded(&self) -> PyResult<bool> {
277 match &self.frame_bound {
278 WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
279 WindowFrameBound::CurrentRow => Ok(false),
280 }
281 }
282}
283
284impl LogicalNode for PyWindowExpr {
285 fn inputs(&self) -> Vec<PyLogicalPlan> {
286 vec![self.window.input.as_ref().clone().into()]
287 }
288
289 fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
290 self.clone().into_bound_py_any(py)
291 }
292}