arrow_udf_wasm/
lib.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("../README.md")]
16
17use anyhow::{anyhow, bail, ensure, Context};
18use arrow_array::RecordBatch;
19use ram_file::{RamFile, RamFileRef};
20use std::collections::{HashMap, HashSet};
21use std::fmt::Debug;
22use std::sync::Mutex;
23use wasi_common::{sync::WasiCtxBuilder, WasiCtx};
24use wasmtime::*;
25
26#[cfg(feature = "build")]
27pub mod build;
28mod ram_file;
29
30/// The WASM UDF runtime.
31///
32/// This runtime contains an instance pool and can be shared by multiple threads.
33pub struct Runtime {
34    module: Module,
35    /// Configurations.
36    config: Config,
37    /// Function names.
38    functions: HashSet<String>,
39    /// User-defined types.
40    types: HashMap<String, String>,
41    /// Instance pool.
42    instances: Mutex<Vec<Instance>>,
43    /// ABI version. (major, minor)
44    abi_version: (u8, u8),
45}
46
/// Configurations.
///
/// Every limit is optional. A `None` memory limit leaves the store without an
/// explicit memory cap. A `None` file limit falls back to the runtime default
/// when each instance is created.
#[non_exhaustive]
#[derive(Default, Debug, Eq, PartialEq)]
pub struct Config {
    /// Memory size limit in bytes.
    pub memory_size_limit: Option<usize>,
    /// File size limit in bytes, applied to the captured stdout and stderr.
    pub file_size_limit: Option<usize>,
}
56
57struct Instance {
58    // extern "C" fn(len: usize, align: usize) -> *mut u8
59    alloc: TypedFunc<(u32, u32), u32>,
60    // extern "C" fn(ptr: *mut u8, len: usize, align: usize)
61    dealloc: TypedFunc<(u32, u32, u32), ()>,
62    // extern "C" fn(iter: *mut RecordBatchIter, out: *mut CSlice)
63    record_batch_iterator_next: TypedFunc<(u32, u32), ()>,
64    // extern "C" fn(iter: *mut RecordBatchIter)
65    record_batch_iterator_drop: TypedFunc<u32, ()>,
66    // extern "C" fn(ptr: *const u8, len: usize, out: *mut CSlice) -> i32
67    functions: HashMap<String, TypedFunc<(u32, u32, u32), i32>>,
68    memory: Memory,
69    store: Store<(WasiCtx, StoreLimits)>,
70    stdout: RamFileRef,
71    stderr: RamFileRef,
72}
73
74impl Debug for Runtime {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("Runtime")
77            .field("config", &self.config)
78            .field("functions", &self.functions)
79            .field("types", &self.types)
80            .field("instances", &self.instances.lock().unwrap().len())
81            .finish()
82    }
83}
84
85impl Runtime {
86    /// Create a new UDF runtime from a WASM binary.
87    pub fn new(binary: &[u8]) -> Result<Self> {
88        Self::with_config(binary, Config::default())
89    }
90
91    /// Create a new UDF runtime from a WASM binary with configuration.
92    pub fn with_config(binary: &[u8], config: Config) -> Result<Self> {
93        // use a global engine by default
94        static ENGINE: once_cell::sync::Lazy<Engine> = once_cell::sync::Lazy::new(Engine::default);
95        Self::with_config_engine(binary, config, &ENGINE)
96    }
97
98    /// Create a new UDF runtime from a WASM binary with a customized engine.
99    fn with_config_engine(binary: &[u8], config: Config, engine: &Engine) -> Result<Self> {
100        let module = Module::from_binary(engine, binary).context("failed to load wasm binary")?;
101
102        // check abi version
103        let version = module
104            .exports()
105            .find_map(|e| e.name().strip_prefix("ARROWUDF_VERSION_"))
106            .context("version not found")?;
107        let (major, minor) = version.split_once('_').context("invalid version")?;
108        let (major, minor) = (major.parse::<u8>()?, minor.parse::<u8>()?);
109        ensure!(major <= 3, "unsupported abi version: {major}.{minor}");
110
111        let mut functions = HashSet::new();
112        let mut types = HashMap::new();
113        for export in module.exports() {
114            if let Some(encoded) = export.name().strip_prefix("arrowudf_") {
115                let name = base64_decode(encoded).context("invalid symbol")?;
116                functions.insert(name);
117            } else if let Some(encoded) = export.name().strip_prefix("arrowudt_") {
118                let meta = base64_decode(encoded).context("invalid symbol")?;
119                let (name, fields) = meta.split_once('=').context("invalid type string")?;
120                types.insert(name.to_string(), fields.to_string());
121            }
122        }
123
124        Ok(Self {
125            module,
126            config,
127            functions,
128            types,
129            instances: Mutex::new(vec![]),
130            abi_version: (major, minor),
131        })
132    }
133
134    /// Return available functions.
135    pub fn functions(&self) -> impl Iterator<Item = &str> {
136        self.functions.iter().map(|s| s.as_str())
137    }
138
139    /// Return available types.
140    pub fn types(&self) -> impl Iterator<Item = (&str, &str)> {
141        self.types.iter().map(|(k, v)| (k.as_str(), v.as_str()))
142    }
143
144    /// Return the ABI version.
145    pub fn abi_version(&self) -> (u8, u8) {
146        self.abi_version
147    }
148
149    /// Given a function signature that inlines struct types, find the function name.
150    ///
151    /// # Example
152    ///
153    /// ```text
154    /// types = { "KeyValue": "key:string,value:string" }
155    /// input = "keyvalue(string, string) -> struct<key:string,value:string>"
156    /// output = "keyvalue(string, string) -> struct KeyValue"
157    /// ```
158    pub fn find_function_by_inlined_signature(&self, s: &str) -> Option<&str> {
159        self.functions
160            .iter()
161            .find(|f| self.inline_types(f) == s)
162            .map(|f| f.as_str())
163    }
164
165    /// Inline types in function signature.
166    ///
167    /// # Example
168    ///
169    /// ```text
170    /// types = { "KeyValue": "key:string,value:string" }
171    /// input = "keyvalue(string, string) -> struct KeyValue"
172    /// output = "keyvalue(string, string) -> struct<key:string,value:string>"
173    /// ```
174    fn inline_types(&self, s: &str) -> String {
175        let mut inlined = s.to_string();
176        loop {
177            let replaced = inlined.clone();
178            for (k, v) in self.types.iter() {
179                inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>"));
180            }
181            if replaced == inlined {
182                return inlined;
183            }
184        }
185    }
186
187    /// Call a function.
188    pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
189        if !self.functions.contains(name) {
190            bail!("function not found: {name}");
191        }
192
193        // get an instance from the pool, or create a new one if the pool is empty
194        let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
195            instance
196        } else {
197            Instance::new(self)?
198        };
199
200        // call the function
201        let output = instance.call_scalar_function(name, input);
202
203        // put the instance back to the pool
204        if output.is_ok() {
205            self.instances.lock().unwrap().push(instance);
206        }
207
208        output
209    }
210
211    /// Call a table function.
212    pub fn call_table_function<'a>(
213        &'a self,
214        name: &'a str,
215        input: &'a RecordBatch,
216    ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
217        use genawaiter2::{sync::gen, yield_};
218        if !self.functions.contains(name) {
219            bail!("function not found: {name}");
220        }
221
222        // get an instance from the pool, or create a new one if the pool is empty
223        let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
224            instance
225        } else {
226            Instance::new(self)?
227        };
228
229        Ok(gen!({
230            // call the function
231            let iter = match instance.call_table_function(name, input) {
232                Ok(iter) => iter,
233                Err(e) => {
234                    yield_!(Err(e));
235                    return;
236                }
237            };
238            for output in iter {
239                yield_!(output);
240            }
241            // put the instance back to the pool
242            // FIXME: if the iterator is not consumed, the instance will be dropped
243            self.instances.lock().unwrap().push(instance);
244        })
245        .into_iter())
246    }
247}
248
249impl Instance {
250    /// Create a new instance.
251    fn new(rt: &Runtime) -> Result<Self> {
252        let module = &rt.module;
253        let engine = module.engine();
254        let mut linker = Linker::new(engine);
255        wasi_common::sync::add_to_linker(&mut linker, |(wasi, _)| wasi)?;
256
257        // Create a WASI context and put it in a Store; all instances in the store
258        // share this context. `WasiCtxBuilder` provides a number of ways to
259        // configure what the target program will have access to.
260        let file_size_limit = rt.config.file_size_limit.unwrap_or(1024);
261        let stdout = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
262        let stderr = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
263        let wasi = WasiCtxBuilder::new()
264            .stdout(Box::new(stdout.clone()))
265            .stderr(Box::new(stderr.clone()))
266            .build();
267        let limits = {
268            let mut builder = StoreLimitsBuilder::new();
269            if let Some(limit) = rt.config.memory_size_limit {
270                builder = builder.memory_size(limit);
271            }
272            builder.build()
273        };
274        let mut store = Store::new(engine, (wasi, limits));
275        store.limiter(|(_, limiter)| limiter);
276
277        let instance = linker.instantiate(&mut store, module)?;
278        let mut functions = HashMap::new();
279        for export in module.exports() {
280            let Some(encoded) = export.name().strip_prefix("arrowudf_") else {
281                continue;
282            };
283            let name = base64_decode(encoded).context("invalid symbol")?;
284            let func = instance.get_typed_func(&mut store, export.name())?;
285            functions.insert(name, func);
286        }
287        let alloc = instance.get_typed_func(&mut store, "alloc")?;
288        let dealloc = instance.get_typed_func(&mut store, "dealloc")?;
289        let record_batch_iterator_next =
290            instance.get_typed_func(&mut store, "record_batch_iterator_next")?;
291        let record_batch_iterator_drop =
292            instance.get_typed_func(&mut store, "record_batch_iterator_drop")?;
293        let memory = instance
294            .get_memory(&mut store, "memory")
295            .context("no memory")?;
296
297        Ok(Instance {
298            alloc,
299            dealloc,
300            record_batch_iterator_next,
301            record_batch_iterator_drop,
302            memory,
303            store,
304            functions,
305            stdout,
306            stderr,
307        })
308    }
309
310    /// Call a scalar function.
311    fn call_scalar_function(&mut self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
312        // TODO: optimize data transfer
313        // currently there are 3 copies in input path:
314        //      host record batch -> host encoding -> wasm memory -> wasm record batch
315        // and 2 copies in output path:
316        //      wasm record batch -> wasm memory -> host record batch
317
318        // get function
319        let func = self
320            .functions
321            .get(name)
322            .with_context(|| format!("function not found: {name}"))?;
323
324        // encode input batch
325        let input = encode_record_batch(input)?;
326
327        // allocate memory for input buffer and output struct
328        let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
329        let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
330        ensure!(alloc_ptr != 0, "failed to allocate for input");
331        let in_ptr = alloc_ptr + 4 * 2;
332
333        // write input to memory
334        self.memory
335            .write(&mut self.store, in_ptr as usize, &input)?;
336
337        // call the function
338        let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
339        let errno = self.append_stdio(result)?;
340
341        // get return values
342        let out_ptr = self.read_u32(alloc_ptr)?;
343        let out_len = self.read_u32(alloc_ptr + 4)?;
344
345        // read output from memory
346        let out_bytes = self
347            .memory
348            .data(&self.store)
349            .get(out_ptr as usize..(out_ptr + out_len) as usize)
350            .context("output slice out of bounds")?;
351        let result = match errno {
352            0 => Ok(decode_record_batch(out_bytes)?),
353            _ => Err(anyhow!("{}", std::str::from_utf8(out_bytes)?)),
354        };
355
356        // deallocate memory
357        self.dealloc
358            .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
359        self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
360
361        result
362    }
363
364    /// Call a table function.
365    fn call_table_function<'a>(
366        &'a mut self,
367        name: &str,
368        input: &RecordBatch,
369    ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
370        // TODO: optimize data transfer
371        // currently there are 3 copies in input path:
372        //      host record batch -> host encoding -> wasm memory -> wasm record batch
373        // and 2 copies in output path:
374        //      wasm record batch -> wasm memory -> host record batch
375
376        // get function
377        let func = self
378            .functions
379            .get(name)
380            .with_context(|| format!("function not found: {name}"))?;
381
382        // encode input batch
383        let input = encode_record_batch(input)?;
384
385        // allocate memory for input buffer and output struct
386        let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
387        let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
388        ensure!(alloc_ptr != 0, "failed to allocate for input");
389        let in_ptr = alloc_ptr + 4 * 2;
390
391        // write input to memory
392        self.memory
393            .write(&mut self.store, in_ptr as usize, &input)?;
394
395        // call the function
396        let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
397        let errno = self.append_stdio(result)?;
398
399        // get return values
400        let out_ptr = self.read_u32(alloc_ptr)?;
401        let out_len = self.read_u32(alloc_ptr + 4)?;
402
403        // read output from memory
404        let out_bytes = self
405            .memory
406            .data(&self.store)
407            .get(out_ptr as usize..(out_ptr + out_len) as usize)
408            .context("output slice out of bounds")?;
409
410        let ptr = match errno {
411            0 => out_ptr,
412            _ => {
413                let err = anyhow!("{}", std::str::from_utf8(out_bytes)?);
414                // deallocate memory
415                self.dealloc
416                    .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
417                self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
418
419                return Err(err);
420            }
421        };
422
423        struct RecordBatchIter<'a> {
424            instance: &'a mut Instance,
425            ptr: u32,
426            alloc_ptr: u32,
427            alloc_len: u32,
428        }
429
430        impl RecordBatchIter<'_> {
431            /// Get the next record batch.
432            fn next(&mut self) -> Result<Option<RecordBatch>> {
433                self.instance
434                    .record_batch_iterator_next
435                    .call(&mut self.instance.store, (self.ptr, self.alloc_ptr))?;
436                // get return values
437                let out_ptr = self.instance.read_u32(self.alloc_ptr)?;
438                let out_len = self.instance.read_u32(self.alloc_ptr + 4)?;
439
440                if out_ptr == 0 {
441                    // end of iteration
442                    return Ok(None);
443                }
444
445                // read output from memory
446                let out_bytes = self
447                    .instance
448                    .memory
449                    .data(&self.instance.store)
450                    .get(out_ptr as usize..(out_ptr + out_len) as usize)
451                    .context("output slice out of bounds")?;
452                let batch = decode_record_batch(out_bytes)?;
453
454                // dealloc output
455                self.instance
456                    .dealloc
457                    .call(&mut self.instance.store, (out_ptr, out_len, 1))?;
458
459                Ok(Some(batch))
460            }
461        }
462
463        impl Iterator for RecordBatchIter<'_> {
464            type Item = Result<RecordBatch>;
465
466            fn next(&mut self) -> Option<Self::Item> {
467                let result = self.next();
468                self.instance.append_stdio(result).transpose()
469            }
470        }
471
472        impl Drop for RecordBatchIter<'_> {
473            fn drop(&mut self) {
474                _ = self.instance.dealloc.call(
475                    &mut self.instance.store,
476                    (self.alloc_ptr, self.alloc_len, 4),
477                );
478                _ = self
479                    .instance
480                    .record_batch_iterator_drop
481                    .call(&mut self.instance.store, self.ptr);
482            }
483        }
484
485        Ok(RecordBatchIter {
486            instance: self,
487            ptr,
488            alloc_ptr,
489            alloc_len,
490        })
491    }
492
493    /// Read a `u32` from memory.
494    fn read_u32(&mut self, ptr: u32) -> Result<u32> {
495        Ok(u32::from_le_bytes(
496            self.memory.data(&self.store)[ptr as usize..(ptr + 4) as usize]
497                .try_into()
498                .unwrap(),
499        ))
500    }
501
502    /// Take stdout and stderr, append to the error context.
503    fn append_stdio<T>(&self, result: Result<T>) -> Result<T> {
504        let stdout = self.stdout.take();
505        let stderr = self.stderr.take();
506        match result {
507            Ok(v) => Ok(v),
508            Err(e) => Err(e.context(format!(
509                "--- stdout\n{}\n--- stderr\n{}",
510                String::from_utf8_lossy(&stdout),
511                String::from_utf8_lossy(&stderr),
512            ))),
513        }
514    }
515}
516
517/// Decode a string from symbol name using customized base64.
518fn base64_decode(input: &str) -> Result<String> {
519    use base64::{
520        alphabet::Alphabet,
521        engine::{general_purpose::NO_PAD, GeneralPurpose},
522        Engine,
523    };
524    // standard base64 uses '+' and '/', which is not a valid symbol name.
525    // we use '$' and '_' instead.
526    let alphabet =
527        Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789$_").unwrap();
528    let engine = GeneralPurpose::new(&alphabet, NO_PAD);
529    let bytes = engine.decode(input)?;
530    String::from_utf8(bytes).context("invalid utf8")
531}
532
533fn encode_record_batch(batch: &RecordBatch) -> Result<Vec<u8>> {
534    let mut buf = vec![];
535    let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
536    writer.write(batch)?;
537    writer.finish()?;
538    drop(writer);
539    Ok(buf)
540}
541
542fn decode_record_batch(bytes: &[u8]) -> Result<RecordBatch> {
543    let mut reader = arrow_ipc::reader::FileReader::try_new(std::io::Cursor::new(bytes), None)?;
544    let batch = reader.next().unwrap()?;
545    Ok(batch)
546}