datafusion_python/expr/
window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::common::data_type::PyScalarValue;
19use crate::common::df_schema::PyDFSchema;
20use crate::errors::{py_type_err, PyDataFusionResult};
21use crate::expr::logical_node::LogicalNode;
22use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
23use crate::expr::PyExpr;
24use crate::sql::logical::PyLogicalPlan;
25use datafusion::common::{DataFusionError, ScalarValue};
26use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
27use pyo3::exceptions::PyNotImplementedError;
28use pyo3::{prelude::*, IntoPyObjectExt};
29use std::fmt::{self, Display, Formatter};
30
31use super::py_expr_list;
32
// Python-visible wrapper (class `WindowExpr` in module `datafusion.expr`) around a
// DataFusion `Window` logical plan node. The wrapped node is private; it is reached
// through the `From` conversions and the `#[pymethods]` accessors in this file.
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowExpr {
    // The wrapped DataFusion window node.
    window: Window,
}
38
// Python-visible wrapper (class `WindowFrame` in module `datafusion.expr`) around a
// DataFusion `WindowFrame`, i.e. the frame units plus a start and an end bound.
#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
    // The wrapped DataFusion window frame.
    window_frame: WindowFrame,
}
44
45impl From<PyWindowFrame> for WindowFrame {
46    fn from(window_frame: PyWindowFrame) -> Self {
47        window_frame.window_frame
48    }
49}
50
51impl From<WindowFrame> for PyWindowFrame {
52    fn from(window_frame: WindowFrame) -> PyWindowFrame {
53        PyWindowFrame { window_frame }
54    }
55}
56
// Python-visible wrapper (class `WindowFrameBound` in module `datafusion.expr`) around
// a single DataFusion `WindowFrameBound` (preceding, following, or current row).
#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
    // The wrapped DataFusion frame bound.
    frame_bound: WindowFrameBound,
}
62
63impl From<PyWindowExpr> for Window {
64    fn from(window: PyWindowExpr) -> Window {
65        window.window
66    }
67}
68
69impl From<Window> for PyWindowExpr {
70    fn from(window: Window) -> PyWindowExpr {
71        PyWindowExpr { window }
72    }
73}
74
75impl From<WindowFrameBound> for PyWindowFrameBound {
76    fn from(frame_bound: WindowFrameBound) -> Self {
77        PyWindowFrameBound { frame_bound }
78    }
79}
80
81impl Display for PyWindowExpr {
82    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
83        write!(
84            f,
85            "Over\n
86            Window Expr: {:?}
87            Schema: {:?}",
88            &self.window.window_expr, &self.window.schema
89        )
90    }
91}
92
93impl Display for PyWindowFrame {
94    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
95        write!(
96            f,
97            "OVER ({} BETWEEN {} AND {})",
98            self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
99        )
100    }
101}
102
103#[pymethods]
104impl PyWindowExpr {
105    /// Returns the schema of the Window
106    pub fn schema(&self) -> PyResult<PyDFSchema> {
107        Ok(self.window.schema.as_ref().clone().into())
108    }
109
110    /// Returns window expressions
111    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
112        py_expr_list(&self.window.window_expr)
113    }
114
115    /// Returns order by columns in a window function expression
116    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
117        match expr.expr.unalias() {
118            Expr::WindowFunction(boxed_window_fn) => {
119                py_sort_expr_list(&boxed_window_fn.params.order_by)
120            }
121            other => Err(not_window_function_err(other)),
122        }
123    }
124
125    /// Return partition by columns in a window function expression
126    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
127        match expr.expr.unalias() {
128            Expr::WindowFunction(boxed_window_fn) => {
129                py_expr_list(&boxed_window_fn.params.partition_by)
130            }
131            other => Err(not_window_function_err(other)),
132        }
133    }
134
135    /// Return input args for window function
136    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
137        match expr.expr.unalias() {
138            Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
139            other => Err(not_window_function_err(other)),
140        }
141    }
142
143    /// Return window function name
144    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
145        match expr.expr.unalias() {
146            Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
147            other => Err(not_window_function_err(other)),
148        }
149    }
150
151    /// Returns a Pywindow frame for a given window function expression
152    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
153        match expr.expr.unalias() {
154            Expr::WindowFunction(boxed_window_fn) => {
155                Some(boxed_window_fn.params.window_frame.into())
156            }
157            _ => None,
158        }
159    }
160}
161
162fn not_window_function_err(expr: Expr) -> PyErr {
163    py_type_err(format!(
164        "Provided {} Expr {:?} is not a WindowFunction type",
165        expr.variant_name(),
166        expr
167    ))
168}
169
170#[pymethods]
171impl PyWindowFrame {
172    #[new]
173    #[pyo3(signature=(unit, start_bound, end_bound))]
174    pub fn new(
175        unit: &str,
176        start_bound: Option<PyScalarValue>,
177        end_bound: Option<PyScalarValue>,
178    ) -> PyResult<Self> {
179        let units = unit.to_ascii_lowercase();
180        let units = match units.as_str() {
181            "rows" => WindowFrameUnits::Rows,
182            "range" => WindowFrameUnits::Range,
183            "groups" => WindowFrameUnits::Groups,
184            _ => {
185                return Err(PyNotImplementedError::new_err(format!("{units:?}")));
186            }
187        };
188        let start_bound = match start_bound {
189            Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
190            None => match units {
191                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
192                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
193                WindowFrameUnits::Groups => {
194                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
195                }
196            },
197        };
198        let end_bound = match end_bound {
199            Some(end_bound) => WindowFrameBound::Following(end_bound.0),
200            None => match units {
201                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
202                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
203                WindowFrameUnits::Groups => {
204                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
205                }
206            },
207        };
208        Ok(PyWindowFrame {
209            window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
210        })
211    }
212
213    /// Returns the window frame units for the bounds
214    pub fn get_frame_units(&self) -> PyResult<String> {
215        Ok(self.window_frame.units.to_string())
216    }
217    /// Returns starting bound
218    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
219        Ok(self.window_frame.start_bound.clone().into())
220    }
221    /// Returns end bound
222    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
223        Ok(self.window_frame.end_bound.clone().into())
224    }
225
226    /// Get a String representation of this window frame
227    fn __repr__(&self) -> String {
228        format!("{self}")
229    }
230}
231
232#[pymethods]
233impl PyWindowFrameBound {
234    /// Returns if the frame bound is current row
235    pub fn is_current_row(&self) -> bool {
236        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
237    }
238
239    /// Returns if the frame bound is preceding
240    pub fn is_preceding(&self) -> bool {
241        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
242    }
243
244    /// Returns if the frame bound is following
245    pub fn is_following(&self) -> bool {
246        matches!(self.frame_bound, WindowFrameBound::Following(_))
247    }
248    /// Returns the offset of the window frame
249    pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
250        match &self.frame_bound {
251            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
252                x if x.is_null() => Ok(None),
253                ScalarValue::UInt64(v) => Ok(*v),
254                // The cast below is only safe because window bounds cannot be negative
255                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
256                ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
257                    Ok(s) => Ok(Some(s)),
258                    Err(_e) => Err(DataFusionError::Plan(format!(
259                        "Unable to parse u64 from Utf8 value '{s}'"
260                    ))
261                    .into()),
262                },
263                ref x => {
264                    Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
265                }
266            },
267            WindowFrameBound::CurrentRow => Ok(None),
268        }
269    }
270    /// Returns if the frame bound is unbounded
271    pub fn is_unbounded(&self) -> PyResult<bool> {
272        match &self.frame_bound {
273            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
274            WindowFrameBound::CurrentRow => Ok(false),
275        }
276    }
277}
278
279impl LogicalNode for PyWindowExpr {
280    fn inputs(&self) -> Vec<PyLogicalPlan> {
281        vec![self.window.input.as_ref().clone().into()]
282    }
283
284    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
285        self.clone().into_bound_py_any(py)
286    }
287}