1#![doc = include_str!("../README.md")]
16
17use anyhow::{anyhow, bail, ensure, Context};
18use arrow_array::RecordBatch;
19use ram_file::{RamFile, RamFileRef};
20use std::collections::{HashMap, HashSet};
21use std::fmt::Debug;
22use std::sync::Mutex;
23use wasi_common::{sync::WasiCtxBuilder, WasiCtx};
24use wasmtime::*;
25
26#[cfg(feature = "build")]
27pub mod build;
28mod ram_file;
29
/// A loaded wasm module together with the metadata read from its exports and a pool of
/// already-instantiated copies that calls can reuse.
pub struct Runtime {
    /// Module compiled from the binary handed to the constructor.
    module: Module,
    /// Limits applied to every instance created from this runtime.
    config: Config,
    /// Decoded names of the exports that carry the `arrowudf_` prefix.
    functions: HashSet<String>,
    /// Decoded `arrowudt_` exports, split at the first `=` into type name -> field list.
    types: HashMap<String, String>,
    /// Idle instances; a call pops one (or creates a new one) and may push it back when done.
    instances: Mutex<Vec<Instance>>,
    /// `(major, minor)` parsed from the `ARROWUDF_VERSION_<major>_<minor>` export.
    abi_version: (u8, u8),
}
46
/// Optional resource limits for the instances created by a [`Runtime`].
#[derive(Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct Config {
    /// Handed to `StoreLimitsBuilder::memory_size` when set; otherwise no memory size is
    /// configured on the store limits.
    pub memory_size_limit: Option<usize>,
    /// Size limit given to the in-memory files that capture the guest's stdout and stderr;
    /// 1024 is used when unset (presumably bytes -- confirm against `RamFile`).
    pub file_size_limit: Option<usize>,
}
56
/// One instantiated copy of the module, with typed handles to the exports the host calls.
struct Instance {
    /// Guest export `alloc`. Called as `(len, 4)`; a returned pointer of 0 is treated as
    /// failure. The second argument is presumably the alignment -- confirm against the guest.
    alloc: TypedFunc<(u32, u32), u32>,
    /// Guest export `dealloc`. Called as `(ptr, len, 4)` for the input buffer and
    /// `(ptr, len, 1)` for output buffers; the third argument is presumably the alignment.
    dealloc: TypedFunc<(u32, u32, u32), ()>,
    /// Guest export `record_batch_iterator_next`, called as `(iterator ptr, output slot ptr)`.
    record_batch_iterator_next: TypedFunc<(u32, u32), ()>,
    /// Guest export `record_batch_iterator_drop`, called with the iterator pointer.
    record_batch_iterator_drop: TypedFunc<u32, ()>,
    /// Exported `arrowudf_` functions keyed by decoded name. Called as
    /// `(input ptr, input len, output slot ptr)`; a return value of 0 is treated as success.
    functions: HashMap<String, TypedFunc<(u32, u32, u32), i32>>,
    /// The guest's exported linear memory (`memory`).
    memory: Memory,
    /// Store holding the WASI context and the store limits for this instance.
    store: Store<(WasiCtx, StoreLimits)>,
    /// Handle to the in-memory file installed as the guest's stdout.
    stdout: RamFileRef,
    /// Handle to the in-memory file installed as the guest's stderr.
    stderr: RamFileRef,
}
73
74impl Debug for Runtime {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("Runtime")
77 .field("config", &self.config)
78 .field("functions", &self.functions)
79 .field("types", &self.types)
80 .field("instances", &self.instances.lock().unwrap().len())
81 .finish()
82 }
83}
84
85impl Runtime {
86 pub fn new(binary: &[u8]) -> Result<Self> {
88 Self::with_config(binary, Config::default())
89 }
90
91 pub fn with_config(binary: &[u8], config: Config) -> Result<Self> {
93 static ENGINE: once_cell::sync::Lazy<Engine> = once_cell::sync::Lazy::new(Engine::default);
95 Self::with_config_engine(binary, config, &ENGINE)
96 }
97
98 fn with_config_engine(binary: &[u8], config: Config, engine: &Engine) -> Result<Self> {
100 let module = Module::from_binary(engine, binary).context("failed to load wasm binary")?;
101
102 let version = module
104 .exports()
105 .find_map(|e| e.name().strip_prefix("ARROWUDF_VERSION_"))
106 .context("version not found")?;
107 let (major, minor) = version.split_once('_').context("invalid version")?;
108 let (major, minor) = (major.parse::<u8>()?, minor.parse::<u8>()?);
109 ensure!(major <= 3, "unsupported abi version: {major}.{minor}");
110
111 let mut functions = HashSet::new();
112 let mut types = HashMap::new();
113 for export in module.exports() {
114 if let Some(encoded) = export.name().strip_prefix("arrowudf_") {
115 let name = base64_decode(encoded).context("invalid symbol")?;
116 functions.insert(name);
117 } else if let Some(encoded) = export.name().strip_prefix("arrowudt_") {
118 let meta = base64_decode(encoded).context("invalid symbol")?;
119 let (name, fields) = meta.split_once('=').context("invalid type string")?;
120 types.insert(name.to_string(), fields.to_string());
121 }
122 }
123
124 Ok(Self {
125 module,
126 config,
127 functions,
128 types,
129 instances: Mutex::new(vec![]),
130 abi_version: (major, minor),
131 })
132 }
133
134 pub fn functions(&self) -> impl Iterator<Item = &str> {
136 self.functions.iter().map(|s| s.as_str())
137 }
138
139 pub fn types(&self) -> impl Iterator<Item = (&str, &str)> {
141 self.types.iter().map(|(k, v)| (k.as_str(), v.as_str()))
142 }
143
144 pub fn abi_version(&self) -> (u8, u8) {
146 self.abi_version
147 }
148
149 pub fn find_function_by_inlined_signature(&self, s: &str) -> Option<&str> {
159 self.functions
160 .iter()
161 .find(|f| self.inline_types(f) == s)
162 .map(|f| f.as_str())
163 }
164
165 fn inline_types(&self, s: &str) -> String {
175 let mut inlined = s.to_string();
176 loop {
177 let replaced = inlined.clone();
178 for (k, v) in self.types.iter() {
179 inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>"));
180 }
181 if replaced == inlined {
182 return inlined;
183 }
184 }
185 }
186
187 pub fn call(&self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
189 if !self.functions.contains(name) {
190 bail!("function not found: {name}");
191 }
192
193 let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
195 instance
196 } else {
197 Instance::new(self)?
198 };
199
200 let output = instance.call_scalar_function(name, input);
202
203 if output.is_ok() {
205 self.instances.lock().unwrap().push(instance);
206 }
207
208 output
209 }
210
211 pub fn call_table_function<'a>(
213 &'a self,
214 name: &'a str,
215 input: &'a RecordBatch,
216 ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
217 use genawaiter2::{sync::gen, yield_};
218 if !self.functions.contains(name) {
219 bail!("function not found: {name}");
220 }
221
222 let mut instance = if let Some(instance) = self.instances.lock().unwrap().pop() {
224 instance
225 } else {
226 Instance::new(self)?
227 };
228
229 Ok(gen!({
230 let iter = match instance.call_table_function(name, input) {
232 Ok(iter) => iter,
233 Err(e) => {
234 yield_!(Err(e));
235 return;
236 }
237 };
238 for output in iter {
239 yield_!(output);
240 }
241 self.instances.lock().unwrap().push(instance);
244 })
245 .into_iter())
246 }
247}
248
249impl Instance {
250 fn new(rt: &Runtime) -> Result<Self> {
252 let module = &rt.module;
253 let engine = module.engine();
254 let mut linker = Linker::new(engine);
255 wasi_common::sync::add_to_linker(&mut linker, |(wasi, _)| wasi)?;
256
257 let file_size_limit = rt.config.file_size_limit.unwrap_or(1024);
261 let stdout = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
262 let stderr = RamFileRef::new(RamFile::with_size_limit(file_size_limit));
263 let wasi = WasiCtxBuilder::new()
264 .stdout(Box::new(stdout.clone()))
265 .stderr(Box::new(stderr.clone()))
266 .build();
267 let limits = {
268 let mut builder = StoreLimitsBuilder::new();
269 if let Some(limit) = rt.config.memory_size_limit {
270 builder = builder.memory_size(limit);
271 }
272 builder.build()
273 };
274 let mut store = Store::new(engine, (wasi, limits));
275 store.limiter(|(_, limiter)| limiter);
276
277 let instance = linker.instantiate(&mut store, module)?;
278 let mut functions = HashMap::new();
279 for export in module.exports() {
280 let Some(encoded) = export.name().strip_prefix("arrowudf_") else {
281 continue;
282 };
283 let name = base64_decode(encoded).context("invalid symbol")?;
284 let func = instance.get_typed_func(&mut store, export.name())?;
285 functions.insert(name, func);
286 }
287 let alloc = instance.get_typed_func(&mut store, "alloc")?;
288 let dealloc = instance.get_typed_func(&mut store, "dealloc")?;
289 let record_batch_iterator_next =
290 instance.get_typed_func(&mut store, "record_batch_iterator_next")?;
291 let record_batch_iterator_drop =
292 instance.get_typed_func(&mut store, "record_batch_iterator_drop")?;
293 let memory = instance
294 .get_memory(&mut store, "memory")
295 .context("no memory")?;
296
297 Ok(Instance {
298 alloc,
299 dealloc,
300 record_batch_iterator_next,
301 record_batch_iterator_drop,
302 memory,
303 store,
304 functions,
305 stdout,
306 stderr,
307 })
308 }
309
310 fn call_scalar_function(&mut self, name: &str, input: &RecordBatch) -> Result<RecordBatch> {
312 let func = self
320 .functions
321 .get(name)
322 .with_context(|| format!("function not found: {name}"))?;
323
324 let input = encode_record_batch(input)?;
326
327 let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
329 let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
330 ensure!(alloc_ptr != 0, "failed to allocate for input");
331 let in_ptr = alloc_ptr + 4 * 2;
332
333 self.memory
335 .write(&mut self.store, in_ptr as usize, &input)?;
336
337 let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
339 let errno = self.append_stdio(result)?;
340
341 let out_ptr = self.read_u32(alloc_ptr)?;
343 let out_len = self.read_u32(alloc_ptr + 4)?;
344
345 let out_bytes = self
347 .memory
348 .data(&self.store)
349 .get(out_ptr as usize..(out_ptr + out_len) as usize)
350 .context("output slice out of bounds")?;
351 let result = match errno {
352 0 => Ok(decode_record_batch(out_bytes)?),
353 _ => Err(anyhow!("{}", std::str::from_utf8(out_bytes)?)),
354 };
355
356 self.dealloc
358 .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
359 self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
360
361 result
362 }
363
364 fn call_table_function<'a>(
366 &'a mut self,
367 name: &str,
368 input: &RecordBatch,
369 ) -> Result<impl Iterator<Item = Result<RecordBatch>> + 'a> {
370 let func = self
378 .functions
379 .get(name)
380 .with_context(|| format!("function not found: {name}"))?;
381
382 let input = encode_record_batch(input)?;
384
385 let alloc_len = u32::try_from(input.len() + 4 * 2).context("input too large")?;
387 let alloc_ptr = self.alloc.call(&mut self.store, (alloc_len, 4))?;
388 ensure!(alloc_ptr != 0, "failed to allocate for input");
389 let in_ptr = alloc_ptr + 4 * 2;
390
391 self.memory
393 .write(&mut self.store, in_ptr as usize, &input)?;
394
395 let result = func.call(&mut self.store, (in_ptr, input.len() as u32, alloc_ptr));
397 let errno = self.append_stdio(result)?;
398
399 let out_ptr = self.read_u32(alloc_ptr)?;
401 let out_len = self.read_u32(alloc_ptr + 4)?;
402
403 let out_bytes = self
405 .memory
406 .data(&self.store)
407 .get(out_ptr as usize..(out_ptr + out_len) as usize)
408 .context("output slice out of bounds")?;
409
410 let ptr = match errno {
411 0 => out_ptr,
412 _ => {
413 let err = anyhow!("{}", std::str::from_utf8(out_bytes)?);
414 self.dealloc
416 .call(&mut self.store, (alloc_ptr, alloc_len, 4))?;
417 self.dealloc.call(&mut self.store, (out_ptr, out_len, 1))?;
418
419 return Err(err);
420 }
421 };
422
423 struct RecordBatchIter<'a> {
424 instance: &'a mut Instance,
425 ptr: u32,
426 alloc_ptr: u32,
427 alloc_len: u32,
428 }
429
430 impl RecordBatchIter<'_> {
431 fn next(&mut self) -> Result<Option<RecordBatch>> {
433 self.instance
434 .record_batch_iterator_next
435 .call(&mut self.instance.store, (self.ptr, self.alloc_ptr))?;
436 let out_ptr = self.instance.read_u32(self.alloc_ptr)?;
438 let out_len = self.instance.read_u32(self.alloc_ptr + 4)?;
439
440 if out_ptr == 0 {
441 return Ok(None);
443 }
444
445 let out_bytes = self
447 .instance
448 .memory
449 .data(&self.instance.store)
450 .get(out_ptr as usize..(out_ptr + out_len) as usize)
451 .context("output slice out of bounds")?;
452 let batch = decode_record_batch(out_bytes)?;
453
454 self.instance
456 .dealloc
457 .call(&mut self.instance.store, (out_ptr, out_len, 1))?;
458
459 Ok(Some(batch))
460 }
461 }
462
463 impl Iterator for RecordBatchIter<'_> {
464 type Item = Result<RecordBatch>;
465
466 fn next(&mut self) -> Option<Self::Item> {
467 let result = self.next();
468 self.instance.append_stdio(result).transpose()
469 }
470 }
471
472 impl Drop for RecordBatchIter<'_> {
473 fn drop(&mut self) {
474 _ = self.instance.dealloc.call(
475 &mut self.instance.store,
476 (self.alloc_ptr, self.alloc_len, 4),
477 );
478 _ = self
479 .instance
480 .record_batch_iterator_drop
481 .call(&mut self.instance.store, self.ptr);
482 }
483 }
484
485 Ok(RecordBatchIter {
486 instance: self,
487 ptr,
488 alloc_ptr,
489 alloc_len,
490 })
491 }
492
493 fn read_u32(&mut self, ptr: u32) -> Result<u32> {
495 Ok(u32::from_le_bytes(
496 self.memory.data(&self.store)[ptr as usize..(ptr + 4) as usize]
497 .try_into()
498 .unwrap(),
499 ))
500 }
501
502 fn append_stdio<T>(&self, result: Result<T>) -> Result<T> {
504 let stdout = self.stdout.take();
505 let stderr = self.stderr.take();
506 match result {
507 Ok(v) => Ok(v),
508 Err(e) => Err(e.context(format!(
509 "--- stdout\n{}\n--- stderr\n{}",
510 String::from_utf8_lossy(&stdout),
511 String::from_utf8_lossy(&stderr),
512 ))),
513 }
514 }
515}
516
517fn base64_decode(input: &str) -> Result<String> {
519 use base64::{
520 alphabet::Alphabet,
521 engine::{general_purpose::NO_PAD, GeneralPurpose},
522 Engine,
523 };
524 let alphabet =
527 Alphabet::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789$_").unwrap();
528 let engine = GeneralPurpose::new(&alphabet, NO_PAD);
529 let bytes = engine.decode(input)?;
530 String::from_utf8(bytes).context("invalid utf8")
531}
532
533fn encode_record_batch(batch: &RecordBatch) -> Result<Vec<u8>> {
534 let mut buf = vec![];
535 let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
536 writer.write(batch)?;
537 writer.finish()?;
538 drop(writer);
539 Ok(buf)
540}
541
542fn decode_record_batch(bytes: &[u8]) -> Result<RecordBatch> {
543 let mut reader = arrow_ipc::reader::FileReader::try_new(std::io::Cursor::new(bytes), None)?;
544 let batch = reader.next().unwrap()?;
545 Ok(batch)
546}