arrow_udf_python/lib.rs
1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("../README.md")]
16
17// Notice for developers:
18// This library uses the sub-interpreter and per-interpreter GIL introduced in Python 3.12
19// to support concurrent execution of different functions in multiple threads.
// However, pyo3 does not yet support sub-interpreters safely, so we use pyo3's raw FFI API to implement them.
21// Therefore, special attention is needed:
22// **All PyObject created in a sub-interpreter must be destroyed in the same sub-interpreter.**
23// Otherwise, it will cause a crash the next time Python is called.
24// Special attention is needed for PyErr in PyResult.
25// Remember to convert `PyErr` using the `pyerr_to_anyhow` function before passing it out of the sub-interpreter.
26
27use self::interpreter::SubInterpreter;
28pub use self::into_field::IntoField;
29use anyhow::{bail, Context, Result};
30use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
31use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
32use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
33use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
34use pyo3::{Py, PyObject};
35use std::collections::HashMap;
36use std::fmt::Debug;
37use std::sync::Arc;
38
39// #[cfg(Py_3_12)]
40mod interpreter;
41mod into_field;
42mod pyarrow;
43
/// A runtime to execute user defined functions in Python.
///
/// # Usages
///
/// - Create a new runtime with [`Runtime::new`] or [`Runtime::builder`].
/// - For scalar functions, use [`add_function`] and [`call`].
/// - For table functions, use [`add_function`] and [`call_table_function`].
/// - For aggregate functions, create the function with [`add_aggregate`], and then
///   - create a new state with [`create_state`],
///   - update the state with [`accumulate`] or [`accumulate_or_retract`],
///   - merge states with [`merge`],
///   - finally get the result with [`finish`].
///
/// Click on each function to see the example.
///
/// # Parallelism
///
/// Python has a Global Interpreter Lock (GIL) that prevents multiple threads from executing Python code simultaneously.
/// To work around this limitation, each runtime creates a sub-interpreter with its own GIL. This feature requires Python 3.12 or later.
///
/// [`add_function`]: Runtime::add_function
/// [`add_aggregate`]: Runtime::add_aggregate
/// [`call`]: Runtime::call
/// [`call_table_function`]: Runtime::call_table_function
/// [`create_state`]: Runtime::create_state
/// [`accumulate`]: Runtime::accumulate
/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
/// [`merge`]: Runtime::merge
/// [`finish`]: Runtime::finish
pub struct Runtime {
    /// The sub-interpreter in which every Python object held by this runtime was created
    /// and must be destroyed.
    // NOTE(review): struct fields drop in declaration order, so `interpreter` is dropped
    // before `converter`; confirm `converter` holds no Python objects.
    interpreter: SubInterpreter,
    /// Registered scalar and table functions, keyed by function name.
    functions: HashMap<String, Function>,
    /// Registered aggregate functions, keyed by function name.
    aggregates: HashMap<String, Aggregate>,
    /// Converts Arrow values to Python objects and Python results back to Arrow arrays.
    converter: pyarrow::Converter,
}
79
80impl Debug for Runtime {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("Runtime")
83 .field("functions", &self.functions.keys())
84 .field("aggregates", &self.aggregates.keys())
85 .finish()
86 }
87}
88
89/// A user defined function.
90struct Function {
91 function: PyObject,
92 return_field: FieldRef,
93 mode: CallMode,
94}
95
96/// A user defined aggregate function.
97struct Aggregate {
98 state_field: FieldRef,
99 output_field: FieldRef,
100 mode: CallMode,
101 create_state: PyObject,
102 accumulate: PyObject,
103 retract: Option<PyObject>,
104 finish: Option<PyObject>,
105 merge: Option<PyObject>,
106}
107
/// A builder for `Runtime`.
#[derive(Default, Debug)]
pub struct Builder {
    /// Whether to restrict imports and disable some builtins (see [`Builder::sandboxed`]).
    sandboxed: bool,
    /// Symbols (e.g. `__builtins__.eval`) queued for removal with a Python `del` statement
    /// while building the runtime.
    removed_symbols: Vec<String>,
}
114
115impl Builder {
116 /// Set whether the runtime is sandboxed.
117 ///
118 /// When sandboxed, only a limited set of modules can be imported, and some built-in functions are disabled.
119 /// This is useful for running untrusted code.
120 ///
121 /// Allowed modules: `json`, `decimal`, `re`, `math`, `datetime`, `time`.
122 ///
123 /// Disallowed builtins: `breakpoint`, `exit`, `eval`, `help`, `input`, `open`, `print`.
124 ///
125 /// The default is `false`.
126 pub fn sandboxed(mut self, sandboxed: bool) -> Self {
127 self.sandboxed = sandboxed;
128 self.remove_symbol("__builtins__.breakpoint")
129 .remove_symbol("__builtins__.exit")
130 .remove_symbol("__builtins__.eval")
131 .remove_symbol("__builtins__.help")
132 .remove_symbol("__builtins__.input")
133 .remove_symbol("__builtins__.open")
134 .remove_symbol("__builtins__.print")
135 }
136
137 /// Remove a symbol from builtins.
138 ///
139 /// # Examples
140 ///
141 /// ```
142 /// # use arrow_udf_python::Runtime;
143 /// let builder = Runtime::builder().remove_symbol("__builtins__.eval");
144 /// ```
145 pub fn remove_symbol(mut self, symbol: &str) -> Self {
146 self.removed_symbols.push(symbol.to_string());
147 self
148 }
149
150 /// Build the `Runtime`.
151 pub fn build(self) -> Result<Runtime> {
152 let interpreter = SubInterpreter::new()?;
153 interpreter.run(
154 r#"
155# internal use for json types
156import json
157import pickle
158import decimal
159
160# an internal class used for struct input arguments
161class Struct:
162 pass
163"#,
164 )?;
165 if self.sandboxed {
166 let mut script = r#"
167# limit the modules that can be imported
168original_import = __builtins__.__import__
169
170def limited_import(name, globals=None, locals=None, fromlist=(), level=0):
171 # FIXME: 'sys' should not be allowed, but it is required by 'decimal'
172 # FIXME: 'time.sleep' should not be allowed, but 'time' is required by 'datetime'
173 allowlist = (
174 'json',
175 'decimal',
176 're',
177 'math',
178 'datetime',
179 'time',
180 'operator',
181 'numbers',
182 'abc',
183 'sys',
184 'contextvars',
185 '_io',
186 '_contextvars',
187 '_pydecimal',
188 '_pydatetime',
189 )
190 if level == 0 and name in allowlist:
191 return original_import(name, globals, locals, fromlist, level)
192 raise ImportError(f'import {name} is not allowed')
193
194__builtins__.__import__ = limited_import
195del limited_import
196"#
197 .to_string();
198 for symbol in self.removed_symbols {
199 script.push_str(&format!("del {}\n", symbol));
200 }
201 interpreter.run(&script)?;
202 }
203 Ok(Runtime {
204 interpreter,
205 functions: HashMap::new(),
206 aggregates: HashMap::new(),
207 converter: pyarrow::Converter::new(),
208 })
209 }
210}
211
212impl Runtime {
213 /// Create a new `Runtime`.
214 pub fn new() -> Result<Self> {
215 Builder::default().build()
216 }
217
218 /// Return a new builder for `Runtime`.
219 pub fn builder() -> Builder {
220 Builder::default()
221 }
222
223 /// Add a new scalar function or table function.
224 ///
225 /// # Arguments
226 ///
227 /// - `name`: The name of the function.
228 /// - `return_type`: The data type of the return value.
229 /// - `mode`: Whether the function will be called when some of its arguments are null.
230 /// - `code`: The Python code of the function.
231 ///
232 /// The code should define a function with the same name as the function.
233 /// The function should return a value for scalar functions, or yield values for table functions.
234 ///
235 /// # Example
236 ///
237 /// ```
238 /// # use arrow_udf_python::{Runtime, CallMode};
239 /// # use arrow_schema::DataType;
240 /// let mut runtime = Runtime::new().unwrap();
241 /// // add a scalar function
242 /// runtime
243 /// .add_function(
244 /// "gcd",
245 /// DataType::Int32,
246 /// CallMode::ReturnNullOnNullInput,
247 /// r#"
248 /// def gcd(a: int, b: int) -> int:
249 /// while b:
250 /// a, b = b, a % b
251 /// return a
252 /// "#,
253 /// )
254 /// .unwrap();
255 /// // add a table function
256 /// runtime
257 /// .add_function(
258 /// "series",
259 /// DataType::Int32,
260 /// CallMode::ReturnNullOnNullInput,
261 /// r#"
262 /// def series(n: int):
263 /// for i in range(n):
264 /// yield i
265 /// "#,
266 /// )
267 /// .unwrap();
268 /// ```
269 pub fn add_function(
270 &mut self,
271 name: &str,
272 return_type: impl IntoField,
273 mode: CallMode,
274 code: &str,
275 ) -> Result<()> {
276 self.add_function_with_handler(name, return_type, mode, code, name)
277 }
278
279 /// Add a new scalar function or table function with custom handler name.
280 ///
281 /// # Arguments
282 ///
283 /// - `handler`: The name of function in Python code to be called.
284 /// - others: Same as [`add_function`].
285 ///
286 /// [`add_function`]: Runtime::add_function
287 pub fn add_function_with_handler(
288 &mut self,
289 name: &str,
290 return_type: impl IntoField,
291 mode: CallMode,
292 code: &str,
293 handler: &str,
294 ) -> Result<()> {
295 let function = self.interpreter.with_gil(|py| {
296 Ok(PyModule::from_code_bound(py, code, name, name)?
297 .getattr(handler)?
298 .into())
299 })?;
300 let function = Function {
301 function,
302 return_field: return_type.into_field(name).into(),
303 mode,
304 };
305 self.functions.insert(name.to_string(), function);
306 Ok(())
307 }
308
309 /// Add a new aggregate function from Python code.
310 ///
311 /// # Arguments
312 ///
313 /// - `name`: The name of the function.
314 /// - `state_type`: The data type of the internal state.
315 /// - `output_type`: The data type of the aggregate value.
316 /// - `mode`: Whether the function will be called when some of its arguments are null.
317 /// - `code`: The Python code of the aggregate function.
318 ///
319 /// The code should define at least two functions:
320 ///
321 /// - `create_state() -> state`: Create a new state object.
322 /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
323 ///
324 /// optionally, the code can define:
325 ///
326 /// - `finish(state) -> value`: Get the result of the aggregate function.
327 /// If not defined, the state is returned as the result.
328 /// In this case, `output_type` must be the same as `state_type`.
329 /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
330 /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
331 ///
332 /// # Example
333 ///
334 /// ```
335 /// # use arrow_udf_python::{Runtime, CallMode};
336 /// # use arrow_schema::DataType;
337 /// let mut runtime = Runtime::new().unwrap();
338 /// runtime
339 /// .add_aggregate(
340 /// "sum",
341 /// DataType::Int32, // state_type
342 /// DataType::Int32, // output_type
343 /// CallMode::ReturnNullOnNullInput,
344 /// r#"
345 /// def create_state():
346 /// return 0
347 ///
348 /// def accumulate(state, value):
349 /// return state + value
350 ///
351 /// def retract(state, value):
352 /// return state - value
353 ///
354 /// def merge(state1, state2):
355 /// return state1 + state2
356 /// "#,
357 /// )
358 /// .unwrap();
359 /// ```
360 pub fn add_aggregate(
361 &mut self,
362 name: &str,
363 state_type: impl IntoField,
364 output_type: impl IntoField,
365 mode: CallMode,
366 code: &str,
367 ) -> Result<()> {
368 let aggregate = self.interpreter.with_gil(|py| {
369 let module = PyModule::from_code_bound(py, code, name, name)?;
370 Ok(Aggregate {
371 state_field: state_type.into_field(name).into(),
372 output_field: output_type.into_field(name).into(),
373 mode,
374 create_state: module.getattr("create_state")?.into(),
375 accumulate: module.getattr("accumulate")?.into(),
376 retract: module.getattr("retract").ok().map(|f| f.into()),
377 finish: module.getattr("finish").ok().map(|f| f.into()),
378 merge: module.getattr("merge").ok().map(|f| f.into()),
379 })
380 })?;
381 if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
382 bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
383 }
384 self.aggregates.insert(name.to_string(), aggregate);
385 Ok(())
386 }
387
388 /// Remove a scalar or table function.
389 pub fn del_function(&mut self, name: &str) -> Result<()> {
390 let function = self.functions.remove(name).context("function not found")?;
391 _ = self.interpreter.with_gil(|_| {
392 drop(function);
393 Ok(())
394 });
395 Ok(())
396 }
397
398 /// Remove an aggregate function.
399 pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
400 let aggregate = self.functions.remove(name).context("function not found")?;
401 _ = self.interpreter.with_gil(|_| {
402 drop(aggregate);
403 Ok(())
404 });
405 Ok(())
406 }
407
408 /// Call a scalar function.
409 ///
410 /// # Example
411 ///
412 /// ```
413 #[doc = include_str!("doc_create_function.txt")]
414 /// // suppose we have created a scalar function `gcd`
415 /// // see the example in `add_function`
416 ///
417 /// let schema = Schema::new(vec![
418 /// Field::new("x", DataType::Int32, true),
419 /// Field::new("y", DataType::Int32, true),
420 /// ]);
421 /// let arg0 = Int32Array::from(vec![Some(25), None]);
422 /// let arg1 = Int32Array::from(vec![Some(15), None]);
423 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
424 ///
425 /// let output = runtime.call("gcd", &input).unwrap();
426 /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
427 /// ```
428 pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
429 let function = self.functions.get(name).context("function not found")?;
430 // convert each row to python objects and call the function
431 let (output, error) = self.interpreter.with_gil(|py| {
432 let mut results = Vec::with_capacity(input.num_rows());
433 let mut errors = vec![];
434 let mut row = Vec::with_capacity(input.num_columns());
435 for i in 0..input.num_rows() {
436 if function.mode == CallMode::ReturnNullOnNullInput
437 && input.columns().iter().any(|column| column.is_null(i))
438 {
439 results.push(py.None());
440 continue;
441 }
442 row.clear();
443 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
444 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
445 row.push(pyobj);
446 }
447 let args = PyTuple::new_bound(py, row.drain(..));
448 match function.function.call1(py, args) {
449 Ok(result) => results.push(result),
450 Err(e) => {
451 results.push(py.None());
452 errors.push((i, e.to_string()));
453 }
454 }
455 }
456 let output = self
457 .converter
458 .build_array(&function.return_field, py, &results)?;
459 let error = build_error_array(input.num_rows(), errors);
460 Ok((output, error))
461 })?;
462 if let Some(error) = error {
463 let schema = Schema::new(vec![
464 function.return_field.clone(),
465 Field::new("error", DataType::Utf8, true).into(),
466 ]);
467 Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
468 } else {
469 let schema = Schema::new(vec![function.return_field.clone()]);
470 Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
471 }
472 }
473
474 /// Call a table function.
475 ///
476 /// # Example
477 ///
478 /// ```
479 #[doc = include_str!("doc_create_function.txt")]
480 /// // suppose we have created a table function `series`
481 /// // see the example in `add_function`
482 ///
483 /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
484 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
485 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
486 ///
487 /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
488 /// let output = outputs.next().unwrap().unwrap();
489 /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
490 /// assert_eq!(pretty, r#"
491 /// +-----+--------+
492 /// | row | series |
493 /// +-----+--------+
494 /// | 0 | 0 |
495 /// | 2 | 0 |
496 /// | 2 | 1 |
497 /// | 2 | 2 |
498 /// +-----+--------+"#.trim());
499 /// ```
500 pub fn call_table_function<'a>(
501 &'a self,
502 name: &'a str,
503 input: &'a RecordBatch,
504 chunk_size: usize,
505 ) -> Result<RecordBatchIter<'a>> {
506 assert!(chunk_size > 0);
507 let function = self.functions.get(name).context("function not found")?;
508
509 // initial state
510 Ok(RecordBatchIter {
511 interpreter: &self.interpreter,
512 input,
513 function,
514 schema: Arc::new(Schema::new(vec![
515 Field::new("row", DataType::Int32, true).into(),
516 function.return_field.clone(),
517 ])),
518 chunk_size,
519 row: 0,
520 generator: None,
521 converter: &self.converter,
522 })
523 }
524
525 /// Create a new state for an aggregate function.
526 ///
527 /// # Example
528 /// ```
529 #[doc = include_str!("doc_create_aggregate.txt")]
530 /// let state = runtime.create_state("sum").unwrap();
531 /// assert_eq!(&*state, &Int32Array::from(vec![0]));
532 /// ```
533 pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
534 let aggregate = self.aggregates.get(name).context("function not found")?;
535 let state = self.interpreter.with_gil(|py| {
536 let state = aggregate.create_state.call0(py)?;
537 let state = self
538 .converter
539 .build_array(&aggregate.state_field, py, &[state])?;
540 Ok(state)
541 })?;
542 Ok(state)
543 }
544
545 /// Call accumulate of an aggregate function.
546 ///
547 /// # Example
548 /// ```
549 #[doc = include_str!("doc_create_aggregate.txt")]
550 /// let state = runtime.create_state("sum").unwrap();
551 ///
552 /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
553 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
554 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
555 ///
556 /// let state = runtime.accumulate("sum", &state, &input).unwrap();
557 /// assert_eq!(&*state, &Int32Array::from(vec![9]));
558 /// ```
559 pub fn accumulate(
560 &self,
561 name: &str,
562 state: &dyn Array,
563 input: &RecordBatch,
564 ) -> Result<ArrayRef> {
565 let aggregate = self.aggregates.get(name).context("function not found")?;
566 // convert each row to python objects and call the accumulate function
567 let new_state = self.interpreter.with_gil(|py| {
568 let mut state = self
569 .converter
570 .get_pyobject(py, &aggregate.state_field, state, 0)?;
571
572 let mut row = Vec::with_capacity(1 + input.num_columns());
573 for i in 0..input.num_rows() {
574 if aggregate.mode == CallMode::ReturnNullOnNullInput
575 && input.columns().iter().any(|column| column.is_null(i))
576 {
577 continue;
578 }
579 row.clear();
580 row.push(state.clone_ref(py));
581 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
582 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
583 row.push(pyobj);
584 }
585 let args = PyTuple::new_bound(py, row.drain(..));
586 state = aggregate.accumulate.call1(py, args)?;
587 }
588 let output = self
589 .converter
590 .build_array(&aggregate.state_field, py, &[state])?;
591 Ok(output)
592 })?;
593 Ok(new_state)
594 }
595
596 /// Call accumulate or retract of an aggregate function.
597 ///
598 /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
599 /// `false` for accumulate and `true` for retract.
600 ///
601 /// # Example
602 /// ```
603 #[doc = include_str!("doc_create_aggregate.txt")]
604 /// let state = runtime.create_state("sum").unwrap();
605 ///
606 /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
607 /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
608 /// let ops = BooleanArray::from(vec![false, false, true, false]);
609 /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
610 ///
611 /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).unwrap();
612 /// assert_eq!(&*state, &Int32Array::from(vec![3]));
613 /// ```
614 pub fn accumulate_or_retract(
615 &self,
616 name: &str,
617 state: &dyn Array,
618 ops: &BooleanArray,
619 input: &RecordBatch,
620 ) -> Result<ArrayRef> {
621 let aggregate = self.aggregates.get(name).context("function not found")?;
622 let retract = aggregate
623 .retract
624 .as_ref()
625 .context("function does not support retraction")?;
626 // convert each row to python objects and call the accumulate function
627 let new_state = self.interpreter.with_gil(|py| {
628 let mut state = self
629 .converter
630 .get_pyobject(py, &aggregate.state_field, state, 0)?;
631
632 let mut row = Vec::with_capacity(1 + input.num_columns());
633 for i in 0..input.num_rows() {
634 if aggregate.mode == CallMode::ReturnNullOnNullInput
635 && input.columns().iter().any(|column| column.is_null(i))
636 {
637 continue;
638 }
639 row.clear();
640 row.push(state.clone_ref(py));
641 for (column, field) in input.columns().iter().zip(input.schema().fields()) {
642 let pyobj = self.converter.get_pyobject(py, field, column, i)?;
643 row.push(pyobj);
644 }
645 let args = PyTuple::new_bound(py, row.drain(..));
646 let func = if ops.is_valid(i) && ops.value(i) {
647 retract
648 } else {
649 &aggregate.accumulate
650 };
651 state = func.call1(py, args)?;
652 }
653 let output = self
654 .converter
655 .build_array(&aggregate.state_field, py, &[state])?;
656 Ok(output)
657 })?;
658 Ok(new_state)
659 }
660
661 /// Merge states of an aggregate function.
662 ///
663 /// # Example
664 /// ```
665 #[doc = include_str!("doc_create_aggregate.txt")]
666 /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
667 ///
668 /// let state = runtime.merge("sum", &states).unwrap();
669 /// assert_eq!(&*state, &Int32Array::from(vec![9]));
670 /// ```
671 pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
672 let aggregate = self.aggregates.get(name).context("function not found")?;
673 let merge = aggregate.merge.as_ref().context("merge not found")?;
674 let output = self.interpreter.with_gil(|py| {
675 let mut state = self
676 .converter
677 .get_pyobject(py, &aggregate.state_field, states, 0)?;
678 for i in 1..states.len() {
679 if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
680 continue;
681 }
682 let state2 = self
683 .converter
684 .get_pyobject(py, &aggregate.state_field, states, i)?;
685 let args = PyTuple::new_bound(py, [state, state2]);
686 state = merge.call1(py, args)?;
687 }
688 let output = self
689 .converter
690 .build_array(&aggregate.state_field, py, &[state])?;
691 Ok(output)
692 })?;
693 Ok(output)
694 }
695
696 /// Get the result of an aggregate function.
697 ///
698 /// If the `finish` function is not defined, the state is returned as the result.
699 ///
700 /// # Example
701 /// ```
702 #[doc = include_str!("doc_create_aggregate.txt")]
703 /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
704 ///
705 /// let outputs = runtime.finish("sum", &states).unwrap();
706 /// assert_eq!(&outputs, &states);
707 /// ```
708 pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
709 let aggregate = self.aggregates.get(name).context("function not found")?;
710 let Some(finish) = &aggregate.finish else {
711 return Ok(states.clone());
712 };
713 let output = self.interpreter.with_gil(|py| {
714 let mut results = Vec::with_capacity(states.len());
715 for i in 0..states.len() {
716 if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
717 results.push(py.None());
718 continue;
719 }
720 let state = self
721 .converter
722 .get_pyobject(py, &aggregate.state_field, states, i)?;
723 let args = PyTuple::new_bound(py, [state]);
724 let result = finish.call1(py, args)?;
725 results.push(result);
726 }
727 let output = self
728 .converter
729 .build_array(&aggregate.output_field, py, &results)?;
730 Ok(output)
731 })?;
732 Ok(output)
733 }
734}
735
736/// An iterator over the result of a table function.
737pub struct RecordBatchIter<'a> {
738 interpreter: &'a SubInterpreter,
739 input: &'a RecordBatch,
740 function: &'a Function,
741 schema: SchemaRef,
742 chunk_size: usize,
743 // mutable states
744 /// Current row index.
745 row: usize,
746 /// Generator of the current row.
747 generator: Option<Py<PyIterator>>,
748 converter: &'a pyarrow::Converter,
749}
750
751impl RecordBatchIter<'_> {
752 /// Get the schema of the output.
753 pub fn schema(&self) -> &Schema {
754 &self.schema
755 }
756
757 fn next(&mut self) -> Result<Option<RecordBatch>> {
758 if self.row == self.input.num_rows() {
759 return Ok(None);
760 }
761 let batch = self.interpreter.with_gil(|py| {
762 let mut indexes = Int32Builder::with_capacity(self.chunk_size);
763 let mut results = Vec::with_capacity(self.input.num_rows());
764 let mut errors = vec![];
765 let mut row = Vec::with_capacity(self.input.num_columns());
766 while self.row < self.input.num_rows() && results.len() < self.chunk_size {
767 let generator = if let Some(g) = self.generator.as_ref() {
768 g
769 } else {
770 // call the table function to get a generator
771 if self.function.mode == CallMode::ReturnNullOnNullInput
772 && (self.input.columns().iter()).any(|column| column.is_null(self.row))
773 {
774 self.row += 1;
775 continue;
776 }
777 row.clear();
778 for (column, field) in
779 (self.input.columns().iter()).zip(self.input.schema().fields())
780 {
781 let val = self.converter.get_pyobject(py, field, column, self.row)?;
782 row.push(val);
783 }
784 let args = PyTuple::new_bound(py, row.drain(..));
785 match self.function.function.bind(py).call1(args) {
786 Ok(result) => {
787 let iter = result.iter()?.into();
788 self.generator.insert(iter)
789 }
790 Err(e) => {
791 // append a row with null value and error message
792 indexes.append_value(self.row as i32);
793 results.push(py.None());
794 errors.push((indexes.len(), e.to_string()));
795 self.row += 1;
796 continue;
797 }
798 }
799 };
800 match generator.bind(py).clone().next() {
801 Some(Ok(value)) => {
802 indexes.append_value(self.row as i32);
803 results.push(value.into());
804 }
805 Some(Err(e)) => {
806 indexes.append_value(self.row as i32);
807 results.push(py.None());
808 errors.push((indexes.len(), e.to_string()));
809 self.row += 1;
810 self.generator = None;
811 }
812 None => {
813 self.row += 1;
814 self.generator = None;
815 }
816 }
817 }
818
819 if results.is_empty() {
820 return Ok(None);
821 }
822 let indexes = Arc::new(indexes.finish());
823 let output = self
824 .converter
825 .build_array(&self.function.return_field, py, &results)
826 .context("failed to build arrow array from return values")?;
827 let error = build_error_array(indexes.len(), errors);
828 if let Some(error) = error {
829 Ok(Some(
830 RecordBatch::try_new(
831 Arc::new(append_error_to_schema(&self.schema)),
832 vec![indexes, output, error],
833 )
834 .unwrap(),
835 ))
836 } else {
837 Ok(Some(
838 RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
839 ))
840 }
841 })?;
842 Ok(batch)
843 }
844}
845
846impl Iterator for RecordBatchIter<'_> {
847 type Item = Result<RecordBatch>;
848 fn next(&mut self) -> Option<Self::Item> {
849 self.next().transpose()
850 }
851}
852
853impl Drop for RecordBatchIter<'_> {
854 fn drop(&mut self) {
855 if let Some(generator) = self.generator.take() {
856 _ = self.interpreter.with_gil(|_| {
857 drop(generator);
858 Ok(())
859 });
860 }
861 }
862}
863
/// Whether the function will be called when some of its arguments are null.
///
/// This is a plain fieldless enum, so it is `Copy` and can be passed by value or used as a
/// hash-map key.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CallMode {
    /// The function will be called normally when some of its arguments are null.
    /// It is then the function author's responsibility to check for null values if necessary and respond appropriately.
    #[default]
    CalledOnNullInput,

    /// The function always returns null whenever any of its arguments are null.
    /// If this parameter is specified, the function is not executed when there are null arguments;
    /// instead a null result is assumed automatically.
    ReturnNullOnNullInput,
}
877
878impl Drop for Runtime {
879 fn drop(&mut self) {
880 // `PyObject` must be dropped inside the interpreter
881 _ = self.interpreter.with_gil(|_| {
882 self.functions.clear();
883 self.aggregates.clear();
884 Ok(())
885 });
886 }
887}
888
889fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
890 if errors.is_empty() {
891 return None;
892 }
893 let data_capacity = errors.iter().map(|(i, _)| i).sum();
894 let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
895 for (i, msg) in errors {
896 while builder.len() + 1 < i {
897 builder.append_null();
898 }
899 builder.append_value(&msg);
900 }
901 while builder.len() < num_rows {
902 builder.append_null();
903 }
904 Some(Arc::new(builder.finish()))
905}
906
907/// Append an error field to the schema.
908fn append_error_to_schema(schema: &Schema) -> Schema {
909 let mut fields = schema.fields().to_vec();
910 fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
911 Schema::new(fields)
912}