#![doc = include_str!("../README.md")]
use self::interpreter::SubInterpreter;
pub use self::into_field::IntoField;
use anyhow::{bail, Context, Result};
use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
use pyo3::{Py, PyObject};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
mod interpreter;
mod into_field;
mod pyarrow;
/// A Python UDF runtime backed by a dedicated CPython sub-interpreter.
///
/// Holds the registered scalar/table functions and aggregates, plus the
/// converter used to move values between Arrow arrays and Python objects.
pub struct Runtime {
interpreter: SubInterpreter,
/// Scalar/table functions registered by name.
functions: HashMap<String, Function>,
/// Aggregate functions registered by name.
aggregates: HashMap<String, Aggregate>,
/// Converts between Arrow arrays and Python objects.
converter: pyarrow::Converter,
}
impl Debug for Runtime {
    /// Shows only the registered function and aggregate names; the
    /// interpreter and converter carry no useful debug state.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Runtime");
        builder.field("functions", &self.functions.keys());
        builder.field("aggregates", &self.aggregates.keys());
        builder.finish()
    }
}
/// A registered scalar or table function.
struct Function {
/// The Python callable (for table functions, it returns an iterable).
function: PyObject,
/// Arrow field describing the return type, named after the function.
return_field: FieldRef,
/// How NULL input values are handled.
mode: CallMode,
}
/// A registered aggregate function, composed of Python callbacks.
struct Aggregate {
/// Arrow field describing the intermediate state type.
state_field: FieldRef,
/// Arrow field describing the final output type.
output_field: FieldRef,
/// How NULL input values are handled.
mode: CallMode,
/// `create_state() -> state`: produces the initial state.
create_state: PyObject,
/// `accumulate(state, *args) -> state`: folds one input row into the state.
accumulate: PyObject,
/// Optional `retract(state, *args) -> state`: removes one row from the state.
retract: Option<PyObject>,
/// Optional `finish(state) -> output`; when absent, the state itself is the
/// output (`add_aggregate` then requires `output_type == state_type`).
finish: Option<PyObject>,
/// Optional `merge(state, state) -> state`: combines two partial states.
merge: Option<PyObject>,
}
/// Builder for [`Runtime`].
#[derive(Default, Debug)]
pub struct Builder {
/// Whether `build()` installs the sandbox script (import allowlist plus
/// symbol removal).
sandboxed: bool,
/// Dotted symbol paths to `del` in the interpreter; applied by `build()`
/// only when `sandboxed` is true.
removed_symbols: Vec<String>,
}
impl Builder {
/// Enables (or disables) sandbox mode.
///
/// This also queues removal of interactive/dangerous builtins. Note the
/// removals are queued unconditionally here, but `build()` only applies
/// them when `self.sandboxed` is true.
pub fn sandboxed(mut self, sandboxed: bool) -> Self {
self.sandboxed = sandboxed;
self.remove_symbol("__builtins__.breakpoint")
.remove_symbol("__builtins__.exit")
.remove_symbol("__builtins__.eval")
.remove_symbol("__builtins__.help")
.remove_symbol("__builtins__.input")
.remove_symbol("__builtins__.open")
.remove_symbol("__builtins__.print")
}
/// Queues a dotted symbol path to be `del`eted from the interpreter when
/// sandbox mode is enabled.
pub fn remove_symbol(mut self, symbol: &str) -> Self {
self.removed_symbols.push(symbol.to_string());
self
}
/// Creates the [`Runtime`]: spawns a sub-interpreter, runs the bootstrap
/// script, and — when sandboxed — installs the import allowlist and deletes
/// the queued symbols.
pub fn build(self) -> Result<Runtime> {
let interpreter = SubInterpreter::new()?;
// Bootstrap: preload modules used internally and define the `Struct`
// marker class used for struct-typed input arguments.
interpreter.run(
r#"
# internal use for json types
import json
import pickle
import decimal
# an internal class used for struct input arguments
class Struct:
pass
"#,
)?;
if self.sandboxed {
// Replace `__import__` with an allowlist-checked version.
let mut script = r#"
# limit the modules that can be imported
original_import = __builtins__.__import__
def limited_import(name, globals=None, locals=None, fromlist=(), level=0):
# FIXME: 'sys' should not be allowed, but it is required by 'decimal'
# FIXME: 'time.sleep' should not be allowed, but 'time' is required by 'datetime'
allowlist = (
'json',
'decimal',
're',
'math',
'datetime',
'time',
'operator',
'numbers',
'abc',
'sys',
'contextvars',
'_io',
'_contextvars',
'_pydecimal',
'_pydatetime',
)
if level == 0 and name in allowlist:
return original_import(name, globals, locals, fromlist, level)
raise ImportError(f'import {name} is not allowed')
__builtins__.__import__ = limited_import
del limited_import
"#
.to_string();
// Append one `del` statement per queued symbol.
for symbol in self.removed_symbols {
script.push_str(&format!("del {}\n", symbol));
}
interpreter.run(&script)?;
}
Ok(Runtime {
interpreter,
functions: HashMap::new(),
aggregates: HashMap::new(),
converter: pyarrow::Converter::new(),
})
}
}
impl Runtime {
/// Creates a runtime with the default (non-sandboxed) configuration.
pub fn new() -> Result<Self> {
Builder::default().build()
}
/// Returns a builder for configuring the runtime (e.g. sandboxing).
pub fn builder() -> Builder {
Builder::default()
}
/// Registers a function whose Python handler has the same name as `name`.
///
/// Shorthand for [`Runtime::add_function_with_handler`] with `handler = name`.
pub fn add_function(
&mut self,
name: &str,
return_type: impl IntoField,
mode: CallMode,
code: &str,
) -> Result<()> {
self.add_function_with_handler(name, return_type, mode, code, name)
}
pub fn add_function_with_handler(
&mut self,
name: &str,
return_type: impl IntoField,
mode: CallMode,
code: &str,
handler: &str,
) -> Result<()> {
let function = self.interpreter.with_gil(|py| {
Ok(PyModule::from_code_bound(py, code, name, name)?
.getattr(handler)?
.into())
})?;
let function = Function {
function,
return_field: return_type.into_field(name).into(),
mode,
};
self.functions.insert(name.to_string(), function);
Ok(())
}
/// Registers an aggregate function from Python `code`.
///
/// The module must define `create_state` and `accumulate`; `retract`,
/// `finish`, and `merge` are optional. When `finish` is absent the state is
/// returned as-is, so `output_type` must equal `state_type`.
pub fn add_aggregate(
&mut self,
name: &str,
state_type: impl IntoField,
output_type: impl IntoField,
mode: CallMode,
code: &str,
) -> Result<()> {
let aggregate = self.interpreter.with_gil(|py| {
// Compile the code as a module and look up the callbacks by name.
let module = PyModule::from_code_bound(py, code, name, name)?;
Ok(Aggregate {
state_field: state_type.into_field(name).into(),
output_field: output_type.into_field(name).into(),
mode,
create_state: module.getattr("create_state")?.into(),
accumulate: module.getattr("accumulate")?.into(),
// Optional callbacks: a missing attribute is not an error.
retract: module.getattr("retract").ok().map(|f| f.into()),
finish: module.getattr("finish").ok().map(|f| f.into()),
merge: module.getattr("merge").ok().map(|f| f.into()),
})
})?;
if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
}
self.aggregates.insert(name.to_string(), aggregate);
Ok(())
}
/// Removes a registered function.
///
/// The Python object is dropped inside `with_gil` so its reference is
/// released while the sub-interpreter's GIL is held.
pub fn del_function(&mut self, name: &str) -> Result<()> {
    let removed = self.functions.remove(name).context("function not found")?;
    _ = self.interpreter.with_gil(|_| {
        drop(removed);
        Ok(())
    });
    Ok(())
}
/// Removes a registered aggregate function.
///
/// The Python callbacks are dropped inside `with_gil` so their references
/// are released while the sub-interpreter's GIL is held.
pub fn del_aggregate(&mut self, name: &str) -> Result<()> {
    // Fix: this previously removed from `self.functions` (a copy-paste of
    // `del_function`), so aggregates could never actually be deleted.
    let aggregate = self.aggregates.remove(name).context("aggregate not found")?;
    _ = self.interpreter.with_gil(|_| {
        drop(aggregate);
        Ok(())
    });
    Ok(())
}
#[doc = include_str!("doc_create_function.txt")]
pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
    let function = self.functions.get(name).context("function not found")?;
    let (output, error) = self.interpreter.with_gil(|py| {
        // One result per input row; `errors` collects (position, message)
        // pairs for rows where the Python call raised.
        let mut results = Vec::with_capacity(input.num_rows());
        let mut errors = vec![];
        // Reused buffer of Python arguments for the current row.
        let mut row = Vec::with_capacity(input.num_columns());
        for i in 0..input.num_rows() {
            // In `ReturnNullOnNullInput` mode, skip the Python call and
            // produce NULL when any argument is NULL.
            if function.mode == CallMode::ReturnNullOnNullInput
                && input.columns().iter().any(|column| column.is_null(i))
            {
                results.push(py.None());
                continue;
            }
            row.clear();
            for (column, field) in input.columns().iter().zip(input.schema().fields()) {
                let pyobj = self.converter.get_pyobject(py, field, column, i)?;
                row.push(pyobj);
            }
            let args = PyTuple::new_bound(py, row.drain(..));
            match function.function.call1(py, args) {
                Ok(result) => results.push(result),
                Err(e) => {
                    // A failed row yields NULL plus an error message.
                    results.push(py.None());
                    // Fix: `build_error_array` places a message at position
                    // `i - 1` (see its `builder.len() + 1 < i` padding), i.e.
                    // it expects 1-based positions, as the table-function path
                    // provides via `indexes.len()`. Pushing the 0-based row
                    // index here shifted every message one row too early.
                    errors.push((i + 1, e.to_string()));
                }
            }
        }
        let output = self
            .converter
            .build_array(&function.return_field, py, &results)?;
        let error = build_error_array(input.num_rows(), errors);
        Ok((output, error))
    })?;
    // Attach an "error" column only when at least one row failed.
    if let Some(error) = error {
        let schema = Schema::new(vec![
            function.return_field.clone(),
            Field::new("error", DataType::Utf8, true).into(),
        ]);
        Ok(RecordBatch::try_new(Arc::new(schema), vec![output, error])?)
    } else {
        let schema = Schema::new(vec![function.return_field.clone()]);
        Ok(RecordBatch::try_new(Arc::new(schema), vec![output])?)
    }
}
#[doc = include_str!("doc_create_function.txt")]
/// Calls a table function, returning a lazy iterator of output batches.
///
/// Each batch has at most `chunk_size` rows and pairs an Int32 "row" column
/// (the originating input row index) with the function's output column.
///
/// # Panics
/// Panics if `chunk_size` is zero.
pub fn call_table_function<'a>(
&'a self,
name: &'a str,
input: &'a RecordBatch,
chunk_size: usize,
) -> Result<RecordBatchIter<'a>> {
assert!(chunk_size > 0);
let function = self.functions.get(name).context("function not found")?;
Ok(RecordBatchIter {
interpreter: &self.interpreter,
input,
function,
schema: Arc::new(Schema::new(vec![
Field::new("row", DataType::Int32, true).into(),
function.return_field.clone(),
])),
chunk_size,
row: 0,
generator: None,
converter: &self.converter,
})
}
#[doc = include_str!("doc_create_aggregate.txt")]
/// Calls the aggregate's `create_state` and returns the initial state as a
/// single-element Arrow array.
pub fn create_state(&self, name: &str) -> Result<ArrayRef> {
    let aggregate = self.aggregates.get(name).context("function not found")?;
    let array = self.interpreter.with_gil(|py| {
        let initial = aggregate.create_state.call0(py)?;
        let array = self
            .converter
            .build_array(&aggregate.state_field, py, &[initial])?;
        Ok(array)
    })?;
    Ok(array)
}
#[doc = include_str!("doc_create_aggregate.txt")]
/// Folds every row of `input` into the aggregate state and returns the new
/// state as a single-element array.
///
/// The current state is read from row 0 of `state`. In
/// `ReturnNullOnNullInput` mode, rows containing any NULL argument are
/// skipped entirely (the state is left unchanged for them).
pub fn accumulate(
&self,
name: &str,
state: &dyn Array,
input: &RecordBatch,
) -> Result<ArrayRef> {
let aggregate = self.aggregates.get(name).context("function not found")?;
let new_state = self.interpreter.with_gil(|py| {
let mut state = self
.converter
.get_pyobject(py, &aggregate.state_field, state, 0)?;
// Reused argument buffer: state followed by the row's column values.
let mut row = Vec::with_capacity(1 + input.num_columns());
for i in 0..input.num_rows() {
if aggregate.mode == CallMode::ReturnNullOnNullInput
&& input.columns().iter().any(|column| column.is_null(i))
{
continue;
}
row.clear();
row.push(state.clone_ref(py));
for (column, field) in input.columns().iter().zip(input.schema().fields()) {
let pyobj = self.converter.get_pyobject(py, field, column, i)?;
row.push(pyobj);
}
let args = PyTuple::new_bound(py, row.drain(..));
// `accumulate(state, *args)` returns the updated state.
state = aggregate.accumulate.call1(py, args)?;
}
let output = self
.converter
.build_array(&aggregate.state_field, py, &[state])?;
Ok(output)
})?;
Ok(new_state)
}
#[doc = include_str!("doc_create_aggregate.txt")]
/// Like [`Runtime::accumulate`], but each row is either accumulated or
/// retracted depending on `ops`: a true (and valid) entry selects `retract`,
/// anything else (false or NULL) selects `accumulate`.
///
/// Errors if the aggregate does not define `retract`.
pub fn accumulate_or_retract(
&self,
name: &str,
state: &dyn Array,
ops: &BooleanArray,
input: &RecordBatch,
) -> Result<ArrayRef> {
let aggregate = self.aggregates.get(name).context("function not found")?;
let retract = aggregate
.retract
.as_ref()
.context("function does not support retraction")?;
let new_state = self.interpreter.with_gil(|py| {
// The current state is read from row 0 of `state`.
let mut state = self
.converter
.get_pyobject(py, &aggregate.state_field, state, 0)?;
let mut row = Vec::with_capacity(1 + input.num_columns());
for i in 0..input.num_rows() {
// Skip rows with NULL arguments in ReturnNullOnNullInput mode.
if aggregate.mode == CallMode::ReturnNullOnNullInput
&& input.columns().iter().any(|column| column.is_null(i))
{
continue;
}
row.clear();
row.push(state.clone_ref(py));
for (column, field) in input.columns().iter().zip(input.schema().fields()) {
let pyobj = self.converter.get_pyobject(py, field, column, i)?;
row.push(pyobj);
}
let args = PyTuple::new_bound(py, row.drain(..));
// true => retract, false/NULL => accumulate.
let func = if ops.is_valid(i) && ops.value(i) {
retract
} else {
&aggregate.accumulate
};
state = func.call1(py, args)?;
}
let output = self
.converter
.build_array(&aggregate.state_field, py, &[state])?;
Ok(output)
})?;
Ok(new_state)
}
#[doc = include_str!("doc_create_aggregate.txt")]
/// Merges all partial states in `states` into one, left to right, by
/// repeatedly calling the aggregate's `merge(state, state2)`.
///
/// Errors if the aggregate does not define `merge`.
pub fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
let aggregate = self.aggregates.get(name).context("function not found")?;
let merge = aggregate.merge.as_ref().context("merge not found")?;
let output = self.interpreter.with_gil(|py| {
// NOTE(review): element 0 seeds the fold unconditionally, even when it
// is NULL under ReturnNullOnNullInput (only i >= 1 are skipped) —
// confirm this is intended.
let mut state = self
.converter
.get_pyobject(py, &aggregate.state_field, states, 0)?;
for i in 1..states.len() {
if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
continue;
}
let state2 = self
.converter
.get_pyobject(py, &aggregate.state_field, states, i)?;
let args = PyTuple::new_bound(py, [state, state2]);
state = merge.call1(py, args)?;
}
let output = self
.converter
.build_array(&aggregate.state_field, py, &[state])?;
Ok(output)
})?;
Ok(output)
}
#[doc = include_str!("doc_create_aggregate.txt")]
/// Converts each state in `states` into an output value via the aggregate's
/// `finish` callback.
///
/// When the aggregate defines no `finish`, the states array is returned
/// unchanged (output type equals state type in that case; enforced at
/// registration).
pub fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
let aggregate = self.aggregates.get(name).context("function not found")?;
let Some(finish) = &aggregate.finish else {
return Ok(states.clone());
};
let output = self.interpreter.with_gil(|py| {
let mut results = Vec::with_capacity(states.len());
for i in 0..states.len() {
// NULL states map to NULL outputs in ReturnNullOnNullInput mode.
if aggregate.mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
results.push(py.None());
continue;
}
let state = self
.converter
.get_pyobject(py, &aggregate.state_field, states, i)?;
let args = PyTuple::new_bound(py, [state]);
let result = finish.call1(py, args)?;
results.push(result);
}
let output = self
.converter
.build_array(&aggregate.output_field, py, &results)?;
Ok(output)
})?;
Ok(output)
}
}
/// Streaming iterator over the output of a table function.
///
/// Yields [`RecordBatch`]es of at most `chunk_size` rows. Each batch pairs an
/// Int32 "row" column (the originating input row index) with the function's
/// output column, plus an "error" column when any row failed.
pub struct RecordBatchIter<'a> {
interpreter: &'a SubInterpreter,
input: &'a RecordBatch,
function: &'a Function,
schema: SchemaRef,
chunk_size: usize,
/// Index of the input row currently being processed.
row: usize,
/// Python generator currently being drained for `row`, if any.
generator: Option<Py<PyIterator>>,
converter: &'a pyarrow::Converter,
}
impl RecordBatchIter<'_> {
/// Returns the schema of the yielded batches ("row" index + output field).
pub fn schema(&self) -> &Schema {
&self.schema
}
/// Produces the next chunk, or `Ok(None)` when the input is exhausted.
fn next(&mut self) -> Result<Option<RecordBatch>> {
if self.row == self.input.num_rows() {
return Ok(None);
}
let batch = self.interpreter.with_gil(|py| {
// `indexes[k]` records which input row produced output row `k`.
let mut indexes = Int32Builder::with_capacity(self.chunk_size);
let mut results = Vec::with_capacity(self.input.num_rows());
// (1-based output position, message) pairs for failed rows.
let mut errors = vec![];
// Reused buffer of Python arguments for the current input row.
let mut row = Vec::with_capacity(self.input.num_columns());
while self.row < self.input.num_rows() && results.len() < self.chunk_size {
// Resume the in-progress generator, or start one for this row.
let generator = if let Some(g) = self.generator.as_ref() {
g
} else {
// Skip rows with NULL arguments in ReturnNullOnNullInput mode.
if self.function.mode == CallMode::ReturnNullOnNullInput
&& (self.input.columns().iter()).any(|column| column.is_null(self.row))
{
self.row += 1;
continue;
}
row.clear();
for (column, field) in
(self.input.columns().iter()).zip(self.input.schema().fields())
{
let val = self.converter.get_pyobject(py, field, column, self.row)?;
row.push(val);
}
let args = PyTuple::new_bound(py, row.drain(..));
match self.function.function.bind(py).call1(args) {
Ok(result) => {
// Turn the returned iterable into a generator and
// remember it so later calls keep draining it.
let iter = result.iter()?.into();
self.generator.insert(iter)
}
Err(e) => {
// The call itself failed: emit a NULL output row with
// the error message and move to the next input row.
indexes.append_value(self.row as i32);
results.push(py.None());
errors.push((indexes.len(), e.to_string()));
self.row += 1;
continue;
}
}
};
match generator.bind(py).clone().next() {
Some(Ok(value)) => {
// One output row produced by the generator.
indexes.append_value(self.row as i32);
results.push(value.into());
}
Some(Err(e)) => {
// The generator raised mid-stream: emit a NULL row with
// the error, drop the generator, advance.
indexes.append_value(self.row as i32);
results.push(py.None());
errors.push((indexes.len(), e.to_string()));
self.row += 1;
self.generator = None;
}
None => {
// Generator exhausted: advance to the next input row.
self.row += 1;
self.generator = None;
}
}
}
if results.is_empty() {
return Ok(None);
}
let indexes = Arc::new(indexes.finish());
let output = self
.converter
.build_array(&self.function.return_field, py, &results)
.context("failed to build arrow array from return values")?;
let error = build_error_array(indexes.len(), errors);
// Attach the "error" column only when at least one row failed.
if let Some(error) = error {
Ok(Some(
RecordBatch::try_new(
Arc::new(append_error_to_schema(&self.schema)),
vec![indexes, output, error],
)
.unwrap(),
))
} else {
Ok(Some(
RecordBatch::try_new(self.schema.clone(), vec![indexes, output]).unwrap(),
))
}
})?;
Ok(batch)
}
}
impl Iterator for RecordBatchIter<'_> {
type Item = Result<RecordBatch>;
fn next(&mut self) -> Option<Self::Item> {
self.next().transpose()
}
}
impl Drop for RecordBatchIter<'_> {
    /// Releases any in-progress generator inside `with_gil`, so the Python
    /// reference is dropped while the sub-interpreter's GIL is held.
    fn drop(&mut self) {
        let Some(generator) = self.generator.take() else {
            return;
        };
        _ = self.interpreter.with_gil(|_| {
            drop(generator);
            Ok(())
        });
    }
}
/// Determines how NULL input values are handled when calling a function.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum CallMode {
/// The Python function is called normally even when arguments are NULL.
#[default]
CalledOnNullInput,
/// If any argument is NULL, the Python function is not called: scalar
/// calls produce NULL, and aggregate/table-function rows are skipped.
ReturnNullOnNullInput,
}
impl Drop for Runtime {
/// Clears all registered functions and aggregates inside `with_gil`, so
/// their Python references are released while the sub-interpreter's GIL is
/// held (the interpreter field itself is dropped afterwards).
fn drop(&mut self) {
_ = self.interpreter.with_gil(|_| {
self.functions.clear();
self.aggregates.clear();
Ok(())
});
}
}
/// Builds a nullable Utf8 array of length `num_rows` holding error messages.
///
/// `errors` contains `(position, message)` pairs, where `position` is the
/// 1-based output row of the failure, in increasing order; all other rows are
/// NULL. Returns `None` when there are no errors, so callers can omit the
/// "error" column entirely.
fn build_error_array(num_rows: usize, errors: Vec<(usize, String)>) -> Option<ArrayRef> {
    if errors.is_empty() {
        return None;
    }
    // Fix: the data capacity is a byte-size hint for the string payload, so
    // sum the message lengths. The previous code summed the row positions,
    // which mis-sized the buffer (harmless for correctness, wrong as a hint).
    let data_capacity = errors.iter().map(|(_, msg)| msg.len()).sum();
    let mut builder = StringBuilder::with_capacity(num_rows, data_capacity);
    for (i, msg) in errors {
        // Pad with NULLs up to 1-based position `i`, then write the message
        // at index `i - 1`.
        while builder.len() + 1 < i {
            builder.append_null();
        }
        builder.append_value(&msg);
    }
    // Rows after the last error are NULL.
    while builder.len() < num_rows {
        builder.append_null();
    }
    Some(Arc::new(builder.finish()))
}
fn append_error_to_schema(schema: &Schema) -> Schema {
let mut fields = schema.fields().to_vec();
fields.push(Arc::new(Field::new("error", DataType::Utf8, true)));
Schema::new(fields)
}