datafusion_python/
udwf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::ops::Range;
20use std::sync::Arc;
21
22use arrow::array::{make_array, Array, ArrayData, ArrayRef};
23use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs};
24use datafusion::logical_expr::window_state::WindowAggState;
25use datafusion::scalar::ScalarValue;
26use pyo3::exceptions::PyValueError;
27use pyo3::prelude::*;
28
29use crate::common::data_type::PyScalarValue;
30use crate::errors::to_datafusion_err;
31use crate::expr::PyExpr;
32use crate::utils::parse_volatility;
33use datafusion::arrow::datatypes::DataType;
34use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
35use datafusion::error::{DataFusionError, Result};
36use datafusion::logical_expr::{
37    PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl,
38};
39use pyo3::types::{PyList, PyTuple};
40
41#[derive(Debug)]
42struct RustPartitionEvaluator {
43    evaluator: PyObject,
44}
45
46impl RustPartitionEvaluator {
47    fn new(evaluator: PyObject) -> Self {
48        Self { evaluator }
49    }
50}
51
52impl PartitionEvaluator for RustPartitionEvaluator {
53    fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> {
54        Python::with_gil(|py| self.evaluator.bind(py).call_method0("memoize").map(|_| ()))
55            .map_err(|e| DataFusionError::Execution(format!("{e}")))
56    }
57
58    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
59        Python::with_gil(|py| {
60            let py_args = vec![idx.into_pyobject(py)?, n_rows.into_pyobject(py)?];
61            let py_args = PyTuple::new(py, py_args)?;
62
63            self.evaluator
64                .bind(py)
65                .call_method1("get_range", py_args)
66                .and_then(|v| {
67                    let tuple: Bound<'_, PyTuple> = v.extract()?;
68                    if tuple.len() != 2 {
69                        return Err(PyValueError::new_err(format!(
70                            "Expected get_range to return tuple of length 2. Received length {}",
71                            tuple.len()
72                        )));
73                    }
74
75                    let start: usize = tuple.get_item(0).unwrap().extract()?;
76                    let end: usize = tuple.get_item(1).unwrap().extract()?;
77
78                    Ok(Range { start, end })
79                })
80        })
81        .map_err(|e| DataFusionError::Execution(format!("{e}")))
82    }
83
84    fn is_causal(&self) -> bool {
85        Python::with_gil(|py| {
86            self.evaluator
87                .bind(py)
88                .call_method0("is_causal")
89                .and_then(|v| v.extract())
90                .unwrap_or(false)
91        })
92    }
93
94    fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
95        println!("evaluate all called with number of values {}", values.len());
96        Python::with_gil(|py| {
97            let py_values = PyList::new(
98                py,
99                values
100                    .iter()
101                    .map(|arg| arg.into_data().to_pyarrow(py).unwrap()),
102            )?;
103            let py_num_rows = num_rows.into_pyobject(py)?;
104            let py_args = PyTuple::new(py, vec![py_values.as_any(), &py_num_rows])?;
105
106            self.evaluator
107                .bind(py)
108                .call_method1("evaluate_all", py_args)
109                .map(|v| {
110                    let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
111                    make_array(array_data)
112                })
113        })
114        .map_err(to_datafusion_err)
115    }
116
117    fn evaluate(&mut self, values: &[ArrayRef], range: &Range<usize>) -> Result<ScalarValue> {
118        Python::with_gil(|py| {
119            let py_values = PyList::new(
120                py,
121                values
122                    .iter()
123                    .map(|arg| arg.into_data().to_pyarrow(py).unwrap()),
124            )?;
125            let range_tuple = PyTuple::new(py, vec![range.start, range.end])?;
126            let py_args = PyTuple::new(py, vec![py_values.as_any(), range_tuple.as_any()])?;
127
128            self.evaluator
129                .bind(py)
130                .call_method1("evaluate", py_args)
131                .and_then(|v| v.extract::<PyScalarValue>())
132                .map(|v| v.0)
133        })
134        .map_err(to_datafusion_err)
135    }
136
137    fn evaluate_all_with_rank(
138        &self,
139        num_rows: usize,
140        ranks_in_partition: &[Range<usize>],
141    ) -> Result<ArrayRef> {
142        Python::with_gil(|py| {
143            let ranks = ranks_in_partition
144                .iter()
145                .map(|r| PyTuple::new(py, vec![r.start, r.end]))
146                .collect::<PyResult<Vec<_>>>()?;
147
148            // 1. cast args to Pyarrow array
149            let py_args = vec![
150                num_rows.into_pyobject(py)?.into_any(),
151                PyList::new(py, ranks)?.into_any(),
152            ];
153
154            let py_args = PyTuple::new(py, py_args)?;
155
156            // 2. call function
157            self.evaluator
158                .bind(py)
159                .call_method1("evaluate_all_with_rank", py_args)
160                .map(|v| {
161                    let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
162                    make_array(array_data)
163                })
164        })
165        .map_err(to_datafusion_err)
166    }
167
168    fn supports_bounded_execution(&self) -> bool {
169        Python::with_gil(|py| {
170            self.evaluator
171                .bind(py)
172                .call_method0("supports_bounded_execution")
173                .and_then(|v| v.extract())
174                .unwrap_or(false)
175        })
176    }
177
178    fn uses_window_frame(&self) -> bool {
179        Python::with_gil(|py| {
180            self.evaluator
181                .bind(py)
182                .call_method0("uses_window_frame")
183                .and_then(|v| v.extract())
184                .unwrap_or(false)
185        })
186    }
187
188    fn include_rank(&self) -> bool {
189        Python::with_gil(|py| {
190            self.evaluator
191                .bind(py)
192                .call_method0("include_rank")
193                .and_then(|v| v.extract())
194                .unwrap_or(false)
195        })
196    }
197}
198
199pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory {
200    Arc::new(move || -> Result<Box<dyn PartitionEvaluator>> {
201        let evaluator = Python::with_gil(|py| {
202            evaluator
203                .call0(py)
204                .map_err(|e| DataFusionError::Execution(e.to_string()))
205        })?;
206        Ok(Box::new(RustPartitionEvaluator::new(evaluator)))
207    })
208}
209
210/// Represents an WindowUDF
211#[pyclass(name = "WindowUDF", module = "datafusion", subclass)]
212#[derive(Debug, Clone)]
213pub struct PyWindowUDF {
214    pub(crate) function: WindowUDF,
215}
216
217#[pymethods]
218impl PyWindowUDF {
219    #[new]
220    #[pyo3(signature=(name, evaluator, input_types, return_type, volatility))]
221    fn new(
222        name: &str,
223        evaluator: PyObject,
224        input_types: Vec<PyArrowType<DataType>>,
225        return_type: PyArrowType<DataType>,
226        volatility: &str,
227    ) -> PyResult<Self> {
228        let return_type = return_type.0;
229        let input_types = input_types.into_iter().map(|t| t.0).collect();
230
231        let function = WindowUDF::from(MultiColumnWindowUDF::new(
232            name,
233            input_types,
234            return_type,
235            parse_volatility(volatility)?,
236            to_rust_partition_evaluator(evaluator),
237        ));
238        Ok(Self { function })
239    }
240
241    /// creates a new PyExpr with the call of the udf
242    #[pyo3(signature = (*args))]
243    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
244        let args = args.iter().map(|e| e.expr.clone()).collect();
245        Ok(self.function.call(args).into())
246    }
247
248    fn __repr__(&self) -> PyResult<String> {
249        Ok(format!("WindowUDF({})", self.function.name()))
250    }
251}
252
253pub struct MultiColumnWindowUDF {
254    name: String,
255    signature: Signature,
256    return_type: DataType,
257    partition_evaluator_factory: PartitionEvaluatorFactory,
258}
259
260impl std::fmt::Debug for MultiColumnWindowUDF {
261    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
262        f.debug_struct("WindowUDF")
263            .field("name", &self.name)
264            .field("signature", &self.signature)
265            .field("return_type", &"<func>")
266            .field("partition_evaluator_factory", &"<FUNC>")
267            .finish()
268    }
269}
270
271impl MultiColumnWindowUDF {
272    pub fn new(
273        name: impl Into<String>,
274        input_types: Vec<DataType>,
275        return_type: DataType,
276        volatility: Volatility,
277        partition_evaluator_factory: PartitionEvaluatorFactory,
278    ) -> Self {
279        let name = name.into();
280        let signature = Signature::exact(input_types, volatility);
281        Self {
282            name,
283            signature,
284            return_type,
285            partition_evaluator_factory,
286        }
287    }
288}
289
290impl WindowUDFImpl for MultiColumnWindowUDF {
291    fn as_any(&self) -> &dyn Any {
292        self
293    }
294
295    fn name(&self) -> &str {
296        &self.name
297    }
298
299    fn signature(&self) -> &Signature {
300        &self.signature
301    }
302
303    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<arrow::datatypes::Field> {
304        // TODO: Should nullable always be `true`?
305        Ok(arrow::datatypes::Field::new(
306            field_args.name(),
307            self.return_type.clone(),
308            true,
309        ))
310    }
311
312    // TODO: Enable passing partition_evaluator_args to python?
313    fn partition_evaluator(
314        &self,
315        _partition_evaluator_args: PartitionEvaluatorArgs,
316    ) -> Result<Box<dyn PartitionEvaluator>> {
317        let _ = _partition_evaluator_args;
318        (self.partition_evaluator_factory)()
319    }
320}