1use reifydb_core::value::column::data::ColumnData;
7use reifydb_sdk::marshal::wasm::{marshal_columns_to_bytes, unmarshal_columns_from_bytes};
8use reifydb_type::{fragment::Fragment, value::r#type::Type};
9use reifydb_wasm::{Engine, SpawnBinary, module::value::Value, source};
10
11use super::{ScalarFunction, ScalarFunctionContext};
12use crate::error::{ScalarFunctionError, ScalarFunctionResult};
13
/// A scalar function whose implementation is supplied as a WebAssembly binary.
///
/// Only the raw module bytes and the function's name are stored; the module is
/// loaded and invoked by the `ScalarFunction` implementation for this type.
pub struct WasmScalarFunction {
	/// Raw bytes of the WebAssembly module, passed to the engine as a binary source.
	wasm_bytes: Vec<u8>,
	/// Name identifying this function; also attached to execution errors.
	name: String,
}
29
30impl WasmScalarFunction {
31 pub fn new(name: impl Into<String>, wasm_bytes: Vec<u8>) -> Self {
32 Self {
33 name: name.into(),
34 wasm_bytes,
35 }
36 }
37
38 pub fn name(&self) -> &str {
39 &self.name
40 }
41
42 fn err(&self, reason: impl Into<String>) -> ScalarFunctionError {
43 ScalarFunctionError::ExecutionFailed {
44 function: Fragment::internal(&self.name),
45 reason: reason.into(),
46 }
47 }
48}
49
50unsafe impl Send for WasmScalarFunction {}
53unsafe impl Sync for WasmScalarFunction {}
54
55impl ScalarFunction for WasmScalarFunction {
56 fn return_type(&self, _input_types: &[Type]) -> Type {
57 Type::Any
58 }
59
60 fn scalar<'a>(&'a self, ctx: ScalarFunctionContext<'a>) -> ScalarFunctionResult<ColumnData> {
61 let input_bytes = marshal_columns_to_bytes(ctx.columns);
62
63 let mut engine = Engine::default();
64 engine.spawn(source::binary::bytes(&self.wasm_bytes))
65 .map_err(|e| self.err(format!("failed to load: {:?}", e)))?;
66
67 let alloc_result = engine
69 .invoke("alloc", &[Value::I32(input_bytes.len() as i32)])
70 .map_err(|e| self.err(format!("alloc failed: {:?}", e)))?;
71
72 let input_ptr = match alloc_result.first() {
73 Some(Value::I32(v)) => *v,
74 _ => return Err(self.err("alloc returned unexpected result")),
75 };
76
77 engine.write_memory(input_ptr as usize, &input_bytes)
79 .map_err(|e| self.err(format!("write_memory failed: {:?}", e)))?;
80
81 let result = engine
83 .invoke("scalar", &[Value::I32(input_ptr), Value::I32(input_bytes.len() as i32)])
84 .map_err(|e| self.err(format!("scalar call failed: {:?}", e)))?;
85
86 let output_ptr = match result.first() {
87 Some(Value::I32(v)) => *v as usize,
88 _ => return Err(self.err("scalar returned unexpected result")),
89 };
90
91 let len_bytes = engine
93 .read_memory(output_ptr, 4)
94 .map_err(|e| self.err(format!("read output length failed: {:?}", e)))?;
95
96 let output_len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]) as usize;
97
98 let output_bytes = engine
100 .read_memory(output_ptr + 4, output_len)
101 .map_err(|e| self.err(format!("read output data failed: {:?}", e)))?;
102
103 let output_columns = unmarshal_columns_from_bytes(&output_bytes);
105
106 match output_columns.first() {
107 Some(col) => Ok(col.data().clone()),
108 None => Ok(ColumnData::none_typed(Type::Any, ctx.row_count)),
109 }
110 }
111}