arrow_udf_runtime/python/mod.rs
1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("README.md")]
16
// Notice for developers:
// This library uses the sub-interpreters and per-interpreter GIL introduced in Python 3.12
// to support concurrent execution of different functions in multiple threads.
// However, pyo3 does not yet support sub-interpreters safely, so we implement them with pyo3's raw FFI API.
// Therefore, special attention is needed:
// **All PyObjects created in a sub-interpreter must be destroyed in the same sub-interpreter.**
// Otherwise, the next call into Python will crash.
// Take particular care with the `PyErr` inside a `PyResult`:
// convert a `PyErr` with the `pyerr_to_anyhow` function before passing it out of the sub-interpreter.
26
27use crate::into_field::IntoField;
28use crate::CallMode;
29
30use self::interpreter::SubInterpreter;
31use anyhow::{bail, Context, Result};
32use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
33use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
34use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
35use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
36use pyo3::{Py, PyObject};
37use std::collections::HashMap;
38use std::fmt::Debug;
39use std::sync::Arc;
40
41// #[cfg(Py_3_12)]
42mod interpreter;
43mod pyarrow;
44
/// A runtime to execute user defined functions in Python.
///
/// # Usages
///
/// - Create a new runtime with [`Runtime::new`] or [`Runtime::builder`].
/// - For scalar functions, use [`add_function`] and [`call`].
/// - For table functions, use [`add_function`] and [`call_table_function`].
/// - For aggregate functions, create the function with [`add_aggregate`], and then
///   - create a new state with [`create_state`],
///   - update the state with [`accumulate`] or [`accumulate_or_retract`],
///   - merge states with [`merge`],
///   - finally get the result with [`finish`].
///
/// Click on each function to see the example.
///
/// # Parallelism
///
/// Python has a Global Interpreter Lock (GIL) that prevents multiple threads from executing Python code simultaneously.
/// To work around this limitation, each runtime creates a sub-interpreter with its own GIL. This feature requires Python 3.12 or later.
///
/// [`add_function`]: Runtime::add_function
/// [`add_aggregate`]: Runtime::add_aggregate
/// [`call`]: Runtime::call
/// [`call_table_function`]: Runtime::call_table_function
/// [`create_state`]: Runtime::create_state
/// [`accumulate`]: Runtime::accumulate
/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
/// [`merge`]: Runtime::merge
/// [`finish`]: Runtime::finish
// Rust drops fields in declaration order after `Drop::drop` runs, so `interpreter` is dropped
// before `converter`.
// NOTE(review): confirm `pyarrow::Converter` holds no Python objects, or it would outlive its interpreter.
pub struct Runtime {
    /// The sub-interpreter that every Python object of this runtime was created in
    /// and must be dropped in.
    interpreter: SubInterpreter,
    /// Scalar and table functions, keyed by the name they were registered under.
    functions: HashMap<String, Function>,
    /// Aggregate functions, keyed by the name they were registered under.
    aggregates: HashMap<String, Aggregate>,
    /// Converts Arrow values to Python objects and Python results back to Arrow arrays.
    converter: pyarrow::Converter,
}
80
81impl Debug for Runtime {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("Runtime")
84 .field("functions", &self.functions.keys())
85 .field("aggregates", &self.aggregates.keys())
86 .finish()
87 }
88}
89
/// A user defined function.
///
/// `function` is a Python object owned by the runtime's sub-interpreter, so a `Function` must
/// only be dropped while that interpreter's GIL is held.
struct Function {
    /// The Python callable, looked up by handler name in the compiled module.
    function: PyObject,
    /// Field for the returned values, built with `return_type.into_field(name)`.
    return_field: FieldRef,
    /// Null handling: with `CallMode::ReturnNullOnNullInput`, rows containing a null argument
    /// are not passed to the function (scalar calls yield null; table calls yield no rows).
    mode: CallMode,
}
96
/// A user defined aggregate function.
///
/// Every `PyObject` here belongs to the runtime's sub-interpreter, so an `Aggregate` must only
/// be dropped while that interpreter's GIL is held.
struct Aggregate {
    /// Field describing the intermediate state values.
    state_field: FieldRef,
    /// Field describing the final output values produced by `finish`.
    output_field: FieldRef,
    /// Null handling: with `CallMode::ReturnNullOnNullInput`, input rows containing a null are
    /// skipped by accumulate/retract, null states after the first are skipped by merge, and
    /// finish yields null for a null state.
    mode: CallMode,
    /// Python `create_state()`: returns a fresh state.
    create_state: PyObject,
    /// Python `accumulate(state, *args)`: returns the updated state.
    accumulate: PyObject,
    /// Python `retract(state, *args)`, if the module defines it.
    retract: Option<PyObject>,
    /// Python `finish(state)`, if the module defines it; otherwise the state is the output.
    finish: Option<PyObject>,
    /// Python `merge(state1, state2)`, if the module defines it.
    merge: Option<PyObject>,
}
108
/// A builder for `Runtime`.
///
/// Obtain one with `Runtime::builder()`, configure it, then call `build`.
#[derive(Default, Debug)]
pub struct Builder {
    /// Whether to restrict imports and builtins; see `Builder::sandboxed`.
    sandboxed: bool,
    /// Dotted symbol paths (e.g. `__builtins__.eval`) that `build` removes with Python `del`
    /// statements; see `build` for when they are applied.
    removed_symbols: Vec<String>,
}
115
116impl Builder {
117 /// Set whether the runtime is sandboxed.
118 ///
119 /// When sandboxed, only a limited set of modules can be imported, and some built-in functions are disabled.
120 /// This is useful for running untrusted code.
121 ///
122 /// Allowed modules: `json`, `decimal`, `re`, `math`, `datetime`, `time`.
123 ///
124 /// Disallowed builtins: `breakpoint`, `exit`, `eval`, `help`, `input`, `open`, `print`.
125 ///
126 /// The default is `false`.
127 pub fn sandboxed(mut self, sandboxed: bool) -> Self {
128 self.sandboxed = sandboxed;
129 self.remove_symbol("__builtins__.breakpoint")
130 .remove_symbol("__builtins__.exit")
131 .remove_symbol("__builtins__.eval")
132 .remove_symbol("__builtins__.help")
133 .remove_symbol("__builtins__.input")
134 .remove_symbol("__builtins__.open")
135 .remove_symbol("__builtins__.print")
136 }
137
138 /// Remove a symbol from builtins.
139 ///
140 /// # Examples
141 ///
142 /// ```
143 /// # use arrow_udf_runtime::python::Runtime;
144 /// let builder = Runtime::builder().remove_symbol("__builtins__.eval");
145 /// ```
146 pub fn remove_symbol(mut self, symbol: &str) -> Self {
147 self.removed_symbols.push(symbol.to_string());
148 self
149 }
150
151 /// Build the `Runtime`.
152 pub fn build(self) -> Result<Runtime> {
153 let interpreter = SubInterpreter::new()?;
154 interpreter.run(
155 r#"
156# internal use for json types
157import json
158import pickle
159import decimal
160
161# an internal class used for struct input arguments
162class Struct:
163 pass
164"#,
165 )?;
166 if self.sandboxed {
167 let mut script = r#"
168# limit the modules that can be imported
169original_import = __builtins__.__import__
170
171def limited_import(name, globals=None, locals=None, fromlist=(), level=0):
172 # FIXME: 'sys' should not be allowed, but it is required by 'decimal'
173 # FIXME: 'time.sleep' should not be allowed, but 'time' is required by 'datetime'
174 allowlist = (
175 'json',
176 'decimal',
177 're',
178 'math',
179 'datetime',
180 'time',
181 'operator',
182 'numbers',
183 'abc',
184 'sys',
185 'contextvars',
186 '_io',
187 '_contextvars',
188 '_pydecimal',
189 '_pydatetime',
190 )
191 if level == 0 and name in allowlist:
192 return original_import(name, globals, locals, fromlist, level)
193 raise ImportError(f'import {name} is not allowed')
194
195__builtins__.__import__ = limited_import
196del limited_import
197"#
198 .to_string();
199 for symbol in self.removed_symbols {
200 script.push_str(&format!("del {}\n", symbol));
201 }
202 interpreter.run(&script)?;
203 }
204 Ok(Runtime {
205 interpreter,
206 functions: HashMap::new(),
207 aggregates: HashMap::new(),
208 converter: pyarrow::Converter::new(),
209 })
210 }
211}
212
213impl Runtime {
214 /// Create a new `Runtime`.
215 pub fn new() -> Result<Self> {
216 Builder::default().build()
217 }
218
219 /// Return a new builder for `Runtime`.
220 pub fn builder() -> Builder {
221 Builder::default()
222 }
223
224 /// Add a new scalar function or table function.
225 ///
226 /// # Arguments
227 ///
228 /// - `name`: The name of the function.
229 /// - `return_type`: The data type of the return value.
230 /// - `mode`: Whether the function will be called when some of its arguments are null.
231 /// - `code`: The Python code of the function.
232 ///
233 /// The code should define a function with the same name as the function.
234 /// The function should return a value for scalar functions, or yield values for table functions.
235 ///
236 /// # Example
237 ///
238 /// ```
239 /// # use arrow_udf_runtime::python::Runtime;
240 /// # use arrow_udf_runtime::CallMode;
241 /// # use arrow_schema::DataType;
242 /// let mut runtime = Runtime::new().unwrap();
243 /// // add a scalar function
244 /// runtime
245 /// .add_function(
246 /// "gcd",
247 /// DataType::Int32,
248 /// CallMode::ReturnNullOnNullInput,
249 /// r#"
250 /// def gcd(a: int, b: int) -> int:
251 /// while b:
252 /// a, b = b, a % b
253 /// return a
254 /// "#,
255 /// )
256 /// .unwrap();
257 /// // add a table function
258 /// runtime
259 /// .add_function(
260 /// "series",
261 /// DataType::Int32,
262 /// CallMode::ReturnNullOnNullInput,
263 /// r#"
264 /// def series(n: int):
265 /// for i in range(n):
266 /// yield i
267 /// "#,
268 /// )
269 /// .unwrap();
270 /// ```
271 pub fn add_function(
272 &mut self,
273 name: &str,
274 return_type: impl IntoField,
275 mode: CallMode,
276 code: &str,
277 ) -> Result<()> {
278 self.add_function_with_handler(name, return_type, mode, code, name)
279 }
280
281 /// Add a new scalar function or table function with custom handler name.
282 ///
283 /// # Arguments
284 ///
285 /// - `handler`: The name of function in Python code to be called.
286 /// - others: Same as [`add_function`].
287 ///
288 /// [`add_function`]: Runtime::add_function
289 pub fn add_function_with_handler(
290 &mut self,
291 name: &str,
292 return_type: impl IntoField,
293 mode: CallMode,
294 code: &str,
295 handler: &str,
296 ) -> Result<()> {
297 let function = self.interpreter.with_gil(|py| {
298 Ok(PyModule::from_code_bound(py, code, name, name)?
299 .getattr(handler)?
300 .into())
301 })?;
302 let function = Function {
303 function,
304 return_field: return_type.into_field(name).into(),
305 mode,
306 };
307 self.functions.insert(name.to_string(), function);
308 Ok(())
309 }
310
311 /// Add a new aggregate function from Python code.
312 ///
313 /// # Arguments
314 ///
315 /// - `name`: The name of the function.
316 /// - `state_type`: The data type of the internal state.
317 /// - `output_type`: The data type of the aggregate value.
318 /// - `mode`: Whether the function will be called when some of its arguments are null.
319 /// - `code`: The Python code of the aggregate function.
320 ///
321 /// The code should define at least two functions:
322 ///
323 /// - `create_state() -> state`: Create a new state object.
324 /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
325 ///
326 /// optionally, the code can define:
327 ///
328 /// - `finish(state) -> value`: Get the result of the aggregate function.
329 /// If not defined, the state is returned as the result.
330 /// In this case, `output_type` must be the same as `state_type`.
331 /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
332 /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
333 ///
334 /// # Example
335 ///
336 /// ```
337 /// # use arrow_udf_runtime::python::Runtime;
338 /// # use arrow_udf_runtime::CallMode;
339 /// # use arrow_schema::DataType;
340 /// let mut runtime = Runtime::new().unwrap();
341 /// runtime
342 /// .add_aggregate(
343 /// "sum",
344 /// DataType::Int32, // state_type
345 /// DataType::Int32, // output_type
346 /// CallMode::ReturnNullOnNullInput,
347 /// r#"
348 /// def create_state():
349 /// return 0
350 ///
351 /// def accumulate(state, value):
352 /// return state + value
353 ///
354 /// def retract(state, value):
355 /// return state - value
356 ///
357 /// def merge(state1, state2):
358 /// return state1 + state2
359 /// "#,
360 /// )
361 /// .unwrap();
362 /// ```
363 pub fn add_aggregate(
364 &mut self,
365 name: &str,
366 state_type: impl IntoField,
367 output_type: impl IntoField,
368 mode: CallMode,
369 code: &str,
370 ) -> Result<()> {
371 let aggregate = self.interpreter.with_gil(|py| {
372 let module = PyModule::from_code_bound(py, code, name, name)?;
373 Ok(Aggregate {
374 state_field: state_type.into_field(name).into(),
375 output_field: output_type.into_field(name).into(),
376 mode,
377 create_state: module.getattr("create_state")?.into(),
378 accumulate: module.getattr("accumulate")?.into(),
379 retract: module.getattr("retract").ok().map(|f| f.into()),
380 finish: module.getattr("finish").ok().map(|f| f.into()),
381 merge: module.getattr("merge").ok().map(|f| f.into()),
382 })
383 })?;
384 if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
385 bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
386 }
387 self.aggregates.insert(name.to_string(), aggregate);
388 Ok(())
389 }
390
391 /// Remove a scalar or table function.
392 pub fn del_function(&mut self, name: &str) -> Result<()> {
393 let function = self.functions.remove(name).context("function not found")?;
394 _ = self.interpreter.with_gil(|_| {
395 drop(function);
396 Ok(())
397 });
398 Ok(())
399 }
400
401 /// Remove an aggregate function.
402 pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
403 let aggregate = self.functions.remove(name).context("function not found")?;
404 _ = self.interpreter.with_gil(|_| {
405 drop(aggregate);
406 Ok(())
407 });
408 Ok(())
409 }
410
411 /// Call a scalar function.
412 ///
413 /// # Example
414 ///
415 /// ```
416 #[doc = include_str!("doc_create_function.txt")]
417 /// // suppose we have created a scalar function `gcd`
418 /// // see the example in `add_function`
419 ///
420 /// let schema = Schema::new(vec![
421 /// Field::new("x", DataType::Int32, true),
422 /// Field::new("y", DataType::Int32, true),
423 /// ]);
424 /// let arg0 = Int32Array::from(vec![Some(25), None]);
425 /// let arg1 = Int32Array::from(vec![Some(15), None]);
426 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
427 ///
428 /// let output = runtime.call("gcd", &input).unwrap();
429 /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
430 /// ```
431 pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
432 let function = self.functions.get(name).context("function not found")?;
433 // convert each row to python objects and call the function
434 let (output, error) = self.interpreter.with_gil(|py| {
435 let mut results = Vec::with_capacity(input.num_rows());
436 let mut errors = vec![];
437 let mut row = Vec::with_capacity(input.num_columns());
438 for i in 0..input.num_rows() {
439 if function.mode == CallMode::ReturnNullOnNullInput
440 && input.columns().iter().any(|column| column.is_null(i))
441 {
442 results.push(py.None());
443 continue;
444 }
445 row.clear();
446 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
447 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
448 row.push(pyobj);
449 }
450 let args = PyTuple::new_bound(py, row.drain(..));
451 match function.function.call1(py, args) {
452 Ok(result) => results.push(result),
453 Err(e) => {
454 results.push(py.None());
455 errors.push((i, e.to_string()));
456 }
457 }
458 }
459 let output = self
460 .converter
461 .build_array(&function.return_field, py, &results)?;
462 let error = build_error_array(input.num_rows(), errors);
463 Ok((output, error))
464 })?;
465 if let Some(error) = error {
466 let schema = Schema::new(vec![
467 function.return_field.clone(),
468 Field::new("error", DataType::Utf8, true).into(),
469 ]);
470 Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
471 } else {
472 let schema = Schema::new(vec![function.return_field.clone()]);
473 Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
474 }
475 }
476
477 /// Call a table function.
478 ///
479 /// # Example
480 ///
481 /// ```
482 #[doc = include_str!("doc_create_function.txt")]
483 /// // suppose we have created a table function `series`
484 /// // see the example in `add_function`
485 ///
486 /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
487 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
488 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
489 ///
490 /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
491 /// let output = outputs.next().unwrap().unwrap();
492 /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
493 /// assert_eq!(pretty, r#"
494 /// +-----+--------+
495 /// | row | series |
496 /// +-----+--------+
497 /// | 0 | 0 |
498 /// | 2 | 0 |
499 /// | 2 | 1 |
500 /// | 2 | 2 |
501 /// +-----+--------+"#.trim());
502 /// ```
503 pub fn call_table_function<'a>(
504 &'a self,
505 name: &'a str,
506 input: &'a RecordBatch,
507 chunk_size: usize,
508 ) -> Result<RecordBatchIter<'a>> {
509 assert!(chunk_size > 0);
510 let function = self.functions.get(name).context("function not found")?;
511
512 // initial state
513 Ok(RecordBatchIter {
514 interpreter: &self.interpreter,
515 input,
516 function,
517 schema: Arc::new(Schema::new(vec![
518 Field::new("row", DataType::Int32, true).into(),
519 function.return_field.clone(),
520 ])),
521 chunk_size,
522 row: 0,
523 generator: None,
524 converter: &self.converter,
525 })
526 }
527
528 /// Create a new state for an aggregate function.
529 ///
530 /// # Example
531 /// ```
532 #[doc = include_str!("doc_create_aggregate.txt")]
533 /// let state = runtime.create_state("sum").unwrap();
534 /// assert_eq!(&*state, &Int32Array::from(vec![0]));
535 /// ```
536 pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
537 let aggregate = self.aggregates.get(name).context("function not found")?;
538 let state = self.interpreter.with_gil(|py| {
539 let state = aggregate.create_state.call0(py)?;
540 let state = self
541 .converter
542 .build_array(&aggregate.state_field, py, &[state])?;
543 Ok(state)
544 })?;
545 Ok(state)
546 }
547
548 /// Call accumulate of an aggregate function.
549 ///
550 /// # Example
551 /// ```
552 #[doc = include_str!("doc_create_aggregate.txt")]
553 /// let state = runtime.create_state("sum").unwrap();
554 ///
555 /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
556 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
557 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
558 ///
559 /// let state = runtime.accumulate("sum", &state, &input).unwrap();
560 /// assert_eq!(&*state, &Int32Array::from(vec![9]));
561 /// ```
562 pub fn accumulate(
563 &self,
564 name: &str,
565 state: &dyn Array,
566 input: &RecordBatch,
567 ) -> Result<ArrayRef> {
568 let aggregate = self.aggregates.get(name).context("function not found")?;
569 // convert each row to python objects and call the accumulate function
570 let new_state = self.interpreter.with_gil(|py| {
571 let mut state = self
572 .converter
573 .get_pyobject(py, &aggregate.state_field, state, 0)?;
574
575 let mut row = Vec::with_capacity(1 + input.num_columns());
576 for i in 0..input.num_rows() {
577 if aggregate.mode == CallMode::ReturnNullOnNullInput
578 && input.columns().iter().any(|column| column.is_null(i))
579 {
580 continue;
581 }
582 row.clear();
583 row.push(state.clone_ref(py));
584 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
585 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
586 row.push(pyobj);
587 }
588 let args = PyTuple::new_bound(py, row.drain(..));
589 state = aggregate.accumulate.call1(py, args)?;
590 }
591 let output = self
592 .converter
593 .build_array(&aggregate.state_field, py, &[state])?;
594 Ok(output)
595 })?;
596 Ok(new_state)
597 }
598
599 /// Call accumulate or retract of an aggregate function.
600 ///
601 /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
602 /// `false` for accumulate and `true` for retract.
603 ///
604 /// # Example
605 /// ```
606 #[doc = include_str!("doc_create_aggregate.txt")]
607 /// let state = runtime.create_state("sum").unwrap();
608 ///
609 /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
610 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
611 /// let ops = BooleanArray::from(vec![false, false, true, false]);
612 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
613 ///
614 /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).unwrap();
615 /// assert_eq!(&*state, &Int32Array::from(vec![3]));
616 /// ```
617 pub fn accumulate_or_retract(
618 &self,
619 name: &str,
620 state: &dyn Array,
621 ops: &BooleanArray,
622 input: &RecordBatch,
623 ) -> Result<ArrayRef> {
624 let aggregate = self.aggregates.get(name).context("function not found")?;
625 let retract = aggregate
626 .retract
627 .as_ref()
628 .context("function does not support retraction")?;
629 // convert each row to python objects and call the accumulate function
630 let new_state = self.interpreter.with_gil(|py| {
631 let mut state = self
632 .converter
633 .get_pyobject(py, &aggregate.state_field, state, 0)?;
634
635 let mut row = Vec::with_capacity(1 + input.num_columns());
636 for i in 0..input.num_rows() {
637 if aggregate.mode == CallMode::ReturnNullOnNullInput
638 && input.columns().iter().any(|column| column.is_null(i))
639 {
640 continue;
641 }
642 row.clear();
643 row.push(state.clone_ref(py));
644 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
645 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
646 row.push(pyobj);
647 }
648 let args = PyTuple::new_bound(py, row.drain(..));
649 let func = if ops.is_valid(i) && ops.value(i) {
650 retract
651 } else {
652 &aggregate.accumulate
653 };
654 state = func.call1(py, args)?;
655 }
656 let output = self
657 .converter
658 .build_array(&aggregate.state_field, py, &[state])?;
659 Ok(output)
660 })?;
661 Ok(new_state)
662 }
663
664 /// Merge states of an aggregate function.
665 ///
666 /// # Example
667 /// ```
668 #[doc = include_str!("doc_create_aggregate.txt")]
669 /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
670 ///
671 /// let state = runtime.merge("sum", &states).unwrap();
672 /// assert_eq!(&*state, &Int32Array::from(vec![9]));
673 /// ```
674 pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
675 let aggregate = self.aggregates.get(name).context("function not found")?;
676 let merge = aggregate.merge.as_ref().context("merge not found")?;
677 let output = self.interpreter.with_gil(|py| {
678 let mut state = self
679 .converter
680 .get_pyobject(py, &aggregate.state_field, states, 0)?;
681 for i in 1..states.len() {
682 if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
683 continue;
684 }
685 let state2 = self
686 .converter
687 .get_pyobject(py, &aggregate.state_field, states, i)?;
688 let args = PyTuple::new_bound(py, [state, state2]);
689 state = merge.call1(py, args)?;
690 }
691 let output = self
692 .converter
693 .build_array(&aggregate.state_field, py, &[state])?;
694 Ok(output)
695 })?;
696 Ok(output)
697 }
698
699 /// Get the result of an aggregate function.
700 ///
701 /// If the `finish` function is not defined, the state is returned as the result.
702 ///
703 /// # Example
704 /// ```
705 #[doc = include_str!("doc_create_aggregate.txt")]
706 /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
707 ///
708 /// let outputs = runtime.finish("sum", &states).unwrap();
709 /// assert_eq!(&outputs, &states);
710 /// ```
711 pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
712 let aggregate = self.aggregates.get(name).context("function not found")?;
713 let Some(finish) = &aggregate.finish else {
714 return Ok(states.clone());
715 };
716 let output = self.interpreter.with_gil(|py| {
717 let mut results = Vec::with_capacity(states.len());
718 for i in 0..states.len() {
719 if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
720 results.push(py.None());
721 continue;
722 }
723 let state = self
724 .converter
725 .get_pyobject(py, &aggregate.state_field, states, i)?;
726 let args = PyTuple::new_bound(py, [state]);
727 let result = finish.call1(py, args)?;
728 results.push(result);
729 }
730 let output = self
731 .converter
732 .build_array(&aggregate.output_field, py, &results)?;
733 Ok(output)
734 })?;
735 Ok(output)
736 }
737}
738
/// An iterator over the result of a table function.
///
/// Each item is a `RecordBatch` whose first column, `row`, holds the index of the input row
/// that produced each value, and whose second column holds the returned values. When any call
/// in a batch raised a Python exception, that batch has an additional `error` column.
pub struct RecordBatchIter<'a> {
    /// The sub-interpreter in which the function and its generators live.
    interpreter: &'a SubInterpreter,
    /// The input rows; each row supplies the arguments for one call.
    input: &'a RecordBatch,
    /// The table function being evaluated.
    function: &'a Function,
    /// Output schema: `row` (Int32) followed by the function's return field.
    schema: SchemaRef,
    /// Maximum number of output rows per batch.
    chunk_size: usize,
    // mutable states
    /// Current row index.
    row: usize,
    /// Generator of the current row.
    ///
    /// This is a Python object, so it must be dropped inside the interpreter (see the `Drop` impl).
    generator: Option<Py<PyIterator>>,
    /// Converts Arrow values to Python objects and Python results back to Arrow arrays.
    converter: &'a pyarrow::Converter,
}
753
754impl RecordBatchIter<'_> {
755 /// Get the schema of the output.
756 pub fn schema(&self) -> &Schema {
757 &self.schema
758 }
759
760 fn next(&mut self) -> Result<Option<RecordBatch>> {
761 if self.row == self.input.num_rows() {
762 return Ok(None);
763 }
764 let batch = self.interpreter.with_gil(|py| {
765 let mut indexes = Int32Builder::with_capacity(self.chunk_size);
766 let mut results = Vec::with_capacity(self.input.num_rows());
767 let mut errors = vec![];
768 let mut row = Vec::with_capacity(self.input.num_columns());
769 while self.row < self.input.num_rows() && results.len() < self.chunk_size {
770 let generator = if let Some(g) = self.generator.as_ref() {
771 g
772 } else {
773 // call the table function to get a generator
774 if self.function.mode == CallMode::ReturnNullOnNullInput
775 && (self.input.columns().iter()).any(|column| column.is_null(self.row))
776 {
777 self.row += 1;
778 continue;
779 }
780 row.clear();
781 for (column, field) in
782 (self.input.columns().iter()).zip(self.input.schema().fields())
783 {
784 let val = self.converter.get_pyobject(py, field, column, self.row)?;
785 row.push(val);
786 }
787 let args = PyTuple::new_bound(py, row.drain(..));
788 match self.function.function.bind(py).call1(args) {
789 Ok(result) => {
790 let iter = result.iter()?.into();
791 self.generator.insert(iter)
792 }
793 Err(e) => {
794 // append a row with null value and error message
795 indexes.append_value(self.row as i32);
796 results.push(py.None());
797 errors.push((indexes.len(), e.to_string()));
798 self.row += 1;
799 continue;
800 }
801 }
802 };
803 match generator.bind(py).clone().next() {
804 Some(Ok(value)) => {
805 indexes.append_value(self.row as i32);
806 results.push(value.into());
807 }
808 Some(Err(e)) => {
809 indexes.append_value(self.row as i32);
810 results.push(py.None());
811 errors.push((indexes.len(), e.to_string()));
812 self.row += 1;
813 self.generator = None;
814 }
815 None => {
816 self.row += 1;
817 self.generator = None;
818 }
819 }
820 }
821
822 if results.is_empty() {
823 return Ok(None);
824 }
825 let indexes = Arc::new(indexes.finish());
826 let output = self
827 .converter
828 .build_array(&self.function.return_field, py, &results)
829 .context("failed to build arrow array from return values")?;
830 let error = build_error_array(indexes.len(), errors);
831 if let Some(error) = error {
832 Ok(Some(
833 RecordBatch::try_new(
834 Arc::new(append_error_to_schema(&self.schema)),
835 vec![indexes, output, error],
836 )
837 .unwrap(),
838 ))
839 } else {
840 Ok(Some(
841 RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
842 ))
843 }
844 })?;
845 Ok(batch)
846 }
847}
848
849impl Iterator for RecordBatchIter<'_> {
850 type Item = Result<RecordBatch>;
851 fn next(&mut self) -> Option<Self::Item> {
852 self.next().transpose()
853 }
854}
855
856impl Drop for RecordBatchIter<'_> {
857 fn drop(&mut self) {
858 if let Some(generator) = self.generator.take() {
859 _ = self.interpreter.with_gil(|_| {
860 drop(generator);
861 Ok(())
862 });
863 }
864 }
865}
866
867impl Drop for Runtime {
868 fn drop(&mut self) {
869 // `PyObject` must be dropped inside the interpreter
870 _ = self.interpreter.with_gil(|_| {
871 self.functions.clear();
872 self.aggregates.clear();
873 Ok(())
874 });
875 }
876}
877
878fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
879 if errors.is_empty() {
880 return None;
881 }
882 let data_capacity = errors.iter().map(|(i, _)| i).sum();
883 let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
884 for (i, msg) in errors {
885 while builder.len() + 1 < i {
886 builder.append_null();
887 }
888 builder.append_value(&msg);
889 }
890 while builder.len() < num_rows {
891 builder.append_null();
892 }
893 Some(Arc::new(builder.finish()))
894}
895
896/// Append an error field to the schema.
897fn append_error_to_schema(schema: &Schema) -> Schema {
898 let mut fields = schema.fields().to_vec();
899 fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
900 Schema::new(fields)
901}