Skip to main content

datafusion_python/expr/
window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::{self, Display, Formatter};
19
20use datafusion::common::{DataFusionError, ScalarValue};
21use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
22use pyo3::IntoPyObjectExt;
23use pyo3::exceptions::PyNotImplementedError;
24use pyo3::prelude::*;
25
26use super::py_expr_list;
27use crate::common::data_type::PyScalarValue;
28use crate::common::df_schema::PyDFSchema;
29use crate::errors::{PyDataFusionResult, py_type_err};
30use crate::expr::PyExpr;
31use crate::expr::logical_node::LogicalNode;
32use crate::expr::sort_expr::{PySortExpr, py_sort_expr_list};
33use crate::sql::logical::PyLogicalPlan;
34
35#[pyclass(
36    from_py_object,
37    frozen,
38    name = "WindowExpr",
39    module = "datafusion.expr",
40    subclass
41)]
42#[derive(Clone)]
43pub struct PyWindowExpr {
44    window: Window,
45}
46
47#[pyclass(
48    from_py_object,
49    frozen,
50    name = "WindowFrame",
51    module = "datafusion.expr",
52    subclass
53)]
54#[derive(Clone)]
55pub struct PyWindowFrame {
56    window_frame: WindowFrame,
57}
58
59impl From<PyWindowFrame> for WindowFrame {
60    fn from(window_frame: PyWindowFrame) -> Self {
61        window_frame.window_frame
62    }
63}
64
65impl From<WindowFrame> for PyWindowFrame {
66    fn from(window_frame: WindowFrame) -> PyWindowFrame {
67        PyWindowFrame { window_frame }
68    }
69}
70
71#[pyclass(
72    from_py_object,
73    frozen,
74    name = "WindowFrameBound",
75    module = "datafusion.expr",
76    subclass
77)]
78#[derive(Clone)]
79pub struct PyWindowFrameBound {
80    frame_bound: WindowFrameBound,
81}
82
83impl From<PyWindowExpr> for Window {
84    fn from(window: PyWindowExpr) -> Window {
85        window.window
86    }
87}
88
89impl From<Window> for PyWindowExpr {
90    fn from(window: Window) -> PyWindowExpr {
91        PyWindowExpr { window }
92    }
93}
94
95impl From<WindowFrameBound> for PyWindowFrameBound {
96    fn from(frame_bound: WindowFrameBound) -> Self {
97        PyWindowFrameBound { frame_bound }
98    }
99}
100
101impl Display for PyWindowExpr {
102    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
103        write!(
104            f,
105            "Over\n
106            Window Expr: {:?}
107            Schema: {:?}",
108            &self.window.window_expr, &self.window.schema
109        )
110    }
111}
112
113impl Display for PyWindowFrame {
114    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
115        write!(
116            f,
117            "OVER ({} BETWEEN {} AND {})",
118            self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
119        )
120    }
121}
122
123#[pymethods]
124impl PyWindowExpr {
125    /// Returns the schema of the Window
126    pub fn schema(&self) -> PyResult<PyDFSchema> {
127        Ok(self.window.schema.as_ref().clone().into())
128    }
129
130    /// Returns window expressions
131    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
132        py_expr_list(&self.window.window_expr)
133    }
134
135    /// Returns order by columns in a window function expression
136    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
137        match expr.expr.unalias() {
138            Expr::WindowFunction(boxed_window_fn) => {
139                py_sort_expr_list(&boxed_window_fn.params.order_by)
140            }
141            other => Err(not_window_function_err(other)),
142        }
143    }
144
145    /// Return partition by columns in a window function expression
146    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
147        match expr.expr.unalias() {
148            Expr::WindowFunction(boxed_window_fn) => {
149                py_expr_list(&boxed_window_fn.params.partition_by)
150            }
151            other => Err(not_window_function_err(other)),
152        }
153    }
154
155    /// Return input args for window function
156    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
157        match expr.expr.unalias() {
158            Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
159            other => Err(not_window_function_err(other)),
160        }
161    }
162
163    /// Return window function name
164    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
165        match expr.expr.unalias() {
166            Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
167            other => Err(not_window_function_err(other)),
168        }
169    }
170
171    /// Returns a Pywindow frame for a given window function expression
172    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
173        match expr.expr.unalias() {
174            Expr::WindowFunction(boxed_window_fn) => {
175                Some(boxed_window_fn.params.window_frame.into())
176            }
177            _ => None,
178        }
179    }
180}
181
182fn not_window_function_err(expr: Expr) -> PyErr {
183    py_type_err(format!(
184        "Provided {} Expr {:?} is not a WindowFunction type",
185        expr.variant_name(),
186        expr
187    ))
188}
189
190#[pymethods]
191impl PyWindowFrame {
192    #[new]
193    #[pyo3(signature=(unit, start_bound, end_bound))]
194    pub fn new(
195        unit: &str,
196        start_bound: Option<PyScalarValue>,
197        end_bound: Option<PyScalarValue>,
198    ) -> PyResult<Self> {
199        let units = unit.to_ascii_lowercase();
200        let units = match units.as_str() {
201            "rows" => WindowFrameUnits::Rows,
202            "range" => WindowFrameUnits::Range,
203            "groups" => WindowFrameUnits::Groups,
204            _ => {
205                return Err(PyNotImplementedError::new_err(format!("{units:?}")));
206            }
207        };
208        let start_bound = match start_bound {
209            Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
210            None => match units {
211                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
212                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
213                WindowFrameUnits::Groups => {
214                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
215                }
216            },
217        };
218        let end_bound = match end_bound {
219            Some(end_bound) => WindowFrameBound::Following(end_bound.0),
220            None => match units {
221                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
222                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
223                WindowFrameUnits::Groups => {
224                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
225                }
226            },
227        };
228        Ok(PyWindowFrame {
229            window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
230        })
231    }
232
233    /// Returns the window frame units for the bounds
234    pub fn get_frame_units(&self) -> PyResult<String> {
235        Ok(self.window_frame.units.to_string())
236    }
237    /// Returns starting bound
238    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
239        Ok(self.window_frame.start_bound.clone().into())
240    }
241    /// Returns end bound
242    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
243        Ok(self.window_frame.end_bound.clone().into())
244    }
245
246    /// Get a String representation of this window frame
247    fn __repr__(&self) -> String {
248        format!("{self}")
249    }
250}
251
252#[pymethods]
253impl PyWindowFrameBound {
254    /// Returns if the frame bound is current row
255    pub fn is_current_row(&self) -> bool {
256        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
257    }
258
259    /// Returns if the frame bound is preceding
260    pub fn is_preceding(&self) -> bool {
261        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
262    }
263
264    /// Returns if the frame bound is following
265    pub fn is_following(&self) -> bool {
266        matches!(self.frame_bound, WindowFrameBound::Following(_))
267    }
268    /// Returns the offset of the window frame
269    pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
270        match &self.frame_bound {
271            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
272                x if x.is_null() => Ok(None),
273                ScalarValue::UInt64(v) => Ok(*v),
274                // The cast below is only safe because window bounds cannot be negative
275                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
276                ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
277                    Ok(s) => Ok(Some(s)),
278                    Err(_e) => Err(DataFusionError::Plan(format!(
279                        "Unable to parse u64 from Utf8 value '{s}'"
280                    ))
281                    .into()),
282                },
283                ref x => {
284                    Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
285                }
286            },
287            WindowFrameBound::CurrentRow => Ok(None),
288        }
289    }
290    /// Returns if the frame bound is unbounded
291    pub fn is_unbounded(&self) -> PyResult<bool> {
292        match &self.frame_bound {
293            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
294            WindowFrameBound::CurrentRow => Ok(false),
295        }
296    }
297}
298
299impl LogicalNode for PyWindowExpr {
300    fn inputs(&self) -> Vec<PyLogicalPlan> {
301        vec![self.window.input.as_ref().clone().into()]
302    }
303
304    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
305        self.clone().into_bound_py_any(py)
306    }
307}