arrow_udf_python/
lib.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("../README.md")]
16
17// Notice for developers:
18// This library uses the sub-interpreter and per-interpreter GIL introduced in Python 3.12
19// to support concurrent execution of different functions in multiple threads.
// However, pyo3 does not yet safely support sub-interpreters, so we use the raw FFI API of pyo3 to implement them.
21// Therefore, special attention is needed:
22// **All PyObject created in a sub-interpreter must be destroyed in the same sub-interpreter.**
23// Otherwise, it will cause a crash the next time Python is called.
24// Special attention is needed for PyErr in PyResult.
25// Remember to convert `PyErr` using the `pyerr_to_anyhow` function before passing it out of the sub-interpreter.
26
27use self::interpreter::SubInterpreter;
28pub use self::into_field::IntoField;
29use anyhow::{bail, Context, Result};
30use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
31use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
32use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
33use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
34use pyo3::{Py, PyObject};
35use std::collections::HashMap;
36use std::fmt::Debug;
37use std::sync::Arc;
38
39// #[cfg(Py_3_12)]
40mod interpreter;
41mod into_field;
42mod pyarrow;
43
/// A runtime to execute user defined functions in Python.
///
/// # Usages
///
/// - Create a new runtime with [`Runtime::new`] or [`Runtime::builder`].
/// - For scalar functions, use [`add_function`] and [`call`].
/// - For table functions, use [`add_function`] and [`call_table_function`].
/// - For aggregate functions, create the function with [`add_aggregate`], and then
///     - create a new state with [`create_state`],
///     - update the state with [`accumulate`] or [`accumulate_or_retract`],
///     - merge states with [`merge`],
///     - finally get the result with [`finish`].
///
/// Click on each function to see the example.
///
/// # Parallelism
///
/// As we know, Python has a Global Interpreter Lock (GIL) that prevents multiple threads from executing Python code simultaneously.
/// To work around this limitation, each runtime creates a sub-interpreter with its own GIL. This feature requires Python 3.12 or later.
///
/// [`add_function`]: Runtime::add_function
/// [`add_aggregate`]: Runtime::add_aggregate
/// [`call`]: Runtime::call
/// [`call_table_function`]: Runtime::call_table_function
/// [`create_state`]: Runtime::create_state
/// [`accumulate`]: Runtime::accumulate
/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
/// [`merge`]: Runtime::merge
/// [`finish`]: Runtime::finish
pub struct Runtime {
    /// The sub-interpreter in which all Python code of this runtime is executed.
    interpreter: SubInterpreter,
    /// Registered scalar and table functions, keyed by function name.
    functions: HashMap<String, Function>,
    /// Registered aggregate functions, keyed by function name.
    aggregates: HashMap<String, Aggregate>,
    /// Converter used to turn Arrow array elements into Python objects and
    /// Python results back into Arrow arrays.
    converter: pyarrow::Converter,
}
79
80impl Debug for Runtime {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("Runtime")
83            .field("functions", &self.functions.keys())
84            .field("aggregates", &self.aggregates.keys())
85            .finish()
86    }
87}
88
/// A user defined function.
///
/// Holds a Python object created inside the runtime's sub-interpreter, so it
/// must also be dropped inside that sub-interpreter.
struct Function {
    /// The Python callable looked up by handler name in the user's module.
    function: PyObject,
    /// The Arrow field (name and data type) of the return value.
    return_field: FieldRef,
    /// Whether the function is invoked for rows that contain null arguments.
    mode: CallMode,
}
95
/// A user defined aggregate function.
///
/// Holds Python objects created inside the runtime's sub-interpreter, so it
/// must also be dropped inside that sub-interpreter.
struct Aggregate {
    /// The Arrow field of the internal state.
    state_field: FieldRef,
    /// The Arrow field of the aggregate result.
    output_field: FieldRef,
    /// Whether rows (or states) containing nulls are passed to the Python functions.
    mode: CallMode,
    /// The `create_state` callable of the user's module (required).
    create_state: PyObject,
    /// The `accumulate` callable of the user's module (required).
    accumulate: PyObject,
    /// The optional `retract` callable; `None` if the module does not define it.
    retract: Option<PyObject>,
    /// The optional `finish` callable; `None` if the module does not define it.
    finish: Option<PyObject>,
    /// The optional `merge` callable; `None` if the module does not define it.
    merge: Option<PyObject>,
}
107
/// A builder for `Runtime`.
///
/// Obtain one with `Runtime::builder`, configure it with the chained setters
/// and finish with `build`.
#[derive(Default, Debug)]
pub struct Builder {
    /// Whether the runtime restricts imports for running untrusted code.
    sandboxed: bool,
    /// Symbols that `build` renders as `del <symbol>` statements,
    /// e.g. `__builtins__.eval`.
    removed_symbols: Vec<String>,
}
114
115impl Builder {
116    /// Set whether the runtime is sandboxed.
117    ///
118    /// When sandboxed, only a limited set of modules can be imported, and some built-in functions are disabled.
119    /// This is useful for running untrusted code.
120    ///
121    /// Allowed modules: `json`, `decimal`, `re`, `math`, `datetime`, `time`.
122    ///
123    /// Disallowed builtins: `breakpoint`, `exit`, `eval`, `help`, `input`, `open`, `print`.
124    ///
125    /// The default is `false`.
126    pub fn sandboxed(mut self, sandboxed: bool) -> Self {
127        self.sandboxed = sandboxed;
128        self.remove_symbol("__builtins__.breakpoint")
129            .remove_symbol("__builtins__.exit")
130            .remove_symbol("__builtins__.eval")
131            .remove_symbol("__builtins__.help")
132            .remove_symbol("__builtins__.input")
133            .remove_symbol("__builtins__.open")
134            .remove_symbol("__builtins__.print")
135    }
136
137    /// Remove a symbol from builtins.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// # use arrow_udf_python::Runtime;
143    /// let builder = Runtime::builder().remove_symbol("__builtins__.eval");
144    /// ```
145    pub fn remove_symbol(mut self, symbol: &str) -> Self {
146        self.removed_symbols.push(symbol.to_string());
147        self
148    }
149
150    /// Build the `Runtime`.
151    pub fn build(self) -> Result<Runtime> {
152        let interpreter = SubInterpreter::new()?;
153        interpreter.run(
154            r#"
155# internal use for json types
156import json
157import pickle
158import decimal
159
160# an internal class used for struct input arguments
161class Struct:
162    pass
163"#,
164        )?;
165        if self.sandboxed {
166            let mut script = r#"
167# limit the modules that can be imported
168original_import = __builtins__.__import__
169
170def limited_import(name, globals=None, locals=None, fromlist=(), level=0):
171    # FIXME: 'sys' should not be allowed, but it is required by 'decimal'
172    # FIXME: 'time.sleep' should not be allowed, but 'time' is required by 'datetime'
173    allowlist = (
174        'json',
175        'decimal',
176        're',
177        'math',
178        'datetime',
179        'time',
180        'operator',
181        'numbers',
182        'abc',
183        'sys',
184        'contextvars',
185        '_io',
186        '_contextvars',
187        '_pydecimal',
188        '_pydatetime',
189    )
190    if level == 0 and name in allowlist:
191        return original_import(name, globals, locals, fromlist, level)
192    raise ImportError(f'import {name} is not allowed')
193
194__builtins__.__import__ = limited_import
195del limited_import
196"#
197            .to_string();
198            for symbol in self.removed_symbols {
199                script.push_str(&format!("del {}\n", symbol));
200            }
201            interpreter.run(&script)?;
202        }
203        Ok(Runtime {
204            interpreter,
205            functions: HashMap::new(),
206            aggregates: HashMap::new(),
207            converter: pyarrow::Converter::new(),
208        })
209    }
210}
211
212impl Runtime {
213    /// Create a new `Runtime`.
214    pub fn new() -> Result<Self> {
215        Builder::default().build()
216    }
217
218    /// Return a new builder for `Runtime`.
219    pub fn builder() -> Builder {
220        Builder::default()
221    }
222
223    /// Add a new scalar function or table function.
224    ///
225    /// # Arguments
226    ///
227    /// - `name`: The name of the function.
228    /// - `return_type`: The data type of the return value.
229    /// - `mode`: Whether the function will be called when some of its arguments are null.
230    /// - `code`: The Python code of the function.
231    ///
232    /// The code should define a function with the same name as the function.
233    /// The function should return a value for scalar functions, or yield values for table functions.
234    ///
235    /// # Example
236    ///
237    /// ```
238    /// # use arrow_udf_python::{Runtime, CallMode};
239    /// # use arrow_schema::DataType;
240    /// let mut runtime = Runtime::new().unwrap();
241    /// // add a scalar function
242    /// runtime
243    ///     .add_function(
244    ///         "gcd",
245    ///         DataType::Int32,
246    ///         CallMode::ReturnNullOnNullInput,
247    ///         r#"
248    /// def gcd(a: int, b: int) -> int:
249    ///     while b:
250    ///         a, b = b, a % b
251    ///     return a
252    /// "#,
253    ///     )
254    ///     .unwrap();
255    /// // add a table function
256    /// runtime
257    ///     .add_function(
258    ///         "series",
259    ///         DataType::Int32,
260    ///         CallMode::ReturnNullOnNullInput,
261    ///         r#"
262    /// def series(n: int):
263    ///     for i in range(n):
264    ///         yield i
265    /// "#,
266    ///     )
267    ///     .unwrap();
268    /// ```
269    pub fn add_function(
270        &mut self,
271        name: &str,
272        return_type: impl IntoField,
273        mode: CallMode,
274        code: &str,
275    ) -> Result<()> {
276        self.add_function_with_handler(name, return_type, mode, code, name)
277    }
278
279    /// Add a new scalar function or table function with custom handler name.
280    ///
281    /// # Arguments
282    ///
283    /// - `handler`: The name of function in Python code to be called.
284    /// - others: Same as [`add_function`].
285    ///
286    /// [`add_function`]: Runtime::add_function
287    pub fn add_function_with_handler(
288        &mut self,
289        name: &str,
290        return_type: impl IntoField,
291        mode: CallMode,
292        code: &str,
293        handler: &str,
294    ) -> Result<()> {
295        let function = self.interpreter.with_gil(|py| {
296            Ok(PyModule::from_code_bound(py, code, name, name)?
297                .getattr(handler)?
298                .into())
299        })?;
300        let function = Function {
301            function,
302            return_field: return_type.into_field(name).into(),
303            mode,
304        };
305        self.functions.insert(name.to_string(), function);
306        Ok(())
307    }
308
309    /// Add a new aggregate function from Python code.
310    ///
311    /// # Arguments
312    ///
313    /// - `name`: The name of the function.
314    /// - `state_type`: The data type of the internal state.
315    /// - `output_type`: The data type of the aggregate value.
316    /// - `mode`: Whether the function will be called when some of its arguments are null.
317    /// - `code`: The Python code of the aggregate function.
318    ///
319    /// The code should define at least two functions:
320    ///
321    /// - `create_state() -> state`: Create a new state object.
322    /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
323    ///
324    /// optionally, the code can define:
325    ///
326    /// - `finish(state) -> value`: Get the result of the aggregate function.
327    ///     If not defined, the state is returned as the result.
328    ///     In this case, `output_type` must be the same as `state_type`.
329    /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
330    /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
331    ///
332    /// # Example
333    ///
334    /// ```
335    /// # use arrow_udf_python::{Runtime, CallMode};
336    /// # use arrow_schema::DataType;
337    /// let mut runtime = Runtime::new().unwrap();
338    /// runtime
339    ///     .add_aggregate(
340    ///         "sum",
341    ///         DataType::Int32, // state_type
342    ///         DataType::Int32, // output_type
343    ///         CallMode::ReturnNullOnNullInput,
344    ///         r#"
345    /// def create_state():
346    ///     return 0
347    ///
348    /// def accumulate(state, value):
349    ///     return state + value
350    ///
351    /// def retract(state, value):
352    ///     return state - value
353    ///
354    /// def merge(state1, state2):
355    ///     return state1 + state2
356    ///         "#,
357    ///     )
358    ///     .unwrap();
359    /// ```
360    pub fn add_aggregate(
361        &mut self,
362        name: &str,
363        state_type: impl IntoField,
364        output_type: impl IntoField,
365        mode: CallMode,
366        code: &str,
367    ) -> Result<()> {
368        let aggregate = self.interpreter.with_gil(|py| {
369            let module = PyModule::from_code_bound(py, code, name, name)?;
370            Ok(Aggregate {
371                state_field: state_type.into_field(name).into(),
372                output_field: output_type.into_field(name).into(),
373                mode,
374                create_state: module.getattr("create_state")?.into(),
375                accumulate: module.getattr("accumulate")?.into(),
376                retract: module.getattr("retract").ok().map(|f| f.into()),
377                finish: module.getattr("finish").ok().map(|f| f.into()),
378                merge: module.getattr("merge").ok().map(|f| f.into()),
379            })
380        })?;
381        if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
382            bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
383        }
384        self.aggregates.insert(name.to_string(), aggregate);
385        Ok(())
386    }
387
388    /// Remove a scalar or table function.
389    pub fn del_function(&mut self, name: &str) -> Result<()> {
390        let function = self.functions.remove(name).context("function not found")?;
391        _ = self.interpreter.with_gil(|_| {
392            drop(function);
393            Ok(())
394        });
395        Ok(())
396    }
397
398    /// Remove an aggregate function.
399    pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
400        let aggregate = self.functions.remove(name).context("function not found")?;
401        _ = self.interpreter.with_gil(|_| {
402            drop(aggregate);
403            Ok(())
404        });
405        Ok(())
406    }
407
408    /// Call a scalar function.
409    ///
410    /// # Example
411    ///
412    /// ```
413    #[doc = include_str!("doc_create_function.txt")]
414    /// // suppose we have created a scalar function `gcd`
415    /// // see the example in `add_function`
416    ///
417    /// let schema = Schema::new(vec![
418    ///     Field::new("x", DataType::Int32, true),
419    ///     Field::new("y", DataType::Int32, true),
420    /// ]);
421    /// let arg0 = Int32Array::from(vec![Some(25), None]);
422    /// let arg1 = Int32Array::from(vec![Some(15), None]);
423    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
424    ///
425    /// let output = runtime.call("gcd", &input).unwrap();
426    /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
427    /// ```
428    pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
429        let function = self.functions.get(name).context("function not found")?;
430        // convert each row to python objects and call the function
431        let (output, error) = self.interpreter.with_gil(|py| {
432            let mut results = Vec::with_capacity(input.num_rows());
433            let mut errors = vec![];
434            let mut row = Vec::with_capacity(input.num_columns());
435            for i in 0..input.num_rows() {
436                if function.mode == CallMode::ReturnNullOnNullInput
437                    && input.columns().iter().any(|column| column.is_null(i))
438                {
439                    results.push(py.None());
440                    continue;
441                }
442                row.clear();
443                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
444                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
445                    row.push(pyobj);
446                }
447                let args = PyTuple::new_bound(py, row.drain(..));
448                match function.function.call1(py, args) {
449                    Ok(result) => results.push(result),
450                    Err(e) => {
451                        results.push(py.None());
452                        errors.push((i, e.to_string()));
453                    }
454                }
455            }
456            let output = self
457                .converter
458                .build_array(&function.return_field, py, &results)?;
459            let error = build_error_array(input.num_rows(), errors);
460            Ok((output, error))
461        })?;
462        if let Some(error) = error {
463            let schema = Schema::new(vec![
464                function.return_field.clone(),
465                Field::new("error", DataType::Utf8, true).into(),
466            ]);
467            Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
468        } else {
469            let schema = Schema::new(vec![function.return_field.clone()]);
470            Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
471        }
472    }
473
474    /// Call a table function.
475    ///
476    /// # Example
477    ///
478    /// ```
479    #[doc = include_str!("doc_create_function.txt")]
480    /// // suppose we have created a table function `series`
481    /// // see the example in `add_function`
482    ///
483    /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
484    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
485    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
486    ///
487    /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
488    /// let output = outputs.next().unwrap().unwrap();
489    /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
490    /// assert_eq!(pretty, r#"
491    /// +-----+--------+
492    /// | row | series |
493    /// +-----+--------+
494    /// | 0   | 0      |
495    /// | 2   | 0      |
496    /// | 2   | 1      |
497    /// | 2   | 2      |
498    /// +-----+--------+"#.trim());
499    /// ```
500    pub fn call_table_function<'a>(
501        &'a self,
502        name: &'a str,
503        input: &'a RecordBatch,
504        chunk_size: usize,
505    ) -> Result<RecordBatchIter<'a>> {
506        assert!(chunk_size > 0);
507        let function = self.functions.get(name).context("function not found")?;
508
509        // initial state
510        Ok(RecordBatchIter {
511            interpreter: &self.interpreter,
512            input,
513            function,
514            schema: Arc::new(Schema::new(vec![
515                Field::new("row", DataType::Int32, true).into(),
516                function.return_field.clone(),
517            ])),
518            chunk_size,
519            row: 0,
520            generator: None,
521            converter: &self.converter,
522        })
523    }
524
525    /// Create a new state for an aggregate function.
526    ///
527    /// # Example
528    /// ```
529    #[doc = include_str!("doc_create_aggregate.txt")]
530    /// let state = runtime.create_state("sum").unwrap();
531    /// assert_eq!(&*state, &Int32Array::from(vec![0]));
532    /// ```
533    pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
534        let aggregate = self.aggregates.get(name).context("function not found")?;
535        let state = self.interpreter.with_gil(|py| {
536            let state = aggregate.create_state.call0(py)?;
537            let state = self
538                .converter
539                .build_array(&aggregate.state_field, py, &[state])?;
540            Ok(state)
541        })?;
542        Ok(state)
543    }
544
545    /// Call accumulate of an aggregate function.
546    ///
547    /// # Example
548    /// ```
549    #[doc = include_str!("doc_create_aggregate.txt")]
550    /// let state = runtime.create_state("sum").unwrap();
551    ///
552    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
553    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
554    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
555    ///
556    /// let state = runtime.accumulate("sum", &state, &input).unwrap();
557    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
558    /// ```
559    pub fn accumulate(
560        &self,
561        name: &str,
562        state: &dyn Array,
563        input: &RecordBatch,
564    ) -> Result<ArrayRef> {
565        let aggregate = self.aggregates.get(name).context("function not found")?;
566        // convert each row to python objects and call the accumulate function
567        let new_state = self.interpreter.with_gil(|py| {
568            let mut state = self
569                .converter
570                .get_pyobject(py, &aggregate.state_field, state, 0)?;
571
572            let mut row = Vec::with_capacity(1 + input.num_columns());
573            for i in 0..input.num_rows() {
574                if aggregate.mode == CallMode::ReturnNullOnNullInput
575                    && input.columns().iter().any(|column| column.is_null(i))
576                {
577                    continue;
578                }
579                row.clear();
580                row.push(state.clone_ref(py));
581                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
582                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
583                    row.push(pyobj);
584                }
585                let args = PyTuple::new_bound(py, row.drain(..));
586                state = aggregate.accumulate.call1(py, args)?;
587            }
588            let output = self
589                .converter
590                .build_array(&aggregate.state_field, py, &[state])?;
591            Ok(output)
592        })?;
593        Ok(new_state)
594    }
595
596    /// Call accumulate or retract of an aggregate function.
597    ///
598    /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
599    /// `false` for accumulate and `true` for retract.
600    ///
601    /// # Example
602    /// ```
603    #[doc = include_str!("doc_create_aggregate.txt")]
604    /// let state = runtime.create_state("sum").unwrap();
605    ///
606    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
607    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
608    /// let ops = BooleanArray::from(vec![false, false, true, false]);
609    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
610    ///
611    /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).unwrap();
612    /// assert_eq!(&*state, &Int32Array::from(vec![3]));
613    /// ```
614    pub fn accumulate_or_retract(
615        &self,
616        name: &str,
617        state: &dyn Array,
618        ops: &BooleanArray,
619        input: &RecordBatch,
620    ) -> Result<ArrayRef> {
621        let aggregate = self.aggregates.get(name).context("function not found")?;
622        let retract = aggregate
623            .retract
624            .as_ref()
625            .context("function does not support retraction")?;
626        // convert each row to python objects and call the accumulate function
627        let new_state = self.interpreter.with_gil(|py| {
628            let mut state = self
629                .converter
630                .get_pyobject(py, &aggregate.state_field, state, 0)?;
631
632            let mut row = Vec::with_capacity(1 + input.num_columns());
633            for i in 0..input.num_rows() {
634                if aggregate.mode == CallMode::ReturnNullOnNullInput
635                    && input.columns().iter().any(|column| column.is_null(i))
636                {
637                    continue;
638                }
639                row.clear();
640                row.push(state.clone_ref(py));
641                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
642                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
643                    row.push(pyobj);
644                }
645                let args = PyTuple::new_bound(py, row.drain(..));
646                let func = if ops.is_valid(i) && ops.value(i) {
647                    retract
648                } else {
649                    &aggregate.accumulate
650                };
651                state = func.call1(py, args)?;
652            }
653            let output = self
654                .converter
655                .build_array(&aggregate.state_field, py, &[state])?;
656            Ok(output)
657        })?;
658        Ok(new_state)
659    }
660
661    /// Merge states of an aggregate function.
662    ///
663    /// # Example
664    /// ```
665    #[doc = include_str!("doc_create_aggregate.txt")]
666    /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
667    ///
668    /// let state = runtime.merge("sum", &states).unwrap();
669    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
670    /// ```
671    pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
672        let aggregate = self.aggregates.get(name).context("function not found")?;
673        let merge = aggregate.merge.as_ref().context("merge not found")?;
674        let output = self.interpreter.with_gil(|py| {
675            let mut state = self
676                .converter
677                .get_pyobject(py, &aggregate.state_field, states, 0)?;
678            for i in 1..states.len() {
679                if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
680                    continue;
681                }
682                let state2 = self
683                    .converter
684                    .get_pyobject(py, &aggregate.state_field, states, i)?;
685                let args = PyTuple::new_bound(py, [state, state2]);
686                state = merge.call1(py, args)?;
687            }
688            let output = self
689                .converter
690                .build_array(&aggregate.state_field, py, &[state])?;
691            Ok(output)
692        })?;
693        Ok(output)
694    }
695
696    /// Get the result of an aggregate function.
697    ///
698    /// If the `finish` function is not defined, the state is returned as the result.
699    ///
700    /// # Example
701    /// ```
702    #[doc = include_str!("doc_create_aggregate.txt")]
703    /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
704    ///
705    /// let outputs = runtime.finish("sum", &states).unwrap();
706    /// assert_eq!(&outputs, &states);
707    /// ```
708    pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
709        let aggregate = self.aggregates.get(name).context("function not found")?;
710        let Some(finish) = &aggregate.finish else {
711            return Ok(states.clone());
712        };
713        let output = self.interpreter.with_gil(|py| {
714            let mut results = Vec::with_capacity(states.len());
715            for i in 0..states.len() {
716                if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
717                    results.push(py.None());
718                    continue;
719                }
720                let state = self
721                    .converter
722                    .get_pyobject(py, &aggregate.state_field, states, i)?;
723                let args = PyTuple::new_bound(py, [state]);
724                let result = finish.call1(py, args)?;
725                results.push(result);
726            }
727            let output = self
728                .converter
729                .build_array(&aggregate.output_field, py, &results)?;
730            Ok(output)
731        })?;
732        Ok(output)
733    }
734}
735
/// An iterator over the result of a table function.
///
/// Created by `Runtime::call_table_function`. Each item is a record batch of
/// at most `chunk_size` rows, or an error.
pub struct RecordBatchIter<'a> {
    /// The sub-interpreter of the runtime that owns the function.
    interpreter: &'a SubInterpreter,
    /// The input rows the table function is applied to.
    input: &'a RecordBatch,
    /// The table function being called.
    function: &'a Function,
    /// Output schema without the optional error column: `row`, then the return field.
    schema: SchemaRef,
    /// Maximum number of rows per output batch.
    chunk_size: usize,
    // mutable states
    /// Current row index.
    row: usize,
    /// Generator of the current row.
    /// `None` when the function has not been called yet for the current row.
    generator: Option<Py<PyIterator>>,
    /// Converter between Arrow values and Python objects, borrowed from the runtime.
    converter: &'a pyarrow::Converter,
}
750
impl RecordBatchIter<'_> {
    /// Get the schema of the output.
    ///
    /// This is the schema without the `error` column, which is appended only
    /// to batches that contain at least one error.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }

    /// Produce the next output batch.
    ///
    /// Returns `Ok(None)` when all input rows have been consumed, or when the
    /// remaining rows produce no output. A batch holds at most `chunk_size`
    /// rows; a generator that is not exhausted is kept in `self.generator`
    /// and resumed on the next call.
    fn next(&mut self) -> Result<Option<RecordBatch>> {
        if self.row == self.input.num_rows() {
            return Ok(None);
        }
        let batch = self.interpreter.with_gil(|py| {
            // `indexes[k]` is the input row that produced `results[k]`
            let mut indexes = Int32Builder::with_capacity(self.chunk_size);
            let mut results = Vec::with_capacity(self.input.num_rows());
            // (1-based position within this batch, error message)
            let mut errors = vec![];
            let mut row = Vec::with_capacity(self.input.num_columns());
            while self.row < self.input.num_rows() && results.len() < self.chunk_size {
                let generator = if let Some(g) = self.generator.as_ref() {
                    // resume the generator of the current row
                    g
                } else {
                    // call the table function to get a generator
                    if self.function.mode == CallMode::ReturnNullOnNullInput
                        && (self.input.columns().iter()).any(|column| column.is_null(self.row))
                    {
                        // rows with null arguments produce no output rows
                        self.row += 1;
                        continue;
                    }
                    row.clear();
                    for (column, field) in
                        (self.input.columns().iter()).zip(self.input.schema().fields())
                    {
                        let val = self.converter.get_pyobject(py, field, column, self.row)?;
                        row.push(val);
                    }
                    let args = PyTuple::new_bound(py, row.drain(..));
                    match self.function.function.bind(py).call1(args) {
                        Ok(result) => {
                            // a return value that is not iterable aborts the batch with an error
                            let iter = result.iter()?.into();
                            self.generator.insert(iter)
                        }
                        Err(e) => {
                            // append a row with null value and error message
                            indexes.append_value(self.row as i32);
                            results.push(py.None());
                            // `indexes.len()` is read after the append, so the position is 1-based
                            errors.push((indexes.len(), e.to_string()));
                            self.row += 1;
                            continue;
                        }
                    }
                };
                match generator.bind(py).clone().next() {
                    Some(Ok(value)) => {
                        indexes.append_value(self.row as i32);
                        results.push(value.into());
                    }
                    Some(Err(e)) => {
                        // the generator raised: emit a null row with the error
                        // and move on to the next input row
                        indexes.append_value(self.row as i32);
                        results.push(py.None());
                        errors.push((indexes.len(), e.to_string()));
                        self.row += 1;
                        self.generator = None;
                    }
                    None => {
                        // the generator is exhausted: move on to the next input row
                        self.row += 1;
                        self.generator = None;
                    }
                }
            }

            if results.is_empty() {
                return Ok(None);
            }
            let indexes = Arc::new(indexes.finish());
            let output = self
                .converter
                .build_array(&self.function.return_field, py, &results)
                .context("failed to build arrow array from return values")?;
            let error = build_error_array(indexes.len(), errors);
            // NOTE(review): `unwrap` panics if the built columns do not match the schema.
            if let Some(error) = error {
                Ok(Some(
                    RecordBatch::try_new(
                        Arc::new(append_error_to_schema(&self.schema)),
                        vec![indexes, output, error],
                    )
                    .unwrap(),
                ))
            } else {
                Ok(Some(
                    RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
                ))
            }
        })?;
        Ok(batch)
    }
}
845
846impl Iterator for RecordBatchIter<'_> {
847    type Item = Result<RecordBatch>;
848    fn next(&mut self) -> Option<Self::Item> {
849        self.next().transpose()
850    }
851}
852
853impl Drop for RecordBatchIter<'_> {
854    fn drop(&mut self) {
855        if let Some(generator) = self.generator.take() {
856            _ = self.interpreter.with_gil(|_| {
857                drop(generator);
858                Ok(())
859            });
860        }
861    }
862}
863
/// Whether the function will be called when some of its arguments are null.
///
/// The type is `Copy`, so one value can be reused for several registrations.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CallMode {
    /// The function will be called normally when some of its arguments are null.
    /// It is then the function author's responsibility to check for null values if necessary and respond appropriately.
    #[default]
    CalledOnNullInput,

    /// The function always returns null whenever any of its arguments are null.
    /// If this parameter is specified, the function is not executed when there are null arguments;
    /// instead a null result is assumed automatically.
    ReturnNullOnNullInput,
}
877
878impl Drop for Runtime {
879    fn drop(&mut self) {
880        // `PyObject` must be dropped inside the interpreter
881        _ = self.interpreter.with_gil(|_| {
882            self.functions.clear();
883            self.aggregates.clear();
884            Ok(())
885        });
886    }
887}
888
889fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
890    if errors.is_empty() {
891        return None;
892    }
893    let data_capacity = errors.iter().map(|(i, _)| i).sum();
894    let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
895    for (i, msg) in errors {
896        while builder.len() + 1 < i {
897            builder.append_null();
898        }
899        builder.append_value(&msg);
900    }
901    while builder.len() < num_rows {
902        builder.append_null();
903    }
904    Some(Arc::new(builder.finish()))
905}
906
907/// Append an error field to the schema.
908fn append_error_to_schema(schema: &Schema) -> Schema {
909    let mut fields = schema.fields().to_vec();
910    fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
911    Schema::new(fields)
912}