arrow_udf_js/
lib.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("../README.md")]
16
17use std::collections::HashMap;
18use std::fmt::Debug;
19use std::pin::Pin;
20use std::sync::{atomic::Ordering, Arc};
21use std::task::{Context, Poll};
22use std::time::{Duration, Instant};
23
24use anyhow::{anyhow, bail, Context as _, Result};
25use arrow_array::{builder::Int32Builder, Array, ArrayRef, BooleanArray, RecordBatch};
26use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
27use futures_util::{FutureExt, Stream};
28use rquickjs::context::intrinsic::{All, Base};
29pub use rquickjs::runtime::MemoryUsage;
30use rquickjs::{
31    async_with, function::Args, module::Evaluated, Array as JsArray, AsyncContext, AsyncRuntime,
32    Ctx, FromJs, IteratorJs as _, Module, Object, Persistent, Promise, Value,
33};
34
35pub use self::into_field::IntoField;
36
37#[cfg(feature = "fetch")]
38mod fetch;
39mod into_field;
40mod jsarrow;
41
42/// A runtime to execute user defined functions in JavaScript.
43///
44/// # Usages
45///
46/// - Create a new runtime with [`Runtime::new`].
47/// - For scalar functions, use [`add_function`] and [`call`].
48/// - For table functions, use [`add_function`] and [`call_table_function`].
49/// - For aggregate functions, create the function with [`add_aggregate`], and then
50///     - create a new state with [`create_state`],
51///     - update the state with [`accumulate`] or [`accumulate_or_retract`],
52///     - merge states with [`merge`],
53///     - finally get the result with [`finish`].
54///
55/// Click on each function to see the example.
56///
57/// [`add_function`]: Runtime::add_function
58/// [`add_aggregate`]: Runtime::add_aggregate
59/// [`call`]: Runtime::call
60/// [`call_table_function`]: Runtime::call_table_function
61/// [`create_state`]: Runtime::create_state
62/// [`accumulate`]: Runtime::accumulate
63/// [`accumulate_or_retract`]: Runtime::accumulate_or_retract
64/// [`merge`]: Runtime::merge
65/// [`finish`]: Runtime::finish
66pub struct Runtime {
67    functions: HashMap<String, Function>,
68    aggregates: HashMap<String, Aggregate>,
69    // NOTE: `functions` and `aggregates` must be put before the `runtime` and `context` to be dropped first.
70    converter: jsarrow::Converter,
71    runtime: AsyncRuntime,
72    context: AsyncContext,
73    /// Timeout of each function call.
74    timeout: Option<Duration>,
75    /// Deadline of the current function call.
76    deadline: Arc<atomic_time::AtomicOptionInstant>,
77}
78
79impl Debug for Runtime {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("Runtime")
82            .field("functions", &self.functions.keys())
83            .field("aggregates", &self.aggregates.keys())
84            .field("timeout", &self.timeout)
85            .finish()
86    }
87}
88
89/// A user defined scalar function or table function.
90struct Function {
91    function: JsFunction,
92    return_field: FieldRef,
93    options: FunctionOptions,
94}
95
96/// A user defined aggregate function.
97struct Aggregate {
98    state_field: FieldRef,
99    output_field: FieldRef,
100    create_state: JsFunction,
101    accumulate: JsFunction,
102    retract: Option<JsFunction>,
103    finish: Option<JsFunction>,
104    merge: Option<JsFunction>,
105    options: AggregateOptions,
106}
107
108// This is required to pass `Function` and `Aggregate` from `async_with!` to outside.
109// Otherwise, the compiler will complain that `*mut JSRuntime` cannot be sent between threads safely
110// SAFETY: We ensure the `JSRuntime` used in `async_with!` is same as the caller's.
111// The `parallel` feature of `rquickjs` is enabled, so itself can't ensure this.
112unsafe impl Send for Function {}
113unsafe impl Sync for Function {}
114unsafe impl Send for Aggregate {}
115unsafe impl Sync for Aggregate {}
116
117/// A persistent function, can be either sync or async.
118type JsFunction = Persistent<rquickjs::Function<'static>>;
119
120// SAFETY: `rquickjs::Runtime` is `Send` and `Sync`
121unsafe impl Send for Runtime {}
122unsafe impl Sync for Runtime {}
123
/// Controls whether a function is invoked at all when some of its arguments are null.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum CallMode {
    /// Invoke the function as usual even when some arguments are null.
    /// Checking for null values, where that matters, is up to the function author.
    #[default]
    CalledOnNullInput,

    /// Produce null, without running the function, whenever any argument is null.
    /// With this mode the function body is skipped for such inputs and a null
    /// result is assumed automatically.
    ReturnNullOnNullInput,
}
137
138/// Options for configuring user-defined functions.
139#[derive(Debug, Clone, Default)]
140pub struct FunctionOptions {
141    /// Whether the function will be called when some of its arguments are null.
142    pub call_mode: CallMode,
143    /// Whether the function is async. An async function would return a Promise.
144    pub is_async: bool,
145    /// Whether the function accepts a batch of records as input.
146    pub is_batched: bool,
147    /// The name of the function in JavaScript code to be called.
148    /// If not set, the function name will be used.
149    pub handler: Option<String>,
150}
151
152impl FunctionOptions {
153    /// Sets the function to return null when some of its arguments are null.
154    /// See [`CallMode`] for more details.
155    pub fn return_null_on_null_input(mut self) -> Self {
156        self.call_mode = CallMode::ReturnNullOnNullInput;
157        self
158    }
159
160    /// Marks the function to be async JS function.
161    pub fn async_mode(mut self) -> Self {
162        self.is_async = true;
163        self
164    }
165
166    /// Sets the function to accept a batch of records as input.
167    pub fn batched(mut self) -> Self {
168        self.is_batched = true;
169        self
170    }
171
172    /// Sets the name of the function in JavaScript code to be called.
173    pub fn handler(mut self, handler: impl Into<String>) -> Self {
174        self.handler = Some(handler.into());
175        self
176    }
177}
178
179/// Options for configuring user-defined aggregate functions.
180#[derive(Debug, Clone, Default)]
181pub struct AggregateOptions {
182    /// Whether the function will be called when some of its arguments are null.
183    pub call_mode: CallMode,
184    /// Whether the function is async. An async function would return a Promise.
185    pub is_async: bool,
186}
187
188impl AggregateOptions {
189    /// Sets the function to return null when some of its arguments are null.
190    /// See [`CallMode`] for more details.
191    pub fn return_null_on_null_input(mut self) -> Self {
192        self.call_mode = CallMode::ReturnNullOnNullInput;
193        self
194    }
195
196    /// Marks the function to be async JS function.
197    pub fn async_mode(mut self) -> Self {
198        self.is_async = true;
199        self
200    }
201}
202
203impl Runtime {
204    /// Create a new `Runtime`.
205    ///
206    /// # Example
207    ///
208    /// ```
209    /// # use arrow_udf_js::Runtime;
210    /// # tokio_test::block_on(async {
211    /// let runtime = Runtime::new().await.unwrap();
212    /// runtime.set_memory_limit(Some(1 << 20)); // 1MB
213    /// # });
214    /// ```
215    pub async fn new() -> Result<Self> {
216        let runtime = AsyncRuntime::new().context("failed to create quickjs runtime")?;
217        let context = AsyncContext::custom::<(Base, All)>(&runtime)
218            .await
219            .context("failed to create quickjs context")?;
220
221        Ok(Self {
222            functions: HashMap::new(),
223            aggregates: HashMap::new(),
224            runtime,
225            context,
226            timeout: None,
227            deadline: Default::default(),
228            converter: jsarrow::Converter::new(),
229        })
230    }
231
232    /// Set the memory limit of the runtime.
233    ///
234    /// # Example
235    ///
236    /// ```
237    /// # use arrow_udf_js::Runtime;
238    /// # tokio_test::block_on(async {
239    /// let runtime = Runtime::new().await.unwrap();
240    /// runtime.set_memory_limit(Some(1 << 20)); // 1MB
241    /// # });
242    /// ```
243    pub async fn set_memory_limit(&self, limit: Option<usize>) {
244        self.runtime.set_memory_limit(limit.unwrap_or(0)).await;
245    }
246
247    /// Set the timeout of each function call.
248    ///
249    /// # Example
250    ///
251    /// ```
252    /// # use arrow_udf_js::Runtime;
253    /// # use std::time::Duration;
254    /// # tokio_test::block_on(async {
255    /// let mut runtime = Runtime::new().await.unwrap();
256    /// runtime.set_timeout(Some(Duration::from_secs(1))).await;
257    /// # });
258    /// ```
259    pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
260        self.timeout = timeout;
261        if timeout.is_some() {
262            let deadline = self.deadline.clone();
263            self.runtime
264                .set_interrupt_handler(Some(Box::new(move || {
265                    if let Some(deadline) = deadline.load(Ordering::Relaxed) {
266                        return deadline <= Instant::now();
267                    }
268                    false
269                })))
270                .await;
271        } else {
272            self.runtime.set_interrupt_handler(None).await;
273        }
274    }
275
276    /// Return the inner quickjs runtime.
277    pub fn inner(&self) -> &AsyncRuntime {
278        &self.runtime
279    }
280
281    /// Return the converter where you can configure the extension metadata key and values.
282    pub fn converter_mut(&mut self) -> &mut jsarrow::Converter {
283        &mut self.converter
284    }
285
286    /// Add a new scalar function or table function.
287    ///
288    /// # Arguments
289    ///
290    /// - `name`: The name of the function.
291    /// - `return_type`: The data type of the return value.
292    /// - `options`: The options for configuring the function.
293    /// - `code`: The JavaScript code of the function.
294    ///
295    /// The code should define an **exported** function with the same name as the function.
296    /// The function should return a value for scalar functions, or yield values for table functions.
297    ///
298    /// # Example
299    ///
300    /// ```
301    /// # use arrow_udf_js::{FunctionOptions, Runtime};
302    /// # use arrow_schema::DataType;
303    /// # tokio_test::block_on(async {
304    /// let mut runtime = Runtime::new().await.unwrap();
305    /// // add a scalar function
306    /// runtime
307    ///     .add_function(
308    ///         "gcd",
309    ///         DataType::Int32,
310    ///         r#"
311    ///         export function gcd(a, b) {
312    ///             while (b != 0) {
313    ///                 let t = b;
314    ///                 b = a % b;
315    ///                 a = t;
316    ///             }
317    ///             return a;
318    ///         }
319    /// "#,
320    ///         FunctionOptions::default().return_null_on_null_input(),
321    ///     )
322    ///     .await
323    ///     .unwrap();
324    /// // add a table function
325    /// runtime
326    ///     .add_function(
327    ///         "series",
328    ///         DataType::Int32,
329    ///         r#"
330    ///         export function* series(n) {
331    ///             for (let i = 0; i < n; i++) {
332    ///                 yield i;
333    ///             }
334    ///         }
335    /// "#,
336    ///         FunctionOptions::default().return_null_on_null_input(),
337    ///     )
338    ///     .await
339    ///     .unwrap();
340    /// # });
341    /// ```
342    pub async fn add_function(
343        &mut self,
344        name: &str,
345        return_type: impl IntoField + Send,
346        code: &str,
347        options: FunctionOptions,
348    ) -> Result<()> {
349        let function = async_with!(self.context => |ctx| {
350            let (module, _) = Module::declare(ctx.clone(), name, code)
351                .map_err(|e| check_exception(e, &ctx))
352                .context("failed to declare module")?
353                .eval()
354                .map_err(|e| check_exception(e, &ctx))
355                .context("failed to evaluate module")?;
356            let function = Self::get_function(&ctx, &module, options.handler.as_deref().unwrap_or(name))?;
357            Ok(Function {
358                function,
359                return_field: return_type.into_field(name).into(),
360                options,
361            }) as Result<Function>
362        })
363        .await?;
364        self.functions.insert(name.to_string(), function);
365        Ok(())
366    }
367
368    /// Get a function from a module.
369    fn get_function<'a>(
370        ctx: &Ctx<'a>,
371        module: &Module<'a, Evaluated>,
372        name: &str,
373    ) -> Result<JsFunction> {
374        let function: rquickjs::Function = module.get(name).with_context(|| {
375            format!("function \"{name}\" not found. HINT: make sure the function is exported")
376        })?;
377        Ok(Persistent::save(ctx, function))
378    }
379
380    /// Add a new aggregate function.
381    ///
382    /// # Arguments
383    ///
384    /// - `name`: The name of the function.
385    /// - `state_type`: The data type of the internal state.
386    /// - `output_type`: The data type of the aggregate value.
387    /// - `mode`: Whether the function will be called when some of its arguments are null.
388    /// - `code`: The JavaScript code of the aggregate function.
389    ///
390    /// The code should define at least two functions:
391    ///
392    /// - `create_state() -> state`: Create a new state object.
393    /// - `accumulate(state, *args) -> state`: Accumulate a new value into the state, returning the updated state.
394    ///
395    /// optionally, the code can define:
396    ///
397    /// - `finish(state) -> value`: Get the result of the aggregate function.
398    ///     If not defined, the state is returned as the result.
399    ///     In this case, `output_type` must be the same as `state_type`.
400    /// - `retract(state, *args) -> state`: Retract a value from the state, returning the updated state.
401    /// - `merge(state, state) -> state`: Merge two states, returning the merged state.
402    ///
403    /// Each function must be **exported**.
404    ///
405    /// # Example
406    ///
407    /// ```
408    /// # use arrow_udf_js::{AggregateOptions, Runtime};
409    /// # use arrow_schema::DataType;
410    /// # tokio_test::block_on(async {
411    /// let mut runtime = Runtime::new().await.unwrap();
412    /// runtime
413    ///     .add_aggregate(
414    ///         "sum",
415    ///         DataType::Int32, // state_type
416    ///         DataType::Int32, // output_type
417    ///         r#"
418    ///         export function create_state() {
419    ///             return 0;
420    ///         }
421    ///         export function accumulate(state, value) {
422    ///             return state + value;
423    ///         }
424    ///         export function retract(state, value) {
425    ///             return state - value;
426    ///         }
427    ///         export function merge(state1, state2) {
428    ///             return state1 + state2;
429    ///         }
430    ///         "#,
431    ///         AggregateOptions::default().return_null_on_null_input(),
432    ///     )
433    ///     .await
434    ///     .unwrap();
435    /// # });
436    /// ```
437    pub async fn add_aggregate(
438        &mut self,
439        name: &str,
440        state_type: impl IntoField + Send,
441        output_type: impl IntoField + Send,
442        code: &str,
443        options: AggregateOptions,
444    ) -> Result<()> {
445        let aggregate = async_with!(self.context => |ctx| {
446            let (module, _) = Module::declare(ctx.clone(), name, code)
447                .map_err(|e| check_exception(e, &ctx))
448                .context("failed to declare module")?
449                .eval()
450                .map_err(|e| check_exception(e, &ctx))
451                .context("failed to evaluate module")?;
452            Ok(Aggregate {
453                state_field: state_type.into_field(name).into(),
454                output_field: output_type.into_field(name).into(),
455                create_state: Self::get_function(&ctx, &module, "create_state")?,
456                accumulate: Self::get_function(&ctx, &module, "accumulate")?,
457                retract: Self::get_function(&ctx, &module, "retract").ok(),
458                finish: Self::get_function(&ctx, &module, "finish").ok(),
459                merge: Self::get_function(&ctx, &module, "merge").ok(),
460                options,
461            }) as Result<Aggregate>
462        })
463        .await?;
464
465        if aggregate.finish.is_none() && aggregate.state_field != aggregate.output_field {
466            bail!("`output_type` must be the same as `state_type` when `finish` is not defined");
467        }
468        self.aggregates.insert(name.to_string(), aggregate);
469        Ok(())
470    }
471
472    /// Call a scalar function.
473    ///
474    /// # Example
475    ///
476    /// ```
477    /// # tokio_test::block_on(async {
478    #[doc = include_str!("doc_create_function.txt")]
479    /// // suppose we have created a scalar function `gcd`
480    /// // see the example in `add_function`
481    ///
482    /// let schema = Schema::new(vec![
483    ///     Field::new("x", DataType::Int32, true),
484    ///     Field::new("y", DataType::Int32, true),
485    /// ]);
486    /// let arg0 = Int32Array::from(vec![Some(25), None]);
487    /// let arg1 = Int32Array::from(vec![Some(15), None]);
488    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0), Arc::new(arg1)]).unwrap();
489    ///
490    /// let output = runtime.call("gcd", &input).await.unwrap();
491    /// assert_eq!(&**output.column(0), &Int32Array::from(vec![Some(5), None]));
492    /// # });
493    /// ```
494    pub async fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
495        let function = self.functions.get(name).context("function not found")?;
496
497        async_with!(self.context => |ctx| {
498            if function.options.is_batched {
499                self.call_batched_function(&ctx, function, input).await
500            } else {
501                self.call_non_batched_function(&ctx, function, input).await
502            }
503        })
504        .await
505    }
506
507    async fn call_non_batched_function(
508        &self,
509        ctx: &Ctx<'_>,
510        function: &Function,
511        input: &RecordBatch,
512    ) -> Result<RecordBatch> {
513        let js_function = function.function.clone().restore(ctx)?;
514
515        let mut results = Vec::with_capacity(input.num_rows());
516        let mut row = Vec::with_capacity(input.num_columns());
517        for i in 0..input.num_rows() {
518            row.clear();
519            for (column, field) in input.columns().iter().zip(input.schema().fields()) {
520                let val = self
521                    .converter
522                    .get_jsvalue(ctx, field, column, i)
523                    .context("failed to get jsvalue from arrow array")?;
524
525                row.push(val);
526            }
527            if function.options.call_mode == CallMode::ReturnNullOnNullInput
528                && row.iter().any(|v| v.is_null())
529            {
530                results.push(Value::new_null(ctx.clone()));
531                continue;
532            }
533            let mut args = Args::new(ctx.clone(), row.len());
534            args.push_args(row.drain(..))?;
535            let result = self
536                .call_user_fn(ctx, &js_function, args, function.options.is_async)
537                .await
538                .context("failed to call function")?;
539            results.push(result);
540        }
541
542        let array = self
543            .converter
544            .build_array(&function.return_field, ctx, results)
545            .context("failed to build arrow array from return values")?;
546        let schema = Schema::new(vec![function.return_field.clone()]);
547        Ok(RecordBatch::try_new(Arc::new(schema), vec![array])?)
548    }
549
550    async fn call_batched_function(
551        &self,
552        ctx: &Ctx<'_>,
553        function: &Function,
554        input: &RecordBatch,
555    ) -> Result<RecordBatch> {
556        let js_function = function.function.clone().restore(ctx)?;
557
558        let mut js_columns = Vec::with_capacity(input.num_columns());
559        for (column, field) in input.columns().iter().zip(input.schema().fields()) {
560            let mut js_values = Vec::with_capacity(input.num_rows());
561            for i in 0..input.num_rows() {
562                let val = self
563                    .converter
564                    .get_jsvalue(ctx, field, column, i)
565                    .context("failed to get jsvalue from arrow array")?;
566                js_values.push(val);
567            }
568            js_columns.push(js_values);
569        }
570
571        let result = match function.options.call_mode {
572            CallMode::CalledOnNullInput => {
573                let mut args = Args::new(ctx.clone(), input.num_columns());
574                for js_values in js_columns {
575                    let js_array = js_values.into_iter().collect_js::<JsArray>(ctx)?;
576                    args.push_arg(js_array)?;
577                }
578                self.call_user_fn(ctx, &js_function, args, function.options.is_async)
579                    .await
580                    .context("failed to call function")?
581            }
582            CallMode::ReturnNullOnNullInput => {
583                // This is a bit tricky. We build input arrays without nulls, call user_fn on them,
584                // and then add back null results to form the final result.
585                let n_cols = input.num_columns();
586                let n_rows = input.num_rows();
587
588                // 1. Build a bitmap of which rows have nulls
589                let mut bitmap = Vec::with_capacity(n_rows);
590                for i in 0..n_rows {
591                    let has_null = (0..n_cols).any(|j| js_columns[j][i].is_null());
592                    bitmap.push(!has_null);
593                }
594
595                // 2. Build new inputs with only the rows that don't have nulls
596                let mut filtered_columns = Vec::with_capacity(n_cols);
597                for js_values in js_columns {
598                    let filtered_js_values: Vec<_> = js_values
599                        .into_iter()
600                        .zip(bitmap.iter())
601                        .filter(|(_, b)| **b)
602                        .map(|(v, _)| v)
603                        .collect();
604                    filtered_columns.push(filtered_js_values);
605                }
606
607                // 3. Call the function on the new inputs
608                let mut args = Args::new(ctx.clone(), filtered_columns.len());
609                for js_values in filtered_columns {
610                    let js_array = js_values.into_iter().collect_js::<JsArray>(ctx)?;
611                    args.push_arg(js_array)?;
612                }
613                let filtered_result: Vec<_> = self
614                    .call_user_fn(ctx, &js_function, args, function.options.is_async)
615                    .await
616                    .context("failed to call function")?;
617                let mut iter = filtered_result.into_iter();
618
619                // 4. Add back null results to the filtered results
620                let mut result = Vec::with_capacity(n_rows);
621                for b in bitmap.iter() {
622                    if *b {
623                        let v = iter.next().expect("filtered result length mismatch");
624                        result.push(v);
625                    } else {
626                        result.push(Value::new_null(ctx.clone()));
627                    }
628                }
629                assert!(iter.next().is_none(), "filtered result length mismatch");
630                result
631            }
632        };
633        let array = self
634            .converter
635            .build_array(&function.return_field, ctx, result)?;
636        let schema = Schema::new(vec![function.return_field.clone()]);
637        Ok(RecordBatch::try_new(Arc::new(schema), vec![array])?)
638    }
639
640    /// Call a table function.
641    ///
642    /// # Example
643    ///
644    /// ```
645    /// # tokio_test::block_on(async {
646    #[doc = include_str!("doc_create_function.txt")]
647    /// // suppose we have created a table function `series`
648    /// // see the example in `add_function`
649    ///
650    /// let schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
651    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3)]);
652    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
653    ///
654    /// let mut outputs = runtime.call_table_function("series", &input, 10).unwrap();
655    /// let output = outputs.next().await.unwrap().unwrap();
656    /// let pretty = arrow_cast::pretty::pretty_format_batches(&[output]).unwrap().to_string();
657    /// assert_eq!(pretty, r#"
658    /// +-----+--------+
659    /// | row | series |
660    /// +-----+--------+
661    /// | 0   | 0      |
662    /// | 2   | 0      |
663    /// | 2   | 1      |
664    /// | 2   | 2      |
665    /// +-----+--------+"#.trim());
666    /// # });
667    /// ```
668    pub fn call_table_function<'a>(
669        &'a self,
670        name: &'a str,
671        input: &'a RecordBatch,
672        chunk_size: usize,
673    ) -> Result<RecordBatchIter<'a>> {
674        assert!(chunk_size > 0);
675        let function = self.functions.get(name).context("function not found")?;
676        if function.options.is_batched {
677            bail!("table function does not support batched mode");
678        }
679
680        // initial state
681        Ok(RecordBatchIter {
682            rt: self,
683            input,
684            function,
685            schema: Arc::new(Schema::new(vec![
686                Arc::new(Field::new("row", DataType::Int32, false)),
687                function.return_field.clone(),
688            ])),
689            chunk_size,
690            row: 0,
691            generator: None,
692            converter: &self.converter,
693        })
694    }
695
696    /// Create a new state for an aggregate function.
697    ///
698    /// # Example
699    /// ```
700    /// # tokio_test::block_on(async {
701    #[doc = include_str!("doc_create_aggregate.txt")]
702    /// let state = runtime.create_state("sum").await.unwrap();
703    /// assert_eq!(&*state, &Int32Array::from(vec![0]));
704    /// # });
705    /// ```
706    pub async fn create_state(&self, name: &str) -> Result<ArrayRef> {
707        let aggregate = self.aggregates.get(name).context("function not found")?;
708        let state = async_with!(self.context => |ctx| {
709            let create_state = aggregate.create_state.clone().restore(&ctx)?;
710            let state = self
711                .call_user_fn(&ctx, &create_state, Args::new(ctx.clone(), 0), aggregate.options.is_async)
712                .await
713                .context("failed to call create_state")?;
714            let state = self
715                .converter
716                .build_array(&aggregate.state_field, &ctx, vec![state])?;
717            Ok(state) as Result<_>
718        })
719        .await?;
720        Ok(state)
721    }
722
723    /// Call accumulate of an aggregate function.
724    ///
725    /// # Example
726    /// ```
727    /// # tokio_test::block_on(async {
728    #[doc = include_str!("doc_create_aggregate.txt")]
729    /// let state = runtime.create_state("sum").await.unwrap();
730    ///
731    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
732    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
733    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
734    ///
735    /// let state = runtime.accumulate("sum", &state, &input).await.unwrap();
736    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
737    /// # });
738    /// ```
739    pub async fn accumulate(
740        &self,
741        name: &str,
742        state: &dyn Array,
743        input: &RecordBatch,
744    ) -> Result<ArrayRef> {
745        let aggregate = self.aggregates.get(name).context("function not found")?;
746        // convert each row to python objects and call the accumulate function
747        let new_state = async_with!(self.context => |ctx| {
748            let accumulate = aggregate.accumulate.clone().restore(&ctx)?;
749            let mut state = self
750                .converter
751                .get_jsvalue(&ctx, &aggregate.state_field, state, 0)?;
752
753            let mut row = Vec::with_capacity(1 + input.num_columns());
754            for i in 0..input.num_rows() {
755                if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput
756                    && input.columns().iter().any(|column| column.is_null(i))
757                {
758                    continue;
759                }
760                row.clear();
761                row.push(state.clone());
762                for (column, field) in input.columns().iter().zip(input.schema().fields()) {
763                    let pyobj = self.converter.get_jsvalue(&ctx, field, column, i)?;
764                    row.push(pyobj);
765                }
766                let mut args = Args::new(ctx.clone(), row.len());
767                args.push_args(row.drain(..))?;
768                state = self
769                    .call_user_fn(&ctx, &accumulate, args, aggregate.options.is_async)
770                    .await
771                    .context("failed to call accumulate")?;
772            }
773            let output = self
774                .converter
775                .build_array(&aggregate.state_field, &ctx, vec![state])?;
776            Ok(output) as Result<_>
777        })
778        .await?;
779        Ok(new_state)
780    }
781
782    /// Call accumulate or retract of an aggregate function.
783    ///
784    /// The `ops` is a boolean array that indicates whether to accumulate or retract each row.
785    /// `false` for accumulate and `true` for retract.
786    ///
787    /// # Example
788    /// ```
789    /// # tokio_test::block_on(async {
790    #[doc = include_str!("doc_create_aggregate.txt")]
791    /// let state = runtime.create_state("sum").await.unwrap();
792    ///
793    /// let schema = Schema::new(vec![Field::new("value", DataType::Int32, true)]);
794    /// let arg0 = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
795    /// let ops = BooleanArray::from(vec![false, false, true, false]);
796    /// let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
797    ///
798    /// let state = runtime.accumulate_or_retract("sum", &state, &ops, &input).await.unwrap();
799    /// assert_eq!(&*state, &Int32Array::from(vec![3]));
800    /// # });
801    /// ```
802    pub async fn accumulate_or_retract(
803        &self,
804        name: &str,
805        state: &dyn Array,
806        ops: &BooleanArray,
807        input: &RecordBatch,
808    ) -> Result<ArrayRef> {
809        let aggregate = self.aggregates.get(name).context("function not found")?;
810        // convert each row to python objects and call the accumulate function
811        let new_state = async_with!(self.context => |ctx| {
812        let accumulate = aggregate.accumulate.clone().restore(&ctx)?;
813        let retract = aggregate
814            .retract
815            .clone()
816            .context("function does not support retraction")?
817            .restore(&ctx)?;
818
819        let mut state = self
820            .converter
821            .get_jsvalue(&ctx, &aggregate.state_field, state, 0)?;
822
823        let mut row = Vec::with_capacity(1 + input.num_columns());
824        for i in 0..input.num_rows() {
825            if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput
826                && input.columns().iter().any(|column| column.is_null(i))
827            {
828                continue;
829            }
830            row.clear();
831            row.push(state.clone());
832            for (column, field) in input.columns().iter().zip(input.schema().fields()) {
833                let pyobj = self.converter.get_jsvalue(&ctx, field, column, i)?;
834                row.push(pyobj);
835            }
836            let func = if ops.is_valid(i) && ops.value(i) {
837                &retract
838            } else {
839                &accumulate
840            };
841            let mut args = Args::new(ctx.clone(), row.len());
842            args.push_args(row.drain(..))?;
843            state = self
844                .call_user_fn(&ctx, func, args, aggregate.options.is_async)
845                .await
846                .context("failed to call accumulate or retract")?;
847        }
848        let output = self
849            .converter
850            .build_array(&aggregate.state_field, &ctx, vec![state])?;
851        Ok(output) as Result<_>
852            })
853        .await?;
854        Ok(new_state)
855    }
856
857    /// Merge states of an aggregate function.
858    ///
859    /// # Example
860    /// ```
861    /// # tokio_test::block_on(async {
862    #[doc = include_str!("doc_create_aggregate.txt")]
863    /// let states = Int32Array::from(vec![Some(1), None, Some(3), Some(5)]);
864    ///
865    /// let state = runtime.merge("sum", &states).await.unwrap();
866    /// assert_eq!(&*state, &Int32Array::from(vec![9]));
867    /// # });
868    /// ```
869    pub async fn merge(&self, name: &str, states: &dyn Array) -> Result<ArrayRef> {
870        let aggregate = self.aggregates.get(name).context("function not found")?;
871        let output = async_with!(self.context => |ctx| {
872            let merge = aggregate
873                .merge
874                .clone()
875                .context("merge not found")?
876                .restore(&ctx)?;
877            let mut state = self
878                .converter
879                .get_jsvalue(&ctx, &aggregate.state_field, states, 0)?;
880            for i in 1..states.len() {
881                if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
882                    continue;
883                }
884                let state2 = self
885                    .converter
886                    .get_jsvalue(&ctx, &aggregate.state_field, states, i)?;
887                let mut args = Args::new(ctx.clone(), 2);
888                args.push_args([state, state2])?;
889                state = self
890                    .call_user_fn(&ctx, &merge, args, aggregate.options.is_async)
891                    .await
892                    .context("failed to call accumulate or retract")?;
893            }
894            let output = self
895                .converter
896                .build_array(&aggregate.state_field, &ctx, vec![state])?;
897            Ok(output) as Result<_>
898        })
899        .await?;
900        Ok(output)
901    }
902
903    /// Get the result of an aggregate function.
904    ///
905    /// If the `finish` function is not defined, the state is returned as the result.
906    ///
907    /// # Example
908    /// ```
909    /// # tokio_test::block_on(async {
910    #[doc = include_str!("doc_create_aggregate.txt")]
911    /// let states: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)]));
912    ///
913    /// let outputs = runtime.finish("sum", &states).await.unwrap();
914    /// assert_eq!(&outputs, &states);
915    /// # });
916    /// ```
917    pub async fn finish(&self, name: &str, states: &ArrayRef) -> Result<ArrayRef> {
918        let aggregate = self.aggregates.get(name).context("function not found")?;
919        if aggregate.finish.is_none() {
920            return Ok(states.clone());
921        };
922        let output = async_with!(self.context => |ctx| {
923            let finish = aggregate.finish.clone().unwrap().restore(&ctx)?;
924            let mut results = Vec::with_capacity(states.len());
925            for i in 0..states.len() {
926                if aggregate.options.call_mode == CallMode::ReturnNullOnNullInput && states.is_null(i) {
927                    results.push(Value::new_null(ctx.clone()));
928                    continue;
929                }
930                let state =
931                    self.converter
932                        .get_jsvalue(&ctx, &aggregate.state_field, states, i)?;
933                let mut args = Args::new(ctx.clone(), 1);
934                args.push_args([state])?;
935                let result = self
936                    .call_user_fn(&ctx, &finish, args, aggregate.options.is_async)
937                    .await
938                    .context("failed to call finish")?;
939                results.push(result);
940            }
941            let output = self
942                .converter
943                .build_array(&aggregate.output_field, &ctx, results)?;
944            Ok(output) as Result<_>
945        })
946        .await?;
947        Ok(output)
948    }
949
950    /// Call a user function.
951    ///
952    /// If `timeout` is set, the function will be interrupted after the timeout.
953    async fn call_user_fn<'js, T: FromJs<'js>>(
954        &self,
955        ctx: &Ctx<'js>,
956        f: &rquickjs::Function<'js>,
957        args: Args<'js>,
958        is_async: bool,
959    ) -> Result<T> {
960        if is_async {
961            Self::call_user_fn_async(self, ctx, f, args).await
962        } else {
963            Self::call_user_fn_sync(self, ctx, f, args)
964        }
965    }
966
967    async fn call_user_fn_async<'js, T: FromJs<'js>>(
968        &self,
969        ctx: &Ctx<'js>,
970        f: &rquickjs::Function<'js>,
971        args: Args<'js>,
972    ) -> Result<T> {
973        let call_result = if let Some(timeout) = self.timeout {
974            self.deadline
975                .store(Some(Instant::now() + timeout), Ordering::Relaxed);
976            let call_result = f.call_arg::<Promise>(args);
977            self.deadline.store(None, Ordering::Relaxed);
978            call_result
979        } else {
980            f.call_arg::<Promise>(args)
981        };
982        let promise = call_result.map_err(|e| check_exception(e, ctx))?;
983        promise
984            .into_future::<T>()
985            .await
986            .map_err(|e| check_exception(e, ctx))
987    }
988
989    fn call_user_fn_sync<'js, T: FromJs<'js>>(
990        &self,
991        ctx: &Ctx<'js>,
992        f: &rquickjs::Function<'js>,
993        args: Args<'js>,
994    ) -> Result<T> {
995        let result = if let Some(timeout) = self.timeout {
996            self.deadline
997                .store(Some(Instant::now() + timeout), Ordering::Relaxed);
998            let result = f.call_arg(args);
999            self.deadline.store(None, Ordering::Relaxed);
1000            result
1001        } else {
1002            f.call_arg(args)
1003        };
1004        result.map_err(|e| check_exception(e, ctx))
1005    }
1006
1007    pub fn context(&self) -> &AsyncContext {
1008        &self.context
1009    }
1010
1011    /// Enable the `fetch` API in the `Runtime`.
1012    ///
1013    /// See module [`fetch`] for more details.
1014    #[cfg(feature = "fetch")]
1015    pub async fn enable_fetch(&self) -> Result<()> {
1016        fetch::enable_fetch(&self.runtime, &self.context).await
1017    }
1018}
1019
/// An iterator over the result of a table function.
///
/// Batches are produced by `RecordBatchIter::next`; the type also implements `Stream`.
pub struct RecordBatchIter<'a> {
    /// Runtime whose JavaScript context is used to call the table function.
    rt: &'a Runtime,
    /// Input rows; each row's values are passed as the arguments of one table-function call.
    input: &'a RecordBatch,
    /// The table function. Calling it on one input row returns a generator object.
    function: &'a Function,
    /// Schema of every output batch.
    schema: SchemaRef,
    /// Maximum number of rows in one output batch.
    chunk_size: usize,
    // Mutable state, advanced by `next`.
    /// Index of the input row currently being processed.
    row: usize,
    /// Generator of the current row, saved between calls to `next` while that row is not
    /// yet exhausted.
    generator: Option<Persistent<Object<'static>>>,
    /// Converts between Arrow arrays and JavaScript values.
    converter: &'a jsarrow::Converter,
}
1035
// XXX: not sure if this is safe.
// NOTE(review): `Send` is asserted by hand because the struct holds a QuickJS handle
// (`Persistent<Object<'static>>` in `generator`) plus borrows of the runtime and converter,
// and their thread-safety is not established in this file. The saved generator is only
// restored inside `async_with!` on `self.rt.context` (see `next`). Whether that makes
// moving the iterator across threads sound depends on rquickjs' locking — TODO confirm.
unsafe impl Send for RecordBatchIter<'_> {}
1038
1039impl RecordBatchIter<'_> {
1040    /// Get the schema of the output.
1041    pub fn schema(&self) -> &Schema {
1042        &self.schema
1043    }
1044
1045    pub async fn next(&mut self) -> Result<Option<RecordBatch>> {
1046        if self.row == self.input.num_rows() {
1047            return Ok(None);
1048        }
1049        async_with!(self.rt.context => |ctx| {
1050            let js_function = self.function.function.clone().restore(&ctx)?;
1051            let mut indexes = Int32Builder::with_capacity(self.chunk_size);
1052            let mut results = Vec::with_capacity(self.input.num_rows());
1053            let mut row = Vec::with_capacity(self.input.num_columns());
1054            // restore generator from state
1055            let mut generator = match self.generator.take() {
1056                Some(generator) => {
1057                    let gen = generator.restore(&ctx)?;
1058                    let next: rquickjs::Function =
1059                        gen.get("next").context("failed to get 'next' method")?;
1060                    Some((gen, next))
1061                }
1062                None => None,
1063            };
1064            while self.row < self.input.num_rows() && results.len() < self.chunk_size {
1065                let (gen, next) = if let Some(g) = generator.as_ref() {
1066                    g
1067                } else {
1068                    // call the table function to get a generator
1069                    row.clear();
1070                    for (column, field) in
1071                        (self.input.columns().iter()).zip(self.input.schema().fields())
1072                    {
1073                        let val = self
1074                            .converter
1075                            .get_jsvalue(&ctx, field, column, self.row)
1076                            .context("failed to get jsvalue from arrow array")?;
1077                        row.push(val);
1078                    }
1079                    if self.function.options.call_mode == CallMode::ReturnNullOnNullInput
1080                        && row.iter().any(|v| v.is_null())
1081                    {
1082                        self.row += 1;
1083                        continue;
1084                    }
1085                    let mut args = Args::new(ctx.clone(), row.len());
1086                    args.push_args(row.drain(..))?;
1087                    // NOTE: A async generator function, defined by `async function*`, itself is NOT async.
1088                    // That's why we call it with `is_async = false` here.
1089                    // The result is a `AsyncGenerator`, which has a async `next` method.
1090                    let gen: Object = self
1091                        .rt
1092                        .call_user_fn(&ctx, &js_function, args, false).await
1093                        .context("failed to call function")?;
1094                    let next: rquickjs::Function =
1095                        gen.get("next").context("failed to get 'next' method")?;
1096                    let mut args = Args::new(ctx.clone(), 0);
1097                    args.this(gen.clone())?;
1098                    generator.insert((gen, next))
1099                };
1100                let mut args = Args::new(ctx.clone(), 0);
1101                args.this(gen.clone())?;
1102                let object: Object = self
1103                    .rt
1104                    .call_user_fn(&ctx, next, args, self.function.options.is_async).await
1105                    .context("failed to call next")?;
1106                let value: Value = object.get("value")?;
1107                let done: bool = object.get("done")?;
1108                if done {
1109                    self.row += 1;
1110                    generator = None;
1111                    continue;
1112                }
1113                indexes.append_value(self.row as i32);
1114                results.push(value);
1115            }
1116            self.generator = generator.map(|(gen, _)| Persistent::save(&ctx, gen));
1117
1118            if results.is_empty() {
1119                return Ok(None);
1120            }
1121            let indexes = Arc::new(indexes.finish());
1122            let array = self
1123                .converter
1124                .build_array(&self.function.return_field, &ctx, results)
1125                .context("failed to build arrow array from return values")?;
1126            Ok(Some(RecordBatch::try_new(
1127                self.schema.clone(),
1128                vec![indexes, array],
1129            )?))
1130        })
1131        .await
1132    }
1133}
1134
/// Adapts `RecordBatchIter::next` to the `Stream` trait.
impl Stream for RecordBatchIter<'_> {
    type Item = Result<RecordBatch>;
    // Each poll creates a fresh `next()` future, polls it once, and drops it.
    // NOTE(review): if that future returns `Pending` (e.g. while awaiting an async JS
    // call), its progress is thrown away:
    // - output values it had already collected are lost, even though `self.row` may have
    //   advanced past the rows that produced them;
    // - the generator it moved out of `self.generator` is dropped, so the next poll calls
    //   the table function again for `self.row`.
    // Keeping the in-flight future in the struct would avoid this — TODO.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Box::pin(self.next().map(|v| v.transpose()))
            .as_mut()
            .poll_unpin(cx)
    }
}
1143
1144/// Get exception from `ctx` if the error is an exception.
1145pub(crate) fn check_exception(err: rquickjs::Error, ctx: &Ctx) -> anyhow::Error {
1146    match err {
1147        rquickjs::Error::Exception => {
1148            anyhow!("exception generated by QuickJS: {:?}", ctx.catch())
1149        }
1150        e => e.into(),
1151    }
1152}