datafusion_python/expr/
window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::{self, Display, Formatter};
19
20use datafusion::common::{DataFusionError, ScalarValue};
21use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
22use pyo3::exceptions::PyNotImplementedError;
23use pyo3::prelude::*;
24use pyo3::IntoPyObjectExt;
25
26use super::py_expr_list;
27use crate::common::data_type::PyScalarValue;
28use crate::common::df_schema::PyDFSchema;
29use crate::errors::{py_type_err, PyDataFusionResult};
30use crate::expr::logical_node::LogicalNode;
31use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
32use crate::expr::PyExpr;
33use crate::sql::logical::PyLogicalPlan;
34
35#[pyclass(frozen, name = "WindowExpr", module = "datafusion.expr", subclass)]
36#[derive(Clone)]
37pub struct PyWindowExpr {
38    window: Window,
39}
40
41#[pyclass(frozen, name = "WindowFrame", module = "datafusion.expr", subclass)]
42#[derive(Clone)]
43pub struct PyWindowFrame {
44    window_frame: WindowFrame,
45}
46
47impl From<PyWindowFrame> for WindowFrame {
48    fn from(window_frame: PyWindowFrame) -> Self {
49        window_frame.window_frame
50    }
51}
52
53impl From<WindowFrame> for PyWindowFrame {
54    fn from(window_frame: WindowFrame) -> PyWindowFrame {
55        PyWindowFrame { window_frame }
56    }
57}
58
59#[pyclass(
60    frozen,
61    name = "WindowFrameBound",
62    module = "datafusion.expr",
63    subclass
64)]
65#[derive(Clone)]
66pub struct PyWindowFrameBound {
67    frame_bound: WindowFrameBound,
68}
69
70impl From<PyWindowExpr> for Window {
71    fn from(window: PyWindowExpr) -> Window {
72        window.window
73    }
74}
75
76impl From<Window> for PyWindowExpr {
77    fn from(window: Window) -> PyWindowExpr {
78        PyWindowExpr { window }
79    }
80}
81
82impl From<WindowFrameBound> for PyWindowFrameBound {
83    fn from(frame_bound: WindowFrameBound) -> Self {
84        PyWindowFrameBound { frame_bound }
85    }
86}
87
88impl Display for PyWindowExpr {
89    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
90        write!(
91            f,
92            "Over\n
93            Window Expr: {:?}
94            Schema: {:?}",
95            &self.window.window_expr, &self.window.schema
96        )
97    }
98}
99
100impl Display for PyWindowFrame {
101    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
102        write!(
103            f,
104            "OVER ({} BETWEEN {} AND {})",
105            self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
106        )
107    }
108}
109
110#[pymethods]
111impl PyWindowExpr {
112    /// Returns the schema of the Window
113    pub fn schema(&self) -> PyResult<PyDFSchema> {
114        Ok(self.window.schema.as_ref().clone().into())
115    }
116
117    /// Returns window expressions
118    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
119        py_expr_list(&self.window.window_expr)
120    }
121
122    /// Returns order by columns in a window function expression
123    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
124        match expr.expr.unalias() {
125            Expr::WindowFunction(boxed_window_fn) => {
126                py_sort_expr_list(&boxed_window_fn.params.order_by)
127            }
128            other => Err(not_window_function_err(other)),
129        }
130    }
131
132    /// Return partition by columns in a window function expression
133    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
134        match expr.expr.unalias() {
135            Expr::WindowFunction(boxed_window_fn) => {
136                py_expr_list(&boxed_window_fn.params.partition_by)
137            }
138            other => Err(not_window_function_err(other)),
139        }
140    }
141
142    /// Return input args for window function
143    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
144        match expr.expr.unalias() {
145            Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
146            other => Err(not_window_function_err(other)),
147        }
148    }
149
150    /// Return window function name
151    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
152        match expr.expr.unalias() {
153            Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
154            other => Err(not_window_function_err(other)),
155        }
156    }
157
158    /// Returns a Pywindow frame for a given window function expression
159    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
160        match expr.expr.unalias() {
161            Expr::WindowFunction(boxed_window_fn) => {
162                Some(boxed_window_fn.params.window_frame.into())
163            }
164            _ => None,
165        }
166    }
167}
168
169fn not_window_function_err(expr: Expr) -> PyErr {
170    py_type_err(format!(
171        "Provided {} Expr {:?} is not a WindowFunction type",
172        expr.variant_name(),
173        expr
174    ))
175}
176
177#[pymethods]
178impl PyWindowFrame {
179    #[new]
180    #[pyo3(signature=(unit, start_bound, end_bound))]
181    pub fn new(
182        unit: &str,
183        start_bound: Option<PyScalarValue>,
184        end_bound: Option<PyScalarValue>,
185    ) -> PyResult<Self> {
186        let units = unit.to_ascii_lowercase();
187        let units = match units.as_str() {
188            "rows" => WindowFrameUnits::Rows,
189            "range" => WindowFrameUnits::Range,
190            "groups" => WindowFrameUnits::Groups,
191            _ => {
192                return Err(PyNotImplementedError::new_err(format!("{units:?}")));
193            }
194        };
195        let start_bound = match start_bound {
196            Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
197            None => match units {
198                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
199                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
200                WindowFrameUnits::Groups => {
201                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
202                }
203            },
204        };
205        let end_bound = match end_bound {
206            Some(end_bound) => WindowFrameBound::Following(end_bound.0),
207            None => match units {
208                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
209                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
210                WindowFrameUnits::Groups => {
211                    return Err(PyNotImplementedError::new_err(format!("{units:?}")));
212                }
213            },
214        };
215        Ok(PyWindowFrame {
216            window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
217        })
218    }
219
220    /// Returns the window frame units for the bounds
221    pub fn get_frame_units(&self) -> PyResult<String> {
222        Ok(self.window_frame.units.to_string())
223    }
224    /// Returns starting bound
225    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
226        Ok(self.window_frame.start_bound.clone().into())
227    }
228    /// Returns end bound
229    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
230        Ok(self.window_frame.end_bound.clone().into())
231    }
232
233    /// Get a String representation of this window frame
234    fn __repr__(&self) -> String {
235        format!("{self}")
236    }
237}
238
239#[pymethods]
240impl PyWindowFrameBound {
241    /// Returns if the frame bound is current row
242    pub fn is_current_row(&self) -> bool {
243        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
244    }
245
246    /// Returns if the frame bound is preceding
247    pub fn is_preceding(&self) -> bool {
248        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
249    }
250
251    /// Returns if the frame bound is following
252    pub fn is_following(&self) -> bool {
253        matches!(self.frame_bound, WindowFrameBound::Following(_))
254    }
255    /// Returns the offset of the window frame
256    pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
257        match &self.frame_bound {
258            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
259                x if x.is_null() => Ok(None),
260                ScalarValue::UInt64(v) => Ok(*v),
261                // The cast below is only safe because window bounds cannot be negative
262                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
263                ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
264                    Ok(s) => Ok(Some(s)),
265                    Err(_e) => Err(DataFusionError::Plan(format!(
266                        "Unable to parse u64 from Utf8 value '{s}'"
267                    ))
268                    .into()),
269                },
270                ref x => {
271                    Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
272                }
273            },
274            WindowFrameBound::CurrentRow => Ok(None),
275        }
276    }
277    /// Returns if the frame bound is unbounded
278    pub fn is_unbounded(&self) -> PyResult<bool> {
279        match &self.frame_bound {
280            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
281            WindowFrameBound::CurrentRow => Ok(false),
282        }
283    }
284}
285
286impl LogicalNode for PyWindowExpr {
287    fn inputs(&self) -> Vec<PyLogicalPlan> {
288        vec![self.window.input.as_ref().clone().into()]
289    }
290
291    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
292        self.clone().into_bound_py_any(py)
293    }
294}