datafusion_python/expr/
window.rs1use crate::common::data_type::PyScalarValue;
19use crate::common::df_schema::PyDFSchema;
20use crate::errors::{py_type_err, PyDataFusionResult};
21use crate::expr::logical_node::LogicalNode;
22use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
23use crate::expr::PyExpr;
24use crate::sql::logical::PyLogicalPlan;
25use datafusion::common::{DataFusionError, ScalarValue};
26use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
27use pyo3::exceptions::PyNotImplementedError;
28use pyo3::{prelude::*, IntoPyObjectExt};
29use std::fmt::{self, Display, Formatter};
30
31use super::py_expr_list;
32
33#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
34#[derive(Clone)]
35pub struct PyWindowExpr {
36 window: Window,
37}
38
39#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
40#[derive(Clone)]
41pub struct PyWindowFrame {
42 window_frame: WindowFrame,
43}
44
45impl From<PyWindowFrame> for WindowFrame {
46 fn from(window_frame: PyWindowFrame) -> Self {
47 window_frame.window_frame
48 }
49}
50
51impl From<WindowFrame> for PyWindowFrame {
52 fn from(window_frame: WindowFrame) -> PyWindowFrame {
53 PyWindowFrame { window_frame }
54 }
55}
56
57#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
58#[derive(Clone)]
59pub struct PyWindowFrameBound {
60 frame_bound: WindowFrameBound,
61}
62
63impl From<PyWindowExpr> for Window {
64 fn from(window: PyWindowExpr) -> Window {
65 window.window
66 }
67}
68
69impl From<Window> for PyWindowExpr {
70 fn from(window: Window) -> PyWindowExpr {
71 PyWindowExpr { window }
72 }
73}
74
75impl From<WindowFrameBound> for PyWindowFrameBound {
76 fn from(frame_bound: WindowFrameBound) -> Self {
77 PyWindowFrameBound { frame_bound }
78 }
79}
80
81impl Display for PyWindowExpr {
82 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
83 write!(
84 f,
85 "Over\n
86 Window Expr: {:?}
87 Schema: {:?}",
88 &self.window.window_expr, &self.window.schema
89 )
90 }
91}
92
93impl Display for PyWindowFrame {
94 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
95 write!(
96 f,
97 "OVER ({} BETWEEN {} AND {})",
98 self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
99 )
100 }
101}
102
103#[pymethods]
104impl PyWindowExpr {
105 pub fn schema(&self) -> PyResult<PyDFSchema> {
107 Ok(self.window.schema.as_ref().clone().into())
108 }
109
110 pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
112 py_expr_list(&self.window.window_expr)
113 }
114
115 pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
117 match expr.expr.unalias() {
118 Expr::WindowFunction(boxed_window_fn) => {
119 py_sort_expr_list(&boxed_window_fn.params.order_by)
120 }
121 other => Err(not_window_function_err(other)),
122 }
123 }
124
125 pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
127 match expr.expr.unalias() {
128 Expr::WindowFunction(boxed_window_fn) => {
129 py_expr_list(&boxed_window_fn.params.partition_by)
130 }
131 other => Err(not_window_function_err(other)),
132 }
133 }
134
135 pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
137 match expr.expr.unalias() {
138 Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
139 other => Err(not_window_function_err(other)),
140 }
141 }
142
143 pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
145 match expr.expr.unalias() {
146 Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
147 other => Err(not_window_function_err(other)),
148 }
149 }
150
151 pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
153 match expr.expr.unalias() {
154 Expr::WindowFunction(boxed_window_fn) => {
155 Some(boxed_window_fn.params.window_frame.into())
156 }
157 _ => None,
158 }
159 }
160}
161
162fn not_window_function_err(expr: Expr) -> PyErr {
163 py_type_err(format!(
164 "Provided {} Expr {:?} is not a WindowFunction type",
165 expr.variant_name(),
166 expr
167 ))
168}
169
170#[pymethods]
171impl PyWindowFrame {
172 #[new]
173 #[pyo3(signature=(unit, start_bound, end_bound))]
174 pub fn new(
175 unit: &str,
176 start_bound: Option<PyScalarValue>,
177 end_bound: Option<PyScalarValue>,
178 ) -> PyResult<Self> {
179 let units = unit.to_ascii_lowercase();
180 let units = match units.as_str() {
181 "rows" => WindowFrameUnits::Rows,
182 "range" => WindowFrameUnits::Range,
183 "groups" => WindowFrameUnits::Groups,
184 _ => {
185 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
186 }
187 };
188 let start_bound = match start_bound {
189 Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
190 None => match units {
191 WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
192 WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
193 WindowFrameUnits::Groups => {
194 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
195 }
196 },
197 };
198 let end_bound = match end_bound {
199 Some(end_bound) => WindowFrameBound::Following(end_bound.0),
200 None => match units {
201 WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
202 WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
203 WindowFrameUnits::Groups => {
204 return Err(PyNotImplementedError::new_err(format!("{units:?}")));
205 }
206 },
207 };
208 Ok(PyWindowFrame {
209 window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
210 })
211 }
212
213 pub fn get_frame_units(&self) -> PyResult<String> {
215 Ok(self.window_frame.units.to_string())
216 }
217 pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
219 Ok(self.window_frame.start_bound.clone().into())
220 }
221 pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
223 Ok(self.window_frame.end_bound.clone().into())
224 }
225
226 fn __repr__(&self) -> String {
228 format!("{self}")
229 }
230}
231
232#[pymethods]
233impl PyWindowFrameBound {
234 pub fn is_current_row(&self) -> bool {
236 matches!(self.frame_bound, WindowFrameBound::CurrentRow)
237 }
238
239 pub fn is_preceding(&self) -> bool {
241 matches!(self.frame_bound, WindowFrameBound::Preceding(_))
242 }
243
244 pub fn is_following(&self) -> bool {
246 matches!(self.frame_bound, WindowFrameBound::Following(_))
247 }
248 pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
250 match &self.frame_bound {
251 WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
252 x if x.is_null() => Ok(None),
253 ScalarValue::UInt64(v) => Ok(*v),
254 ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
256 ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
257 Ok(s) => Ok(Some(s)),
258 Err(_e) => Err(DataFusionError::Plan(format!(
259 "Unable to parse u64 from Utf8 value '{s}'"
260 ))
261 .into()),
262 },
263 ref x => {
264 Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
265 }
266 },
267 WindowFrameBound::CurrentRow => Ok(None),
268 }
269 }
270 pub fn is_unbounded(&self) -> PyResult<bool> {
272 match &self.frame_bound {
273 WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
274 WindowFrameBound::CurrentRow => Ok(false),
275 }
276 }
277}
278
279impl LogicalNode for PyWindowExpr {
280 fn inputs(&self) -> Vec<PyLogicalPlan> {
281 vec![self.window.input.as_ref().clone().into()]
282 }
283
284 fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
285 self.clone().into_bound_py_any(py)
286 }
287}