1#![doc = include_str!("README.md")]
16
17use anyhow::{anyhow, bail, ensure, Context};
18use arrow_array::RecordBatch;
19use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
20use itertools::Itertools;
21use ram_file::{RamFile, RamFileRef};
22use std::collections::{HashMap, HashSet};
23use std::fmt::Debug;
24use std::sync::{Arc, Mutex};
25use wasi_common::{sync::WasiCtxBuilder, WasiCtx};
26use wasmtime::*;
27
28use crate::into_field::IntoField;
29
30#[cfg(feature = "wasm-build")]
31pub mod build;
32mod ram_file;
33
34pub struct Runtime {
38 module: Module,
39 config: Config,
41 functions: Arc<HashSet<String>>,
43 types: Arc<HashMap<String, String>>,
45 instances: Mutex<Vec<Instance>>,
47 abi_version: (u8, u8),
49}
50
/// Resource limits for the instances of a [`Runtime`].
///
/// Every field defaults to `None`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Config {
    /// Maximum linear-memory size of each instance, handed to wasmtime's store limiter.
    /// `None` leaves the memory size unlimited.
    pub memory_size_limit: Option<usize>,
    /// Size limit of the in-memory files that capture each instance's stdout and stderr
    /// (presumably bytes). `None` falls back to 1024.
    pub file_size_limit: Option<usize>,
}
60
61impl Clone for Runtime {
64 fn clone(&self) -> Self {
65 Self {
66 module: self.module.clone(), config: self.config,
68 functions: self.functions.clone(),
69 types: self.types.clone(),
70 instances: Default::default(), abi_version: self.abi_version,
72 }
73 }
74}
75
76impl Debug for Runtime {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("Runtime")
79 .field("config", &self.config)
80 .field("functions", &self.functions)
81 .field("types", &self.types)
82 .field("instances", &self.instances.lock().unwrap().len())
83 .finish()
84 }
85}
86
87impl Runtime {
88 pub fn new(binary: &[u8]) -> Result<Self> {
90 Self::with_config(binary, Config::default())
91 }
92
93 pub fn with_config(binary: &[u8], config: Config) -> Result<Self> {
95 static ENGINE: std::sync::LazyLock<Engine> = std::sync::LazyLock::new(Engine::default);
97 Self::with_config_engine(binary, config, &ENGINE)
98 }
99
100 fn with_config_engine(binary: &[u8], config: Config, engine: &Engine) -> Result<Self> {
102 let module = Module::from_binary(engine, binary).context("failed to load wasm binary")?;
103
104 let version = module
106 .exports()
107 .find_map(|e| e.name().strip_prefix("ARROWUDF_VERSION_"))
108 .context("version not found")?;
109 let (major, minor) = version.split_once('_').context("invalid version")?;
110 let (major, minor) = (major.parse::<u8>()?, minor.parse::<u8>()?);
111 ensure!(major <= 3, "unsupported abi version: {major}.{minor}");
112
113 let mut functions = HashSet::new();
114 let mut types = HashMap::new();
115 for export in module.exports() {
116 if let Some(encoded) = export.name().strip_prefix("arrowudf_") {
117 let name = base64_decode(encoded).context("invalid symbol")?;
118 functions.insert(name);
119 } else if let Some(encoded) = export.name().strip_prefix("arrowudt_") {
120 let meta = base64_decode(encoded).context("invalid symbol")?;
121 let (name, fields) = meta.split_once('=').context("invalid type string")?;
122 types.insert(name.to_string(), fields.to_string());
123 }
124 }
125
126 Ok(Self {
127 module,
128 config,
129 functions: functions.into(),
130 types: types.into(),
131 instances: Mutex::new(vec![]),
132 abi_version: (major, minor),
133 })
134 }
135
136 pub fn functions(&self) -> impl Iterator<Item = &str> {
138 self.functions.iter().map(|s| s.as_str())
139 }
140
141 pub fn types(&self) -> impl Iterator<Item = (&str, &str)> {
143 self.types.iter().map(|(k, v)| (k.as_str(), v.as_str()))
144 }
145
146 pub fn abi_version(&self) -> (u8, u8) {
148 self.abi_version
149 }
150
151 pub fn find_function(
154 &self,
155 name: &str,
156 arg_types: Vec<impl IntoField>,
157 return_type: impl IntoField,
158 ) -> Result<FunctionHandle> {
159 let args = arg_types
161 .into_iter()
162 .map(|x| x.into_field(""))
163 .collect::<Vec<_>>();
164 let ret = return_type.into_field("");
165 let sig = function_signature_of(name, &args, &ret, false)?;
166 if let Some(export_name) = self.get_function_export_name_by_inlined_signature(&sig) {
168 return Ok(FunctionHandle {
169 export_name: export_name.to_owned(),
170 });
171 }
172 bail!(
173 "function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}\navailable types:\n {}",
174 sig,
175 self.functions.iter().join("\n "),
176 self.types.iter().map(|(k, v)| format!("{k}: {v}")).join("\n "),
177 )
178 }
179
180 pub fn find_table_function(
183 &self,
184 name: &str,
185 arg_types: Vec<impl IntoField>,
186 return_type: impl IntoField,
187 ) -> Result<TableFunctionHandle> {
188 let args = arg_types
190 .into_iter()
191 .map(|x| x.into_field(""))
192 .collect::<Vec<_>>();
193 let ret = return_type.into_field("");
194 let sig = function_signature_of(name, &args, &ret, true)?;
195 if let Some(export_name) = self.get_function_export_name_by_inlined_signature(&sig) {
197 return Ok(TableFunctionHandle {
198 export_name: export_name.to_owned(),
199 });
200 }
201 bail!(
202 "table function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}\navailable types:\n {}",
203 sig,
204 self.functions.iter().join("\n "),
205 self.types.iter().map(|(k, v)| format!("{k}: {v}")).join("\n "),
206 )
207 }
208
209 fn get_function_export_name_by_inlined_signature(&self, s: &str) -> Option<&str> {
219 if let Some(f) = self.functions.get(s) {
220 return Some(f);
221 }
222 self.functions
223 .iter()
224 .find(|f| self.inline_types(f) == s)
225 .map(|f| f.as_str())
226 }
227
228 fn inline_types(&self, s: &str) -> String {
238 let mut inlined = s.to_string();
239 loop {
240 let replaced = inlined.clone();
241 for (k, v) in self.types.iter() {
242 inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>"));
243 }
244 if replaced == inlined {
245 return inlined;
246 }
247 }
248 }
249
250 pub fn call(&self, func: &FunctionHandle, input: &RecordBatch) -> Result<RecordBatch> {
252 let export_name = &func.export_name;
253 if !self.functions.contains(export_name) {
254 bail!("function not found: {export_name}");
255 }
256
257 let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
259 instance
260 } else {
261 Instance::new(self)?
262 };
263
264 let output = instance.call_scalar_function(export_name, input);
266
267 if output.is_ok() {
269 self.instances.lock().unwrap().push(instance);
270 }
271
272 output
273 }
274
275 pub fn call_table_function<'a>(
277 &'a self,
278 func: &'a TableFunctionHandle,
279 input: &'a RecordBatch,
280 ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
281 use genawaiter2::{sync::gen, yield_};
282
283 let export_name = &func.export_name;
284 if !self.functions.contains(export_name) {
285 bail!("function not found: {export_name}");
286 }
287
288 let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
290 instance
291 } else {
292 Instance::new(self)?
293 };
294
295 Ok(gen!({
296 let iter = match instance.call_table_function(export_name, input) {
298 Ok(iter) => iter,
299 Err(e) => {
300 yield_!(Err(e));
301 return;
302 }
303 };
304 for output in iter {
305 yield_!(output);
306 }
307 self.instances.lock().unwrap().push(instance);
310 })
311 .into_iter())
312 }
313}
314
/// Handle to a scalar function, obtained from `Runtime::find_function`.
#[derive(Debug, Clone)]
pub struct FunctionHandle {
    /// Exported signature that the handle refers to.
    export_name: String,
}
318
/// Handle to a table function, obtained from `Runtime::find_table_function`.
#[derive(Debug, Clone)]
pub struct TableFunctionHandle {
    /// Exported signature that the handle refers to.
    export_name: String,
}
322
323struct Instance {
324 alloc: TypedFunc<(u32, u32), u32>,
326 dealloc: TypedFunc<(u32, u32, u32), ()>,
328 record_batch_iterator_next: TypedFunc<(u32, u32), ()>,
330 record_batch_iterator_drop: TypedFunc<u32, ()>,
332 functions: HashMap<String, TypedFunc<(u32, u32, u32), i32>>,
334 memory: Memory,
335 store: Store<(WasiCtx, StoreLimits)>,
336 stdout: RamFileRef,
337 stderr: RamFileRef,
338}
339
340impl Instance {
341 fn new(rt: &Runtime) -> Result<Self> {
343 let module = &rt.module;
344 let engine = module.engine();
345 let mut linker = Linker::new(engine);
346 wasi_common::sync::add_to_linker(&mut linker, |(wasi, _)| wasi)?;
347
348 let file_size_limit = rt.config.file_size_limit.unwrap_or(1024);
352 let stdout = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
353 let stderr = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
354 let wasi = WasiCtxBuilder::new()
355 .stdout(Box::new(stdout.clone()))
356 .stderr(Box::new(stderr.clone()))
357 .build();
358 let limits = {
359 let mut builder = StoreLimitsBuilder::new();
360 if let Some(limit) = rt.config.memory_size_limit {
361 builder = builder.memory_size(limit);
362 }
363 builder.build()
364 };
365 let mut store = Store::new(engine, (wasi, limits));
366 store.limiter(|(_, limiter)| limiter);
367
368 let instance = linker.instantiate(&mut store, module)?;
369 let mut functions = HashMap::new();
370 for export in module.exports() {
371 let Some(encoded) = export.name().strip_prefix("arrowudf_") else {
372 continue;
373 };
374 let name = base64_decode(encoded).context("invalid symbol")?;
375 let func = instance.get_typed_func(&mut store, export.name())?;
376 functions.insert(name, func);
377 }
378 let alloc = instance.get_typed_func(&mut store, "alloc")?;
379 let dealloc = instance.get_typed_func(&mut store, "dealloc")?;
380 let record_batch_iterator_next =
381 instance.get_typed_func(&mut store, "record_batch_iterator_next")?;
382 let record_batch_iterator_drop =
383 instance.get_typed_func(&mut store, "record_batch_iterator_drop")?;
384 let memory = instance
385 .get_memory(&mut store, "memory")
386 .context("no memory")?;
387
388 Ok(Instance {
389 alloc,
390 dealloc,
391 record_batch_iterator_next,
392 record_batch_iterator_drop,
393 memory,
394 store,
395 functions,
396 stdout,
397 stderr,
398 })
399 }
400
401 fn call_scalar_function(&mut self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
403 let func = self
411 .functions
412 .get(name)
413 .with_context(|| format!("function not found: {name}"))?;
414
415 let input = encode_record_batch(input)?;
417
418 let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
420 let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
421 ensure!(alloc_ptr != 0, "failed to allocate for input");
422 let in_ptr = alloc_ptr + 4 * 2;
423
424 self.memory
426 .write(&mut self.store, in_ptr as usize, &input)?;
427
428 let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
430 let errno = self.append_stdio(result)?;
431
432 let out_ptr = self.read_u32(alloc_ptr)?;
434 let out_len = self.read_u32(alloc_ptr + 4)?;
435
436 let out_bytes = self
438 .memory
439 .data(&self.store)
440 .get(out_ptr as usize..(out_ptr + out_len) as usize)
441 .context("output slice out of bounds")?;
442 let result = match errno {
443 0 => Ok(decode_record_batch(out_bytes)?),
444 _ => Err(anyhow!("{}", std::str::from_utf8(out_bytes)?)),
445 };
446
447 self.dealloc
449 .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
450 self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
451
452 result
453 }
454
455 fn call_table_function<'a>(
457 &'a mut self,
458 name: &str,
459 input: &RecordBatch,
460 ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
461 let func = self
469 .functions
470 .get(name)
471 .with_context(|| format!("function not found: {name}"))?;
472
473 let input = encode_record_batch(input)?;
475
476 let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
478 let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
479 ensure!(alloc_ptr != 0, "failed to allocate for input");
480 let in_ptr = alloc_ptr + 4 * 2;
481
482 self.memory
484 .write(&mut self.store, in_ptr as usize, &input)?;
485
486 let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
488 let errno = self.append_stdio(result)?;
489
490 let out_ptr = self.read_u32(alloc_ptr)?;
492 let out_len = self.read_u32(alloc_ptr + 4)?;
493
494 let out_bytes = self
496 .memory
497 .data(&self.store)
498 .get(out_ptr as usize..(out_ptr + out_len) as usize)
499 .context("output slice out of bounds")?;
500
501 let ptr = match errno {
502 0 => out_ptr,
503 _ => {
504 let err = anyhow!("{}", std::str::from_utf8(out_bytes)?);
505 self.dealloc
507 .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
508 self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
509
510 return Err(err);
511 }
512 };
513
514 struct RecordBatchIter<'a> {
515 instance: &'a mut Instance,
516 ptr: u32,
517 alloc_ptr: u32,
518 alloc_len: u32,
519 }
520
521 impl RecordBatchIter<'_> {
522 fn next(&mut self) -> Result<Option<RecordBatch>> {
524 self.instance
525 .record_batch_iterator_next
526 .call(&mut self.instance.store, (self.ptr, self.alloc_ptr))?;
527 let out_ptr = self.instance.read_u32(self.alloc_ptr)?;
529 let out_len = self.instance.read_u32(self.alloc_ptr + 4)?;
530
531 if out_ptr == 0 {
532 return Ok(None);
534 }
535
536 let out_bytes = self
538 .instance
539 .memory
540 .data(&self.instance.store)
541 .get(out_ptr as usize..(out_ptr + out_len) as usize)
542 .context("output slice out of bounds")?;
543 let batch = decode_record_batch(out_bytes)?;
544
545 self.instance
547 .dealloc
548 .call(&mut self.instance.store, (out_ptr, out_len, 1))?;
549
550 Ok(Some(batch))
551 }
552 }
553
554 impl Iterator for RecordBatchIter<'_> {
555 type Item = Result<RecordBatch>;
556
557 fn next(&mut self) -> Option<Self::Item> {
558 let result = self.next();
559 self.instance.append_stdio(result).transpose()
560 }
561 }
562
563 impl Drop for RecordBatchIter<'_> {
564 fn drop(&mut self) {
565 _ = self.instance.dealloc.call(
566 &mut self.instance.store,
567 (self.alloc_ptr, self.alloc_len, 4),
568 );
569 _ = self
570 .instance
571 .record_batch_iterator_drop
572 .call(&mut self.instance.store, self.ptr);
573 }
574 }
575
576 Ok(RecordBatchIter {
577 instance: self,
578 ptr,
579 alloc_ptr,
580 alloc_len,
581 })
582 }
583
584 fn read_u32(&mut self, ptr: u32) -> Result<u32> {
586 Ok(u32::from_le_bytes(
587 self.memory.data(&self.store)[ptr as usize..(ptr + 4) as usize]
588 .try_into()
589 .unwrap(),
590 ))
591 }
592
593 fn append_stdio<T>(&self, result: Result<T>) -> Result<T> {
595 let stdout = self.stdout.take();
596 let stderr = self.stderr.take();
597 match result {
598 Ok(v) => Ok(v),
599 Err(e) => Err(e.context(format!(
600 "--- stdout\n{}\n--- stderr\n{}",
601 String::from_utf8_lossy(&stdout),
602 String::from_utf8_lossy(&stderr),
603 ))),
604 }
605 }
606}
607
608fn base64_decode(input: &str) -> Result<String> {
610 use base64::{
611 alphabet::Alphabet,
612 engine::{general_purpose::NO_PAD, GeneralPurpose},
613 Engine,
614 };
615 let alphabet =
618 Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789$_").unwrap();
619 let engine = GeneralPurpose::new(&alphabet, NO_PAD);
620 let bytes = engine.decode(input)?;
621 String::from_utf8(bytes).context("invalid utf8")
622}
623
624fn encode_record_batch(batch: &RecordBatch) -> Result<Vec<u8>> {
625 let mut buf = vec![];
626 let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
627 writer.write(batch)?;
628 writer.finish()?;
629 drop(writer);
630 Ok(buf)
631}
632
633fn decode_record_batch(bytes: &[u8]) -> Result<RecordBatch> {
634 let mut reader = arrow_ipc::reader::FileReader::try_new(std::io::Cursor::new(bytes), None)?;
635 let batch = reader.next().unwrap()?;
636 Ok(batch)
637}
638
639fn function_signature_of(
641 name: &str,
642 args: &[Field],
643 ret: &Field,
644 is_table_function: bool,
645) -> Result<String> {
646 let args: Vec<_> = args.iter().map(field_to_typename).try_collect()?;
647 let ret = field_to_typename(ret)?;
648 Ok(format!(
649 "{}({}){}{}",
650 name,
651 args.iter().join(","),
652 if is_table_function { "->>" } else { "->" },
653 ret
654 ))
655}
656
657fn field_to_typename(field: &Field) -> Result<String> {
658 if let Some(ext_typename) = field.metadata().get("ARROW:extension:name") {
659 if let Some(typename) = ext_typename.strip_prefix("arrowudf.") {
660 return Ok(typename.to_owned());
662 }
663 }
664 let ty = field.data_type();
665 Ok(match ty {
668 DataType::Null => "null".to_owned(),
669 DataType::Boolean => "boolean".to_owned(),
670 DataType::Int8 => "int8".to_owned(),
671 DataType::Int16 => "int16".to_owned(),
672 DataType::Int32 => "int32".to_owned(),
673 DataType::Int64 => "int64".to_owned(),
674 DataType::UInt8 => "uint8".to_owned(),
675 DataType::UInt16 => "uint16".to_owned(),
676 DataType::UInt32 => "uint32".to_owned(),
677 DataType::UInt64 => "uint64".to_owned(),
678 DataType::Float32 => "float32".to_owned(),
679 DataType::Float64 => "float64".to_owned(),
680 DataType::Date32 => "date32".to_owned(),
681 DataType::Time64(TimeUnit::Microsecond) => "time64".to_owned(),
682 DataType::Timestamp(TimeUnit::Microsecond, None) => "timestamp".to_owned(),
683 DataType::Interval(IntervalUnit::MonthDayNano) => "interval".to_owned(),
684 DataType::Utf8 => "string".to_owned(),
685 DataType::Binary => "binary".to_owned(),
686 DataType::LargeUtf8 => "largestring".to_owned(),
687 DataType::LargeBinary => "largebinary".to_owned(),
688 DataType::List(elem) => format!("{}[]", field_to_typename(elem)?),
689 DataType::Struct(fields) => {
690 let fields: Vec<_> = fields
691 .iter()
692 .map(|x| Ok::<_, anyhow::Error>((x.name(), field_to_typename(x)?)))
693 .try_collect()?;
694 format!(
695 "struct<{}>",
696 fields
697 .iter()
698 .map(|(name, typename)| format!("{}:{}", name, typename))
699 .join(",")
700 )
701 }
702 _ => {
703 bail!("unsupported data type: {ty}");
704 }
705 })
706}