arrow_udf_runtime/python/mod.rs
1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("README.md")]
16
17use crate::into_field::IntoField;
18use crate::CallMode;
19
20use self::interpreter::Interpreter;
21use anyhow::{bail, Context, Result};
22use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
23use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
24use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
25use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
26use pyo3::{Py, PyObject};
27use std::collections::HashMap;
28use std::ffi::CString;
29use std::fmt::Debug;
30use std::sync::Arc;
31
32// #[cfg(Py_3_12)]
33mod interpreter;
34mod pyarrow;
35
/// A runtime to execute user defined functions in Python.
///
/// # Usages
///
/// - Create a new runtime with [`Runtime::new`] or [`Runtime::builder`].
/// - For scalar functions, use [`add_function`] and [`call`].
/// - For table functions, use [`add_function`] and [`call_table_function`].
/// - For aggregate functions, create the function with [`add_aggregate`], and then
///     - create a new state with [`create_state`],
///     - update the state with [`accumulate`] or [`accumulate_or_retract`],
///     - merge states with [`merge`],
///     - finally get the result with [`finish`].
///
/// Click on each function to see the example.
///
/// # Parallelism
///
/// Python has a Global Interpreter Lock (GIL) that prevents multiple threads from executing Python code simultaneously.
/// To work around this limitation, each runtime creates a sub-interpreter with its own GIL. This feature requires Python 3.12 or later.
///
/// [`add_function`]: Runtime::add_function
/// [`add_aggregate`]: Runtime::add_aggregate
/// [`call`]: Runtime::call
/// [`call_table_function`]: Runtime::call_table_function
/// [`create_state`]: Runtime::create_state
/// [`accumulate`]: Runtime::accumulate
/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
/// [`merge`]: Runtime::merge
/// [`finish`]: Runtime::finish
pub struct Runtime {
    /// Interpreter in which every Python object of this runtime is created; the
    /// `Drop` impl clears the function maps inside it because `PyObject`s must be
    /// dropped there.
    interpreter: Interpreter,
    /// Scalar and table functions, keyed by the registered function name.
    functions: HashMap<String, Function>,
    /// Aggregate functions, keyed by the registered function name.
    aggregates: HashMap<String, Aggregate>,
    /// Converts Arrow values to Python objects and Python results back to Arrow arrays.
    converter: pyarrow::Converter,
}
71
72impl Debug for Runtime {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("Runtime")
75 .field("functions", &self.functions.keys())
76 .field("aggregates", &self.aggregates.keys())
77 .finish()
78 }
79}
80
/// A user defined scalar or table function, registered via `add_function` or
/// `add_function_with_handler`.
struct Function {
    /// The Python callable invoked with one row of arguments per call.
    function: PyObject,
    /// Field describing the output column (built with `IntoField::into_field(name)`).
    return_field: FieldRef,
    /// Whether the function is skipped (null output) for rows with a null argument.
    mode: CallMode,
}
87
/// A user defined aggregate function, registered via `add_aggregate`.
struct Aggregate {
    /// Field of the one-element state arrays passed to and returned from the runtime.
    state_field: FieldRef,
    /// Field of the arrays returned by `finish`.
    output_field: FieldRef,
    /// Whether rows (or states) containing nulls are skipped instead of passed to Python.
    mode: CallMode,
    /// Python `create_state()`: returns a fresh state.
    create_state: PyObject,
    /// Python `accumulate(state, *args)`: returns the updated state.
    accumulate: PyObject,
    /// Python `retract(state, *args)`, if defined; required by `accumulate_or_retract`.
    retract: Option<PyObject>,
    /// Python `finish(state)`, if defined; when absent the state itself is the result.
    finish: Option<PyObject>,
    /// Python `merge(state1, state2)`, if defined; required by `merge`.
    merge: Option<PyObject>,
}
99
100/// A builder for `Runtime`.
101#[derive(Default, Debug)]
102pub struct Builder {}
103
104impl Builder {
105 /// Build the `Runtime`.
106 pub fn build(self) -> Result<Runtime> {
107 let interpreter = Interpreter::new()?;
108 interpreter.run(
109 r#"
110# internal use for json types
111import json
112import pickle
113import decimal
114
115# an internal class used for struct input arguments
116class Struct:
117 pass
118"#,
119 )?;
120 Ok(Runtime {
121 interpreter,
122 functions: HashMap::new(),
123 aggregates: HashMap::new(),
124 converter: pyarrow::Converter::new(),
125 })
126 }
127}
128
129impl Runtime {
130 /// Create a new `Runtime`.
131 pub fn new() -> Result<Self> {
132 Builder::default().build()
133 }
134
135 /// Return a new builder for `Runtime`.
136 pub fn builder() -> Builder {
137 Builder::default()
138 }
139
140 /// Add a new scalar function or table function.
141 ///
142 /// # Arguments
143 ///
144 /// - `name`: The name of the function.
145 /// - `return_type`: The data type of the return value.
146 /// - `mode`: Whether the function will be called when some of its arguments are null.
147 /// - `code`: The Python code of the function.
148 ///
149 /// The code should define a function with the same name as the function.
150 /// The function should return a value for scalar functions, or yield values for table functions.
151 ///
152 /// # Example
153 ///
154 /// ```
155 /// # use arrow_udf_runtime::python::Runtime;
156 /// # use arrow_udf_runtime::CallMode;
157 /// # use arrow_schema::DataType;
158 /// let mut runtime = Runtime::new().unwrap();
159 /// // add a scalar function
160 /// runtime
161 /// .add_function(
162 /// "gcd",
163 /// DataType::Int32,
164 /// CallMode::ReturnNullOnNullInput,
165 /// r#"
166 /// def gcd(a: int, b: int) -> int:
167 /// while b:
168 /// a, b = b, a % b
169 /// return a
170 /// "#,
171 /// )
172 /// .unwrap();
173 /// // add a table function
174 /// runtime
175 /// .add_function(
176 /// "series",
177 /// DataType::Int32,
178 /// CallMode::ReturnNullOnNullInput,
179 /// r#"
180 /// def series(n: int):
181 /// for i in range(n):
182 /// yield i
183 /// "#,
184 /// )
185 /// .unwrap();
186 /// ```
187 pub fn add_function(
188 &mut self,
189 name: &str,
190 return_type: impl IntoField,
191 mode: CallMode,
192 code: &str,
193 ) -> Result<()> {
194 self.add_function_with_handler(name, return_type, mode, code, name)
195 }
196
197 /// Add a new scalar function or table function with custom handler name.
198 ///
199 /// # Arguments
200 ///
201 /// - `handler`: The name of function in Python code to be called.
202 /// - others: Same as [`add_function`].
203 ///
204 /// [`add_function`]: Runtime::add_function
205 pub fn add_function_with_handler(
206 &mut self,
207 name: &str,
208 return_type: impl IntoField,
209 mode: CallMode,
210 code: &str,
211 handler: &str,
212 ) -> Result<()> {
213 let function = self.interpreter.with_gil(|py| {
214 let code = CString::new(code).unwrap();
215 let name = CString::new(name).unwrap();
216 Ok(PyModule::from_code(py, &code, &name, &name)?
217 .getattr(handler)?
218 .into())
219 })?;
220 let function = Function {
221 function,
222 return_field: return_type.into_field(name).into(),
223 mode,
224 };
225 self.functions.insert(name.to_string(), function);
226 Ok(())
227 }
228
229 /// Add a new aggregate function from Python code.
230 ///
231 /// # Arguments
232 ///
233 /// - `name`: The name of the function.
234 /// - `state_type`: The data type of the internal state.
235 /// - `output_type`: The data type of the aggregate value.
236 /// - `mode`: Whether the function will be called when some of its arguments are null.
237 /// - `code`: The Python code of the aggregate function.
238 ///
239 /// The code should define at least two functions:
240 ///
241 /// - `create_state() -> state`: Create a new state object.
242 /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
243 ///
244 /// optionally, the code can define:
245 ///
246 /// - `finish(state) -> value`: Get the result of the aggregate function.
247 /// If not defined, the state is returned as the result.
248 /// In this case, `output_type` must be the same as `state_type`.
249 /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
250 /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
251 ///
252 /// # Example
253 ///
254 /// ```
255 /// # use arrow_udf_runtime::python::Runtime;
256 /// # use arrow_udf_runtime::CallMode;
257 /// # use arrow_schema::DataType;
258 /// let mut runtime = Runtime::new().unwrap();
259 /// runtime
260 /// .add_aggregate(
261 /// "sum",
262 /// DataType::Int32, // state_type
263 /// DataType::Int32, // output_type
264 /// CallMode::ReturnNullOnNullInput,
265 /// r#"
266 /// def create_state():
267 /// return 0
268 ///
269 /// def accumulate(state, value):
270 /// return state + value
271 ///
272 /// def retract(state, value):
273 /// return state - value
274 ///
275 /// def merge(state1, state2):
276 /// return state1 + state2
277 /// "#,
278 /// )
279 /// .unwrap();
280 /// ```
281 pub fn add_aggregate(
282 &mut self,
283 name: &str,
284 state_type: impl IntoField,
285 output_type: impl IntoField,
286 mode: CallMode,
287 code: &str,
288 ) -> Result<()> {
289 let aggregate = self.interpreter.with_gil(|py| {
290 let c_code = CString::new(code).unwrap();
291 let c_name = CString::new(name).unwrap();
292 let module = PyModule::from_code(py, &c_code, &c_name, &c_name)?;
293 Ok(Aggregate {
294 state_field: state_type.into_field(name).into(),
295 output_field: output_type.into_field(name).into(),
296 mode,
297 create_state: module.getattr("create_state")?.into(),
298 accumulate: module.getattr("accumulate")?.into(),
299 retract: module.getattr("retract").ok().map(|f| f.into()),
300 finish: module.getattr("finish").ok().map(|f| f.into()),
301 merge: module.getattr("merge").ok().map(|f| f.into()),
302 })
303 })?;
304 if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
305 bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
306 }
307 self.aggregates.insert(name.to_string(), aggregate);
308 Ok(())
309 }
310
311 /// Remove a scalar or table function.
312 pub fn del_function(&mut self, name: &str) -> Result<()> {
313 let function = self.functions.remove(name).context("function not found")?;
314 _ = self.interpreter.with_gil(|_| {
315 drop(function);
316 Ok(())
317 });
318 Ok(())
319 }
320
321 /// Remove an aggregate function.
322 pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
323 let aggregate = self.functions.remove(name).context("function not found")?;
324 _ = self.interpreter.with_gil(|_| {
325 drop(aggregate);
326 Ok(())
327 });
328 Ok(())
329 }
330
331 /// Call a scalar function.
332 ///
333 /// # Example
334 ///
335 /// ```
336 #[doc = include_str!("doc_create_function.txt")]
337 /// // suppose we have created a scalar function `gcd`
338 /// // see the example in `add_function`
339 ///
340 /// let schema = Schema::new(vec![
341 /// Field::new("x", DataType::Int32, true),
342 /// Field::new("y", DataType::Int32, true),
343 /// ]);
344 /// let arg0 = Int32Array::from(vec![Some(25), None]);
345 /// let arg1 = Int32Array::from(vec![Some(15), None]);
346 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
347 ///
348 /// let output = runtime.call("gcd", &input).unwrap();
349 /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
350 /// ```
351 pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
352 let function = self.functions.get(name).context("function not found")?;
353 // convert each row to python objects and call the function
354 let (output, error) = self.interpreter.with_gil(|py| {
355 let mut results = Vec::with_capacity(input.num_rows());
356 let mut errors = vec![];
357 let mut row = Vec::with_capacity(input.num_columns());
358 for i in 0..input.num_rows() {
359 if function.mode == CallMode::ReturnNullOnNullInput
360 && input.columns().iter().any(|column| column.is_null(i))
361 {
362 results.push(py.None());
363 continue;
364 }
365 row.clear();
366 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
367 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
368 row.push(pyobj);
369 }
370 let args = PyTuple::new(py, row.drain(..))?;
371 match function.function.call1(py, args) {
372 Ok(result) => results.push(result),
373 Err(e) => {
374 results.push(py.None());
375 errors.push((i, e.to_string()));
376 }
377 }
378 }
379 let output = self
380 .converter
381 .build_array(&function.return_field, py, &results)?;
382 let error = build_error_array(input.num_rows(), errors);
383 Ok((output, error))
384 })?;
385 if let Some(error) = error {
386 let schema = Schema::new(vec![
387 function.return_field.clone(),
388 Field::new("error", DataType::Utf8, true).into(),
389 ]);
390 Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
391 } else {
392 let schema = Schema::new(vec![function.return_field.clone()]);
393 Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
394 }
395 }
396
397 /// Call a table function.
398 ///
399 /// # Example
400 ///
401 /// ```
402 #[doc = include_str!("doc_create_function.txt")]
403 /// // suppose we have created a table function `series`
404 /// // see the example in `add_function`
405 ///
406 /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
407 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
408 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
409 ///
410 /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
411 /// let output = outputs.next().unwrap().unwrap();
412 /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
413 /// assert_eq!(pretty, r#"
414 /// +-----+--------+
415 /// | row | series |
416 /// +-----+--------+
417 /// | 0 | 0 |
418 /// | 2 | 0 |
419 /// | 2 | 1 |
420 /// | 2 | 2 |
421 /// +-----+--------+"#.trim());
422 /// ```
423 pub fn call_table_function<'a>(
424 &'a self,
425 name: &'a str,
426 input: &'a RecordBatch,
427 chunk_size: usize,
428 ) -> Result<RecordBatchIter<'a>> {
429 assert!(chunk_size > 0);
430 let function = self.functions.get(name).context("function not found")?;
431
432 // initial state
433 Ok(RecordBatchIter {
434 interpreter: &self.interpreter,
435 input,
436 function,
437 schema: Arc::new(Schema::new(vec![
438 Field::new("row", DataType::Int32, true).into(),
439 function.return_field.clone(),
440 ])),
441 chunk_size,
442 row: 0,
443 generator: None,
444 converter: &self.converter,
445 })
446 }
447
448 /// Create a new state for an aggregate function.
449 ///
450 /// # Example
451 /// ```
452 #[doc = include_str!("doc_create_aggregate.txt")]
453 /// let state = runtime.create_state("sum").unwrap();
454 /// assert_eq!(&*state, &Int32Array::from(vec![0]));
455 /// ```
456 pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
457 let aggregate = self.aggregates.get(name).context("function not found")?;
458 let state = self.interpreter.with_gil(|py| {
459 let state = aggregate.create_state.call0(py)?;
460 let state = self
461 .converter
462 .build_array(&aggregate.state_field, py, &[state])?;
463 Ok(state)
464 })?;
465 Ok(state)
466 }
467
468 /// Call accumulate of an aggregate function.
469 ///
470 /// # Example
471 /// ```
472 #[doc = include_str!("doc_create_aggregate.txt")]
473 /// let state = runtime.create_state("sum").unwrap();
474 ///
475 /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
476 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
477 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
478 ///
479 /// let state = runtime.accumulate("sum", &state, &input).unwrap();
480 /// assert_eq!(&*state, &Int32Array::from(vec![9]));
481 /// ```
482 pub fn accumulate(
483 &self,
484 name: &str,
485 state: &dyn Array,
486 input: &RecordBatch,
487 ) -> Result<ArrayRef> {
488 let aggregate = self.aggregates.get(name).context("function not found")?;
489 // convert each row to python objects and call the accumulate function
490 let new_state = self.interpreter.with_gil(|py| {
491 let mut state = self
492 .converter
493 .get_pyobject(py, &aggregate.state_field, state, 0)?;
494
495 let mut row = Vec::with_capacity(1 + input.num_columns());
496 for i in 0..input.num_rows() {
497 if aggregate.mode == CallMode::ReturnNullOnNullInput
498 && input.columns().iter().any(|column| column.is_null(i))
499 {
500 continue;
501 }
502 row.clear();
503 row.push(state.clone_ref(py));
504 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
505 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
506 row.push(pyobj);
507 }
508 let args = PyTuple::new(py, row.drain(..))?;
509 state = aggregate.accumulate.call1(py, args)?;
510 }
511 let output = self
512 .converter
513 .build_array(&aggregate.state_field, py, &[state])?;
514 Ok(output)
515 })?;
516 Ok(new_state)
517 }
518
519 /// Call accumulate or retract of an aggregate function.
520 ///
521 /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
522 /// `false` for accumulate and `true` for retract.
523 ///
524 /// # Example
525 /// ```
526 #[doc = include_str!("doc_create_aggregate.txt")]
527 /// let state = runtime.create_state("sum").unwrap();
528 ///
529 /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
530 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
531 /// let ops = BooleanArray::from(vec![false, false, true, false]);
532 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
533 ///
534 /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).unwrap();
535 /// assert_eq!(&*state, &Int32Array::from(vec![3]));
536 /// ```
537 pub fn accumulate_or_retract(
538 &self,
539 name: &str,
540 state: &dyn Array,
541 ops: &BooleanArray,
542 input: &RecordBatch,
543 ) -> Result<ArrayRef> {
544 let aggregate = self.aggregates.get(name).context("function not found")?;
545 let retract = aggregate
546 .retract
547 .as_ref()
548 .context("function does not support retraction")?;
549 // convert each row to python objects and call the accumulate function
550 let new_state = self.interpreter.with_gil(|py| {
551 let mut state = self
552 .converter
553 .get_pyobject(py, &aggregate.state_field, state, 0)?;
554
555 let mut row = Vec::with_capacity(1 + input.num_columns());
556 for i in 0..input.num_rows() {
557 if aggregate.mode == CallMode::ReturnNullOnNullInput
558 && input.columns().iter().any(|column| column.is_null(i))
559 {
560 continue;
561 }
562 row.clear();
563 row.push(state.clone_ref(py));
564 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
565 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
566 row.push(pyobj);
567 }
568 let args = PyTuple::new(py, row.drain(..))?;
569 let func = if ops.is_valid(i) && ops.value(i) {
570 retract
571 } else {
572 &aggregate.accumulate
573 };
574 state = func.call1(py, args)?;
575 }
576 let output = self
577 .converter
578 .build_array(&aggregate.state_field, py, &[state])?;
579 Ok(output)
580 })?;
581 Ok(new_state)
582 }
583
584 /// Merge states of an aggregate function.
585 ///
586 /// # Example
587 /// ```
588 #[doc = include_str!("doc_create_aggregate.txt")]
589 /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
590 ///
591 /// let state = runtime.merge("sum", &states).unwrap();
592 /// assert_eq!(&*state, &Int32Array::from(vec![9]));
593 /// ```
594 pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
595 let aggregate = self.aggregates.get(name).context("function not found")?;
596 let merge = aggregate.merge.as_ref().context("merge not found")?;
597 let output = self.interpreter.with_gil(|py| {
598 let mut state = self
599 .converter
600 .get_pyobject(py, &aggregate.state_field, states, 0)?;
601 for i in 1..states.len() {
602 if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
603 continue;
604 }
605 let state2 = self
606 .converter
607 .get_pyobject(py, &aggregate.state_field, states, i)?;
608 let args = PyTuple::new(py, [state, state2])?;
609 state = merge.call1(py, args)?;
610 }
611 let output = self
612 .converter
613 .build_array(&aggregate.state_field, py, &[state])?;
614 Ok(output)
615 })?;
616 Ok(output)
617 }
618
619 /// Get the result of an aggregate function.
620 ///
621 /// If the `finish` function is not defined, the state is returned as the result.
622 ///
623 /// # Example
624 /// ```
625 #[doc = include_str!("doc_create_aggregate.txt")]
626 /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
627 ///
628 /// let outputs = runtime.finish("sum", &states).unwrap();
629 /// assert_eq!(&outputs, &states);
630 /// ```
631 pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
632 let aggregate = self.aggregates.get(name).context("function not found")?;
633 let Some(finish) = &aggregate.finish else {
634 return Ok(states.clone());
635 };
636 let output = self.interpreter.with_gil(|py| {
637 let mut results = Vec::with_capacity(states.len());
638 for i in 0..states.len() {
639 if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
640 results.push(py.None());
641 continue;
642 }
643 let state = self
644 .converter
645 .get_pyobject(py, &aggregate.state_field, states, i)?;
646 let args = PyTuple::new(py, [state])?;
647 let result = finish.call1(py, args)?;
648 results.push(result);
649 }
650 let output = self
651 .converter
652 .build_array(&aggregate.output_field, py, &results)?;
653 Ok(output)
654 })?;
655 Ok(output)
656 }
657}
658
/// An iterator over the result of a table function.
///
/// Created by [`Runtime::call_table_function`]. Each item is one output batch of at
/// most `chunk_size` rows; batches containing failed rows carry an extra `error` column.
pub struct RecordBatchIter<'a> {
    /// Interpreter in which the table function runs and the generator is dropped.
    interpreter: &'a Interpreter,
    /// Input arguments; the function is called once per input row.
    input: &'a RecordBatch,
    /// The table function being evaluated.
    function: &'a Function,
    /// Output schema (`row` column plus the function's return field), without the
    /// optional `error` column.
    schema: SchemaRef,
    /// Maximum number of output rows per batch.
    chunk_size: usize,
    // mutable states
    /// Index of the input row currently being processed.
    row: usize,
    /// Python iterator returned by the function for input row `row`; `None` until the
    /// function has been called for that row.
    generator: Option<Py<PyIterator>>,
    /// Converts between Arrow values and Python objects.
    converter: &'a pyarrow::Converter,
}
673
674impl RecordBatchIter<'_> {
675 /// Get the schema of the output.
676 pub fn schema(&self) -> &Schema {
677 &self.schema
678 }
679
680 fn next(&mut self) -> Result<Option<RecordBatch>> {
681 if self.row == self.input.num_rows() {
682 return Ok(None);
683 }
684 let batch = self.interpreter.with_gil(|py| {
685 let mut indexes = Int32Builder::with_capacity(self.chunk_size);
686 let mut results = Vec::with_capacity(self.input.num_rows());
687 let mut errors = vec![];
688 let mut row = Vec::with_capacity(self.input.num_columns());
689 while self.row < self.input.num_rows() && results.len() < self.chunk_size {
690 let generator = if let Some(g) = self.generator.as_ref() {
691 g
692 } else {
693 // call the table function to get a generator
694 if self.function.mode == CallMode::ReturnNullOnNullInput
695 && (self.input.columns().iter()).any(|column| column.is_null(self.row))
696 {
697 self.row += 1;
698 continue;
699 }
700 row.clear();
701 for (column, field) in
702 (self.input.columns().iter()).zip(self.input.schema().fields())
703 {
704 let val = self.converter.get_pyobject(py, field, column, self.row)?;
705 row.push(val);
706 }
707 let args = PyTuple::new(py, row.drain(..))?;
708 match self.function.function.bind(py).call1(args) {
709 Ok(result) => {
710 let iter = result.try_iter()?.into();
711 self.generator.insert(iter)
712 }
713 Err(e) => {
714 // append a row with null value and error message
715 indexes.append_value(self.row as i32);
716 results.push(py.None());
717 errors.push((indexes.len(), e.to_string()));
718 self.row += 1;
719 continue;
720 }
721 }
722 };
723 match generator.bind(py).clone().next() {
724 Some(Ok(value)) => {
725 indexes.append_value(self.row as i32);
726 results.push(value.into());
727 }
728 Some(Err(e)) => {
729 indexes.append_value(self.row as i32);
730 results.push(py.None());
731 errors.push((indexes.len(), e.to_string()));
732 self.row += 1;
733 self.generator = None;
734 }
735 None => {
736 self.row += 1;
737 self.generator = None;
738 }
739 }
740 }
741
742 if results.is_empty() {
743 return Ok(None);
744 }
745 let indexes = Arc::new(indexes.finish());
746 let output = self
747 .converter
748 .build_array(&self.function.return_field, py, &results)
749 .context("failed to build arrow array from return values")?;
750 let error = build_error_array(indexes.len(), errors);
751 if let Some(error) = error {
752 Ok(Some(
753 RecordBatch::try_new(
754 Arc::new(append_error_to_schema(&self.schema)),
755 vec![indexes, output, error],
756 )
757 .unwrap(),
758 ))
759 } else {
760 Ok(Some(
761 RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
762 ))
763 }
764 })?;
765 Ok(batch)
766 }
767}
768
769impl Iterator for RecordBatchIter<'_> {
770 type Item = Result<RecordBatch>;
771 fn next(&mut self) -> Option<Self::Item> {
772 self.next().transpose()
773 }
774}
775
776impl Drop for RecordBatchIter<'_> {
777 fn drop(&mut self) {
778 if let Some(generator) = self.generator.take() {
779 _ = self.interpreter.with_gil(|_| {
780 drop(generator);
781 Ok(())
782 });
783 }
784 }
785}
786
787impl Drop for Runtime {
788 fn drop(&mut self) {
789 // `PyObject` must be dropped inside the interpreter
790 _ = self.interpreter.with_gil(|_| {
791 self.functions.clear();
792 self.aggregates.clear();
793 Ok(())
794 });
795 }
796}
797
798fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
799 if errors.is_empty() {
800 return None;
801 }
802 let data_capacity = errors.iter().map(|(i, _)| i).sum();
803 let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
804 for (i, msg) in errors {
805 while builder.len() + 1 < i {
806 builder.append_null();
807 }
808 builder.append_value(&msg);
809 }
810 while builder.len() < num_rows {
811 builder.append_null();
812 }
813 Some(Arc::new(builder.finish()))
814}
815
816/// Append an error field to the schema.
817fn append_error_to_schema(schema: &Schema) -> Schema {
818 let mut fields = schema.fields().to_vec();
819 fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
820 Schema::new(fields)
821}