arrow_udf_runtime/python/
mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("README.md")]
16
17// Notice for developers:
18// This library uses the sub-interpreter and per-interpreter GIL introduced in Python 3.12
19// to support concurrent execution of different functions in multiple threads.
// However, pyo3 does not yet safely support sub-interpreters, so we use pyo3's raw FFI API to implement them.
21// Therefore, special attention is needed:
22// **All PyObject created in a sub-interpreter must be destroyed in the same sub-interpreter.**
23// Otherwise, it will cause a crash the next time Python is called.
24// Special attention is needed for PyErr in PyResult.
25// Remember to convert `PyErr` using the `pyerr_to_anyhow` function before passing it out of the sub-interpreter.
26
27use crate::into_field::IntoField;
28use crate::CallMode;
29
30use self::interpreter::SubInterpreter;
31use anyhow::{bail, Context, Result};
32use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
33use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
34use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
35use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
36use pyo3::{Py, PyObject};
37use std::collections::HashMap;
38use std::fmt::Debug;
39use std::sync::Arc;
40
41// #[cfg(Py_3_12)]
42mod interpreter;
43mod pyarrow;
44
45/// A runtime to execute user defined functions in Python.
46///
47/// # Usages
48///
49/// - Create a new runtime with [`Runtime::new`] or [`Runtime::builder`].
50/// - For scalar functions, use [`add_function`] and [`call`].
51/// - For table functions, use [`add_function`] and [`call_table_function`].
52/// - For aggregate functions, create the function with [`add_aggregate`], and then
53///     - create a new state with [`create_state`],
54///     - update the state with [`accumulate`] or [`accumulate_or_retract`],
55///     - merge states with [`merge`],
56///     - finally get the result with [`finish`].
57///
58/// Click on each function to see the example.
59///
60/// # Parallelism
61///
62/// As we know, Python has a Global Interpreter Lock (GIL) that prevents multiple threads from executing Python code simultaneously.
63/// To work around this limitation, each runtime creates a sub-interpreter with its own GIL. This feature requires Python 3.12 or later.
64///
65/// [`add_function`]: Runtime::add_function
66/// [`add_aggregate`]: Runtime::add_aggregate
67/// [`call`]: Runtime::call
68/// [`call_table_function`]: Runtime::call_table_function
69/// [`create_state`]: Runtime::create_state
70/// [`accumulate`]: Runtime::accumulate
71/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
72/// [`merge`]: Runtime::merge
73/// [`finish`]: Runtime::finish
74pub struct Runtime {
75    interpreter: SubInterpreter,
76    functions: HashMap<String, Function>,
77    aggregates: HashMap<String, Aggregate>,
78    converter: pyarrow::Converter,
79}
80
81impl Debug for Runtime {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("Runtime")
84            .field("functions", &self.functions.keys())
85            .field("aggregates", &self.aggregates.keys())
86            .finish()
87    }
88}
89
90/// A user defined function.
91struct Function {
92    function: PyObject,
93    return_field: FieldRef,
94    mode: CallMode,
95}
96
97/// A user defined aggregate function.
98struct Aggregate {
99    state_field: FieldRef,
100    output_field: FieldRef,
101    mode: CallMode,
102    create_state: PyObject,
103    accumulate: PyObject,
104    retract: Option<PyObject>,
105    finish: Option<PyObject>,
106    merge: Option<PyObject>,
107}
108
/// A builder for `Runtime`.
#[derive(Default, Debug)]
pub struct Builder {
    /// Whether `build` restricts imports and deletes the `removed_symbols`.
    sandboxed: bool,
    /// Dotted names (e.g. `__builtins__.eval`) that `build` deletes with `del`
    /// when the runtime is sandboxed.
    removed_symbols: Vec<String>,
}
115
116impl Builder {
117    /// Set whether the runtime is sandboxed.
118    ///
119    /// When sandboxed, only a limited set of modules can be imported, and some built-in functions are disabled.
120    /// This is useful for running untrusted code.
121    ///
122    /// Allowed modules: `json`, `decimal`, `re`, `math`, `datetime`, `time`.
123    ///
124    /// Disallowed builtins: `breakpoint`, `exit`, `eval`, `help`, `input`, `open`, `print`.
125    ///
126    /// The default is `false`.
127    pub fn sandboxed(mut self, sandboxed: bool) -> Self {
128        self.sandboxed = sandboxed;
129        self.remove_symbol("__builtins__.breakpoint")
130            .remove_symbol("__builtins__.exit")
131            .remove_symbol("__builtins__.eval")
132            .remove_symbol("__builtins__.help")
133            .remove_symbol("__builtins__.input")
134            .remove_symbol("__builtins__.open")
135            .remove_symbol("__builtins__.print")
136    }
137
138    /// Remove a symbol from builtins.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// # use arrow_udf_runtime::python::Runtime;
144    /// let builder = Runtime::builder().remove_symbol("__builtins__.eval");
145    /// ```
146    pub fn remove_symbol(mut self, symbol: &str) -> Self {
147        self.removed_symbols.push(symbol.to_string());
148        self
149    }
150
151    /// Build the `Runtime`.
152    pub fn build(self) -> Result<Runtime> {
153        let interpreter = SubInterpreter::new()?;
154        interpreter.run(
155            r#"
156# internal use for json types
157import json
158import pickle
159import decimal
160
161# an internal class used for struct input arguments
162class Struct:
163    pass
164"#,
165        )?;
166        if self.sandboxed {
167            let mut script = r#"
168# limit the modules that can be imported
169original_import = __builtins__.__import__
170
171def limited_import(name, globals=None, locals=None, fromlist=(), level=0):
172    # FIXME: 'sys' should not be allowed, but it is required by 'decimal'
173    # FIXME: 'time.sleep' should not be allowed, but 'time' is required by 'datetime'
174    allowlist = (
175        'json',
176        'decimal',
177        're',
178        'math',
179        'datetime',
180        'time',
181        'operator',
182        'numbers',
183        'abc',
184        'sys',
185        'contextvars',
186        '_io',
187        '_contextvars',
188        '_pydecimal',
189        '_pydatetime',
190    )
191    if level == 0 and name in allowlist:
192        return original_import(name, globals, locals, fromlist, level)
193    raise ImportError(f'import {name} is not allowed')
194
195__builtins__.__import__ = limited_import
196del limited_import
197"#
198            .to_string();
199            for symbol in self.removed_symbols {
200                script.push_str(&format!("del {}\n", symbol));
201            }
202            interpreter.run(&script)?;
203        }
204        Ok(Runtime {
205            interpreter,
206            functions: HashMap::new(),
207            aggregates: HashMap::new(),
208            converter: pyarrow::Converter::new(),
209        })
210    }
211}
212
213impl Runtime {
214    /// Create a new `Runtime`.
215    pub fn new() -> Result<Self> {
216        Builder::default().build()
217    }
218
219    /// Return a new builder for `Runtime`.
220    pub fn builder() -> Builder {
221        Builder::default()
222    }
223
224    /// Add a new scalar function or table function.
225    ///
226    /// # Arguments
227    ///
228    /// - `name`: The name of the function.
229    /// - `return_type`: The data type of the return value.
230    /// - `mode`: Whether the function will be called when some of its arguments are null.
231    /// - `code`: The Python code of the function.
232    ///
233    /// The code should define a function with the same name as the function.
234    /// The function should return a value for scalar functions, or yield values for table functions.
235    ///
236    /// # Example
237    ///
238    /// ```
239    /// # use arrow_udf_runtime::python::Runtime;
240    /// # use arrow_udf_runtime::CallMode;
241    /// # use arrow_schema::DataType;
242    /// let mut runtime = Runtime::new().unwrap();
243    /// // add a scalar function
244    /// runtime
245    ///     .add_function(
246    ///         "gcd",
247    ///         DataType::Int32,
248    ///         CallMode::ReturnNullOnNullInput,
249    ///         r#"
250    /// def gcd(a: int, b: int) -> int:
251    ///     while b:
252    ///         a, b = b, a % b
253    ///     return a
254    /// "#,
255    ///     )
256    ///     .unwrap();
257    /// // add a table function
258    /// runtime
259    ///     .add_function(
260    ///         "series",
261    ///         DataType::Int32,
262    ///         CallMode::ReturnNullOnNullInput,
263    ///         r#"
264    /// def series(n: int):
265    ///     for i in range(n):
266    ///         yield i
267    /// "#,
268    ///     )
269    ///     .unwrap();
270    /// ```
271    pub fn add_function(
272        &mut self,
273        name: &str,
274        return_type: impl IntoField,
275        mode: CallMode,
276        code: &str,
277    ) -> Result<()> {
278        self.add_function_with_handler(name, return_type, mode, code, name)
279    }
280
281    /// Add a new scalar function or table function with custom handler name.
282    ///
283    /// # Arguments
284    ///
285    /// - `handler`: The name of function in Python code to be called.
286    /// - others: Same as [`add_function`].
287    ///
288    /// [`add_function`]: Runtime::add_function
289    pub fn add_function_with_handler(
290        &mut self,
291        name: &str,
292        return_type: impl IntoField,
293        mode: CallMode,
294        code: &str,
295        handler: &str,
296    ) -> Result<()> {
297        let function = self.interpreter.with_gil(|py| {
298            Ok(PyModule::from_code_bound(py, code, name, name)?
299                .getattr(handler)?
300                .into())
301        })?;
302        let function = Function {
303            function,
304            return_field: return_type.into_field(name).into(),
305            mode,
306        };
307        self.functions.insert(name.to_string(), function);
308        Ok(())
309    }
310
311    /// Add a new aggregate function from Python code.
312    ///
313    /// # Arguments
314    ///
315    /// - `name`: The name of the function.
316    /// - `state_type`: The data type of the internal state.
317    /// - `output_type`: The data type of the aggregate value.
318    /// - `mode`: Whether the function will be called when some of its arguments are null.
319    /// - `code`: The Python code of the aggregate function.
320    ///
321    /// The code should define at least two functions:
322    ///
323    /// - `create_state() -> state`: Create a new state object.
324    /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
325    ///
326    /// optionally, the code can define:
327    ///
328    /// - `finish(state) -> value`: Get the result of the aggregate function.
329    ///     If not defined, the state is returned as the result.
330    ///     In this case, `output_type` must be the same as `state_type`.
331    /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
332    /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
333    ///
334    /// # Example
335    ///
336    /// ```
337    /// # use arrow_udf_runtime::python::Runtime;
338    /// # use arrow_udf_runtime::CallMode;
339    /// # use arrow_schema::DataType;
340    /// let mut runtime = Runtime::new().unwrap();
341    /// runtime
342    ///     .add_aggregate(
343    ///         "sum",
344    ///         DataType::Int32, // state_type
345    ///         DataType::Int32, // output_type
346    ///         CallMode::ReturnNullOnNullInput,
347    ///         r#"
348    /// def create_state():
349    ///     return 0
350    ///
351    /// def accumulate(state, value):
352    ///     return state + value
353    ///
354    /// def retract(state, value):
355    ///     return state - value
356    ///
357    /// def merge(state1, state2):
358    ///     return state1 + state2
359    ///         "#,
360    ///     )
361    ///     .unwrap();
362    /// ```
363    pub fn add_aggregate(
364        &mut self,
365        name: &str,
366        state_type: impl IntoField,
367        output_type: impl IntoField,
368        mode: CallMode,
369        code: &str,
370    ) -> Result<()> {
371        let aggregate = self.interpreter.with_gil(|py| {
372            let module = PyModule::from_code_bound(py, code, name, name)?;
373            Ok(Aggregate {
374                state_field: state_type.into_field(name).into(),
375                output_field: output_type.into_field(name).into(),
376                mode,
377                create_state: module.getattr("create_state")?.into(),
378                accumulate: module.getattr("accumulate")?.into(),
379                retract: module.getattr("retract").ok().map(|f| f.into()),
380                finish: module.getattr("finish").ok().map(|f| f.into()),
381                merge: module.getattr("merge").ok().map(|f| f.into()),
382            })
383        })?;
384        if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
385            bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
386        }
387        self.aggregates.insert(name.to_string(), aggregate);
388        Ok(())
389    }
390
391    /// Remove a scalar or table function.
392    pub fn del_function(&mut self, name: &str) -> Result<()> {
393        let function = self.functions.remove(name).context("function not found")?;
394        _ = self.interpreter.with_gil(|_| {
395            drop(function);
396            Ok(())
397        });
398        Ok(())
399    }
400
401    /// Remove an aggregate function.
402    pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
403        let aggregate = self.functions.remove(name).context("function not found")?;
404        _ = self.interpreter.with_gil(|_| {
405            drop(aggregate);
406            Ok(())
407        });
408        Ok(())
409    }
410
411    /// Call a scalar function.
412    ///
413    /// # Example
414    ///
415    /// ```
416    #[doc = include_str!("doc_create_function.txt")]
417    /// // suppose we have created a scalar function `gcd`
418    /// // see the example in `add_function`
419    ///
420    /// let schema = Schema::new(vec![
421    ///     Field::new("x", DataType::Int32, true),
422    ///     Field::new("y", DataType::Int32, true),
423    /// ]);
424    /// let arg0 = Int32Array::from(vec![Some(25), None]);
425    /// let arg1 = Int32Array::from(vec![Some(15), None]);
426    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
427    ///
428    /// let output = runtime.call("gcd", &input).unwrap();
429    /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
430    /// ```
431    pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
432        let function = self.functions.get(name).context("function not found")?;
433        // convert each row to python objects and call the function
434        let (output, error) = self.interpreter.with_gil(|py| {
435            let mut results = Vec::with_capacity(input.num_rows());
436            let mut errors = vec![];
437            let mut row = Vec::with_capacity(input.num_columns());
438            for i in 0..input.num_rows() {
439                if function.mode == CallMode::ReturnNullOnNullInput
440                    && input.columns().iter().any(|column| column.is_null(i))
441                {
442                    results.push(py.None());
443                    continue;
444                }
445                row.clear();
446                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
447                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
448                    row.push(pyobj);
449                }
450                let args = PyTuple::new_bound(py, row.drain(..));
451                match function.function.call1(py, args) {
452                    Ok(result) => results.push(result),
453                    Err(e) => {
454                        results.push(py.None());
455                        errors.push((i, e.to_string()));
456                    }
457                }
458            }
459            let output = self
460                .converter
461                .build_array(&function.return_field, py, &results)?;
462            let error = build_error_array(input.num_rows(), errors);
463            Ok((output, error))
464        })?;
465        if let Some(error) = error {
466            let schema = Schema::new(vec![
467                function.return_field.clone(),
468                Field::new("error", DataType::Utf8, true).into(),
469            ]);
470            Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
471        } else {
472            let schema = Schema::new(vec![function.return_field.clone()]);
473            Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
474        }
475    }
476
477    /// Call a table function.
478    ///
479    /// # Example
480    ///
481    /// ```
482    #[doc = include_str!("doc_create_function.txt")]
483    /// // suppose we have created a table function `series`
484    /// // see the example in `add_function`
485    ///
486    /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
487    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
488    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
489    ///
490    /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
491    /// let output = outputs.next().unwrap().unwrap();
492    /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
493    /// assert_eq!(pretty, r#"
494    /// +-----+--------+
495    /// | row | series |
496    /// +-----+--------+
497    /// | 0   | 0      |
498    /// | 2   | 0      |
499    /// | 2   | 1      |
500    /// | 2   | 2      |
501    /// +-----+--------+"#.trim());
502    /// ```
503    pub fn call_table_function<'a>(
504        &'a self,
505        name: &'a str,
506        input: &'a RecordBatch,
507        chunk_size: usize,
508    ) -> Result<RecordBatchIter<'a>> {
509        assert!(chunk_size > 0);
510        let function = self.functions.get(name).context("function not found")?;
511
512        // initial state
513        Ok(RecordBatchIter {
514            interpreter: &self.interpreter,
515            input,
516            function,
517            schema: Arc::new(Schema::new(vec![
518                Field::new("row", DataType::Int32, true).into(),
519                function.return_field.clone(),
520            ])),
521            chunk_size,
522            row: 0,
523            generator: None,
524            converter: &self.converter,
525        })
526    }
527
528    /// Create a new state for an aggregate function.
529    ///
530    /// # Example
531    /// ```
532    #[doc = include_str!("doc_create_aggregate.txt")]
533    /// let state = runtime.create_state("sum").unwrap();
534    /// assert_eq!(&*state, &Int32Array::from(vec![0]));
535    /// ```
536    pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
537        let aggregate = self.aggregates.get(name).context("function not found")?;
538        let state = self.interpreter.with_gil(|py| {
539            let state = aggregate.create_state.call0(py)?;
540            let state = self
541                .converter
542                .build_array(&aggregate.state_field, py, &[state])?;
543            Ok(state)
544        })?;
545        Ok(state)
546    }
547
548    /// Call accumulate of an aggregate function.
549    ///
550    /// # Example
551    /// ```
552    #[doc = include_str!("doc_create_aggregate.txt")]
553    /// let state = runtime.create_state("sum").unwrap();
554    ///
555    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
556    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
557    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
558    ///
559    /// let state = runtime.accumulate("sum", &state, &input).unwrap();
560    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
561    /// ```
562    pub fn accumulate(
563        &self,
564        name: &str,
565        state: &dyn Array,
566        input: &RecordBatch,
567    ) -> Result<ArrayRef> {
568        let aggregate = self.aggregates.get(name).context("function not found")?;
569        // convert each row to python objects and call the accumulate function
570        let new_state = self.interpreter.with_gil(|py| {
571            let mut state = self
572                .converter
573                .get_pyobject(py, &aggregate.state_field, state, 0)?;
574
575            let mut row = Vec::with_capacity(1 + input.num_columns());
576            for i in 0..input.num_rows() {
577                if aggregate.mode == CallMode::ReturnNullOnNullInput
578                    && input.columns().iter().any(|column| column.is_null(i))
579                {
580                    continue;
581                }
582                row.clear();
583                row.push(state.clone_ref(py));
584                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
585                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
586                    row.push(pyobj);
587                }
588                let args = PyTuple::new_bound(py, row.drain(..));
589                state = aggregate.accumulate.call1(py, args)?;
590            }
591            let output = self
592                .converter
593                .build_array(&aggregate.state_field, py, &[state])?;
594            Ok(output)
595        })?;
596        Ok(new_state)
597    }
598
599    /// Call accumulate or retract of an aggregate function.
600    ///
601    /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
602    /// `false` for accumulate and `true` for retract.
603    ///
604    /// # Example
605    /// ```
606    #[doc = include_str!("doc_create_aggregate.txt")]
607    /// let state = runtime.create_state("sum").unwrap();
608    ///
609    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
610    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
611    /// let ops = BooleanArray::from(vec![false, false, true, false]);
612    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
613    ///
614    /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).unwrap();
615    /// assert_eq!(&*state, &Int32Array::from(vec![3]));
616    /// ```
617    pub fn accumulate_or_retract(
618        &self,
619        name: &str,
620        state: &dyn Array,
621        ops: &BooleanArray,
622        input: &RecordBatch,
623    ) -> Result<ArrayRef> {
624        let aggregate = self.aggregates.get(name).context("function not found")?;
625        let retract = aggregate
626            .retract
627            .as_ref()
628            .context("function does not support retraction")?;
629        // convert each row to python objects and call the accumulate function
630        let new_state = self.interpreter.with_gil(|py| {
631            let mut state = self
632                .converter
633                .get_pyobject(py, &aggregate.state_field, state, 0)?;
634
635            let mut row = Vec::with_capacity(1 + input.num_columns());
636            for i in 0..input.num_rows() {
637                if aggregate.mode == CallMode::ReturnNullOnNullInput
638                    && input.columns().iter().any(|column| column.is_null(i))
639                {
640                    continue;
641                }
642                row.clear();
643                row.push(state.clone_ref(py));
644                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
645                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
646                    row.push(pyobj);
647                }
648                let args = PyTuple::new_bound(py, row.drain(..));
649                let func = if ops.is_valid(i) && ops.value(i) {
650                    retract
651                } else {
652                    &aggregate.accumulate
653                };
654                state = func.call1(py, args)?;
655            }
656            let output = self
657                .converter
658                .build_array(&aggregate.state_field, py, &[state])?;
659            Ok(output)
660        })?;
661        Ok(new_state)
662    }
663
664    /// Merge states of an aggregate function.
665    ///
666    /// # Example
667    /// ```
668    #[doc = include_str!("doc_create_aggregate.txt")]
669    /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
670    ///
671    /// let state = runtime.merge("sum", &states).unwrap();
672    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
673    /// ```
674    pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
675        let aggregate = self.aggregates.get(name).context("function not found")?;
676        let merge = aggregate.merge.as_ref().context("merge not found")?;
677        let output = self.interpreter.with_gil(|py| {
678            let mut state = self
679                .converter
680                .get_pyobject(py, &aggregate.state_field, states, 0)?;
681            for i in 1..states.len() {
682                if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
683                    continue;
684                }
685                let state2 = self
686                    .converter
687                    .get_pyobject(py, &aggregate.state_field, states, i)?;
688                let args = PyTuple::new_bound(py, [state, state2]);
689                state = merge.call1(py, args)?;
690            }
691            let output = self
692                .converter
693                .build_array(&aggregate.state_field, py, &[state])?;
694            Ok(output)
695        })?;
696        Ok(output)
697    }
698
699    /// Get the result of an aggregate function.
700    ///
701    /// If the `finish` function is not defined, the state is returned as the result.
702    ///
703    /// # Example
704    /// ```
705    #[doc = include_str!("doc_create_aggregate.txt")]
706    /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
707    ///
708    /// let outputs = runtime.finish("sum", &states).unwrap();
709    /// assert_eq!(&outputs, &states);
710    /// ```
711    pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
712        let aggregate = self.aggregates.get(name).context("function not found")?;
713        let Some(finish) = &aggregate.finish else {
714            return Ok(states.clone());
715        };
716        let output = self.interpreter.with_gil(|py| {
717            let mut results = Vec::with_capacity(states.len());
718            for i in 0..states.len() {
719                if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
720                    results.push(py.None());
721                    continue;
722                }
723                let state = self
724                    .converter
725                    .get_pyobject(py, &aggregate.state_field, states, i)?;
726                let args = PyTuple::new_bound(py, [state]);
727                let result = finish.call1(py, args)?;
728                results.push(result);
729            }
730            let output = self
731                .converter
732                .build_array(&aggregate.output_field, py, &results)?;
733            Ok(output)
734        })?;
735        Ok(output)
736    }
737}
738
739/// An iterator over the result of a table function.
740pub struct RecordBatchIter<'a> {
741    interpreter: &'a SubInterpreter,
742    input: &'a RecordBatch,
743    function: &'a Function,
744    schema: SchemaRef,
745    chunk_size: usize,
746    // mutable states
747    /// Current row index.
748    row: usize,
749    /// Generator of the current row.
750    generator: Option<Py<PyIterator>>,
751    converter: &'a pyarrow::Converter,
752}
753
754impl RecordBatchIter<'_> {
755    /// Get the schema of the output.
756    pub fn schema(&self) -> &Schema {
757        &self.schema
758    }
759
760    fn next(&mut self) -> Result<Option<RecordBatch>> {
761        if self.row == self.input.num_rows() {
762            return Ok(None);
763        }
764        let batch = self.interpreter.with_gil(|py| {
765            let mut indexes = Int32Builder::with_capacity(self.chunk_size);
766            let mut results = Vec::with_capacity(self.input.num_rows());
767            let mut errors = vec![];
768            let mut row = Vec::with_capacity(self.input.num_columns());
769            while self.row < self.input.num_rows() && results.len() < self.chunk_size {
770                let generator = if let Some(g) = self.generator.as_ref() {
771                    g
772                } else {
773                    // call the table function to get a generator
774                    if self.function.mode == CallMode::ReturnNullOnNullInput
775                        && (self.input.columns().iter()).any(|column| column.is_null(self.row))
776                    {
777                        self.row += 1;
778                        continue;
779                    }
780                    row.clear();
781                    for (column, field) in
782                        (self.input.columns().iter()).zip(self.input.schema().fields())
783                    {
784                        let val = self.converter.get_pyobject(py, field, column, self.row)?;
785                        row.push(val);
786                    }
787                    let args = PyTuple::new_bound(py, row.drain(..));
788                    match self.function.function.bind(py).call1(args) {
789                        Ok(result) => {
790                            let iter = result.iter()?.into();
791                            self.generator.insert(iter)
792                        }
793                        Err(e) => {
794                            // append a row with null value and error message
795                            indexes.append_value(self.row as i32);
796                            results.push(py.None());
797                            errors.push((indexes.len(), e.to_string()));
798                            self.row += 1;
799                            continue;
800                        }
801                    }
802                };
803                match generator.bind(py).clone().next() {
804                    Some(Ok(value)) => {
805                        indexes.append_value(self.row as i32);
806                        results.push(value.into());
807                    }
808                    Some(Err(e)) => {
809                        indexes.append_value(self.row as i32);
810                        results.push(py.None());
811                        errors.push((indexes.len(), e.to_string()));
812                        self.row += 1;
813                        self.generator = None;
814                    }
815                    None => {
816                        self.row += 1;
817                        self.generator = None;
818                    }
819                }
820            }
821
822            if results.is_empty() {
823                return Ok(None);
824            }
825            let indexes = Arc::new(indexes.finish());
826            let output = self
827                .converter
828                .build_array(&self.function.return_field, py, &results)
829                .context("failed to build arrow array from return values")?;
830            let error = build_error_array(indexes.len(), errors);
831            if let Some(error) = error {
832                Ok(Some(
833                    RecordBatch::try_new(
834                        Arc::new(append_error_to_schema(&self.schema)),
835                        vec![indexes, output, error],
836                    )
837                    .unwrap(),
838                ))
839            } else {
840                Ok(Some(
841                    RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
842                ))
843            }
844        })?;
845        Ok(batch)
846    }
847}
848
849impl Iterator for RecordBatchIter<'_> {
850    type Item = Result<RecordBatch>;
851    fn next(&mut self) -> Option<Self::Item> {
852        self.next().transpose()
853    }
854}
855
856impl Drop for RecordBatchIter<'_> {
857    fn drop(&mut self) {
858        if let Some(generator) = self.generator.take() {
859            _ = self.interpreter.with_gil(|_| {
860                drop(generator);
861                Ok(())
862            });
863        }
864    }
865}
866
867impl Drop for Runtime {
868    fn drop(&mut self) {
869        // `PyObject` must be dropped inside the interpreter
870        _ = self.interpreter.with_gil(|_| {
871            self.functions.clear();
872            self.aggregates.clear();
873            Ok(())
874        });
875    }
876}
877
878fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
879    if errors.is_empty() {
880        return None;
881    }
882    let data_capacity = errors.iter().map(|(i, _)| i).sum();
883    let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
884    for (i, msg) in errors {
885        while builder.len() + 1 < i {
886            builder.append_null();
887        }
888        builder.append_value(&msg);
889    }
890    while builder.len() < num_rows {
891        builder.append_null();
892    }
893    Some(Arc::new(builder.finish()))
894}
895
896/// Append an error field to the schema.
897fn append_error_to_schema(schema: &Schema) -> Schema {
898    let mut fields = schema.fields().to_vec();
899    fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
900    Schema::new(fields)
901}