arrow_udf_runtime/javascript/
mod.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("README.md")]
16
17use std::collections::HashMap;
18use std::fmt::Debug;
19use std::pin::Pin;
20use std::sync::{atomic::Ordering, Arc};
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use anyhow::{anyhow, bail, Context as _, Result};
25use arrow_array::{builder::Int32Builder, Array, ArrayRef, BooleanArray, RecordBatch};
26use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
27use futures_util::{FutureExt, Stream};
28use rquickjs::context::intrinsic::{All, Base};
29pub use rquickjs::runtime::MemoryUsage;
30use rquickjs::{
31    async_with, function::Args, module::Evaluated, Array as JsArray, AsyncContext, AsyncRuntime,
32    Ctx, FromJs, IteratorJs as _, Module, Object, Persistent, Promise, Value,
33};
34
35use crate::into_field::IntoField;
36use crate::CallMode;
37
38#[cfg(feature = "javascript-fetch")]
39mod fetch;
40mod jsarrow;
41
/// A runtime to execute user defined functions in JavaScript.
///
/// # Usages
///
/// - Create a new runtime with [`Runtime::new`].
/// - For scalar functions, use [`add_function`] and [`call`].
/// - For table functions, use [`add_function`] and [`call_table_function`].
/// - For aggregate functions, create the function with [`add_aggregate`], and then
///     - create a new state with [`create_state`],
///     - update the state with [`accumulate`] or [`accumulate_or_retract`],
///     - merge states with [`merge`],
///     - finally get the result with [`finish`].
///
/// Click on each function to see the example.
///
/// [`add_function`]: Runtime::add_function
/// [`add_aggregate`]: Runtime::add_aggregate
/// [`call`]: Runtime::call
/// [`call_table_function`]: Runtime::call_table_function
/// [`create_state`]: Runtime::create_state
/// [`accumulate`]: Runtime::accumulate
/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
/// [`merge`]: Runtime::merge
/// [`finish`]: Runtime::finish
pub struct Runtime {
    /// Registered scalar and table functions, keyed by the name given at registration.
    functions: HashMap<String, Function>,
    /// Registered aggregate functions, keyed by the name given at registration.
    aggregates: HashMap<String, Aggregate>,
    // NOTE: `functions` and `aggregates` must be declared before `runtime` and `context`.
    // Struct fields are dropped in declaration order, so this guarantees that the persistent
    // JavaScript handles are dropped first.
    /// Converter between Arrow arrays and JavaScript values.
    converter: jsarrow::Converter,
    /// The underlying QuickJS runtime.
    runtime: AsyncRuntime,
    /// The QuickJS context in which all user code is evaluated and called.
    context: AsyncContext,
    /// Timeout of each function call.
    timeout: Option<Duration>,
    /// Deadline of the current function call.
    ///
    /// Shared with the interrupt handler installed by `set_timeout`, which interrupts
    /// execution once the stored instant has passed.
    deadline: Arc<atomic_time::AtomicOptionInstant>,
}
78
79impl Debug for Runtime {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("Runtime")
82            .field("functions", &self.functions.keys())
83            .field("aggregates", &self.aggregates.keys())
84            .field("timeout", &self.timeout)
85            .finish()
86    }
87}
88
/// A user defined scalar function or table function.
struct Function {
    /// Persistent handle to the JavaScript function to call.
    function: JsFunction,
    /// Arrow field describing the return value, built from the registered return type
    /// and the function name.
    return_field: FieldRef,
    /// Options the function was registered with.
    options: FunctionOptions,
}
95
/// A user defined aggregate function.
struct Aggregate {
    /// Arrow field describing the internal state.
    state_field: FieldRef,
    /// Arrow field describing the final aggregate value.
    output_field: FieldRef,
    /// Required JavaScript function `create_state`.
    create_state: JsFunction,
    /// Required JavaScript function `accumulate`.
    accumulate: JsFunction,
    /// Optional JavaScript function `retract`; `None` if the module does not export it.
    retract: Option<JsFunction>,
    /// Optional JavaScript function `finish`; `None` if the module does not export it.
    finish: Option<JsFunction>,
    /// Optional JavaScript function `merge`; `None` if the module does not export it.
    merge: Option<JsFunction>,
    /// Options the aggregate was registered with.
    options: AggregateOptions,
}
107
// These impls are required to pass `Function` and `Aggregate` out of `async_with!`.
// Without them, the compiler complains that `*mut JSRuntime` cannot be sent between
// threads safely.
//
// SAFETY: We ensure the `JSRuntime` used in `async_with!` is the same as the caller's.
// The `parallel` feature of `rquickjs` is enabled, so `rquickjs` itself cannot ensure this.
unsafe impl Send for Function {}
unsafe impl Sync for Function {}
unsafe impl Send for Aggregate {}
unsafe impl Sync for Aggregate {}
116
/// A persistent handle to a JavaScript function, which can be either sync or async.
///
/// The handle has to be restored against a `Ctx` (see the `restore` calls in this file)
/// before it can be called.
type JsFunction = Persistent<rquickjs::Function<'static>>;
119
// SAFETY: `rquickjs::Runtime` is `Send` and `Sync`.
// NOTE(review): the struct actually stores `AsyncRuntime`/`AsyncContext` together with
// persistent JavaScript handles; confirm that the same reasoning covers all of them.
unsafe impl Send for Runtime {}
unsafe impl Sync for Runtime {}
123
/// Options for configuring user-defined functions.
///
/// Start from `FunctionOptions::default()` and refine it with the builder-style methods
/// such as `return_null_on_null_input`, `async_mode`, `batched` and `handler`.
#[derive(Debug, Clone, Default)]
pub struct FunctionOptions {
    /// Whether the function will be called when some of its arguments are null.
    pub call_mode: CallMode,
    /// Whether the function is async. An async function would return a Promise.
    pub is_async: bool,
    /// Whether the function accepts a batch of records as input.
    ///
    /// A batched function receives one JavaScript array per input column instead of
    /// one scalar value per argument.
    pub is_batched: bool,
    /// The name of the function in JavaScript code to be called.
    /// If not set, the function name will be used.
    pub handler: Option<String>,
}
137
138impl FunctionOptions {
139    /// Sets the function to return null when some of its arguments are null.
140    /// See [`CallMode`] for more details.
141    pub fn return_null_on_null_input(mut self) -> Self {
142        self.call_mode = CallMode::ReturnNullOnNullInput;
143        self
144    }
145
146    /// Marks the function to be async JS function.
147    pub fn async_mode(mut self) -> Self {
148        self.is_async = true;
149        self
150    }
151
152    /// Sets the function to accept a batch of records as input.
153    pub fn batched(mut self) -> Self {
154        self.is_batched = true;
155        self
156    }
157
158    /// Sets the name of the function in JavaScript code to be called.
159    pub fn handler(mut self, handler: impl Into<String>) -> Self {
160        self.handler = Some(handler.into());
161        self
162    }
163}
164
/// Options for configuring user-defined aggregate functions.
///
/// Start from `AggregateOptions::default()` and refine it with the builder-style methods
/// `return_null_on_null_input` and `async_mode`.
#[derive(Debug, Clone, Default)]
pub struct AggregateOptions {
    /// Whether the function will be called when some of its arguments are null.
    pub call_mode: CallMode,
    /// Whether the function is async. An async function would return a Promise.
    ///
    /// The same flag is used for every JavaScript function of the aggregate.
    pub is_async: bool,
}
173
174impl AggregateOptions {
175    /// Sets the function to return null when some of its arguments are null.
176    /// See [`CallMode`] for more details.
177    pub fn return_null_on_null_input(mut self) -> Self {
178        self.call_mode = CallMode::ReturnNullOnNullInput;
179        self
180    }
181
182    /// Marks the function to be async JS function.
183    pub fn async_mode(mut self) -> Self {
184        self.is_async = true;
185        self
186    }
187}
188
189impl Runtime {
190    /// Create a new `Runtime`.
191    ///
192    /// # Example
193    ///
194    /// ```
195    /// # use arrow_udf_runtime::javascript::Runtime;
196    /// # tokio_test::block_on(async {
197    /// let runtime = Runtime::new().await.unwrap();
198    /// runtime.set_memory_limit(Some(1 << 20)); // 1MB
199    /// # });
200    /// ```
201    pub async fn new() -> Result<Self> {
202        let runtime = AsyncRuntime::new().context("failed to create quickjs runtime")?;
203        let context = AsyncContext::custom::<(Base, All)>(&runtime)
204            .await
205            .context("failed to create quickjs context")?;
206
207        Ok(Self {
208            functions: HashMap::new(),
209            aggregates: HashMap::new(),
210            runtime,
211            context,
212            timeout: None,
213            deadline: Default::default(),
214            converter: jsarrow::Converter::new(),
215        })
216    }
217
218    /// Set the memory limit of the runtime.
219    ///
220    /// # Example
221    ///
222    /// ```
223    /// # use arrow_udf_runtime::javascript::Runtime;
224    /// # tokio_test::block_on(async {
225    /// let runtime = Runtime::new().await.unwrap();
226    /// runtime.set_memory_limit(Some(1 << 20)); // 1MB
227    /// # });
228    /// ```
229    pub async fn set_memory_limit(&self, limit: Option<usize>) {
230        self.runtime.set_memory_limit(limit.unwrap_or(0)).await;
231    }
232
233    /// Set the timeout of each function call.
234    ///
235    /// # Example
236    ///
237    /// ```
238    /// # use arrow_udf_runtime::javascript::Runtime;
239    /// # use std::time::Duration;
240    /// # tokio_test::block_on(async {
241    /// let mut runtime = Runtime::new().await.unwrap();
242    /// runtime.set_timeout(Some(Duration::from_secs(1))).await;
243    /// # });
244    /// ```
245    pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
246        self.timeout = timeout;
247        if timeout.is_some() {
248            let deadline = self.deadline.clone();
249            self.runtime
250                .set_interrupt_handler(Some(Box::new(move || {
251                    if let Some(deadline) = deadline.load(Ordering::Relaxed) {
252                        return deadline <= Instant::now();
253                    }
254                    false
255                })))
256                .await;
257        } else {
258            self.runtime.set_interrupt_handler(None).await;
259        }
260    }
261
262    /// Return the inner quickjs runtime.
263    pub fn inner(&self) -> &AsyncRuntime {
264        &self.runtime
265    }
266
267    /// Return the converter where you can configure the extension metadata key and values.
268    pub fn converter_mut(&mut self) -> &mut jsarrow::Converter {
269        &mut self.converter
270    }
271
272    /// Add a new scalar function or table function.
273    ///
274    /// # Arguments
275    ///
276    /// - `name`: The name of the function.
277    /// - `return_type`: The data type of the return value.
278    /// - `options`: The options for configuring the function.
279    /// - `code`: The JavaScript code of the function.
280    ///
281    /// The code should define an **exported** function with the same name as the function.
282    /// The function should return a value for scalar functions, or yield values for table functions.
283    ///
284    /// # Example
285    ///
286    /// ```
287    /// # use arrow_udf_runtime::javascript::{FunctionOptions, Runtime};
288    /// # use arrow_schema::DataType;
289    /// # tokio_test::block_on(async {
290    /// let mut runtime = Runtime::new().await.unwrap();
291    /// // add a scalar function
292    /// runtime
293    ///     .add_function(
294    ///         "gcd",
295    ///         DataType::Int32,
296    ///         r#"
297    ///         export function gcd(a, b) {
298    ///             while (b != 0) {
299    ///                 let t = b;
300    ///                 b = a % b;
301    ///                 a = t;
302    ///             }
303    ///             return a;
304    ///         }
305    /// "#,
306    ///         FunctionOptions::default().return_null_on_null_input(),
307    ///     )
308    ///     .await
309    ///     .unwrap();
310    /// // add a table function
311    /// runtime
312    ///     .add_function(
313    ///         "series",
314    ///         DataType::Int32,
315    ///         r#"
316    ///         export function* series(n) {
317    ///             for (let i = 0; i < n; i++) {
318    ///                 yield i;
319    ///             }
320    ///         }
321    /// "#,
322    ///         FunctionOptions::default().return_null_on_null_input(),
323    ///     )
324    ///     .await
325    ///     .unwrap();
326    /// # });
327    /// ```
328    pub async fn add_function(
329        &mut self,
330        name: &str,
331        return_type: impl IntoField + Send,
332        code: &str,
333        options: FunctionOptions,
334    ) -> Result<()> {
335        let function = async_with!(self.context => |ctx| {
336            let (module, _) = Module::declare(ctx.clone(), name, code)
337                .map_err(|e| check_exception(e, &ctx))
338                .context("failed to declare module")?
339                .eval()
340                .map_err(|e| check_exception(e, &ctx))
341                .context("failed to evaluate module")?;
342            let function = Self::get_function(&ctx, &module, options.handler.as_deref().unwrap_or(name))?;
343            Ok(Function {
344                function,
345                return_field: return_type.into_field(name).into(),
346                options,
347            }) as Result<Function>
348        })
349        .await?;
350        self.functions.insert(name.to_string(), function);
351        Ok(())
352    }
353
354    /// Get a function from a module.
355    fn get_function<'a>(
356        ctx: &Ctx<'a>,
357        module: &Module<'a, Evaluated>,
358        name: &str,
359    ) -> Result<JsFunction> {
360        let function: rquickjs::Function = module.get(name).with_context(|| {
361            format!("function \"{name}\" not found. HINT: make sure the function is exported")
362        })?;
363        Ok(Persistent::save(ctx, function))
364    }
365
366    /// Add a new aggregate function.
367    ///
368    /// # Arguments
369    ///
370    /// - `name`: The name of the function.
371    /// - `state_type`: The data type of the internal state.
372    /// - `output_type`: The data type of the aggregate value.
373    /// - `mode`: Whether the function will be called when some of its arguments are null.
374    /// - `code`: The JavaScript code of the aggregate function.
375    ///
376    /// The code should define at least two functions:
377    ///
378    /// - `create_state() -> state`: Create a new state object.
379    /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
380    ///
381    /// optionally, the code can define:
382    ///
383    /// - `finish(state) -> value`: Get the result of the aggregate function.
384    ///   If not defined, the state is returned as the result.
385    ///   In this case, `output_type` must be the same as `state_type`.
386    /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
387    /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
388    ///
389    /// Each function must be **exported**.
390    ///
391    /// # Example
392    ///
393    /// ```
394    /// # use arrow_udf_runtime::javascript::{AggregateOptions, Runtime};
395    /// # use arrow_schema::DataType;
396    /// # tokio_test::block_on(async {
397    /// let mut runtime = Runtime::new().await.unwrap();
398    /// runtime
399    ///     .add_aggregate(
400    ///         "sum",
401    ///         DataType::Int32, // state_type
402    ///         DataType::Int32, // output_type
403    ///         r#"
404    ///         export function create_state() {
405    ///             return 0;
406    ///         }
407    ///         export function accumulate(state, value) {
408    ///             return state + value;
409    ///         }
410    ///         export function retract(state, value) {
411    ///             return state - value;
412    ///         }
413    ///         export function merge(state1, state2) {
414    ///             return state1 + state2;
415    ///         }
416    ///         "#,
417    ///         AggregateOptions::default().return_null_on_null_input(),
418    ///     )
419    ///     .await
420    ///     .unwrap();
421    /// # });
422    /// ```
423    pub async fn add_aggregate(
424        &mut self,
425        name: &str,
426        state_type: impl IntoField + Send,
427        output_type: impl IntoField + Send,
428        code: &str,
429        options: AggregateOptions,
430    ) -> Result<()> {
431        let aggregate = async_with!(self.context => |ctx| {
432            let (module, _) = Module::declare(ctx.clone(), name, code)
433                .map_err(|e| check_exception(e, &ctx))
434                .context("failed to declare module")?
435                .eval()
436                .map_err(|e| check_exception(e, &ctx))
437                .context("failed to evaluate module")?;
438            Ok(Aggregate {
439                state_field: state_type.into_field(name).into(),
440                output_field: output_type.into_field(name).into(),
441                create_state: Self::get_function(&ctx, &module, "create_state")?,
442                accumulate: Self::get_function(&ctx, &module, "accumulate")?,
443                retract: Self::get_function(&ctx, &module, "retract").ok(),
444                finish: Self::get_function(&ctx, &module, "finish").ok(),
445                merge: Self::get_function(&ctx, &module, "merge").ok(),
446                options,
447            }) as Result<Aggregate>
448        })
449        .await?;
450
451        if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
452            bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
453        }
454        self.aggregates.insert(name.to_string(), aggregate);
455        Ok(())
456    }
457
458    /// Call a scalar function.
459    ///
460    /// # Example
461    ///
462    /// ```
463    /// # tokio_test::block_on(async {
464    #[doc = include_str!("doc_create_function.txt")]
465    /// // suppose we have created a scalar function `gcd`
466    /// // see the example in `add_function`
467    ///
468    /// let schema = Schema::new(vec![
469    ///     Field::new("x", DataType::Int32, true),
470    ///     Field::new("y", DataType::Int32, true),
471    /// ]);
472    /// let arg0 = Int32Array::from(vec![Some(25), None]);
473    /// let arg1 = Int32Array::from(vec![Some(15), None]);
474    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
475    ///
476    /// let output = runtime.call("gcd", &input).await.unwrap();
477    /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
478    /// # });
479    /// ```
480    pub async fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
481        let function = self.functions.get(name).context("function not found")?;
482
483        async_with!(self.context => |ctx| {
484            if function.options.is_batched {
485                self.call_batched_function(&ctx, function, input).await
486            } else {
487                self.call_non_batched_function(&ctx, function, input).await
488            }
489        })
490        .await
491    }
492
493    async fn call_non_batched_function(
494        &self,
495        ctx: &Ctx<'_>,
496        function: &Function,
497        input: &RecordBatch,
498    ) -> Result<RecordBatch> {
499        let js_function = function.function.clone().restore(ctx)?;
500
501        let mut results = Vec::with_capacity(input.num_rows());
502        let mut row = Vec::with_capacity(input.num_columns());
503        for i in 0..input.num_rows() {
504            row.clear();
505            for (column, field) in input.columns().iter().zip(input.schema().fields()) {
506                let val = self
507                    .converter
508                    .get_jsvalue(ctx, field, column, i)
509                    .context("failed to get jsvalue from arrow array")?;
510
511                row.push(val);
512            }
513            if function.options.call_mode == CallMode::ReturnNullOnNullInput
514                && row.iter().any(|v| v.is_null())
515            {
516                results.push(Value::new_null(ctx.clone()));
517                continue;
518            }
519            let mut args = Args::new(ctx.clone(), row.len());
520            args.push_args(row.drain(..))?;
521            let result = self
522                .call_user_fn(ctx, &js_function, args, function.options.is_async)
523                .await
524                .context("failed to call function")?;
525            results.push(result);
526        }
527
528        let array = self
529            .converter
530            .build_array(&function.return_field, ctx, results)
531            .context("failed to build arrow array from return values")?;
532        let schema = Schema::new(vec![function.return_field.clone()]);
533        Ok(RecordBatch::try_new(Arc::new(schema), vec![array])?)
534    }
535
536    async fn call_batched_function(
537        &self,
538        ctx: &Ctx<'_>,
539        function: &Function,
540        input: &RecordBatch,
541    ) -> Result<RecordBatch> {
542        let js_function = function.function.clone().restore(ctx)?;
543
544        let mut js_columns = Vec::with_capacity(input.num_columns());
545        for (column, field) in input.columns().iter().zip(input.schema().fields()) {
546            let mut js_values = Vec::with_capacity(input.num_rows());
547            for i in 0..input.num_rows() {
548                let val = self
549                    .converter
550                    .get_jsvalue(ctx, field, column, i)
551                    .context("failed to get jsvalue from arrow array")?;
552                js_values.push(val);
553            }
554            js_columns.push(js_values);
555        }
556
557        let result = match function.options.call_mode {
558            CallMode::CalledOnNullInput => {
559                let mut args = Args::new(ctx.clone(), input.num_columns());
560                for js_values in js_columns {
561                    let js_array = js_values.into_iter().collect_js::<JsArray>(ctx)?;
562                    args.push_arg(js_array)?;
563                }
564                self.call_user_fn(ctx, &js_function, args, function.options.is_async)
565                    .await
566                    .context("failed to call function")?
567            }
568            CallMode::ReturnNullOnNullInput => {
569                // This is a bit tricky. We build input arrays without nulls, call user_fn on them,
570                // and then add back null results to form the final result.
571                let n_cols = input.num_columns();
572                let n_rows = input.num_rows();
573
574                // 1. Build a bitmap of which rows have nulls
575                let mut bitmap = Vec::with_capacity(n_rows);
576                for i in 0..n_rows {
577                    let has_null = (0..n_cols).any(|j| js_columns[j][i].is_null());
578                    bitmap.push(!has_null);
579                }
580
581                // 2. Build new inputs with only the rows that don't have nulls
582                let mut filtered_columns = Vec::with_capacity(n_cols);
583                for js_values in js_columns {
584                    let filtered_js_values: Vec<_> = js_values
585                        .into_iter()
586                        .zip(bitmap.iter())
587                        .filter(|(_, b)| **b)
588                        .map(|(v, _)| v)
589                        .collect();
590                    filtered_columns.push(filtered_js_values);
591                }
592
593                // 3. Call the function on the new inputs
594                let mut args = Args::new(ctx.clone(), filtered_columns.len());
595                for js_values in filtered_columns {
596                    let js_array = js_values.into_iter().collect_js::<JsArray>(ctx)?;
597                    args.push_arg(js_array)?;
598                }
599                let filtered_result: Vec<_> = self
600                    .call_user_fn(ctx, &js_function, args, function.options.is_async)
601                    .await
602                    .context("failed to call function")?;
603                let mut iter = filtered_result.into_iter();
604
605                // 4. Add back null results to the filtered results
606                let mut result = Vec::with_capacity(n_rows);
607                for b in bitmap.iter() {
608                    if *b {
609                        let v = iter.next().expect("filtered result length mismatch");
610                        result.push(v);
611                    } else {
612                        result.push(Value::new_null(ctx.clone()));
613                    }
614                }
615                assert!(iter.next().is_none(), "filtered result length mismatch");
616                result
617            }
618        };
619        let array = self
620            .converter
621            .build_array(&function.return_field, ctx, result)?;
622        let schema = Schema::new(vec![function.return_field.clone()]);
623        Ok(RecordBatch::try_new(Arc::new(schema), vec![array])?)
624    }
625
626    /// Call a table function.
627    ///
628    /// # Example
629    ///
630    /// ```
631    /// # tokio_test::block_on(async {
632    #[doc = include_str!("doc_create_function.txt")]
633    /// // suppose we have created a table function `series`
634    /// // see the example in `add_function`
635    ///
636    /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
637    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
638    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
639    ///
640    /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
641    /// let output = outputs.next().await.unwrap().unwrap();
642    /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
643    /// assert_eq!(pretty, r#"
644    /// +-----+--------+
645    /// | row | series |
646    /// +-----+--------+
647    /// | 0   | 0      |
648    /// | 2   | 0      |
649    /// | 2   | 1      |
650    /// | 2   | 2      |
651    /// +-----+--------+"#.trim());
652    /// # });
653    /// ```
654    pub fn call_table_function<'a>(
655        &'a self,
656        name: &'a str,
657        input: &'a RecordBatch,
658        chunk_size: usize,
659    ) -> Result<RecordBatchIter<'a>> {
660        assert!(chunk_size > 0);
661        let function = self.functions.get(name).context("function not found")?;
662        if function.options.is_batched {
663            bail!("table function does not support batched mode");
664        }
665
666        // initial state
667        Ok(RecordBatchIter {
668            rt: self,
669            input,
670            function,
671            schema: Arc::new(Schema::new(vec![
672                Arc::new(Field::new("row", DataType::Int32, false)),
673                function.return_field.clone(),
674            ])),
675            chunk_size,
676            row: 0,
677            generator: None,
678            converter: &self.converter,
679        })
680    }
681
682    /// Create a new state for an aggregate function.
683    ///
684    /// # Example
685    /// ```
686    /// # tokio_test::block_on(async {
687    #[doc = include_str!("doc_create_aggregate.txt")]
688    /// let state = runtime.create_state("sum").await.unwrap();
689    /// assert_eq!(&*state, &Int32Array::from(vec![0]));
690    /// # });
691    /// ```
692    pub async fn create_state(&self, name: &str) -> Result<ArrayRef> {
693        let aggregate = self.aggregates.get(name).context("function not found")?;
694        let state = async_with!(self.context => |ctx| {
695            let create_state = aggregate.create_state.clone().restore(&ctx)?;
696            let state = self
697                .call_user_fn(&ctx, &create_state, Args::new(ctx.clone(), 0), aggregate.options.is_async)
698                .await
699                .context("failed to call create_state")?;
700            let state = self
701                .converter
702                .build_array(&aggregate.state_field, &ctx, vec![state])?;
703            Ok(state) as Result<_>
704        })
705        .await?;
706        Ok(state)
707    }
708
709    /// Call accumulate of an aggregate function.
710    ///
711    /// # Example
712    /// ```
713    /// # tokio_test::block_on(async {
714    #[doc = include_str!("doc_create_aggregate.txt")]
715    /// let state = runtime.create_state("sum").await.unwrap();
716    ///
717    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
718    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
719    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
720    ///
721    /// let state = runtime.accumulate("sum", &state, &input).await.unwrap();
722    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
723    /// # });
724    /// ```
725    pub async fn accumulate(
726        &self,
727        name: &str,
728        state: &dyn Array,
729        input: &RecordBatch,
730    ) -> Result<ArrayRef> {
731        let aggregate = self.aggregates.get(name).context("function not found")?;
732        // convert each row to python objects and call the accumulate function
733        let new_state = async_with!(self.context => |ctx| {
734            let accumulate = aggregate.accumulate.clone().restore(&ctx)?;
735            let mut state = self
736                .converter
737                .get_jsvalue(&ctx, &aggregate.state_field, state, 0)?;
738
739            let mut row = Vec::with_capacity(1 + input.num_columns());
740            for i in 0..input.num_rows() {
741                if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput
742                    && input.columns().iter().any(|column| column.is_null(i))
743                {
744                    continue;
745                }
746                row.clear();
747                row.push(state.clone());
748                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
749                    let pyobj = self.converter.get_jsvalue(&ctx, field, column, i)?;
750                    row.push(pyobj);
751                }
752                let mut args = Args::new(ctx.clone(), row.len());
753                args.push_args(row.drain(..))?;
754                state = self
755                    .call_user_fn(&ctx, &accumulate, args, aggregate.options.is_async)
756                    .await
757                    .context("failed to call accumulate")?;
758            }
759            let output = self
760                .converter
761                .build_array(&aggregate.state_field, &ctx, vec![state])?;
762            Ok(output) as Result<_>
763        })
764        .await?;
765        Ok(new_state)
766    }
767
768    /// Call accumulate or retract of an aggregate function.
769    ///
770    /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
771    /// `false` for accumulate and `true` for retract.
772    ///
773    /// # Example
774    /// ```
775    /// # tokio_test::block_on(async {
776    #[doc = include_str!("doc_create_aggregate.txt")]
777    /// let state = runtime.create_state("sum").await.unwrap();
778    ///
779    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
780    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
781    /// let ops = BooleanArray::from(vec![false, false, true, false]);
782    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
783    ///
784    /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).await.unwrap();
785    /// assert_eq!(&*state, &Int32Array::from(vec![3]));
786    /// # });
787    /// ```
788    pub async fn accumulate_or_retract(
789        &self,
790        name: &str,
791        state: &dyn Array,
792        ops: &BooleanArray,
793        input: &RecordBatch,
794    ) -> Result<ArrayRef> {
795        let aggregate = self.aggregates.get(name).context("function not found")?;
796        // convert each row to python objects and call the accumulate function
797        let new_state = async_with!(self.context => |ctx| {
798        let accumulate = aggregate.accumulate.clone().restore(&ctx)?;
799        let retract = aggregate
800            .retract
801            .clone()
802            .context("function does not support retraction")?
803            .restore(&ctx)?;
804
805        let mut state = self
806            .converter
807            .get_jsvalue(&ctx, &aggregate.state_field, state, 0)?;
808
809        let mut row = Vec::with_capacity(1 + input.num_columns());
810        for i in 0..input.num_rows() {
811            if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput
812                && input.columns().iter().any(|column| column.is_null(i))
813            {
814                continue;
815            }
816            row.clear();
817            row.push(state.clone());
818            for (column, field) in input.columns().iter().zip(input.schema().fields()) {
819                let pyobj = self.converter.get_jsvalue(&ctx, field, column, i)?;
820                row.push(pyobj);
821            }
822            let func = if ops.is_valid(i) && ops.value(i) {
823                &retract
824            } else {
825                &accumulate
826            };
827            let mut args = Args::new(ctx.clone(), row.len());
828            args.push_args(row.drain(..))?;
829            state = self
830                .call_user_fn(&ctx, func, args, aggregate.options.is_async)
831                .await
832                .context("failed to call accumulate or retract")?;
833        }
834        let output = self
835            .converter
836            .build_array(&aggregate.state_field, &ctx, vec![state])?;
837        Ok(output) as Result<_>
838            })
839        .await?;
840        Ok(new_state)
841    }
842
843    /// Merge states of an aggregate function.
844    ///
845    /// # Example
846    /// ```
847    /// # tokio_test::block_on(async {
848    #[doc = include_str!("doc_create_aggregate.txt")]
849    /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
850    ///
851    /// let state = runtime.merge("sum", &states).await.unwrap();
852    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
853    /// # });
854    /// ```
855    pub async fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
856        let aggregate = self.aggregates.get(name).context("function not found")?;
857        let output = async_with!(self.context => |ctx| {
858            let merge = aggregate
859                .merge
860                .clone()
861                .context("merge not found")?
862                .restore(&ctx)?;
863            let mut state = self
864                .converter
865                .get_jsvalue(&ctx, &aggregate.state_field, states, 0)?;
866            for i in 1..states.len() {
867                if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
868                    continue;
869                }
870                let state2 = self
871                    .converter
872                    .get_jsvalue(&ctx, &aggregate.state_field, states, i)?;
873                let mut args = Args::new(ctx.clone(), 2);
874                args.push_args([state, state2])?;
875                state = self
876                    .call_user_fn(&ctx, &merge, args, aggregate.options.is_async)
877                    .await
878                    .context("failed to call accumulate or retract")?;
879            }
880            let output = self
881                .converter
882                .build_array(&aggregate.state_field, &ctx, vec![state])?;
883            Ok(output) as Result<_>
884        })
885        .await?;
886        Ok(output)
887    }
888
889    /// Get the result of an aggregate function.
890    ///
891    /// If the `finish` function is not defined, the state is returned as the result.
892    ///
893    /// # Example
894    /// ```
895    /// # tokio_test::block_on(async {
896    #[doc = include_str!("doc_create_aggregate.txt")]
897    /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
898    ///
899    /// let outputs = runtime.finish("sum", &states).await.unwrap();
900    /// assert_eq!(&outputs, &states);
901    /// # });
902    /// ```
903    pub async fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
904        let aggregate = self.aggregates.get(name).context("function not found")?;
905        if aggregate.finish.is_none() {
906            return Ok(states.clone());
907        };
908        let output = async_with!(self.context => |ctx| {
909            let finish = aggregate.finish.clone().unwrap().restore(&ctx)?;
910            let mut results = Vec::with_capacity(states.len());
911            for i in 0..states.len() {
912                if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
913                    results.push(Value::new_null(ctx.clone()));
914                    continue;
915                }
916                let state =
917                    self.converter
918                        .get_jsvalue(&ctx, &aggregate.state_field, states, i)?;
919                let mut args = Args::new(ctx.clone(), 1);
920                args.push_args([state])?;
921                let result = self
922                    .call_user_fn(&ctx, &finish, args, aggregate.options.is_async)
923                    .await
924                    .context("failed to call finish")?;
925                results.push(result);
926            }
927            let output = self
928                .converter
929                .build_array(&aggregate.output_field, &ctx, results)?;
930            Ok(output) as Result<_>
931        })
932        .await?;
933        Ok(output)
934    }
935
936    /// Call a user function.
937    ///
938    /// If `timeout` is set, the function will be interrupted after the timeout.
939    async fn call_user_fn<'js, T: FromJs<'js>>(
940        &self,
941        ctx: &Ctx<'js>,
942        f: &rquickjs::Function<'js>,
943        args: Args<'js>,
944        is_async: bool,
945    ) -> Result<T> {
946        if is_async {
947            Self::call_user_fn_async(self, ctx, f, args).await
948        } else {
949            Self::call_user_fn_sync(self, ctx, f, args)
950        }
951    }
952
953    async fn call_user_fn_async<'js, T: FromJs<'js>>(
954        &self,
955        ctx: &Ctx<'js>,
956        f: &rquickjs::Function<'js>,
957        args: Args<'js>,
958    ) -> Result<T> {
959        let call_result = if let Some(timeout) = self.timeout {
960            self.deadline
961                .store(Some(Instant::now() + timeout), Ordering::Relaxed);
962            let call_result = f.call_arg::<Promise>(args);
963            self.deadline.store(None, Ordering::Relaxed);
964            call_result
965        } else {
966            f.call_arg::<Promise>(args)
967        };
968        let promise = call_result.map_err(|e| check_exception(e, ctx))?;
969        promise
970            .into_future::<T>()
971            .await
972            .map_err(|e| check_exception(e, ctx))
973    }
974
975    fn call_user_fn_sync<'js, T: FromJs<'js>>(
976        &self,
977        ctx: &Ctx<'js>,
978        f: &rquickjs::Function<'js>,
979        args: Args<'js>,
980    ) -> Result<T> {
981        let result = if let Some(timeout) = self.timeout {
982            self.deadline
983                .store(Some(Instant::now() + timeout), Ordering::Relaxed);
984            let result = f.call_arg(args);
985            self.deadline.store(None, Ordering::Relaxed);
986            result
987        } else {
988            f.call_arg(args)
989        };
990        result.map_err(|e| check_exception(e, ctx))
991    }
992
993    pub fn context(&self) -> &AsyncContext {
994        &self.context
995    }
996
997    /// Enable the `fetch` API in the `Runtime`.
998    ///
999    /// See module [`fetch`] for more details.
1000    #[cfg(feature = "javascript-fetch")]
1001    pub async fn enable_fetch(&self) -> Result<()> {
1002        fetch::enable_fetch(&self.runtime, &self.context).await
1003    }
1004}
1005
1006/// An iterator over the result of a table function.
1007pub struct RecordBatchIter<'a> {
1008    rt: &'a Runtime,
1009    input: &'a RecordBatch,
1010    // The function to generate the generator
1011    function: &'a Function,
1012    schema: SchemaRef,
1013    chunk_size: usize,
1014    // mutable states
1015    /// Current row index.
1016    row: usize,
1017    /// Generator of the current row.
1018    generator: Option<Persistent<Object<'static>>>,
1019    converter: &'a jsarrow::Converter,
1020}
1021
1022// XXX: not sure if this is safe.
1023unsafe impl Send for RecordBatchIter<'_> {}
1024
1025impl RecordBatchIter<'_> {
1026    /// Get the schema of the output.
1027    pub fn schema(&self) -> &Schema {
1028        &self.schema
1029    }
1030
1031    pub async fn next(&mut self) -> Result<Option<RecordBatch>> {
1032        if self.row == self.input.num_rows() {
1033            return Ok(None);
1034        }
1035        async_with!(self.rt.context => |ctx| {
1036            let js_function = self.function.function.clone().restore(&ctx)?;
1037            let mut indexes = Int32Builder::with_capacity(self.chunk_size);
1038            let mut results = Vec::with_capacity(self.input.num_rows());
1039            let mut row = Vec::with_capacity(self.input.num_columns());
1040            // restore generator from state
1041            let mut generator = match self.generator.take() {
1042                Some(generator) => {
1043                    let gen = generator.restore(&ctx)?;
1044                    let next: rquickjs::Function =
1045                        gen.get("next").context("failed to get 'next' method")?;
1046                    Some((gen, next))
1047                }
1048                None => None,
1049            };
1050            while self.row < self.input.num_rows() && results.len() < self.chunk_size {
1051                let (gen, next) = if let Some(g) = generator.as_ref() {
1052                    g
1053                } else {
1054                    // call the table function to get a generator
1055                    row.clear();
1056                    for (column, field) in
1057                        (self.input.columns().iter()).zip(self.input.schema().fields())
1058                    {
1059                        let val = self
1060                            .converter
1061                            .get_jsvalue(&ctx, field, column, self.row)
1062                            .context("failed to get jsvalue from arrow array")?;
1063                        row.push(val);
1064                    }
1065                    if self.function.options.call_mode == CallMode::ReturnNullOnNullInput
1066                        && row.iter().any(|v| v.is_null())
1067                    {
1068                        self.row += 1;
1069                        continue;
1070                    }
1071                    let mut args = Args::new(ctx.clone(), row.len());
1072                    args.push_args(row.drain(..))?;
1073                    // NOTE: A async generator function, defined by `async function*`, itself is NOT async.
1074                    // That's why we call it with `is_async = false` here.
1075                    // The result is a `AsyncGenerator`, which has a async `next` method.
1076                    let gen: Object = self
1077                        .rt
1078                        .call_user_fn(&ctx, &js_function, args, false).await
1079                        .context("failed to call function")?;
1080                    let next: rquickjs::Function =
1081                        gen.get("next").context("failed to get 'next' method")?;
1082                    let mut args = Args::new(ctx.clone(), 0);
1083                    args.this(gen.clone())?;
1084                    generator.insert((gen, next))
1085                };
1086                let mut args = Args::new(ctx.clone(), 0);
1087                args.this(gen.clone())?;
1088                let object: Object = self
1089                    .rt
1090                    .call_user_fn(&ctx, next, args, self.function.options.is_async).await
1091                    .context("failed to call next")?;
1092                let value: Value = object.get("value")?;
1093                let done: bool = object.get("done")?;
1094                if done {
1095                    self.row += 1;
1096                    generator = None;
1097                    continue;
1098                }
1099                indexes.append_value(self.row as i32);
1100                results.push(value);
1101            }
1102            self.generator = generator.map(|(gen, _)| Persistent::save(&ctx, gen));
1103
1104            if results.is_empty() {
1105                return Ok(None);
1106            }
1107            let indexes = Arc::new(indexes.finish());
1108            let array = self
1109                .converter
1110                .build_array(&self.function.return_field, &ctx, results)
1111                .context("failed to build arrow array from return values")?;
1112            Ok(Some(RecordBatch::try_new(
1113                self.schema.clone(),
1114                vec![indexes, array],
1115            )?))
1116        })
1117        .await
1118    }
1119}
1120
1121impl Stream for RecordBatchIter<'_> {
1122    type Item = Result<RecordBatch>;
1123    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1124        Box::pin(self.next().map(|v| v.transpose()))
1125            .as_mut()
1126            .poll_unpin(cx)
1127    }
1128}
1129
1130/// Get exception from `ctx` if the error is an exception.
1131pub(crate) fn check_exception(err: rquickjs::Error, ctx: &Ctx) -> anyhow::Error {
1132    match err {
1133        rquickjs::Error::Exception => {
1134            anyhow!("exception generated by QuickJS: {:?}", ctx.catch())
1135        }
1136        e => e.into(),
1137    }
1138}