#![doc = include_str!("README.md")]
use std::collections::HashMap;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::{Arc, atomic::Ordering};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use anyhow::{Context as _, Result, anyhow, bail};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, builder::Int32Builder};
use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
use futures_util::{FutureExt, Stream};
use rquickjs::context::intrinsic::{All, Base};
pub use rquickjs::runtime::MemoryUsage;
use rquickjs::{
Array as JsArray, AsyncContext, AsyncRuntime, Ctx, FromJs, IteratorJs as _, Module, Object,
Persistent, Promise, Value, async_with, function::Args, module::Evaluated,
};
use crate::CallMode;
use crate::into_field::IntoField;
#[cfg(feature = "javascript-fetch")]
mod fetch;
mod jsarrow;
/// A JavaScript UDF runtime.
///
/// Bundles a QuickJS runtime/context pair with the functions and aggregates that have been
/// registered on it, the Arrow <-> JS value converter, and the optional per-call timeout.
// NOTE(review): struct fields drop in declaration order, so the persistent JS handles held in
// `functions`/`aggregates` are dropped before `runtime` and `context`. Presumably this ordering
// is relied upon; confirm before reordering fields.
pub struct Runtime {
/// Functions registered through `add_function`, keyed by their registration name.
functions: HashMap<String, Function>,
/// Aggregates registered through `add_aggregate`, keyed by their registration name.
aggregates: HashMap<String, Aggregate>,
/// Converts Arrow array elements to JS values and JS values back into Arrow arrays.
converter: jsarrow::Converter,
/// The underlying QuickJS runtime (memory limit and interrupt handler are set on it).
runtime: AsyncRuntime,
/// The QuickJS context in which modules are evaluated and functions are called.
context: AsyncContext,
/// Time budget applied to each JS call; `None` means no deadline is armed.
timeout: Option<Duration>,
/// Absolute deadline of the JS call currently in progress, if any. It is written around
/// each call and read by the interrupt handler installed in `set_timeout`.
deadline: Arc<atomic_time::AtomicOptionInstant>,
}
impl Debug for Runtime {
    /// Prints the registered function names, aggregate names and the timeout.
    ///
    /// The QuickJS handles and the converter are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let function_names = self.functions.keys();
        let aggregate_names = self.aggregates.keys();
        let mut out = f.debug_struct("Runtime");
        out.field("functions", &function_names);
        out.field("aggregates", &aggregate_names);
        out.field("timeout", &self.timeout);
        out.finish()
    }
}
/// A registered scalar or table function.
struct Function {
/// Persistent handle to the exported JS function.
function: JsFunction,
/// Arrow field describing the values the function returns.
return_field: FieldRef,
/// Options the function was registered with.
options: FunctionOptions,
}
/// A registered aggregate function and the JS exports that implement it.
struct Aggregate {
/// Arrow field describing the aggregate state.
state_field: FieldRef,
/// Arrow field describing the final output.
output_field: FieldRef,
/// Export `create_state` (required): called with no arguments to produce the initial state.
create_state: JsFunction,
/// Export `accumulate` (required): called as `accumulate(state, ...row)` and returns the new state.
accumulate: JsFunction,
/// Export `retract` (optional): called with the same arguments as `accumulate`.
retract: Option<JsFunction>,
/// Export `finish` (optional): called as `finish(state)`; when absent the state is the output.
finish: Option<JsFunction>,
/// Export `merge` (optional): called as `merge(state, state2)` and returns the merged state.
merge: Option<JsFunction>,
/// Options the aggregate was registered with.
options: AggregateOptions,
}
// SAFETY: NOTE(review) - these impls assert thread-safety the compiler cannot verify.
// `Function` and `Aggregate` hold persistent JS handles; in this file those handles are only
// cloned/restored inside `async_with!` blocks on the owning context. Presumably that is what
// makes sharing them across threads sound - confirm against the rquickjs threading contract.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
unsafe impl Send for Aggregate {}
unsafe impl Sync for Aggregate {}
/// A JS function saved outside of any context lifetime; it must be restored into a `Ctx`
/// before it can be called.
type JsFunction = Persistent<rquickjs::Function<'static>>;
// SAFETY: NOTE(review) - same caveat as above: `Runtime` is declared `Send + Sync` manually.
// Confirm that every access to the JS runtime/context goes through the async runtime's lock.
unsafe impl Send for Runtime {}
unsafe impl Sync for Runtime {}
/// Options for registering a function with `Runtime::add_function`.
#[derive(Debug, Clone, Default)]
pub struct FunctionOptions {
/// How rows containing null arguments are handled (see `CallMode`).
pub call_mode: CallMode,
/// When `true`, the JS function is expected to return a `Promise`, which is awaited.
pub is_async: bool,
/// When `true`, `Runtime::call` passes each input column as one JS array in a single call
/// instead of calling the function once per row.
pub is_batched: bool,
/// Name of the module export to use as the function. When `None`, the registration name
/// is used as the export name.
pub handler: Option<String>,
}
impl FunctionOptions {
    /// Sets the call mode to `CallMode::ReturnNullOnNullInput`.
    pub fn return_null_on_null_input(self) -> Self {
        Self {
            call_mode: CallMode::ReturnNullOnNullInput,
            ..self
        }
    }

    /// Marks the function as asynchronous (`is_async = true`).
    pub fn async_mode(self) -> Self {
        Self {
            is_async: true,
            ..self
        }
    }

    /// Marks the function as batched (`is_batched = true`).
    pub fn batched(self) -> Self {
        Self {
            is_batched: true,
            ..self
        }
    }

    /// Sets the name of the module export to use as the function.
    pub fn handler(self, handler: impl Into<String>) -> Self {
        Self {
            handler: Some(handler.into()),
            ..self
        }
    }
}
/// Options for registering an aggregate with `Runtime::add_aggregate`.
#[derive(Debug, Clone, Default)]
pub struct AggregateOptions {
/// How null inputs are handled (see `CallMode`). With `ReturnNullOnNullInput`, input rows
/// containing a null are skipped by `accumulate`/`accumulate_or_retract`.
pub call_mode: CallMode,
/// When `true`, the aggregate's JS functions are expected to return `Promise`s, which are awaited.
pub is_async: bool,
}
impl AggregateOptions {
    /// Sets the call mode to `CallMode::ReturnNullOnNullInput`.
    pub fn return_null_on_null_input(self) -> Self {
        Self {
            call_mode: CallMode::ReturnNullOnNullInput,
            ..self
        }
    }

    /// Marks the aggregate's functions as asynchronous (`is_async = true`).
    pub fn async_mode(self) -> Self {
        Self {
            is_async: true,
            ..self
        }
    }
}
impl Runtime {
pub async fn new() -> Result<Self> {
let runtime = AsyncRuntime::new().context("failed to create quickjs runtime")?;
let context = AsyncContext::custom::<(Base, All)>(&runtime)
.await
.context("failed to create quickjs context")?;
Ok(Self {
functions: HashMap::new(),
aggregates: HashMap::new(),
runtime,
context,
timeout: None,
deadline: Default::default(),
converter: jsarrow::Converter::new(),
})
}
/// Sets the memory limit of the underlying QuickJS runtime, in bytes.
///
/// `None` is forwarded to the runtime as `0`.
// NOTE(review): presumably QuickJS treats a limit of 0 as "unlimited" - confirm.
pub async fn set_memory_limit(&self, limit: Option<usize>) {
    let bytes = limit.unwrap_or_default();
    self.runtime.set_memory_limit(bytes).await;
}
pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
self.timeout = timeout;
if timeout.is_some() {
let deadline = self.deadline.clone();
self.runtime
.set_interrupt_handler(Some(Box::new(move || {
if let Some(deadline) = deadline.load(Ordering::Relaxed) {
return deadline <= Instant::now();
}
false
})))
.await;
} else {
self.runtime.set_interrupt_handler(None).await;
}
}
pub fn inner(&self) -> &AsyncRuntime {
&self.runtime
}
pub fn converter_mut(&mut self) -> &mut jsarrow::Converter {
&mut self.converter
}
/// Registers a function under `name`.
///
/// `code` is declared and evaluated as a JS module named `name`. The function is taken from
/// the export named by `options.handler`, or from the export called `name` when no handler
/// is set. A function already registered under `name` is replaced.
///
/// # Errors
///
/// Fails when the module cannot be declared or evaluated, or when the export is not found.
pub async fn add_function(
    &mut self,
    name: &str,
    return_type: impl IntoField + Send,
    code: &str,
    options: FunctionOptions,
) -> Result<()> {
    let function = async_with!(self.context => |ctx| {
        let declared = Module::declare(ctx.clone(), name, code)
            .map_err(|e| check_exception(e, &ctx))
            .context("failed to declare module")?;
        let (module, _) = declared
            .eval()
            .map_err(|e| check_exception(e, &ctx))
            .context("failed to evaluate module")?;
        let export_name = options.handler.as_deref().unwrap_or(name);
        let function = Self::get_function(&ctx, &module, export_name)?;
        let return_field: FieldRef = return_type.into_field(name).into();
        Ok(Function {
            function,
            return_field,
            options,
        }) as Result<Function>
    })
    .await?;
    self.functions.insert(name.to_string(), function);
    Ok(())
}
/// Looks up the export `name` on an evaluated module and saves it as a persistent handle.
///
/// # Errors
///
/// Fails when the export cannot be fetched as a function; the error hints that the
/// function may not have been exported.
fn get_function<'a>(
    ctx: &Ctx<'a>,
    module: &Module<'a, Evaluated>,
    name: &str,
) -> Result<JsFunction> {
    let missing_export =
        || format!("function \"{name}\" not found. HINT: make sure the function is exported");
    let exported: rquickjs::Function = module.get(name).with_context(missing_export)?;
    let saved = Persistent::save(ctx, exported);
    Ok(saved)
}
/// Registers an aggregate function under `name`.
///
/// `code` is declared and evaluated as a JS module named `name`. The exports `create_state`
/// and `accumulate` are required; `retract`, `finish` and `merge` are optional and recorded
/// as absent when they cannot be fetched. An aggregate already registered under `name` is
/// replaced.
///
/// # Errors
///
/// Fails when the module cannot be declared or evaluated, when a required export is missing,
/// or when `finish` is absent and `output_type` differs from `state_type`.
pub async fn add_aggregate(
    &mut self,
    name: &str,
    state_type: impl IntoField + Send,
    output_type: impl IntoField + Send,
    code: &str,
    options: AggregateOptions,
) -> Result<()> {
    let aggregate = async_with!(self.context => |ctx| {
        let declared = Module::declare(ctx.clone(), name, code)
            .map_err(|e| check_exception(e, &ctx))
            .context("failed to declare module")?;
        let (module, _) = declared
            .eval()
            .map_err(|e| check_exception(e, &ctx))
            .context("failed to evaluate module")?;
        let state_field: FieldRef = state_type.into_field(name).into();
        let output_field: FieldRef = output_type.into_field(name).into();
        // Required exports: a lookup failure aborts the registration.
        let create_state = Self::get_function(&ctx, &module, "create_state")?;
        let accumulate = Self::get_function(&ctx, &module, "accumulate")?;
        // Optional exports: a lookup failure only means "not provided".
        let retract = Self::get_function(&ctx, &module, "retract").ok();
        let finish = Self::get_function(&ctx, &module, "finish").ok();
        let merge = Self::get_function(&ctx, &module, "merge").ok();
        Ok(Aggregate {
            state_field,
            output_field,
            create_state,
            accumulate,
            retract,
            finish,
            merge,
            options,
        }) as Result<Aggregate>
    })
    .await?;
    // Without `finish` the state is returned as the output, so both types must agree.
    let state_is_output = aggregate.finish.is_none();
    if state_is_output && aggregate.state_field != aggregate.output_field {
        bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
    }
    self.aggregates.insert(name.to_string(), aggregate);
    Ok(())
}
/// Calls the function registered as `name` on `input` and returns a single-column batch
/// holding the results.
///
/// Batched functions receive whole columns in one call; other functions are called once
/// per row.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not registered, or when conversion or
/// the JS call fails.
///
#[doc = include_str!("doc_create_function.txt")]
pub async fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
    let function = self.functions.get(name).context("function not found")?;
    let is_batched = function.options.is_batched;
    async_with!(self.context => |ctx| {
        if is_batched {
            self.call_batched_function(&ctx, function, input).await
        } else {
            self.call_non_batched_function(&ctx, function, input).await
        }
    })
    .await
}
/// Calls `function` once per input row and collects the results into a one-column batch.
///
/// With `CallMode::ReturnNullOnNullInput`, a row with any null argument yields a JS null
/// without calling the function.
async fn call_non_batched_function(
    &self,
    ctx: &Ctx<'_>,
    function: &Function,
    input: &RecordBatch,
) -> Result<RecordBatch> {
    let js_function = function.function.clone().restore(ctx)?;
    let skip_null_rows = function.options.call_mode == CallMode::ReturnNullOnNullInput;
    let input_schema = input.schema();
    let num_rows = input.num_rows();
    let mut outputs = Vec::with_capacity(num_rows);
    // Argument buffer reused across rows.
    let mut arguments = Vec::with_capacity(input.num_columns());
    for row_idx in 0..num_rows {
        arguments.clear();
        for (column, field) in input.columns().iter().zip(input_schema.fields()) {
            let js_value = self
                .converter
                .get_jsvalue(ctx, field, column, row_idx)
                .context("failed to get jsvalue from arrow array")?;
            arguments.push(js_value);
        }
        let output = if skip_null_rows && arguments.iter().any(|v| v.is_null()) {
            Value::new_null(ctx.clone())
        } else {
            let mut args = Args::new(ctx.clone(), arguments.len());
            args.push_args(arguments.drain(..))?;
            self.call_user_fn(ctx, &js_function, args, function.options.is_async)
                .await
                .context("failed to call function")?
        };
        outputs.push(output);
    }
    let array = self
        .converter
        .build_array(&function.return_field, ctx, outputs)
        .context("failed to build arrow array from return values")?;
    let output_schema = Arc::new(Schema::new(vec![function.return_field.clone()]));
    let batch = RecordBatch::try_new(output_schema, vec![array])?;
    Ok(batch)
}
async fn call_batched_function(
&self,
ctx: &Ctx<'_>,
function: &Function,
input: &RecordBatch,
) -> Result<RecordBatch> {
let js_function = function.function.clone().restore(ctx)?;
let mut js_columns = Vec::with_capacity(input.num_columns());
for (column, field) in input.columns().iter().zip(input.schema().fields()) {
let mut js_values = Vec::with_capacity(input.num_rows());
for i in 0..input.num_rows() {
let val = self
.converter
.get_jsvalue(ctx, field, column, i)
.context("failed to get jsvalue from arrow array")?;
js_values.push(val);
}
js_columns.push(js_values);
}
let result = match function.options.call_mode {
CallMode::CalledOnNullInput => {
let mut args = Args::new(ctx.clone(), input.num_columns());
for js_values in js_columns {
let js_array = js_values.into_iter().collect_js::<JsArray>(ctx)?;
args.push_arg(js_array)?;
}
self.call_user_fn(ctx, &js_function, args, function.options.is_async)
.await
.context("failed to call function")?
}
CallMode::ReturnNullOnNullInput => {
let n_cols = input.num_columns();
let n_rows = input.num_rows();
let bitmap: Vec<bool> = (0..n_rows)
.map(|row_idx| {
let has_null = (0..n_cols).any(|j| js_columns[j][row_idx].is_null());
!has_null
})
.collect();
let mut filtered_columns = Vec::with_capacity(n_cols);
for js_values in js_columns {
let filtered_js_values: Vec<_> = js_values
.into_iter()
.zip(bitmap.iter())
.filter(|(_, b)| **b)
.map(|(v, _)| v)
.collect();
filtered_columns.push(filtered_js_values);
}
let mut args = Args::new(ctx.clone(), filtered_columns.len());
for js_values in filtered_columns {
let js_array = js_values.into_iter().collect_js::<JsArray>(ctx)?;
args.push_arg(js_array)?;
}
let filtered_result: Vec<_> = self
.call_user_fn(ctx, &js_function, args, function.options.is_async)
.await
.context("failed to call function")?;
let mut iter = filtered_result.into_iter();
let mut result = Vec::with_capacity(n_rows);
for b in bitmap.iter() {
if *b {
let v = iter.next().expect("filtered result length mismatch");
result.push(v);
} else {
result.push(Value::new_null(ctx.clone()));
}
}
assert!(iter.next().is_none(), "filtered result length mismatch");
result
}
};
let array = self
.converter
.build_array(&function.return_field, ctx, result)?;
let schema = Schema::new(vec![function.return_field.clone()]);
Ok(RecordBatch::try_new(Arc::new(schema), vec![array])?)
}
/// Prepares a table-function call and returns an iterator over its output batches.
///
/// Each output batch has two columns: `row` (non-nullable `Int32`, the index of the input
/// row that produced the value) and the function's return field. At most `chunk_size`
/// values are placed in one batch.
///
/// # Panics
///
/// Panics when `chunk_size` is 0.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not registered, or when the function was
/// registered in batched mode.
///
#[doc = include_str!("doc_create_function.txt")]
pub fn call_table_function<'a>(
    &'a self,
    name: &'a str,
    input: &'a RecordBatch,
    chunk_size: usize,
) -> Result<RecordBatchIter<'a>> {
    assert!(chunk_size > 0);
    let function = self.functions.get(name).context("function not found")?;
    if function.options.is_batched {
        bail!("table function does not support batched mode");
    }
    let row_field = Arc::new(Field::new("row", DataType::Int32, false));
    let fields = vec![row_field, function.return_field.clone()];
    let schema = Arc::new(Schema::new(fields));
    Ok(RecordBatchIter {
        rt: self,
        input,
        function,
        schema,
        chunk_size,
        row: 0,
        generator: None,
        converter: &self.converter,
    })
}
/// Creates the initial state of the aggregate `name` by calling its `create_state` export
/// with no arguments, and returns it as a one-element array of the state type.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not a registered aggregate, or when the
/// JS call or the conversion of its result fails.
///
#[doc = include_str!("doc_create_aggregate.txt")]
pub async fn create_state(&self, name: &str) -> Result<ArrayRef> {
    let aggregate = self.aggregates.get(name).context("function not found")?;
    async_with!(self.context => |ctx| {
        let create_state = aggregate.create_state.clone().restore(&ctx)?;
        let no_args = Args::new(ctx.clone(), 0);
        let initial = self
            .call_user_fn(&ctx, &create_state, no_args, aggregate.options.is_async)
            .await
            .context("failed to call create_state")?;
        let array = self
            .converter
            .build_array(&aggregate.state_field, &ctx, vec![initial])?;
        Ok(array)
    })
    .await
}
/// Folds every row of `input` into the aggregate state and returns the new state as a
/// one-element array.
///
/// The starting state is element 0 of `state`. For each row the `accumulate` export is
/// called as `accumulate(state, ...row_values)` and its result becomes the next state.
/// With `CallMode::ReturnNullOnNullInput`, rows with a null in any column are skipped.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not a registered aggregate, or when a
/// conversion or the JS call fails.
///
#[doc = include_str!("doc_create_aggregate.txt")]
pub async fn accumulate(
    &self,
    name: &str,
    state: &dyn Array,
    input: &RecordBatch,
) -> Result<ArrayRef> {
    let aggregate = self.aggregates.get(name).context("function not found")?;
    async_with!(self.context => |ctx| {
        let accumulate = aggregate.accumulate.clone().restore(&ctx)?;
        let skip_null_rows = aggregate.options.call_mode == CallMode::ReturnNullOnNullInput;
        let input_schema = input.schema();
        let mut current = self
            .converter
            .get_jsvalue(&ctx, &aggregate.state_field, state, 0)?;
        // Argument buffer reused across rows: [state, col0, col1, ...].
        let mut call_values = Vec::with_capacity(1 + input.num_columns());
        for row_idx in 0..input.num_rows() {
            if skip_null_rows && input.columns().iter().any(|column| column.is_null(row_idx)) {
                continue;
            }
            call_values.clear();
            call_values.push(current.clone());
            for (column, field) in input.columns().iter().zip(input_schema.fields()) {
                let js_value = self.converter.get_jsvalue(&ctx, field, column, row_idx)?;
                call_values.push(js_value);
            }
            let mut args = Args::new(ctx.clone(), call_values.len());
            args.push_args(call_values.drain(..))?;
            current = self
                .call_user_fn(&ctx, &accumulate, args, aggregate.options.is_async)
                .await
                .context("failed to call accumulate")?;
        }
        let new_state = self
            .converter
            .build_array(&aggregate.state_field, &ctx, vec![current])?;
        Ok(new_state)
    })
    .await
}
/// Like `accumulate`, but each row is either accumulated or retracted according to `ops`.
///
/// For row `i`, the `retract` export is used when `ops[i]` is valid and `true`; otherwise
/// the `accumulate` export is used. With `CallMode::ReturnNullOnNullInput`, rows with a null
/// in any input column are skipped.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not a registered aggregate, with
/// "function does not support retraction" when the aggregate has no `retract` export, or
/// when a conversion or the JS call fails.
///
#[doc = include_str!("doc_create_aggregate.txt")]
pub async fn accumulate_or_retract(
    &self,
    name: &str,
    state: &dyn Array,
    ops: &BooleanArray,
    input: &RecordBatch,
) -> Result<ArrayRef> {
    let aggregate = self.aggregates.get(name).context("function not found")?;
    async_with!(self.context => |ctx| {
        let accumulate = aggregate.accumulate.clone().restore(&ctx)?;
        let retract = aggregate
            .retract
            .clone()
            .context("function does not support retraction")?
            .restore(&ctx)?;
        let skip_null_rows = aggregate.options.call_mode == CallMode::ReturnNullOnNullInput;
        let input_schema = input.schema();
        let mut current = self
            .converter
            .get_jsvalue(&ctx, &aggregate.state_field, state, 0)?;
        // Argument buffer reused across rows: [state, col0, col1, ...].
        let mut call_values = Vec::with_capacity(1 + input.num_columns());
        for row_idx in 0..input.num_rows() {
            if skip_null_rows && input.columns().iter().any(|column| column.is_null(row_idx)) {
                continue;
            }
            call_values.clear();
            call_values.push(current.clone());
            for (column, field) in input.columns().iter().zip(input_schema.fields()) {
                let js_value = self.converter.get_jsvalue(&ctx, field, column, row_idx)?;
                call_values.push(js_value);
            }
            let is_retraction = ops.is_valid(row_idx) && ops.value(row_idx);
            let step = if is_retraction { &retract } else { &accumulate };
            let mut args = Args::new(ctx.clone(), call_values.len());
            args.push_args(call_values.drain(..))?;
            current = self
                .call_user_fn(&ctx, step, args, aggregate.options.is_async)
                .await
                .context("failed to call accumulate or retract")?;
        }
        let new_state = self
            .converter
            .build_array(&aggregate.state_field, &ctx, vec![current])?;
        Ok(new_state)
    })
    .await
}
/// Merges all states in `states` into one and returns it as a one-element array.
///
/// Starts from element 0 and folds each following element in with the aggregate's `merge`
/// export, called as `merge(state, state2)`. With `CallMode::ReturnNullOnNullInput`, null
/// elements after the first are skipped.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not a registered aggregate, with
/// "merge not found" when the aggregate has no `merge` export, or when a conversion or the
/// JS call fails.
// NOTE(review): element 0 is read unconditionally, so this assumes `states` is non-empty -
// TODO confirm how the converter behaves for an empty array.
///
#[doc = include_str!("doc_create_aggregate.txt")]
pub async fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
    let aggregate = self.aggregates.get(name).context("function not found")?;
    let output = async_with!(self.context => |ctx| {
        let merge = aggregate
            .merge
            .clone()
            .context("merge not found")?
            .restore(&ctx)?;
        let mut state = self
            .converter
            .get_jsvalue(&ctx, &aggregate.state_field, states, 0)?;
        for i in 1..states.len() {
            if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
                continue;
            }
            let state2 = self
                .converter
                .get_jsvalue(&ctx, &aggregate.state_field, states, i)?;
            let mut args = Args::new(ctx.clone(), 2);
            args.push_args([state, state2])?;
            // Report failures under the name of the export that was actually called.
            state = self
                .call_user_fn(&ctx, &merge, args, aggregate.options.is_async)
                .await
                .context("failed to call merge")?;
        }
        let output = self
            .converter
            .build_array(&aggregate.state_field, &ctx, vec![state])?;
        Ok(output) as Result<_>
    })
    .await?;
    Ok(output)
}
/// Converts aggregate states into final outputs.
///
/// When the aggregate has no `finish` export, `states` is returned unchanged. Otherwise
/// `finish(state)` is called for each element and the results are collected into an array
/// of the output type. With `CallMode::ReturnNullOnNullInput`, a null state yields a null
/// output without calling `finish`.
///
/// # Errors
///
/// Fails with "function not found" when `name` is not a registered aggregate, or when a
/// conversion or the JS call fails.
///
#[doc = include_str!("doc_create_aggregate.txt")]
pub async fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
    let aggregate = self.aggregates.get(name).context("function not found")?;
    if aggregate.finish.is_none() {
        return Ok(states.clone());
    }
    async_with!(self.context => |ctx| {
        let finish = match aggregate.finish.clone() {
            Some(saved) => saved.restore(&ctx)?,
            // The absence of `finish` was handled before entering the context.
            None => unreachable!(),
        };
        let skip_null_states = aggregate.options.call_mode == CallMode::ReturnNullOnNullInput;
        let mut outputs = Vec::with_capacity(states.len());
        for idx in 0..states.len() {
            let output = if skip_null_states && states.is_null(idx) {
                Value::new_null(ctx.clone())
            } else {
                let state = self
                    .converter
                    .get_jsvalue(&ctx, &aggregate.state_field, states, idx)?;
                let mut args = Args::new(ctx.clone(), 1);
                args.push_args([state])?;
                self.call_user_fn(&ctx, &finish, args, aggregate.options.is_async)
                    .await
                    .context("failed to call finish")?
            };
            outputs.push(output);
        }
        let array = self
            .converter
            .build_array(&aggregate.output_field, &ctx, outputs)?;
        Ok(array)
    })
    .await
}
/// Calls the JS function `f` with `args` and converts the result to `T`.
///
/// Dispatches to the promise-awaiting path when `is_async` is set and to the plain
/// synchronous path otherwise.
async fn call_user_fn<'js, T: FromJs<'js>>(
    &self,
    ctx: &Ctx<'js>,
    f: &rquickjs::Function<'js>,
    args: Args<'js>,
    is_async: bool,
) -> Result<T> {
    if !is_async {
        return self.call_user_fn_sync(ctx, f, args);
    }
    self.call_user_fn_async(ctx, f, args).await
}
/// Calls `f` expecting a `Promise`, awaits it, and converts the resolved value to `T`.
///
/// When a timeout is configured, the deadline is armed only around the call that produces
/// the promise and is cleared again before the promise is awaited.
///
/// # Errors
///
/// JS exceptions raised by the call or by the promise are converted with `check_exception`.
async fn call_user_fn_async<'js, T: FromJs<'js>>(
    &self,
    ctx: &Ctx<'js>,
    f: &rquickjs::Function<'js>,
    args: Args<'js>,
) -> Result<T> {
    let armed_until = self.timeout.map(|budget| Instant::now() + budget);
    if armed_until.is_some() {
        self.deadline.store(armed_until, Ordering::Relaxed);
    }
    let started = f.call_arg::<Promise>(args);
    if armed_until.is_some() {
        self.deadline.store(None, Ordering::Relaxed);
    }
    let promise = started.map_err(|e| check_exception(e, ctx))?;
    let settled = promise.into_future::<T>().await;
    settled.map_err(|e| check_exception(e, ctx))
}
/// Calls `f` synchronously and converts the returned value to `T`.
///
/// When a timeout is configured, the deadline is armed right before the call and cleared
/// right after it, whether or not the call succeeded.
///
/// # Errors
///
/// JS exceptions raised by the call are converted with `check_exception`.
fn call_user_fn_sync<'js, T: FromJs<'js>>(
    &self,
    ctx: &Ctx<'js>,
    f: &rquickjs::Function<'js>,
    args: Args<'js>,
) -> Result<T> {
    let armed_until = self.timeout.map(|budget| Instant::now() + budget);
    if armed_until.is_some() {
        self.deadline.store(armed_until, Ordering::Relaxed);
    }
    let outcome = f.call_arg(args);
    if armed_until.is_some() {
        self.deadline.store(None, Ordering::Relaxed);
    }
    outcome.map_err(|e| check_exception(e, ctx))
}
pub fn context(&self) -> &AsyncContext {
&self.context
}
/// Enables the fetch support provided by the `fetch` module on this runtime and context.
///
/// Only available with the `javascript-fetch` feature.
#[cfg(feature = "javascript-fetch")]
pub async fn enable_fetch(&self) -> Result<()> {
    let Self {
        runtime, context, ..
    } = self;
    fetch::enable_fetch(runtime, context).await
}
}
/// Iterator over the output batches of a table-function call, created by
/// `Runtime::call_table_function`.
pub struct RecordBatchIter<'a> {
/// Runtime whose context is used for the JS calls.
rt: &'a Runtime,
/// Input batch; the function is called once per row of it.
input: &'a RecordBatch,
/// The table function being called.
function: &'a Function,
/// Output schema: the `row` index column followed by the function's return field.
schema: SchemaRef,
/// Maximum number of values placed in one output batch.
chunk_size: usize,
/// Index of the input row currently being processed.
row: usize,
/// Generator object of the current row that has not been exhausted yet, saved between
/// calls to `next`.
generator: Option<Persistent<Object<'static>>>,
/// Converter used for Arrow <-> JS value conversion.
converter: &'a jsarrow::Converter,
}
// SAFETY: NOTE(review) - asserts that the iterator (including the persistent generator
// handle) may move between threads. The handle is only restored/saved inside `async_with!`
// in this file; confirm that this is sufficient for soundness.
unsafe impl Send for RecordBatchIter<'_> {}
impl RecordBatchIter<'_> {
/// Returns the schema of the batches produced by this iterator.
pub fn schema(&self) -> &Schema {
    self.schema.as_ref()
}
/// Produces the next output batch, or `None` when all input rows have been consumed.
///
/// For each input row the function is called to obtain a generator object, whose `next`
/// method is then called until it reports `done`. Every yielded value is emitted together
/// with the index of the input row that produced it. A batch is closed once it holds
/// `chunk_size` values; a generator that is not exhausted at that point is saved and
/// resumed by the following call. With `CallMode::ReturnNullOnNullInput`, rows with a null
/// argument are skipped without calling the function.
///
/// # Errors
///
/// Fails when a conversion or a JS call fails, or when an input row index does not fit in
/// the `Int32` `row` column.
pub async fn next(&mut self) -> Result<Option<RecordBatch>> {
    if self.row == self.input.num_rows() {
        return Ok(None);
    }
    async_with!(self.rt.context => |ctx| {
        let js_function = self.function.function.clone().restore(&ctx)?;
        let mut indexes = Int32Builder::with_capacity(self.chunk_size);
        let mut results = Vec::with_capacity(self.input.num_rows());
        let mut row = Vec::with_capacity(self.input.num_columns());
        // Resume the generator left over from the previous call, if any.
        let mut generator = match self.generator.take() {
            Some(generator) => {
                let generator_obj = generator.restore(&ctx)?;
                let next: rquickjs::Function =
                    generator_obj.get("next").context("failed to get 'next' method")?;
                Some((generator_obj, next))
            }
            None => None,
        };
        while self.row < self.input.num_rows() && results.len() < self.chunk_size {
            let (generator_obj, next) = if let Some(g) = generator.as_ref() {
                g
            } else {
                // No active generator: call the function for the current row to create one.
                row.clear();
                for (column, field) in
                    (self.input.columns().iter()).zip(self.input.schema().fields())
                {
                    let val = self
                        .converter
                        .get_jsvalue(&ctx, field, column, self.row)
                        .context("failed to get jsvalue from arrow array")?;
                    row.push(val);
                }
                if self.function.options.call_mode == CallMode::ReturnNullOnNullInput
                    && row.iter().any(|v| v.is_null())
                {
                    self.row += 1;
                    continue;
                }
                let mut args = Args::new(ctx.clone(), row.len());
                args.push_args(row.drain(..))?;
                let generator_obj: Object = self
                    .rt
                    .call_user_fn(&ctx, &js_function, args, false).await
                    .context("failed to call function")?;
                let next: rquickjs::Function =
                    generator_obj.get("next").context("failed to get 'next' method")?;
                generator.insert((generator_obj, next))
            };
            // Call `generator.next()` with the generator bound as `this`.
            let mut args = Args::new(ctx.clone(), 0);
            args.this(generator_obj.clone())?;
            let object: Object = self
                .rt
                .call_user_fn(&ctx, next, args, self.function.options.is_async).await
                .context("failed to call next")?;
            let value: Value = object.get("value")?;
            let done: bool = object.get("done")?;
            if done {
                self.row += 1;
                generator = None;
                continue;
            }
            // The `row` column is Int32; refuse indexes that would be truncated.
            let row_index = i32::try_from(self.row)
                .context("input row index does not fit in the Int32 `row` column")?;
            indexes.append_value(row_index);
            results.push(value);
        }
        // Keep an unfinished generator for the next call.
        self.generator = generator.map(|(generator_obj, _)| Persistent::save(&ctx, generator_obj));
        if results.is_empty() {
            return Ok(None);
        }
        let indexes = Arc::new(indexes.finish());
        let array = self
            .converter
            .build_array(&self.function.return_field, &ctx, results)
            .context("failed to build arrow array from return values")?;
        Ok(Some(RecordBatch::try_new(
            self.schema.clone(),
            vec![indexes, array],
        )?))
    })
    .await
}
}
impl Stream for RecordBatchIter<'_> {
    type Item = Result<RecordBatch>;

    /// Polls for the next batch by driving `RecordBatchIter::next`.
    ///
    /// Every call builds a new `next()` future, boxes it, polls it once and drops it when
    /// this function returns.
    // NOTE(review): because the future is not kept between polls, a `Pending` result
    // presumably discards whatever progress that future had made - confirm this is intended.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut pending_batch = Box::pin(self.next().map(|batch| batch.transpose()));
        pending_batch.poll_unpin(cx)
    }
}
pub(crate) fn check_exception(err: rquickjs::Error, ctx: &Ctx) -> anyhow::Error {
match err {
rquickjs::Error::Exception => {
anyhow!("exception generated by QuickJS: {:?}", ctx.catch())
}
e => e.into(),
}
}