#![doc = include_str!("../README.md")]
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use anyhow::{anyhow, Context as _, Result};
use arrow_array::{builder::Int32Builder, Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use rquickjs::{
context::intrinsic::{BaseObjects, BigDecimal, Eval, Json, TypedArrays},
function::Args,
Context, Ctx, Object, Persistent, Value,
};
mod jsarrow;
/// A JavaScript UDF runtime backed by QuickJS.
///
/// Functions are registered with [`Runtime::add_function`] and invoked on Arrow
/// [`RecordBatch`]es with [`Runtime::call`] or [`Runtime::call_table_function`].
pub struct Runtime {
    /// Registered functions, keyed by the name passed to `add_function`.
    functions: HashMap<String, Function>,
    /// The global `BigDecimal` constructor, passed to `jsarrow::get_jsvalue` when
    /// converting Arrow values into JavaScript values.
    bigdecimal: Persistent<rquickjs::Function<'static>>,
    // Held so this struct owns the QuickJS runtime; it is not otherwise read in this file.
    // NOTE(review): Rust drops fields in declaration order, so the `Persistent` handles
    // above are dropped before `_runtime` and `context`. Presumably that order is required
    // for QuickJS handles to be released safely — keep it unless confirmed otherwise.
    _runtime: rquickjs::Runtime,
    /// The QuickJS context in which all functions are compiled and executed.
    context: Context,
}
impl Debug for Runtime {
    /// Formats the runtime as a `Runtime` struct that lists only the names of the
    /// registered functions; the QuickJS handles are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.functions.keys();
        let mut builder = formatter.debug_struct("Runtime");
        builder.field("functions", &registered);
        builder.finish()
    }
}
/// A compiled JavaScript function registered in a [`Runtime`].
struct Function {
    /// The exported JS function, saved so it can be restored inside later
    /// `context.with(...)` calls.
    function: Persistent<rquickjs::Function<'static>>,
    /// Arrow type that the function's return values are converted to.
    return_type: DataType,
    /// How rows containing null arguments are handled.
    mode: CallMode,
}
// SAFETY: NOTE(review): no invariant is documented for these impls. `Function` holds a
// `Persistent` QuickJS handle; within this file that handle is only restored inside
// `Runtime::context.with(...)`. Confirm that the QuickJS runtime may be accessed from
// other threads (e.g. via rquickjs' `parallel` feature) before relying on these impls.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
/// Controls how a registered function is invoked for rows that contain null arguments.
///
/// This is a plain fieldless enum, so it is `Copy`: the same mode can be passed to any
/// number of `Runtime::add_function` calls.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum CallMode {
    /// Always call the function, even when some arguments are null. This is the default.
    #[default]
    CalledOnNullInput,
    /// Skip the call when any argument is null: a scalar call yields a null result for
    /// that row, and a table call yields no output rows for it.
    ReturnNullOnNullInput,
}
impl Runtime {
    /// Creates a runtime with a fresh QuickJS runtime and context.
    ///
    /// The context is created with the `BaseObjects`, `Eval`, `Json`, `BigDecimal` and
    /// `TypedArrays` intrinsics. The global `BigDecimal` constructor is looked up once and
    /// kept for the Arrow-to-JavaScript value conversion.
    ///
    /// # Errors
    ///
    /// Returns an error if the QuickJS runtime or context cannot be created, or if
    /// evaluating `BigDecimal` does not yield a function.
    pub fn new() -> Result<Self> {
        let runtime = rquickjs::Runtime::new().context("failed to create quickjs runtime")?;
        let context =
            rquickjs::Context::custom::<(BaseObjects, Eval, Json, BigDecimal, TypedArrays)>(
                &runtime,
            )
            .context("failed to create quickjs context")?;
        let bigdecimal = context.with(|ctx| {
            let bigdecimal: rquickjs::Function = ctx.eval("BigDecimal")?;
            Ok(Persistent::save(&ctx, bigdecimal)) as Result<_>
        })?;
        Ok(Self {
            functions: HashMap::new(),
            bigdecimal,
            _runtime: runtime,
            context,
        })
    }

    /// Compiles `code` as a JavaScript module and registers its export `name`.
    ///
    /// `return_type` is the Arrow type that the function's results are converted to.
    /// `mode` selects how rows with null arguments are handled. Registering a `name` that
    /// is already present replaces the previous function.
    ///
    /// # Errors
    ///
    /// Returns an error if the module fails to compile (a thrown JS exception is included
    /// in the message), or if the module has no export `name` that is a function.
    pub fn add_function(
        &mut self,
        name: &str,
        return_type: DataType,
        mode: CallMode,
        code: &str,
    ) -> Result<()> {
        let function = self.context.with(|ctx| {
            let module = ctx
                .clone()
                .compile("main", code)
                .map_err(|e| check_exception(e, &ctx))
                .context("failed to compile module")?;
            let function: rquickjs::Function = module
                .get(name)
                .context("failed to get function. HINT: make sure the function is exported")?;
            Ok(Persistent::save(&ctx, function)) as Result<_>
        })?;
        let function = Function {
            function,
            return_type,
            mode,
        };
        self.functions.insert(name.to_string(), function);
        Ok(())
    }

    /// Calls the scalar function `name` once per row of `input`.
    ///
    /// The values of a row's columns are passed as positional arguments, in column order.
    /// The result is a batch with a single nullable column named `name`, holding one
    /// value per input row. Under [`CallMode::ReturnNullOnNullInput`], a row with any null
    /// argument yields null without calling the function.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `name` is not registered;
    /// - an input value cannot be converted to JavaScript;
    /// - the function throws;
    /// - the results cannot be converted to the function's return type.
    pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
        let function = self.functions.get(name).context("function not found")?;
        self.context.with(|ctx| {
            let bigdecimal = self.bigdecimal.clone().restore(&ctx)?;
            let js_function = function.function.clone().restore(&ctx)?;
            let mut results = Vec::with_capacity(input.num_rows());
            // Argument buffer, reused across rows.
            let mut row = Vec::with_capacity(input.num_columns());
            for i in 0..input.num_rows() {
                row.clear();
                for column in input.columns() {
                    let val = jsarrow::get_jsvalue(&ctx, &bigdecimal, column, i)
                        .context("failed to get jsvalue from arrow array")?;
                    row.push(val);
                }
                if function.mode == CallMode::ReturnNullOnNullInput
                    && row.iter().any(|v| v.is_null())
                {
                    results.push(Value::new_null(ctx.clone()));
                    continue;
                }
                let mut args = Args::new(ctx.clone(), row.len());
                args.push_args(row.drain(..))?;
                let result = js_function
                    .call_arg(args)
                    .map_err(|e| check_exception(e, &ctx))
                    .context("failed to call function")?;
                results.push(result);
            }
            let array = jsarrow::build_array(&function.return_type, &ctx, results)
                .context("failed to build arrow array from return values")?;
            let schema = Schema::new(vec![Field::new(name, array.data_type().clone(), true)]);
            Ok(RecordBatch::try_new(Arc::new(schema), vec![array])?)
        })
    }

    /// Calls the table function `name` on every row of `input`.
    ///
    /// For each row, the function is called with the row's values as arguments. It must
    /// return an object with a `next()` method, such as a generator. `next()` is called
    /// repeatedly, and each `value` is emitted until `done` is true.
    ///
    /// Output is produced lazily, in batches of at most `chunk_size` rows, each with two
    /// columns:
    /// - `row`: the `Int32` index of the input row that produced the value;
    /// - `name`: the produced values.
    ///
    /// A generator that is still running when a batch fills up is resumed for the next
    /// batch. Under [`CallMode::ReturnNullOnNullInput`], rows with any null argument
    /// produce no output.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0.
    ///
    /// # Errors
    ///
    /// Returns an error immediately if `name` is not registered. Each item of the returned
    /// iterator is an error if:
    /// - an input value cannot be converted;
    /// - the function or `next()` throws;
    /// - an input row index does not fit in `i32`;
    /// - the values cannot be converted to the return type.
    pub fn call_table_function<'a>(
        &'a self,
        name: &'a str,
        input: &'a RecordBatch,
        chunk_size: usize,
    ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
        assert!(chunk_size > 0);

        /// Lazy iteration state for one table-function call.
        struct State<'a> {
            context: &'a Context,
            bigdecimal: &'a Persistent<rquickjs::Function<'static>>,
            input: &'a RecordBatch,
            function: &'a Function,
            name: &'a str,
            chunk_size: usize,
            /// Index of the input row currently being processed.
            row: usize,
            /// Generator for `row`, saved when a batch filled up before it was exhausted.
            generator: Option<Persistent<Object<'static>>>,
        }

        // SAFETY: NOTE(review): no invariant is documented here. `State` holds only shared
        // references into the `Runtime` plus an optional `Persistent` generator handle, and
        // every QuickJS access below goes through `context.with(...)`. Confirm the QuickJS
        // runtime tolerates being driven from another thread before relying on this impl.
        unsafe impl Send for State<'_> {}

        impl State<'_> {
            /// Produces the next output batch, or `Ok(None)` once all rows are consumed.
            fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
                if self.row == self.input.num_rows() {
                    return Ok(None);
                }
                self.context.with(|ctx| {
                    let bigdecimal = self.bigdecimal.clone().restore(&ctx)?;
                    let js_function = self.function.function.clone().restore(&ctx)?;
                    let mut indexes = Int32Builder::with_capacity(self.chunk_size);
                    // At most `chunk_size` values are collected per batch.
                    let mut results = Vec::with_capacity(self.chunk_size);
                    let mut row = Vec::with_capacity(self.input.num_columns());
                    // Resume the generator left over from the previous batch, if any.
                    let mut generator = match self.generator.take() {
                        Some(generator) => {
                            let gen = generator.restore(&ctx)?;
                            let next: rquickjs::Function =
                                gen.get("next").context("failed to get 'next' method")?;
                            Some((gen, next))
                        }
                        None => None,
                    };
                    while self.row < self.input.num_rows() && results.len() < self.chunk_size {
                        // Use the active generator for the current row, or call the function
                        // on the current row to start a new one.
                        let (gen, next) = if let Some(g) = generator.as_ref() {
                            g
                        } else {
                            row.clear();
                            for column in self.input.columns() {
                                let val =
                                    jsarrow::get_jsvalue(&ctx, &bigdecimal, column, self.row)
                                        .context("failed to get jsvalue from arrow array")?;
                                row.push(val);
                            }
                            if self.function.mode == CallMode::ReturnNullOnNullInput
                                && row.iter().any(|v| v.is_null())
                            {
                                // Rows with a null argument produce no output.
                                self.row += 1;
                                continue;
                            }
                            let mut args = Args::new(ctx.clone(), row.len());
                            args.push_args(row.drain(..))?;
                            let gen = js_function
                                .call_arg::<Object>(args)
                                .map_err(|e| check_exception(e, &ctx))
                                .context("failed to call function")?;
                            let next: rquickjs::Function =
                                gen.get("next").context("failed to get 'next' method")?;
                            generator.insert((gen, next))
                        };
                        // Invoke `next()` with the generator bound as `this`.
                        let mut args = Args::new(ctx.clone(), 0);
                        args.this(gen.clone())?;
                        let object: Object = next
                            .call_arg(args)
                            .map_err(|e| check_exception(e, &ctx))
                            .context("failed to call next")?;
                        let value: Value = object.get("value")?;
                        let done: bool = object.get("done")?;
                        if done {
                            // Generator exhausted: advance to the next input row.
                            self.row += 1;
                            generator = None;
                            continue;
                        }
                        // Checked conversion: a plain `as` cast would silently truncate.
                        let row_index = i32::try_from(self.row)
                            .context("row index does not fit in the Int32 `row` column")?;
                        indexes.append_value(row_index);
                        results.push(value);
                    }
                    // Keep an unfinished generator so the next batch resumes it.
                    self.generator = generator.map(|(gen, _)| Persistent::save(&ctx, gen));
                    if results.is_empty() {
                        return Ok(None);
                    }
                    let indexes = Arc::new(indexes.finish());
                    let array = jsarrow::build_array(&self.function.return_type, &ctx, results)
                        .context("failed to build arrow array from return values")?;
                    Ok(Some(RecordBatch::try_new(
                        Arc::new(Schema::new(vec![
                            Field::new("row", DataType::Int32, true),
                            Field::new(self.name, array.data_type().clone(), true),
                        ])),
                        vec![indexes, array],
                    )?))
                })
            }
        }

        impl Iterator for State<'_> {
            type Item = Result<RecordBatch>;
            fn next(&mut self) -> Option<Self::Item> {
                self.next_batch().transpose()
            }
        }

        Ok(State {
            context: &self.context,
            bigdecimal: &self.bigdecimal,
            input,
            function: self.functions.get(name).context("function not found")?,
            name,
            chunk_size,
            row: 0,
            generator: None,
        })
    }
}
fn check_exception(err: rquickjs::Error, ctx: &Ctx) -> anyhow::Error {
match err {
rquickjs::Error::Exception => {
anyhow!("exception generated by QuickJS: {:?}", ctx.catch())
}
e => e.into(),
}
}