datafusion_python/expr/
window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::common::{DataFusionError, ScalarValue};
19use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams};
20use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
21use pyo3::{prelude::*, IntoPyObjectExt};
22use std::fmt::{self, Display, Formatter};
23
24use crate::common::data_type::PyScalarValue;
25use crate::common::df_schema::PyDFSchema;
26use crate::errors::{py_type_err, PyDataFusionResult};
27use crate::expr::logical_node::LogicalNode;
28use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
29use crate::expr::PyExpr;
30use crate::sql::logical::PyLogicalPlan;
31
32use super::py_expr_list;
33
34use crate::errors::py_datafusion_err;
35
/// Wrapper around a DataFusion `Window` that is exposed to Python as
/// `datafusion.expr.WindowExpr`.
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowExpr {
    // The wrapped DataFusion `Window`; its `input`, `window_expr` and `schema`
    // fields are read by the methods and trait impls of this type.
    window: Window,
}
41
/// Wrapper around a DataFusion `WindowFrame` that is exposed to Python as
/// `datafusion.expr.WindowFrame`.
#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
    // The wrapped frame; its `units`, `start_bound` and `end_bound` fields are
    // what the accessors and the `Display` impl of this type report.
    window_frame: WindowFrame,
}
47
48impl From<PyWindowFrame> for WindowFrame {
49    fn from(window_frame: PyWindowFrame) -> Self {
50        window_frame.window_frame
51    }
52}
53
54impl From<WindowFrame> for PyWindowFrame {
55    fn from(window_frame: WindowFrame) -> PyWindowFrame {
56        PyWindowFrame { window_frame }
57    }
58}
59
/// Wrapper around a DataFusion `WindowFrameBound` that is exposed to Python as
/// `datafusion.expr.WindowFrameBound`.
#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
    // The wrapped bound: `Preceding(offset)`, `Following(offset)` or `CurrentRow`.
    frame_bound: WindowFrameBound,
}
65
66impl From<PyWindowExpr> for Window {
67    fn from(window: PyWindowExpr) -> Window {
68        window.window
69    }
70}
71
72impl From<Window> for PyWindowExpr {
73    fn from(window: Window) -> PyWindowExpr {
74        PyWindowExpr { window }
75    }
76}
77
78impl From<WindowFrameBound> for PyWindowFrameBound {
79    fn from(frame_bound: WindowFrameBound) -> Self {
80        PyWindowFrameBound { frame_bound }
81    }
82}
83
84impl Display for PyWindowExpr {
85    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
86        write!(
87            f,
88            "Over\n
89            Window Expr: {:?}
90            Schema: {:?}",
91            &self.window.window_expr, &self.window.schema
92        )
93    }
94}
95
96impl Display for PyWindowFrame {
97    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
98        write!(
99            f,
100            "OVER ({} BETWEEN {} AND {})",
101            self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
102        )
103    }
104}
105
106#[pymethods]
107impl PyWindowExpr {
108    /// Returns the schema of the Window
109    pub fn schema(&self) -> PyResult<PyDFSchema> {
110        Ok(self.window.schema.as_ref().clone().into())
111    }
112
113    /// Returns window expressions
114    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
115        py_expr_list(&self.window.window_expr)
116    }
117
118    /// Returns order by columns in a window function expression
119    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
120        match expr.expr.unalias() {
121            Expr::WindowFunction(WindowFunction {
122                params: WindowFunctionParams { order_by, .. },
123                ..
124            }) => py_sort_expr_list(&order_by),
125            other => Err(not_window_function_err(other)),
126        }
127    }
128
129    /// Return partition by columns in a window function expression
130    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
131        match expr.expr.unalias() {
132            Expr::WindowFunction(WindowFunction {
133                params: WindowFunctionParams { partition_by, .. },
134                ..
135            }) => py_expr_list(&partition_by),
136            other => Err(not_window_function_err(other)),
137        }
138    }
139
140    /// Return input args for window function
141    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
142        match expr.expr.unalias() {
143            Expr::WindowFunction(WindowFunction {
144                params: WindowFunctionParams { args, .. },
145                ..
146            }) => py_expr_list(&args),
147            other => Err(not_window_function_err(other)),
148        }
149    }
150
151    /// Return window function name
152    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
153        match expr.expr.unalias() {
154            Expr::WindowFunction(WindowFunction { fun, .. }) => Ok(fun.to_string()),
155            other => Err(not_window_function_err(other)),
156        }
157    }
158
159    /// Returns a Pywindow frame for a given window function expression
160    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
161        match expr.expr.unalias() {
162            Expr::WindowFunction(WindowFunction {
163                params: WindowFunctionParams { window_frame, .. },
164                ..
165            }) => Some(window_frame.into()),
166            _ => None,
167        }
168    }
169}
170
171fn not_window_function_err(expr: Expr) -> PyErr {
172    py_type_err(format!(
173        "Provided {} Expr {:?} is not a WindowFunction type",
174        expr.variant_name(),
175        expr
176    ))
177}
178
179#[pymethods]
180impl PyWindowFrame {
181    #[new]
182    #[pyo3(signature=(unit, start_bound, end_bound))]
183    pub fn new(
184        unit: &str,
185        start_bound: Option<PyScalarValue>,
186        end_bound: Option<PyScalarValue>,
187    ) -> PyResult<Self> {
188        let units = unit.to_ascii_lowercase();
189        let units = match units.as_str() {
190            "rows" => WindowFrameUnits::Rows,
191            "range" => WindowFrameUnits::Range,
192            "groups" => WindowFrameUnits::Groups,
193            _ => {
194                return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
195                    "{:?}",
196                    units,
197                ))));
198            }
199        };
200        let start_bound = match start_bound {
201            Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
202            None => match units {
203                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
204                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
205                WindowFrameUnits::Groups => {
206                    return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
207                        "{:?}",
208                        units,
209                    ))));
210                }
211            },
212        };
213        let end_bound = match end_bound {
214            Some(end_bound) => WindowFrameBound::Following(end_bound.0),
215            None => match units {
216                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
217                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
218                WindowFrameUnits::Groups => {
219                    return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
220                        "{:?}",
221                        units,
222                    ))));
223                }
224            },
225        };
226        Ok(PyWindowFrame {
227            window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
228        })
229    }
230
231    /// Returns the window frame units for the bounds
232    pub fn get_frame_units(&self) -> PyResult<String> {
233        Ok(self.window_frame.units.to_string())
234    }
235    /// Returns starting bound
236    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
237        Ok(self.window_frame.start_bound.clone().into())
238    }
239    /// Returns end bound
240    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
241        Ok(self.window_frame.end_bound.clone().into())
242    }
243
244    /// Get a String representation of this window frame
245    fn __repr__(&self) -> String {
246        format!("{}", self)
247    }
248}
249
250#[pymethods]
251impl PyWindowFrameBound {
252    /// Returns if the frame bound is current row
253    pub fn is_current_row(&self) -> bool {
254        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
255    }
256
257    /// Returns if the frame bound is preceding
258    pub fn is_preceding(&self) -> bool {
259        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
260    }
261
262    /// Returns if the frame bound is following
263    pub fn is_following(&self) -> bool {
264        matches!(self.frame_bound, WindowFrameBound::Following(_))
265    }
266    /// Returns the offset of the window frame
267    pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
268        match &self.frame_bound {
269            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
270                x if x.is_null() => Ok(None),
271                ScalarValue::UInt64(v) => Ok(*v),
272                // The cast below is only safe because window bounds cannot be negative
273                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
274                ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
275                    Ok(s) => Ok(Some(s)),
276                    Err(_e) => Err(DataFusionError::Plan(format!(
277                        "Unable to parse u64 from Utf8 value '{s}'"
278                    ))
279                    .into()),
280                },
281                ref x => {
282                    Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
283                }
284            },
285            WindowFrameBound::CurrentRow => Ok(None),
286        }
287    }
288    /// Returns if the frame bound is unbounded
289    pub fn is_unbounded(&self) -> PyResult<bool> {
290        match &self.frame_bound {
291            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
292            WindowFrameBound::CurrentRow => Ok(false),
293        }
294    }
295}
296
297impl LogicalNode for PyWindowExpr {
298    fn inputs(&self) -> Vec<PyLogicalPlan> {
299        vec![self.window.input.as_ref().clone().into()]
300    }
301
302    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
303        self.clone().into_bound_py_any(py)
304    }
305}