arrow_udf_runtime/python/
mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("README.md")]
16
17use crate::into_field::IntoField;
18use crate::CallMode;
19
20use self::interpreter::Interpreter;
21use anyhow::{bail, Context, Result};
22use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
23use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
24use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
25use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
26use pyo3::{Py, PyObject};
27use std::collections::HashMap;
28use std::ffi::CString;
29use std::fmt::Debug;
30use std::sync::Arc;
31
32// #[cfg(Py_3_12)]
33mod interpreter;
34mod pyarrow;
35
/// A runtime to execute user defined functions in Python.
///
/// # Usages
///
/// - Create a new runtime with [`Runtime::new`] or [`Runtime::builder`].
/// - For scalar functions, use [`add_function`] and [`call`].
/// - For table functions, use [`add_function`] and [`call_table_function`].
/// - For aggregate functions, create the function with [`add_aggregate`], and then
///     - create a new state with [`create_state`],
///     - update the state with [`accumulate`] or [`accumulate_or_retract`],
///     - merge states with [`merge`],
///     - finally get the result with [`finish`].
///
/// Click on each function to see an example.
///
/// # Parallelism
///
/// Python's Global Interpreter Lock (GIL) prevents multiple threads from executing Python code
/// simultaneously within one interpreter. To work around this limitation, each runtime creates
/// its own sub-interpreter with its own GIL. This feature requires Python 3.12 or later.
///
/// [`add_function`]: Runtime::add_function
/// [`add_aggregate`]: Runtime::add_aggregate
/// [`call`]: Runtime::call
/// [`call_table_function`]: Runtime::call_table_function
/// [`create_state`]: Runtime::create_state
/// [`accumulate`]: Runtime::accumulate
/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
/// [`merge`]: Runtime::merge
/// [`finish`]: Runtime::finish
// NOTE(review): after `Drop::drop` runs, fields are dropped in declaration order, so
// `interpreter` is dropped before `converter`. Keep that in mind before reordering fields
// or if `pyarrow::Converter` ever holds Python objects.
pub struct Runtime {
    /// Interpreter in which all of this runtime's Python code is loaded and executed.
    interpreter: Interpreter,
    /// Scalar and table functions, keyed by the name they were registered under.
    functions: HashMap<String, Function>,
    /// Aggregate functions, keyed by the name they were registered under.
    aggregates: HashMap<String, Aggregate>,
    /// Converts Arrow values to Python objects and Python results back to Arrow arrays.
    converter: pyarrow::Converter,
}
71
72impl Debug for Runtime {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("Runtime")
75            .field("functions", &self.functions.keys())
76            .field("aggregates", &self.aggregates.keys())
77            .finish()
78    }
79}
80
81/// A user defined function.
82struct Function {
83    function: PyObject,
84    return_field: FieldRef,
85    mode: CallMode,
86}
87
88/// A user defined aggregate function.
89struct Aggregate {
90    state_field: FieldRef,
91    output_field: FieldRef,
92    mode: CallMode,
93    create_state: PyObject,
94    accumulate: PyObject,
95    retract: Option<PyObject>,
96    finish: Option<PyObject>,
97    merge: Option<PyObject>,
98}
99
100/// A builder for `Runtime`.
101#[derive(Default, Debug)]
102pub struct Builder {}
103
104impl Builder {
105    /// Build the `Runtime`.
106    pub fn build(self) -> Result<Runtime> {
107        let interpreter = Interpreter::new()?;
108        interpreter.run(
109            r#"
110# internal use for json types
111import json
112import pickle
113import decimal
114
115# an internal class used for struct input arguments
116class Struct:
117    pass
118"#,
119        )?;
120        Ok(Runtime {
121            interpreter,
122            functions: HashMap::new(),
123            aggregates: HashMap::new(),
124            converter: pyarrow::Converter::new(),
125        })
126    }
127}
128
129impl Runtime {
130    /// Create a new `Runtime`.
131    pub fn new() -> Result<Self> {
132        Builder::default().build()
133    }
134
135    /// Return a new builder for `Runtime`.
136    pub fn builder() -> Builder {
137        Builder::default()
138    }
139
140    /// Add a new scalar function or table function.
141    ///
142    /// # Arguments
143    ///
144    /// - `name`: The name of the function.
145    /// - `return_type`: The data type of the return value.
146    /// - `mode`: Whether the function will be called when some of its arguments are null.
147    /// - `code`: The Python code of the function.
148    ///
149    /// The code should define a function with the same name as the function.
150    /// The function should return a value for scalar functions, or yield values for table functions.
151    ///
152    /// # Example
153    ///
154    /// ```
155    /// # use arrow_udf_runtime::python::Runtime;
156    /// # use arrow_udf_runtime::CallMode;
157    /// # use arrow_schema::DataType;
158    /// let mut runtime = Runtime::new().unwrap();
159    /// // add a scalar function
160    /// runtime
161    ///     .add_function(
162    ///         "gcd",
163    ///         DataType::Int32,
164    ///         CallMode::ReturnNullOnNullInput,
165    ///         r#"
166    /// def gcd(a: int, b: int) -> int:
167    ///     while b:
168    ///         a, b = b, a % b
169    ///     return a
170    /// "#,
171    ///     )
172    ///     .unwrap();
173    /// // add a table function
174    /// runtime
175    ///     .add_function(
176    ///         "series",
177    ///         DataType::Int32,
178    ///         CallMode::ReturnNullOnNullInput,
179    ///         r#"
180    /// def series(n: int):
181    ///     for i in range(n):
182    ///         yield i
183    /// "#,
184    ///     )
185    ///     .unwrap();
186    /// ```
187    pub fn add_function(
188        &mut self,
189        name: &str,
190        return_type: impl IntoField,
191        mode: CallMode,
192        code: &str,
193    ) -> Result<()> {
194        self.add_function_with_handler(name, return_type, mode, code, name)
195    }
196
197    /// Add a new scalar function or table function with custom handler name.
198    ///
199    /// # Arguments
200    ///
201    /// - `handler`: The name of function in Python code to be called.
202    /// - others: Same as [`add_function`].
203    ///
204    /// [`add_function`]: Runtime::add_function
205    pub fn add_function_with_handler(
206        &mut self,
207        name: &str,
208        return_type: impl IntoField,
209        mode: CallMode,
210        code: &str,
211        handler: &str,
212    ) -> Result<()> {
213        let function = self.interpreter.with_gil(|py| {
214            let code = CString::new(code).unwrap();
215            let name = CString::new(name).unwrap();
216            Ok(PyModule::from_code(py, &code, &name, &name)?
217                .getattr(handler)?
218                .into())
219        })?;
220        let function = Function {
221            function,
222            return_field: return_type.into_field(name).into(),
223            mode,
224        };
225        self.functions.insert(name.to_string(), function);
226        Ok(())
227    }
228
229    /// Add a new aggregate function from Python code.
230    ///
231    /// # Arguments
232    ///
233    /// - `name`: The name of the function.
234    /// - `state_type`: The data type of the internal state.
235    /// - `output_type`: The data type of the aggregate value.
236    /// - `mode`: Whether the function will be called when some of its arguments are null.
237    /// - `code`: The Python code of the aggregate function.
238    ///
239    /// The code should define at least two functions:
240    ///
241    /// - `create_state() -> state`: Create a new state object.
242    /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
243    ///
244    /// optionally, the code can define:
245    ///
246    /// - `finish(state) -> value`: Get the result of the aggregate function.
247    ///   If not defined, the state is returned as the result.
248    ///   In this case, `output_type` must be the same as `state_type`.
249    /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
250    /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
251    ///
252    /// # Example
253    ///
254    /// ```
255    /// # use arrow_udf_runtime::python::Runtime;
256    /// # use arrow_udf_runtime::CallMode;
257    /// # use arrow_schema::DataType;
258    /// let mut runtime = Runtime::new().unwrap();
259    /// runtime
260    ///     .add_aggregate(
261    ///         "sum",
262    ///         DataType::Int32, // state_type
263    ///         DataType::Int32, // output_type
264    ///         CallMode::ReturnNullOnNullInput,
265    ///         r#"
266    /// def create_state():
267    ///     return 0
268    ///
269    /// def accumulate(state, value):
270    ///     return state + value
271    ///
272    /// def retract(state, value):
273    ///     return state - value
274    ///
275    /// def merge(state1, state2):
276    ///     return state1 + state2
277    ///         "#,
278    ///     )
279    ///     .unwrap();
280    /// ```
281    pub fn add_aggregate(
282        &mut self,
283        name: &str,
284        state_type: impl IntoField,
285        output_type: impl IntoField,
286        mode: CallMode,
287        code: &str,
288    ) -> Result<()> {
289        let aggregate = self.interpreter.with_gil(|py| {
290            let c_code = CString::new(code).unwrap();
291            let c_name = CString::new(name).unwrap();
292            let module = PyModule::from_code(py, &c_code, &c_name, &c_name)?;
293            Ok(Aggregate {
294                state_field: state_type.into_field(name).into(),
295                output_field: output_type.into_field(name).into(),
296                mode,
297                create_state: module.getattr("create_state")?.into(),
298                accumulate: module.getattr("accumulate")?.into(),
299                retract: module.getattr("retract").ok().map(|f| f.into()),
300                finish: module.getattr("finish").ok().map(|f| f.into()),
301                merge: module.getattr("merge").ok().map(|f| f.into()),
302            })
303        })?;
304        if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
305            bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
306        }
307        self.aggregates.insert(name.to_string(), aggregate);
308        Ok(())
309    }
310
311    /// Remove a scalar or table function.
312    pub fn del_function(&mut self, name: &str) -> Result<()> {
313        let function = self.functions.remove(name).context("function not found")?;
314        _ = self.interpreter.with_gil(|_| {
315            drop(function);
316            Ok(())
317        });
318        Ok(())
319    }
320
321    /// Remove an aggregate function.
322    pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
323        let aggregate = self.functions.remove(name).context("function not found")?;
324        _ = self.interpreter.with_gil(|_| {
325            drop(aggregate);
326            Ok(())
327        });
328        Ok(())
329    }
330
331    /// Call a scalar function.
332    ///
333    /// # Example
334    ///
335    /// ```
336    #[doc = include_str!("doc_create_function.txt")]
337    /// // suppose we have created a scalar function `gcd`
338    /// // see the example in `add_function`
339    ///
340    /// let schema = Schema::new(vec![
341    ///     Field::new("x", DataType::Int32, true),
342    ///     Field::new("y", DataType::Int32, true),
343    /// ]);
344    /// let arg0 = Int32Array::from(vec![Some(25), None]);
345    /// let arg1 = Int32Array::from(vec![Some(15), None]);
346    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
347    ///
348    /// let output = runtime.call("gcd", &input).unwrap();
349    /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
350    /// ```
351    pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
352        let function = self.functions.get(name).context("function not found")?;
353        // convert each row to python objects and call the function
354        let (output, error) = self.interpreter.with_gil(|py| {
355            let mut results = Vec::with_capacity(input.num_rows());
356            let mut errors = vec![];
357            let mut row = Vec::with_capacity(input.num_columns());
358            for i in 0..input.num_rows() {
359                if function.mode == CallMode::ReturnNullOnNullInput
360                    && input.columns().iter().any(|column| column.is_null(i))
361                {
362                    results.push(py.None());
363                    continue;
364                }
365                row.clear();
366                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
367                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
368                    row.push(pyobj);
369                }
370                let args = PyTuple::new(py, row.drain(..))?;
371                match function.function.call1(py, args) {
372                    Ok(result) => results.push(result),
373                    Err(e) => {
374                        results.push(py.None());
375                        errors.push((i, e.to_string()));
376                    }
377                }
378            }
379            let output = self
380                .converter
381                .build_array(&function.return_field, py, &results)?;
382            let error = build_error_array(input.num_rows(), errors);
383            Ok((output, error))
384        })?;
385        if let Some(error) = error {
386            let schema = Schema::new(vec![
387                function.return_field.clone(),
388                Field::new("error", DataType::Utf8, true).into(),
389            ]);
390            Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
391        } else {
392            let schema = Schema::new(vec![function.return_field.clone()]);
393            Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
394        }
395    }
396
397    /// Call a table function.
398    ///
399    /// # Example
400    ///
401    /// ```
402    #[doc = include_str!("doc_create_function.txt")]
403    /// // suppose we have created a table function `series`
404    /// // see the example in `add_function`
405    ///
406    /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
407    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
408    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
409    ///
410    /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
411    /// let output = outputs.next().unwrap().unwrap();
412    /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
413    /// assert_eq!(pretty, r#"
414    /// +-----+--------+
415    /// | row | series |
416    /// +-----+--------+
417    /// | 0   | 0      |
418    /// | 2   | 0      |
419    /// | 2   | 1      |
420    /// | 2   | 2      |
421    /// +-----+--------+"#.trim());
422    /// ```
423    pub fn call_table_function<'a>(
424        &'a self,
425        name: &'a str,
426        input: &'a RecordBatch,
427        chunk_size: usize,
428    ) -> Result<RecordBatchIter<'a>> {
429        assert!(chunk_size > 0);
430        let function = self.functions.get(name).context("function not found")?;
431
432        // initial state
433        Ok(RecordBatchIter {
434            interpreter: &self.interpreter,
435            input,
436            function,
437            schema: Arc::new(Schema::new(vec![
438                Field::new("row", DataType::Int32, true).into(),
439                function.return_field.clone(),
440            ])),
441            chunk_size,
442            row: 0,
443            generator: None,
444            converter: &self.converter,
445        })
446    }
447
448    /// Create a new state for an aggregate function.
449    ///
450    /// # Example
451    /// ```
452    #[doc = include_str!("doc_create_aggregate.txt")]
453    /// let state = runtime.create_state("sum").unwrap();
454    /// assert_eq!(&*state, &Int32Array::from(vec![0]));
455    /// ```
456    pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
457        let aggregate = self.aggregates.get(name).context("function not found")?;
458        let state = self.interpreter.with_gil(|py| {
459            let state = aggregate.create_state.call0(py)?;
460            let state = self
461                .converter
462                .build_array(&aggregate.state_field, py, &[state])?;
463            Ok(state)
464        })?;
465        Ok(state)
466    }
467
468    /// Call accumulate of an aggregate function.
469    ///
470    /// # Example
471    /// ```
472    #[doc = include_str!("doc_create_aggregate.txt")]
473    /// let state = runtime.create_state("sum").unwrap();
474    ///
475    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
476    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
477    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
478    ///
479    /// let state = runtime.accumulate("sum", &state, &input).unwrap();
480    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
481    /// ```
482    pub fn accumulate(
483        &self,
484        name: &str,
485        state: &dyn Array,
486        input: &RecordBatch,
487    ) -> Result<ArrayRef> {
488        let aggregate = self.aggregates.get(name).context("function not found")?;
489        // convert each row to python objects and call the accumulate function
490        let new_state = self.interpreter.with_gil(|py| {
491            let mut state = self
492                .converter
493                .get_pyobject(py, &aggregate.state_field, state, 0)?;
494
495            let mut row = Vec::with_capacity(1 + input.num_columns());
496            for i in 0..input.num_rows() {
497                if aggregate.mode == CallMode::ReturnNullOnNullInput
498                    && input.columns().iter().any(|column| column.is_null(i))
499                {
500                    continue;
501                }
502                row.clear();
503                row.push(state.clone_ref(py));
504                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
505                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
506                    row.push(pyobj);
507                }
508                let args = PyTuple::new(py, row.drain(..))?;
509                state = aggregate.accumulate.call1(py, args)?;
510            }
511            let output = self
512                .converter
513                .build_array(&aggregate.state_field, py, &[state])?;
514            Ok(output)
515        })?;
516        Ok(new_state)
517    }
518
519    /// Call accumulate or retract of an aggregate function.
520    ///
521    /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
522    /// `false` for accumulate and `true` for retract.
523    ///
524    /// # Example
525    /// ```
526    #[doc = include_str!("doc_create_aggregate.txt")]
527    /// let state = runtime.create_state("sum").unwrap();
528    ///
529    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
530    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
531    /// let ops = BooleanArray::from(vec![false, false, true, false]);
532    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
533    ///
534    /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).unwrap();
535    /// assert_eq!(&*state, &Int32Array::from(vec![3]));
536    /// ```
537    pub fn accumulate_or_retract(
538        &self,
539        name: &str,
540        state: &dyn Array,
541        ops: &BooleanArray,
542        input: &RecordBatch,
543    ) -> Result<ArrayRef> {
544        let aggregate = self.aggregates.get(name).context("function not found")?;
545        let retract = aggregate
546            .retract
547            .as_ref()
548            .context("function does not support retraction")?;
549        // convert each row to python objects and call the accumulate function
550        let new_state = self.interpreter.with_gil(|py| {
551            let mut state = self
552                .converter
553                .get_pyobject(py, &aggregate.state_field, state, 0)?;
554
555            let mut row = Vec::with_capacity(1 + input.num_columns());
556            for i in 0..input.num_rows() {
557                if aggregate.mode == CallMode::ReturnNullOnNullInput
558                    && input.columns().iter().any(|column| column.is_null(i))
559                {
560                    continue;
561                }
562                row.clear();
563                row.push(state.clone_ref(py));
564                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
565                    let pyobj = self.converter.get_pyobject(py, field, column, i)?;
566                    row.push(pyobj);
567                }
568                let args = PyTuple::new(py, row.drain(..))?;
569                let func = if ops.is_valid(i) && ops.value(i) {
570                    retract
571                } else {
572                    &aggregate.accumulate
573                };
574                state = func.call1(py, args)?;
575            }
576            let output = self
577                .converter
578                .build_array(&aggregate.state_field, py, &[state])?;
579            Ok(output)
580        })?;
581        Ok(new_state)
582    }
583
584    /// Merge states of an aggregate function.
585    ///
586    /// # Example
587    /// ```
588    #[doc = include_str!("doc_create_aggregate.txt")]
589    /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
590    ///
591    /// let state = runtime.merge("sum", &states).unwrap();
592    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
593    /// ```
594    pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
595        let aggregate = self.aggregates.get(name).context("function not found")?;
596        let merge = aggregate.merge.as_ref().context("merge not found")?;
597        let output = self.interpreter.with_gil(|py| {
598            let mut state = self
599                .converter
600                .get_pyobject(py, &aggregate.state_field, states, 0)?;
601            for i in 1..states.len() {
602                if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
603                    continue;
604                }
605                let state2 = self
606                    .converter
607                    .get_pyobject(py, &aggregate.state_field, states, i)?;
608                let args = PyTuple::new(py, [state, state2])?;
609                state = merge.call1(py, args)?;
610            }
611            let output = self
612                .converter
613                .build_array(&aggregate.state_field, py, &[state])?;
614            Ok(output)
615        })?;
616        Ok(output)
617    }
618
619    /// Get the result of an aggregate function.
620    ///
621    /// If the `finish` function is not defined, the state is returned as the result.
622    ///
623    /// # Example
624    /// ```
625    #[doc = include_str!("doc_create_aggregate.txt")]
626    /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
627    ///
628    /// let outputs = runtime.finish("sum", &states).unwrap();
629    /// assert_eq!(&outputs, &states);
630    /// ```
631    pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
632        let aggregate = self.aggregates.get(name).context("function not found")?;
633        let Some(finish) = &aggregate.finish else {
634            return Ok(states.clone());
635        };
636        let output = self.interpreter.with_gil(|py| {
637            let mut results = Vec::with_capacity(states.len());
638            for i in 0..states.len() {
639                if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
640                    results.push(py.None());
641                    continue;
642                }
643                let state = self
644                    .converter
645                    .get_pyobject(py, &aggregate.state_field, states, i)?;
646                let args = PyTuple::new(py, [state])?;
647                let result = finish.call1(py, args)?;
648                results.push(result);
649            }
650            let output = self
651                .converter
652                .build_array(&aggregate.output_field, py, &results)?;
653            Ok(output)
654        })?;
655        Ok(output)
656    }
657}
658
659/// An iterator over the result of a table function.
660pub struct RecordBatchIter<'a> {
661    interpreter: &'a Interpreter,
662    input: &'a RecordBatch,
663    function: &'a Function,
664    schema: SchemaRef,
665    chunk_size: usize,
666    // mutable states
667    /// Current row index.
668    row: usize,
669    /// Generator of the current row.
670    generator: Option<Py<PyIterator>>,
671    converter: &'a pyarrow::Converter,
672}
673
674impl RecordBatchIter<'_> {
675    /// Get the schema of the output.
676    pub fn schema(&self) -> &Schema {
677        &self.schema
678    }
679
680    fn next(&mut self) -> Result<Option<RecordBatch>> {
681        if self.row == self.input.num_rows() {
682            return Ok(None);
683        }
684        let batch = self.interpreter.with_gil(|py| {
685            let mut indexes = Int32Builder::with_capacity(self.chunk_size);
686            let mut results = Vec::with_capacity(self.input.num_rows());
687            let mut errors = vec![];
688            let mut row = Vec::with_capacity(self.input.num_columns());
689            while self.row < self.input.num_rows() && results.len() < self.chunk_size {
690                let generator = if let Some(g) = self.generator.as_ref() {
691                    g
692                } else {
693                    // call the table function to get a generator
694                    if self.function.mode == CallMode::ReturnNullOnNullInput
695                        && (self.input.columns().iter()).any(|column| column.is_null(self.row))
696                    {
697                        self.row += 1;
698                        continue;
699                    }
700                    row.clear();
701                    for (column, field) in
702                        (self.input.columns().iter()).zip(self.input.schema().fields())
703                    {
704                        let val = self.converter.get_pyobject(py, field, column, self.row)?;
705                        row.push(val);
706                    }
707                    let args = PyTuple::new(py, row.drain(..))?;
708                    match self.function.function.bind(py).call1(args) {
709                        Ok(result) => {
710                            let iter = result.try_iter()?.into();
711                            self.generator.insert(iter)
712                        }
713                        Err(e) => {
714                            // append a row with null value and error message
715                            indexes.append_value(self.row as i32);
716                            results.push(py.None());
717                            errors.push((indexes.len(), e.to_string()));
718                            self.row += 1;
719                            continue;
720                        }
721                    }
722                };
723                match generator.bind(py).clone().next() {
724                    Some(Ok(value)) => {
725                        indexes.append_value(self.row as i32);
726                        results.push(value.into());
727                    }
728                    Some(Err(e)) => {
729                        indexes.append_value(self.row as i32);
730                        results.push(py.None());
731                        errors.push((indexes.len(), e.to_string()));
732                        self.row += 1;
733                        self.generator = None;
734                    }
735                    None => {
736                        self.row += 1;
737                        self.generator = None;
738                    }
739                }
740            }
741
742            if results.is_empty() {
743                return Ok(None);
744            }
745            let indexes = Arc::new(indexes.finish());
746            let output = self
747                .converter
748                .build_array(&self.function.return_field, py, &results)
749                .context("failed to build arrow array from return values")?;
750            let error = build_error_array(indexes.len(), errors);
751            if let Some(error) = error {
752                Ok(Some(
753                    RecordBatch::try_new(
754                        Arc::new(append_error_to_schema(&self.schema)),
755                        vec![indexes, output, error],
756                    )
757                    .unwrap(),
758                ))
759            } else {
760                Ok(Some(
761                    RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
762                ))
763            }
764        })?;
765        Ok(batch)
766    }
767}
768
769impl Iterator for RecordBatchIter<'_> {
770    type Item = Result<RecordBatch>;
771    fn next(&mut self) -> Option<Self::Item> {
772        self.next().transpose()
773    }
774}
775
776impl Drop for RecordBatchIter<'_> {
777    fn drop(&mut self) {
778        if let Some(generator) = self.generator.take() {
779            _ = self.interpreter.with_gil(|_| {
780                drop(generator);
781                Ok(())
782            });
783        }
784    }
785}
786
787impl Drop for Runtime {
788    fn drop(&mut self) {
789        // `PyObject` must be dropped inside the interpreter
790        _ = self.interpreter.with_gil(|_| {
791            self.functions.clear();
792            self.aggregates.clear();
793            Ok(())
794        });
795    }
796}
797
798fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
799    if errors.is_empty() {
800        return None;
801    }
802    let data_capacity = errors.iter().map(|(i, _)| i).sum();
803    let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
804    for (i, msg) in errors {
805        while builder.len() + 1 < i {
806            builder.append_null();
807        }
808        builder.append_value(&msg);
809    }
810    while builder.len() < num_rows {
811        builder.append_null();
812    }
813    Some(Arc::new(builder.finish()))
814}
815
816/// Append an error field to the schema.
817fn append_error_to_schema(schema: &Schema) -> Schema {
818    let mut fields = schema.fields().to_vec();
819    fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
820    Schema::new(fields)
821}