1#![doc = include_str!("README.md")]
16
17use anyhow::{Context, anyhow, bail, ensure};
18use arrow_array::RecordBatch;
19use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
20use itertools::Itertools;
21use ram_file::{RamFile, RamFileRef};
22use std::collections::{HashMap, HashSet};
23use std::fmt::Debug;
24use std::sync::{Arc, Mutex};
25use wasi_common::{WasiCtx, sync::WasiCtxBuilder};
26use wasmtime::*;
27
28use crate::into_field::IntoField;
29
30#[cfg(feature = "wasm-build")]
31pub mod build;
32mod ram_file;
33
34pub struct Runtime {
38 module: Module,
39 config: Config,
41 functions: Arc<HashSet<String>>,
43 types: Arc<HashMap<String, String>>,
45 instances: Mutex<Vec<Instance>>,
47 abi_version: (u8, u8),
49}
50
/// Resource limits applied to each wasm instance created by a [`Runtime`].
///
/// Every field is optional; the default configuration leaves all of them unset.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
#[non_exhaustive]
pub struct Config {
    /// When set, forwarded to `StoreLimitsBuilder::memory_size` for every new instance.
    /// When unset, no memory limit is configured on the builder.
    pub memory_size_limit: Option<usize>,
    /// Size limit of each in-memory file that captures the guest's stdout and stderr.
    /// When unset, `1024` is used.
    pub file_size_limit: Option<usize>,
}
60
61impl Clone for Runtime {
64 fn clone(&self) -> Self {
65 Self {
66 module: self.module.clone(), config: self.config,
68 functions: self.functions.clone(),
69 types: self.types.clone(),
70 instances: Default::default(), abi_version: self.abi_version,
72 }
73 }
74}
75
76impl Debug for Runtime {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("Runtime")
79 .field("config", &self.config)
80 .field("functions", &self.functions)
81 .field("types", &self.types)
82 .field("instances", &self.instances.lock().unwrap().len())
83 .finish()
84 }
85}
86
87impl Runtime {
88 pub fn new(binary: &[u8]) -> Result<Self> {
90 Self::with_config(binary, Config::default())
91 }
92
93 pub fn with_config(binary: &[u8], config: Config) -> Result<Self> {
95 static ENGINE: std::sync::LazyLock<Engine> = std::sync::LazyLock::new(Engine::default);
97 Self::with_config_engine(binary, config, &ENGINE)
98 }
99
100 fn with_config_engine(binary: &[u8], config: Config, engine: &Engine) -> Result<Self> {
102 let module = Module::from_binary(engine, binary).context("failed to load wasm binary")?;
103
104 let version = module
106 .exports()
107 .find_map(|e| e.name().strip_prefix("ARROWUDF_VERSION_"))
108 .context("version not found")?;
109 let (major, minor) = version.split_once('_').context("invalid version")?;
110 let (major, minor) = (major.parse::<u8>()?, minor.parse::<u8>()?);
111 ensure!(major <= 3, "unsupported abi version: {major}.{minor}");
112
113 let mut functions = HashSet::new();
114 let mut types = HashMap::new();
115 for export in module.exports() {
116 if let Some(encoded) = export.name().strip_prefix("arrowudf_") {
117 let name = base64_decode(encoded).context("invalid symbol")?;
118 functions.insert(name);
119 } else if let Some(encoded) = export.name().strip_prefix("arrowudt_") {
120 let meta = base64_decode(encoded).context("invalid symbol")?;
121 let (name, fields) = meta.split_once('=').context("invalid type string")?;
122 types.insert(name.to_string(), fields.to_string());
123 }
124 }
125
126 Ok(Self {
127 module,
128 config,
129 functions: functions.into(),
130 types: types.into(),
131 instances: Mutex::new(vec![]),
132 abi_version: (major, minor),
133 })
134 }
135
136 pub fn functions(&self) -> impl Iterator<Item = &str> {
138 self.functions.iter().map(|s| s.as_str())
139 }
140
141 pub fn types(&self) -> impl Iterator<Item = (&str, &str)> {
143 self.types.iter().map(|(k, v)| (k.as_str(), v.as_str()))
144 }
145
146 pub fn abi_version(&self) -> (u8, u8) {
148 self.abi_version
149 }
150
151 pub fn find_function(
154 &self,
155 name: &str,
156 arg_types: Vec<impl IntoField>,
157 return_type: impl IntoField,
158 ) -> Result<FunctionHandle> {
159 let args = arg_types
161 .into_iter()
162 .map(|x| x.into_field(""))
163 .collect::<Vec<_>>();
164 let ret = return_type.into_field("");
165 let sig = function_signature_of(name, &args, &ret, false)?;
166 if let Some(export_name) = self.get_function_export_name_by_inlined_signature(&sig) {
168 return Ok(FunctionHandle {
169 export_name: export_name.to_owned(),
170 });
171 }
172 bail!(
173 "function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}\navailable types:\n {}",
174 sig,
175 self.functions.iter().join("\n "),
176 self.types
177 .iter()
178 .map(|(k, v)| format!("{k}: {v}"))
179 .join("\n "),
180 )
181 }
182
183 pub fn find_table_function(
186 &self,
187 name: &str,
188 arg_types: Vec<impl IntoField>,
189 return_type: impl IntoField,
190 ) -> Result<TableFunctionHandle> {
191 let args = arg_types
193 .into_iter()
194 .map(|x| x.into_field(""))
195 .collect::<Vec<_>>();
196 let ret = return_type.into_field("");
197 let sig = function_signature_of(name, &args, &ret, true)?;
198 if let Some(export_name) = self.get_function_export_name_by_inlined_signature(&sig) {
200 return Ok(TableFunctionHandle {
201 export_name: export_name.to_owned(),
202 });
203 }
204 bail!(
205 "table function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}\navailable types:\n {}",
206 sig,
207 self.functions.iter().join("\n "),
208 self.types
209 .iter()
210 .map(|(k, v)| format!("{k}: {v}"))
211 .join("\n "),
212 )
213 }
214
215 fn get_function_export_name_by_inlined_signature(&self, s: &str) -> Option<&str> {
225 if let Some(f) = self.functions.get(s) {
226 return Some(f);
227 }
228 self.functions
229 .iter()
230 .find(|f| self.inline_types(f) == s)
231 .map(|f| f.as_str())
232 }
233
234 fn inline_types(&self, s: &str) -> String {
244 let mut inlined = s.to_string();
245 loop {
246 let replaced = inlined.clone();
247 for (k, v) in self.types.iter() {
248 inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>"));
249 }
250 if replaced == inlined {
251 return inlined;
252 }
253 }
254 }
255
256 pub fn call(&self, func: &FunctionHandle, input: &RecordBatch) -> Result<RecordBatch> {
258 let export_name = &func.export_name;
259 if !self.functions.contains(export_name) {
260 bail!("function not found: {export_name}");
261 }
262
263 let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
265 instance
266 } else {
267 Instance::new(self)?
268 };
269
270 let output = instance.call_scalar_function(export_name, input);
272
273 if output.is_ok() {
275 self.instances.lock().unwrap().push(instance);
276 }
277
278 output
279 }
280
281 pub fn call_table_function<'a>(
283 &'a self,
284 func: &'a TableFunctionHandle,
285 input: &'a RecordBatch,
286 ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
287 use genawaiter2::sync::r#gen as generator;
288 use genawaiter2::yield_;
289
290 let export_name = &func.export_name;
291 if !self.functions.contains(export_name) {
292 bail!("function not found: {export_name}");
293 }
294
295 let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
297 instance
298 } else {
299 Instance::new(self)?
300 };
301
302 Ok(generator!({
303 let iter = match instance.call_table_function(export_name, input) {
305 Ok(iter) => iter,
306 Err(e) => {
307 yield_!(Err(e));
308 return;
309 }
310 };
311 for output in iter {
312 yield_!(output);
313 }
314 self.instances.lock().unwrap().push(instance);
317 })
318 .into_iter())
319 }
320}
321
/// Handle to a scalar function, as returned by `Runtime::find_function`.
pub struct FunctionHandle {
    /// Decoded function name used to look the function up at call time.
    export_name: String,
}
325
/// Handle to a table function, as returned by `Runtime::find_table_function`.
pub struct TableFunctionHandle {
    /// Decoded function name used to look the function up at call time.
    export_name: String,
}
329
330struct Instance {
331 alloc: TypedFunc<(u32, u32), u32>,
333 dealloc: TypedFunc<(u32, u32, u32), ()>,
335 record_batch_iterator_next: TypedFunc<(u32, u32), ()>,
337 record_batch_iterator_drop: TypedFunc<u32, ()>,
339 functions: HashMap<String, TypedFunc<(u32, u32, u32), i32>>,
341 memory: Memory,
342 store: Store<(WasiCtx, StoreLimits)>,
343 stdout: RamFileRef,
344 stderr: RamFileRef,
345}
346
347impl Instance {
348 fn new(rt: &Runtime) -> Result<Self> {
350 let module = &rt.module;
351 let engine = module.engine();
352 let mut linker = Linker::new(engine);
353 wasi_common::sync::add_to_linker(&mut linker, |(wasi, _)| wasi)?;
354
355 let file_size_limit = rt.config.file_size_limit.unwrap_or(1024);
359 let stdout = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
360 let stderr = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
361 let wasi = WasiCtxBuilder::new()
362 .stdout(Box::new(stdout.clone()))
363 .stderr(Box::new(stderr.clone()))
364 .build();
365 let limits = {
366 let mut builder = StoreLimitsBuilder::new();
367 if let Some(limit) = rt.config.memory_size_limit {
368 builder = builder.memory_size(limit);
369 }
370 builder.build()
371 };
372 let mut store = Store::new(engine, (wasi, limits));
373 store.limiter(|(_, limiter)| limiter);
374
375 let instance = linker.instantiate(&mut store, module)?;
376 let mut functions = HashMap::new();
377 for export in module.exports() {
378 let Some(encoded) = export.name().strip_prefix("arrowudf_") else {
379 continue;
380 };
381 let name = base64_decode(encoded).context("invalid symbol")?;
382 let func = instance.get_typed_func(&mut store, export.name())?;
383 functions.insert(name, func);
384 }
385 let alloc = instance.get_typed_func(&mut store, "alloc")?;
386 let dealloc = instance.get_typed_func(&mut store, "dealloc")?;
387 let record_batch_iterator_next =
388 instance.get_typed_func(&mut store, "record_batch_iterator_next")?;
389 let record_batch_iterator_drop =
390 instance.get_typed_func(&mut store, "record_batch_iterator_drop")?;
391 let memory = instance
392 .get_memory(&mut store, "memory")
393 .context("no memory")?;
394
395 Ok(Instance {
396 alloc,
397 dealloc,
398 record_batch_iterator_next,
399 record_batch_iterator_drop,
400 memory,
401 store,
402 functions,
403 stdout,
404 stderr,
405 })
406 }
407
408 fn call_scalar_function(&mut self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
410 let func = self
418 .functions
419 .get(name)
420 .with_context(|| format!("function not found: {name}"))?;
421
422 let input = encode_record_batch(input)?;
424
425 let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
427 let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
428 ensure!(alloc_ptr != 0, "failed to allocate for input");
429 let in_ptr = alloc_ptr + 4 * 2;
430
431 self.memory
433 .write(&mut self.store, in_ptr as usize, &input)?;
434
435 let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
437 let errno = self.append_stdio(result)?;
438
439 let out_ptr = self.read_u32(alloc_ptr)?;
441 let out_len = self.read_u32(alloc_ptr + 4)?;
442
443 let out_bytes = self
445 .memory
446 .data(&self.store)
447 .get(out_ptr as usize..(out_ptr + out_len) as usize)
448 .context("output slice out of bounds")?;
449 let result = match errno {
450 0 => Ok(decode_record_batch(out_bytes)?),
451 _ => Err(anyhow!("{}", std::str::from_utf8(out_bytes)?)),
452 };
453
454 self.dealloc
456 .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
457 self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
458
459 result
460 }
461
462 fn call_table_function<'a>(
464 &'a mut self,
465 name: &str,
466 input: &RecordBatch,
467 ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
468 let func = self
476 .functions
477 .get(name)
478 .with_context(|| format!("function not found: {name}"))?;
479
480 let input = encode_record_batch(input)?;
482
483 let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
485 let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
486 ensure!(alloc_ptr != 0, "failed to allocate for input");
487 let in_ptr = alloc_ptr + 4 * 2;
488
489 self.memory
491 .write(&mut self.store, in_ptr as usize, &input)?;
492
493 let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
495 let errno = self.append_stdio(result)?;
496
497 let out_ptr = self.read_u32(alloc_ptr)?;
499 let out_len = self.read_u32(alloc_ptr + 4)?;
500
501 let out_bytes = self
503 .memory
504 .data(&self.store)
505 .get(out_ptr as usize..(out_ptr + out_len) as usize)
506 .context("output slice out of bounds")?;
507
508 let ptr = match errno {
509 0 => out_ptr,
510 _ => {
511 let err = anyhow!("{}", std::str::from_utf8(out_bytes)?);
512 self.dealloc
514 .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
515 self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
516
517 return Err(err);
518 }
519 };
520
521 struct RecordBatchIter<'a> {
522 instance: &'a mut Instance,
523 ptr: u32,
524 alloc_ptr: u32,
525 alloc_len: u32,
526 }
527
528 impl RecordBatchIter<'_> {
529 fn next(&mut self) -> Result<Option<RecordBatch>> {
531 self.instance
532 .record_batch_iterator_next
533 .call(&mut self.instance.store, (self.ptr, self.alloc_ptr))?;
534 let out_ptr = self.instance.read_u32(self.alloc_ptr)?;
536 let out_len = self.instance.read_u32(self.alloc_ptr + 4)?;
537
538 if out_ptr == 0 {
539 return Ok(None);
541 }
542
543 let out_bytes = self
545 .instance
546 .memory
547 .data(&self.instance.store)
548 .get(out_ptr as usize..(out_ptr + out_len) as usize)
549 .context("output slice out of bounds")?;
550 let batch = decode_record_batch(out_bytes)?;
551
552 self.instance
554 .dealloc
555 .call(&mut self.instance.store, (out_ptr, out_len, 1))?;
556
557 Ok(Some(batch))
558 }
559 }
560
561 impl Iterator for RecordBatchIter<'_> {
562 type Item = Result<RecordBatch>;
563
564 fn next(&mut self) -> Option<Self::Item> {
565 let result = self.next();
566 self.instance.append_stdio(result).transpose()
567 }
568 }
569
570 impl Drop for RecordBatchIter<'_> {
571 fn drop(&mut self) {
572 _ = self.instance.dealloc.call(
573 &mut self.instance.store,
574 (self.alloc_ptr, self.alloc_len, 4),
575 );
576 _ = self
577 .instance
578 .record_batch_iterator_drop
579 .call(&mut self.instance.store, self.ptr);
580 }
581 }
582
583 Ok(RecordBatchIter {
584 instance: self,
585 ptr,
586 alloc_ptr,
587 alloc_len,
588 })
589 }
590
591 fn read_u32(&mut self, ptr: u32) -> Result<u32> {
593 Ok(u32::from_le_bytes(
594 self.memory.data(&self.store)[ptr as usize..(ptr + 4) as usize]
595 .try_into()
596 .unwrap(),
597 ))
598 }
599
600 fn append_stdio<T>(&self, result: Result<T>) -> Result<T> {
602 let stdout = self.stdout.take();
603 let stderr = self.stderr.take();
604 match result {
605 Ok(v) => Ok(v),
606 Err(e) => Err(e.context(format!(
607 "--- stdout\n{}\n--- stderr\n{}",
608 String::from_utf8_lossy(&stdout),
609 String::from_utf8_lossy(&stderr),
610 ))),
611 }
612 }
613}
614
615fn base64_decode(input: &str) -> Result<String> {
617 use base64::{
618 Engine,
619 alphabet::Alphabet,
620 engine::{GeneralPurpose, general_purpose::NO_PAD},
621 };
622 let alphabet =
625 Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789$_").unwrap();
626 let engine = GeneralPurpose::new(&alphabet, NO_PAD);
627 let bytes = engine.decode(input)?;
628 String::from_utf8(bytes).context("invalid utf8")
629}
630
631fn encode_record_batch(batch: &RecordBatch) -> Result<Vec<u8>> {
632 let mut buf = vec![];
633 let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
634 writer.write(batch)?;
635 writer.finish()?;
636 drop(writer);
637 Ok(buf)
638}
639
640fn decode_record_batch(bytes: &[u8]) -> Result<RecordBatch> {
641 let mut reader = arrow_ipc::reader::FileReader::try_new(std::io::Cursor::new(bytes), None)?;
642 let batch = reader.next().unwrap()?;
643 Ok(batch)
644}
645
646fn function_signature_of(
648 name: &str,
649 args: &[Field],
650 ret: &Field,
651 is_table_function: bool,
652) -> Result<String> {
653 let args: Vec<_> = args.iter().map(field_to_typename).try_collect()?;
654 let ret = field_to_typename(ret)?;
655 Ok(format!(
656 "{}({}){}{}",
657 name,
658 args.iter().join(","),
659 if is_table_function { "->>" } else { "->" },
660 ret
661 ))
662}
663
664fn field_to_typename(field: &Field) -> Result<String> {
665 if let Some(ext_typename) = field.metadata().get("ARROW:extension:name")
666 && let Some(typename) = ext_typename.strip_prefix("arrowudf.")
667 {
668 return Ok(typename.to_owned());
670 }
671 let ty = field.data_type();
672 Ok(match ty {
675 DataType::Null => "null".to_owned(),
676 DataType::Boolean => "boolean".to_owned(),
677 DataType::Int8 => "int8".to_owned(),
678 DataType::Int16 => "int16".to_owned(),
679 DataType::Int32 => "int32".to_owned(),
680 DataType::Int64 => "int64".to_owned(),
681 DataType::UInt8 => "uint8".to_owned(),
682 DataType::UInt16 => "uint16".to_owned(),
683 DataType::UInt32 => "uint32".to_owned(),
684 DataType::UInt64 => "uint64".to_owned(),
685 DataType::Float32 => "float32".to_owned(),
686 DataType::Float64 => "float64".to_owned(),
687 DataType::Date32 => "date32".to_owned(),
688 DataType::Time64(TimeUnit::Microsecond) => "time64".to_owned(),
689 DataType::Timestamp(TimeUnit::Microsecond, None) => "timestamp".to_owned(),
690 DataType::Interval(IntervalUnit::MonthDayNano) => "interval".to_owned(),
691 DataType::Utf8 => "string".to_owned(),
692 DataType::Binary => "binary".to_owned(),
693 DataType::LargeUtf8 => "largestring".to_owned(),
694 DataType::LargeBinary => "largebinary".to_owned(),
695 DataType::List(elem) => format!("{}[]", field_to_typename(elem)?),
696 DataType::Struct(fields) => {
697 let fields: Vec<_> = fields
698 .iter()
699 .map(|x| Ok::<_, anyhow::Error>((x.name(), field_to_typename(x)?)))
700 .try_collect()?;
701 format!(
702 "struct<{}>",
703 fields
704 .iter()
705 .map(|(name, typename)| format!("{name}:{typename}"))
706 .join(",")
707 )
708 }
709 _ => {
710 bail!("unsupported data type: {ty}");
711 }
712 })
713}