1use std::any::Any;
19use std::ops::Range;
20use std::sync::Arc;
21
22use arrow::array::{make_array, Array, ArrayData, ArrayRef};
23use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs};
24use datafusion::logical_expr::window_state::WindowAggState;
25use datafusion::scalar::ScalarValue;
26use pyo3::exceptions::PyValueError;
27use pyo3::prelude::*;
28
29use crate::common::data_type::PyScalarValue;
30use crate::errors::to_datafusion_err;
31use crate::expr::PyExpr;
32use crate::utils::parse_volatility;
33use datafusion::arrow::datatypes::DataType;
34use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
35use datafusion::error::{DataFusionError, Result};
36use datafusion::logical_expr::{
37 PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl,
38};
39use pyo3::types::{PyList, PyTuple};
40
41#[derive(Debug)]
42struct RustPartitionEvaluator {
43 evaluator: PyObject,
44}
45
46impl RustPartitionEvaluator {
47 fn new(evaluator: PyObject) -> Self {
48 Self { evaluator }
49 }
50}
51
52impl PartitionEvaluator for RustPartitionEvaluator {
53 fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> {
54 Python::with_gil(|py| self.evaluator.bind(py).call_method0("memoize").map(|_| ()))
55 .map_err(|e| DataFusionError::Execution(format!("{e}")))
56 }
57
58 fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
59 Python::with_gil(|py| {
60 let py_args = vec![idx.into_pyobject(py)?, n_rows.into_pyobject(py)?];
61 let py_args = PyTuple::new(py, py_args)?;
62
63 self.evaluator
64 .bind(py)
65 .call_method1("get_range", py_args)
66 .and_then(|v| {
67 let tuple: Bound<'_, PyTuple> = v.extract()?;
68 if tuple.len() != 2 {
69 return Err(PyValueError::new_err(format!(
70 "Expected get_range to return tuple of length 2. Received length {}",
71 tuple.len()
72 )));
73 }
74
75 let start: usize = tuple.get_item(0).unwrap().extract()?;
76 let end: usize = tuple.get_item(1).unwrap().extract()?;
77
78 Ok(Range { start, end })
79 })
80 })
81 .map_err(|e| DataFusionError::Execution(format!("{e}")))
82 }
83
84 fn is_causal(&self) -> bool {
85 Python::with_gil(|py| {
86 self.evaluator
87 .bind(py)
88 .call_method0("is_causal")
89 .and_then(|v| v.extract())
90 .unwrap_or(false)
91 })
92 }
93
94 fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
95 println!("evaluate all called with number of values {}", values.len());
96 Python::with_gil(|py| {
97 let py_values = PyList::new(
98 py,
99 values
100 .iter()
101 .map(|arg| arg.into_data().to_pyarrow(py).unwrap()),
102 )?;
103 let py_num_rows = num_rows.into_pyobject(py)?;
104 let py_args = PyTuple::new(py, vec![py_values.as_any(), &py_num_rows])?;
105
106 self.evaluator
107 .bind(py)
108 .call_method1("evaluate_all", py_args)
109 .map(|v| {
110 let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
111 make_array(array_data)
112 })
113 })
114 .map_err(to_datafusion_err)
115 }
116
117 fn evaluate(&mut self, values: &[ArrayRef], range: &Range<usize>) -> Result<ScalarValue> {
118 Python::with_gil(|py| {
119 let py_values = PyList::new(
120 py,
121 values
122 .iter()
123 .map(|arg| arg.into_data().to_pyarrow(py).unwrap()),
124 )?;
125 let range_tuple = PyTuple::new(py, vec![range.start, range.end])?;
126 let py_args = PyTuple::new(py, vec![py_values.as_any(), range_tuple.as_any()])?;
127
128 self.evaluator
129 .bind(py)
130 .call_method1("evaluate", py_args)
131 .and_then(|v| v.extract::<PyScalarValue>())
132 .map(|v| v.0)
133 })
134 .map_err(to_datafusion_err)
135 }
136
137 fn evaluate_all_with_rank(
138 &self,
139 num_rows: usize,
140 ranks_in_partition: &[Range<usize>],
141 ) -> Result<ArrayRef> {
142 Python::with_gil(|py| {
143 let ranks = ranks_in_partition
144 .iter()
145 .map(|r| PyTuple::new(py, vec![r.start, r.end]))
146 .collect::<PyResult<Vec<_>>>()?;
147
148 let py_args = vec![
150 num_rows.into_pyobject(py)?.into_any(),
151 PyList::new(py, ranks)?.into_any(),
152 ];
153
154 let py_args = PyTuple::new(py, py_args)?;
155
156 self.evaluator
158 .bind(py)
159 .call_method1("evaluate_all_with_rank", py_args)
160 .map(|v| {
161 let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
162 make_array(array_data)
163 })
164 })
165 .map_err(to_datafusion_err)
166 }
167
168 fn supports_bounded_execution(&self) -> bool {
169 Python::with_gil(|py| {
170 self.evaluator
171 .bind(py)
172 .call_method0("supports_bounded_execution")
173 .and_then(|v| v.extract())
174 .unwrap_or(false)
175 })
176 }
177
178 fn uses_window_frame(&self) -> bool {
179 Python::with_gil(|py| {
180 self.evaluator
181 .bind(py)
182 .call_method0("uses_window_frame")
183 .and_then(|v| v.extract())
184 .unwrap_or(false)
185 })
186 }
187
188 fn include_rank(&self) -> bool {
189 Python::with_gil(|py| {
190 self.evaluator
191 .bind(py)
192 .call_method0("include_rank")
193 .and_then(|v| v.extract())
194 .unwrap_or(false)
195 })
196 }
197}
198
199pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory {
200 Arc::new(move || -> Result<Box<dyn PartitionEvaluator>> {
201 let evaluator = Python::with_gil(|py| {
202 evaluator
203 .call0(py)
204 .map_err(|e| DataFusionError::Execution(e.to_string()))
205 })?;
206 Ok(Box::new(RustPartitionEvaluator::new(evaluator)))
207 })
208}
209
210#[pyclass(name = "WindowUDF", module = "datafusion", subclass)]
212#[derive(Debug, Clone)]
213pub struct PyWindowUDF {
214 pub(crate) function: WindowUDF,
215}
216
217#[pymethods]
218impl PyWindowUDF {
219 #[new]
220 #[pyo3(signature=(name, evaluator, input_types, return_type, volatility))]
221 fn new(
222 name: &str,
223 evaluator: PyObject,
224 input_types: Vec<PyArrowType<DataType>>,
225 return_type: PyArrowType<DataType>,
226 volatility: &str,
227 ) -> PyResult<Self> {
228 let return_type = return_type.0;
229 let input_types = input_types.into_iter().map(|t| t.0).collect();
230
231 let function = WindowUDF::from(MultiColumnWindowUDF::new(
232 name,
233 input_types,
234 return_type,
235 parse_volatility(volatility)?,
236 to_rust_partition_evaluator(evaluator),
237 ));
238 Ok(Self { function })
239 }
240
241 #[pyo3(signature = (*args))]
243 fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
244 let args = args.iter().map(|e| e.expr.clone()).collect();
245 Ok(self.function.call(args).into())
246 }
247
248 fn __repr__(&self) -> PyResult<String> {
249 Ok(format!("WindowUDF({})", self.function.name()))
250 }
251}
252
/// A [`WindowUDFImpl`] whose evaluators come from a [`PartitionEvaluatorFactory`].
///
/// It has a fixed signature and return type.
pub struct MultiColumnWindowUDF {
    /// Function name reported by `WindowUDFImpl::name`.
    name: String,
    /// Call signature reported by `WindowUDFImpl::signature`.
    signature: Signature,
    /// Data type of the output field produced by this function.
    return_type: DataType,
    /// Invoked by `partition_evaluator` to build a new evaluator.
    partition_evaluator_factory: PartitionEvaluatorFactory,
}
259
260impl std::fmt::Debug for MultiColumnWindowUDF {
261 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
262 f.debug_struct("WindowUDF")
263 .field("name", &self.name)
264 .field("signature", &self.signature)
265 .field("return_type", &"<func>")
266 .field("partition_evaluator_factory", &"<FUNC>")
267 .finish()
268 }
269}
270
271impl MultiColumnWindowUDF {
272 pub fn new(
273 name: impl Into<String>,
274 input_types: Vec<DataType>,
275 return_type: DataType,
276 volatility: Volatility,
277 partition_evaluator_factory: PartitionEvaluatorFactory,
278 ) -> Self {
279 let name = name.into();
280 let signature = Signature::exact(input_types, volatility);
281 Self {
282 name,
283 signature,
284 return_type,
285 partition_evaluator_factory,
286 }
287 }
288}
289
290impl WindowUDFImpl for MultiColumnWindowUDF {
291 fn as_any(&self) -> &dyn Any {
292 self
293 }
294
295 fn name(&self) -> &str {
296 &self.name
297 }
298
299 fn signature(&self) -> &Signature {
300 &self.signature
301 }
302
303 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<arrow::datatypes::Field> {
304 Ok(arrow::datatypes::Field::new(
306 field_args.name(),
307 self.return_type.clone(),
308 true,
309 ))
310 }
311
312 fn partition_evaluator(
314 &self,
315 _partition_evaluator_args: PartitionEvaluatorArgs,
316 ) -> Result<Box<dyn PartitionEvaluator>> {
317 let _ = _partition_evaluator_args;
318 (self.partition_evaluator_factory)()
319 }
320}