datafusion_python/expr/
window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::common::data_type::PyScalarValue;
19use crate::common::df_schema::PyDFSchema;
20use crate::errors::{py_type_err, PyDataFusionResult};
21use crate::expr::logical_node::LogicalNode;
22use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
23use crate::expr::PyExpr;
24use crate::sql::logical::PyLogicalPlan;
25use datafusion::common::{DataFusionError, ScalarValue};
26use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
27use pyo3::exceptions::PyNotImplementedError;
28use pyo3::{prelude::*, IntoPyObjectExt};
29use std::fmt::{self, Display, Formatter};
30
31use super::py_expr_list;
32
/// Python-facing wrapper around a DataFusion [`Window`] logical plan node.
///
/// Exposed to Python as `WindowExpr` in the `datafusion.expr` module.
#[pyclass(frozen, name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowExpr {
    /// The wrapped DataFusion window node.
    window: Window,
}
38
/// Python-facing wrapper around a DataFusion [`WindowFrame`].
///
/// Exposed to Python as `WindowFrame` in the `datafusion.expr` module.
#[pyclass(frozen, name = "WindowFrame", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
    /// The wrapped DataFusion window frame (units plus start and end bounds).
    window_frame: WindowFrame,
}
44
45impl From<PyWindowFrame> for WindowFrame {
46    fn from(window_frame: PyWindowFrame) -> Self {
47        window_frame.window_frame
48    }
49}
50
51impl From<WindowFrame> for PyWindowFrame {
52    fn from(window_frame: WindowFrame) -> PyWindowFrame {
53        PyWindowFrame { window_frame }
54    }
55}
56
/// Python-facing wrapper around a DataFusion [`WindowFrameBound`].
///
/// Exposed to Python as `WindowFrameBound` in the `datafusion.expr` module.
#[pyclass(
    frozen,
    name = "WindowFrameBound",
    module = "datafusion.expr",
    subclass
)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
    /// The wrapped bound: preceding, following, or current row.
    frame_bound: WindowFrameBound,
}
67
68impl From<PyWindowExpr> for Window {
69    fn from(window: PyWindowExpr) -> Window {
70        window.window
71    }
72}
73
74impl From<Window> for PyWindowExpr {
75    fn from(window: Window) -> PyWindowExpr {
76        PyWindowExpr { window }
77    }
78}
79
80impl From<WindowFrameBound> for PyWindowFrameBound {
81    fn from(frame_bound: WindowFrameBound) -> Self {
82        PyWindowFrameBound { frame_bound }
83    }
84}
85
86impl Display for PyWindowExpr {
87    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
88        write!(
89            f,
90            "Over\n
91            Window Expr: {:?}
92            Schema: {:?}",
93            &self.window.window_expr, &self.window.schema
94        )
95    }
96}
97
98impl Display for PyWindowFrame {
99    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
100        write!(
101            f,
102            "OVER ({} BETWEEN {} AND {})",
103            self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
104        )
105    }
106}
107
108#[pymethods]
109impl PyWindowExpr {
110    /// Returns the schema of the Window
111    pub fn schema(&self) -> PyResult<PyDFSchema> {
112        Ok(self.window.schema.as_ref().clone().into())
113    }
114
115    /// Returns window expressions
116    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
117        py_expr_list(&self.window.window_expr)
118    }
119
120    /// Returns order by columns in a window function expression
121    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
122        match expr.expr.unalias() {
123            Expr::WindowFunction(boxed_window_fn) => {
124                py_sort_expr_list(&boxed_window_fn.params.order_by)
125            }
126            other => Err(not_window_function_err(other)),
127        }
128    }
129
130    /// Return partition by columns in a window function expression
131    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
132        match expr.expr.unalias() {
133            Expr::WindowFunction(boxed_window_fn) => {
134                py_expr_list(&boxed_window_fn.params.partition_by)
135            }
136            other => Err(not_window_function_err(other)),
137        }
138    }
139
140    /// Return input args for window function
141    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
142        match expr.expr.unalias() {
143            Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
144            other => Err(not_window_function_err(other)),
145        }
146    }
147
148    /// Return window function name
149    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
150        match expr.expr.unalias() {
151            Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
152            other => Err(not_window_function_err(other)),
153        }
154    }
155
156    /// Returns a Pywindow frame for a given window function expression
157    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
158        match expr.expr.unalias() {
159            Expr::WindowFunction(boxed_window_fn) => {
160                Some(boxed_window_fn.params.window_frame.into())
161            }
162            _ => None,
163        }
164    }
165}
166
167fn not_window_function_err(expr: Expr) -> PyErr {
168    py_type_err(format!(
169        "Provided {} Expr {:?} is not a WindowFunction type",
170        expr.variant_name(),
171        expr
172    ))
173}
174
175#[pymethods]
176impl PyWindowFrame {
177    #[new]
178    #[pyo3(signature=(unit, start_bound, end_bound))]
179    pub fn new(
180        unit: &str,
181        start_bound: Option<PyScalarValue>,
182        end_bound: Option<PyScalarValue>,
183    ) -> PyResult<Self> {
184        let units = unit.to_ascii_lowercase();
185        let units = match units.as_str() {
186            "rows" => WindowFrameUnits::Rows,
187            "range" => WindowFrameUnits::Range,
188            "groups" => WindowFrameUnits::Groups,
189            _ => {
190                return Err(PyNotImplementedError::new_err(format!("{units:?}")));
191            }
192        };
193        let start_bound = match start_bound {
194            Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
195            None => match units {
196                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
197                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
198                WindowFrameUnits::Groups => {
199                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
200                }
201            },
202        };
203        let end_bound = match end_bound {
204            Some(end_bound) => WindowFrameBound::Following(end_bound.0),
205            None => match units {
206                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
207                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
208                WindowFrameUnits::Groups => {
209                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
210                }
211            },
212        };
213        Ok(PyWindowFrame {
214            window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
215        })
216    }
217
218    /// Returns the window frame units for the bounds
219    pub fn get_frame_units(&self) -> PyResult<String> {
220        Ok(self.window_frame.units.to_string())
221    }
222    /// Returns starting bound
223    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
224        Ok(self.window_frame.start_bound.clone().into())
225    }
226    /// Returns end bound
227    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
228        Ok(self.window_frame.end_bound.clone().into())
229    }
230
231    /// Get a String representation of this window frame
232    fn __repr__(&self) -> String {
233        format!("{self}")
234    }
235}
236
237#[pymethods]
238impl PyWindowFrameBound {
239    /// Returns if the frame bound is current row
240    pub fn is_current_row(&self) -> bool {
241        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
242    }
243
244    /// Returns if the frame bound is preceding
245    pub fn is_preceding(&self) -> bool {
246        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
247    }
248
249    /// Returns if the frame bound is following
250    pub fn is_following(&self) -> bool {
251        matches!(self.frame_bound, WindowFrameBound::Following(_))
252    }
253    /// Returns the offset of the window frame
254    pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
255        match &self.frame_bound {
256            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
257                x if x.is_null() => Ok(None),
258                ScalarValue::UInt64(v) => Ok(*v),
259                // The cast below is only safe because window bounds cannot be negative
260                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
261                ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
262                    Ok(s) => Ok(Some(s)),
263                    Err(_e) => Err(DataFusionError::Plan(format!(
264                        "Unable to parse u64 from Utf8 value '{s}'"
265                    ))
266                    .into()),
267                },
268                ref x => {
269                    Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
270                }
271            },
272            WindowFrameBound::CurrentRow => Ok(None),
273        }
274    }
275    /// Returns if the frame bound is unbounded
276    pub fn is_unbounded(&self) -> PyResult<bool> {
277        match &self.frame_bound {
278            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
279            WindowFrameBound::CurrentRow => Ok(false),
280        }
281    }
282}
283
284impl LogicalNode for PyWindowExpr {
285    fn inputs(&self) -> Vec<PyLogicalPlan> {
286        vec![self.window.input.as_ref().clone().into()]
287    }
288
289    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
290        self.clone().into_bound_py_any(py)
291    }
292}