1use std::ffi::{c_char, c_void, CStr};
2use wasmtime::{Config, Engine, Linker, Module, Store, Val};
3use wasmtime_wasi::WasiCtxBuilder;
4
// Value type codes. They are compared against `libsql_value_type`'s return value and are
// also written as the one-byte tag that prefixes TEXT/BLOB/NULL values exchanged through
// Wasm linear memory.
// NOTE(review): presumably these mirror SQLite's SQLITE_INTEGER..SQLITE_NULL codes — confirm.

/// Type code for integer values.
const LIBSQL_INTEGER: i8 = 0x01;
/// Type code for floating-point values.
const LIBSQL_FLOAT: i8 = 0x02;
/// Type code (and memory tag) for text values.
const LIBSQL_TEXT: i8 = 0x03;
/// Type code (and memory tag) for blob values.
const LIBSQL_BLOB: i8 = 0x04;
/// Type code (and memory tag) for SQL NULL.
const LIBSQL_NULL: i8 = 0x05;
10
/// Stores a NUL-terminated copy of `err_str` in `*err_buf`.
///
/// The buffer is allocated with the caller-supplied `alloc_err` (size includes the
/// terminator). Does nothing when `err_buf` is null, and leaves `*err_buf` untouched when
/// the allocator returns null.
fn maybe_set_err_buf(
    err_buf: *mut *const u8,
    err_str: String,
    alloc_err: unsafe extern "C" fn(u64) -> *mut u8,
) {
    if err_buf.is_null() {
        return;
    }
    let bytes = err_str.as_bytes();
    // One extra byte for the NUL terminator the C side expects.
    let err_ptr = unsafe { alloc_err(bytes.len() as u64 + 1) };
    if err_ptr.is_null() {
        return;
    }
    // SAFETY: `err_ptr` is non-null and was requested with room for `bytes.len() + 1` bytes.
    let dest = unsafe { std::slice::from_raw_parts_mut(err_ptr, bytes.len() + 1) };
    dest[..bytes.len()].copy_from_slice(bytes);
    dest[bytes.len()] = 0;
    // SAFETY: `err_buf` was checked to be non-null above.
    unsafe { *err_buf = err_ptr as *const u8 };
}
23
24#[no_mangle]
25pub fn libsql_compile_wasm_module(
26 engine: *const wasmtime::Engine,
27 p_src_body: *const u8,
28 n_body: i32,
29 alloc_err: unsafe extern "C" fn(u64) -> *mut u8,
30 err_msg_buf: *mut *const u8,
31) -> *const c_void {
32 let src_body: &[u8] = unsafe { std::slice::from_raw_parts(p_src_body, n_body as usize) };
33
34 let module = match Module::new(unsafe { &*engine }, src_body) {
35 Ok(m) => m,
36 Err(orig_e) => {
37 let src_body_str: &str = match std::str::from_utf8(src_body) {
39 Ok(src) => src,
40 Err(e) => {
41 maybe_set_err_buf(
42 err_msg_buf,
43 format!(
44 "Failed to compile module: {}, and it's not valid .wat either: {}",
45 orig_e, e
46 ),
47 alloc_err,
48 );
49 return std::ptr::null() as *const c_void;
50 }
51 };
52 if src_body_str.len() < 2 {
53 maybe_set_err_buf(
54 err_msg_buf,
55 format!("Failed to compile module: {}", orig_e),
56 alloc_err,
57 );
58 return std::ptr::null() as *const c_void;
59 }
60 let src_body_dequoted =
61 String::from(&src_body_str[1..src_body_str.len() - 2]).replace("''", "'");
62 match Module::new(unsafe { &*engine }, src_body_dequoted.as_bytes()) {
63 Ok(m) => m,
64 Err(e) => {
65 maybe_set_err_buf(
66 err_msg_buf,
67 format!("Failed to compile .wat module: {}", e),
68 alloc_err,
69 );
70 return std::ptr::null();
71 }
72 }
73 }
74 };
75 let module = Box::new(module);
76 let module_ptr = &*module as *const Module as *const c_void;
77 std::mem::forget(module);
78 module_ptr
79}
80
81#[no_mangle]
82pub fn libsql_wasm_engine_new() -> *const c_void {
83 let engine = match Engine::new(&Config::new()) {
84 Ok(eng) => eng,
85 Err(_) => return std::ptr::null() as *const c_void,
86 };
87
88 Box::into_raw(Box::new(engine)) as *const c_void
89}
90
91#[no_mangle]
92pub fn libsql_wasm_engine_free(engine: *mut c_void) {
93 unsafe { Box::from_raw(engine as *mut Engine) };
94}
95
/// Table of host callbacks passed to `libsql_run_wasm` as `api`.
///
/// `#[repr(C)]` fixes the field order and layout. The table is presumably populated on the
/// C side, so this definition must stay in sync with it: do not reorder, add or remove
/// fields here alone.
/// NOTE(review): the names suggest thin wrappers over SQLite's `sqlite3_value_*`,
/// `sqlite3_result_*`, `sqlite3_malloc` and `sqlite3_free` — confirm against the C side.
#[repr(C)]
pub struct libsql_wasm_udf_api {
    // Returns the argument's type code, compared against LIBSQL_INTEGER..LIBSQL_NULL.
    libsql_value_type: unsafe extern "C" fn(*const c_void) -> i32,
    // 32-bit integer accessor; wider integers cannot be read through this table.
    libsql_value_int: unsafe extern "C" fn(*const c_void) -> i32,
    libsql_value_double: unsafe extern "C" fn(*const c_void) -> f64,
    libsql_value_text: unsafe extern "C" fn(*const c_void) -> *const u8,
    libsql_value_blob: unsafe extern "C" fn(*const c_void) -> *const c_void,
    // Byte length of the argument; used for both TEXT and BLOB arguments.
    libsql_value_bytes: unsafe extern "C" fn(*const c_void) -> i32,
    // (context, message pointer, message length in bytes).
    libsql_result_error: unsafe extern "C" fn(*const c_void, *const u8, i32),
    libsql_result_error_nomem: unsafe extern "C" fn(*const c_void),
    // 32-bit integer result setter; 64-bit guest results are truncated before this call.
    libsql_result_int: unsafe extern "C" fn(*const c_void, i32),
    libsql_result_double: unsafe extern "C" fn(*const c_void, f64),
    // (context, buffer, byte length, destructor); this file passes `libsql_free` as the
    // destructor for buffers obtained from `libsql_malloc`.
    libsql_result_text: unsafe extern "C" fn(*const c_void, *const u8, i32, *const c_void),
    libsql_result_blob: unsafe extern "C" fn(*const c_void, *const c_void, i32, *const c_void),
    libsql_result_null: unsafe extern "C" fn(*const c_void),
    // Host allocator used for result buffers (see `alloc_slice`).
    libsql_malloc: unsafe extern "C" fn(i32) -> *mut c_void,
    libsql_free: unsafe extern "C" fn(*mut c_void),
}
114
115fn alloc_slice(api: *const libsql_wasm_udf_api, s: &[u8]) -> *const c_void {
116 let len = s.len();
117 let ptr = unsafe { ((*api).libsql_malloc)(len as i32) };
118 unsafe { std::slice::from_raw_parts_mut(ptr as *mut u8, len) }.copy_from_slice(s);
119 ptr as *const c_void
120}
121
122#[no_mangle]
123pub fn libsql_run_wasm(
124 api: *const libsql_wasm_udf_api,
125 libsql_ctx: *const c_void,
126 engine: *mut Engine,
127 module: *mut Module,
128 func_name: *const u8,
129 argc: i32,
130 argv: *mut *mut c_void,
131) {
132 let engine = unsafe { &*engine };
133 let module = unsafe { &*module };
134
135 let run_wasm = |engine: &Engine, module: &Module| -> Result<(), String> {
136 let mut linker = Linker::new(&engine);
137 wasmtime_wasi::add_to_linker(&mut linker, |s| s)
138 .map_err(|e| format!("Add WASI failed: {}", e))?;
139 let wasi = WasiCtxBuilder::new()
140 .inherit_stdio()
141 .args(&[])
142 .map_err(|e| format!("Creating WasiCtx failed: {}", e))?
143 .build();
144 let mut store = Store::new(&engine, wasi);
145
146 let instance = linker
147 .instantiate(&mut store, module)
148 .map_err(|e| format!("Creating instance failed: {}", e))?;
149
150 let func_name: &str = unsafe { CStr::from_ptr(func_name as *const c_char) }
151 .to_str()
152 .map_err(|e| format!("Function name is not valid utf-8: {}", e))?;
153
154 let func = instance
155 .get_func(&mut store, func_name)
156 .ok_or_else(|| format!("Function {} not found in Wasm module", func_name))?;
157
158 let memory = instance
159 .get_memory(&mut store, "memory")
160 .ok_or_else(|| format!("Memory \"memory\" not found in wasm module"))?;
161
162 let mem_size = memory.size(&mut store) as usize;
163
164 let mut vals: Vec<Val> = Vec::new();
165 for i in 0..argc {
166 let arg = unsafe { *argv.offset(i as isize) };
167 match unsafe { ((*api).libsql_value_type)(arg) } as i8 {
168 LIBSQL_INTEGER => {
169 vals.push(Val::I64(unsafe { ((*api).libsql_value_int)(arg) } as i64))
170 }
171 LIBSQL_FLOAT => vals.push(Val::F64(
172 unsafe { ((*api).libsql_value_double)(arg) }.to_bits(),
173 )),
174 LIBSQL_TEXT => {
175 let text_len = unsafe { ((*api).libsql_value_bytes)(arg) } as usize;
176 let text: &[u8] = unsafe {
177 std::slice::from_raw_parts(((*api).libsql_value_text)(arg), text_len)
178 };
179
180 let func_name = "libsql_malloc";
181 let func_malloc =
182 instance.get_func(&mut store, func_name).ok_or_else(|| {
183 format!("Function {} not found in Wasm module", func_name)
184 })?;
185 let params = [Val::I32((text_len + 2) as i32)];
186 let mut result = Val::null();
187 func_malloc
188 .call(&mut store, ¶ms, std::slice::from_mut(&mut result))
189 .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
190
191 let mem_offset = result.i32().unwrap_or(mem_size as i32) as usize;
192
193 let data = memory.data_mut(&mut store);
194 data[mem_offset] = LIBSQL_TEXT as u8;
195 data[mem_offset + 1..mem_offset + 1 + text_len].copy_from_slice(text);
196 data[mem_offset + 1 + text_len] = 0;
197
198 vals.push(Val::I32(mem_offset as i32));
199 }
200 LIBSQL_BLOB => {
201 let blob_len = unsafe { ((*api).libsql_value_bytes)(arg) } as usize;
202 let blob: &[u8] = unsafe {
203 std::slice::from_raw_parts(
204 ((*api).libsql_value_blob)(arg) as *const u8,
205 blob_len,
206 )
207 };
208 let blob_len_i32 = blob_len as i32;
209
210 let func_name = "libsql_malloc";
211 let func_malloc =
212 instance.get_func(&mut store, func_name).ok_or_else(|| {
213 format!("Function {} not found in Wasm module", func_name)
214 })?;
215 let params = [Val::I32((blob_len_i32 + 5) as i32)];
216 let mut result = Val::null();
217 func_malloc
218 .call(&mut store, ¶ms, std::slice::from_mut(&mut result))
219 .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
220
221 let mem_offset = result.i32().unwrap_or(mem_size as i32) as usize;
222
223 let data = memory.data_mut(&mut store);
224 data[mem_offset] = LIBSQL_BLOB as u8;
225 data[mem_offset + 1..mem_offset + 1 + 4]
226 .copy_from_slice(&blob_len_i32.to_be_bytes());
227 data[mem_offset + 1 + 4..mem_offset + 1 + 4 + blob_len].copy_from_slice(blob);
228
229 vals.push(Val::I32(mem_offset as i32));
230 }
231 LIBSQL_NULL => {
232 let func_name = "libsql_malloc";
233 let func_malloc =
234 instance.get_func(&mut store, func_name).ok_or_else(|| {
235 format!("Function {} not found in Wasm module", func_name)
236 })?;
237 let params = [Val::I32(1)];
238 let mut result = Val::null();
239 func_malloc
240 .call(&mut store, ¶ms, std::slice::from_mut(&mut result))
241 .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
242
243 let mem_offset = result.i32().unwrap_or(mem_size as i32) as usize;
244
245 memory.data_mut(&mut store)[mem_offset] = LIBSQL_NULL as u8;
246
247 vals.push(Val::I32(mem_offset as i32));
248 }
249 _ => {
250 return Err(format!("Unknown libSQL type"));
251 }
252 }
253 }
254
255 let mut result = Val::null();
256 func.call(&mut store, &vals, std::slice::from_mut(&mut result))
257 .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
258
259 match result {
260 Val::I64(v) => unsafe { ((*api).libsql_result_int)(libsql_ctx, v as i32) },
261 Val::F64(v) => unsafe { ((*api).libsql_result_double)(libsql_ctx, f64::from_bits(v)) },
262 Val::I32(v) => {
263 let v = v as usize;
264 match memory.data(&store)[v] as i8 {
265 LIBSQL_TEXT => {
266 let result_str = unsafe {
267 CStr::from_ptr(
268 (memory.data(&store).as_ptr() as *const c_char)
269 .offset(v as isize + 1),
270 )
271 };
272 let result_ptr = alloc_slice(api, result_str.to_bytes_with_nul());
273 unsafe {
274 ((*api).libsql_result_text)(
275 libsql_ctx,
276 result_ptr as *const u8,
277 result_str.to_str().unwrap().len() as i32, (*api).libsql_free as *const c_void,
279 )
280 }
281 }
282 LIBSQL_BLOB => {
283 let blob_len = i32::from_be_bytes(
284 memory.data(&store)[v + 1..v + 1 + 4].try_into().unwrap(), );
286 let result_ptr = alloc_slice(
287 api,
288 &memory.data(&store)[v + 1 + 4..v + 1 + 4 + blob_len as usize],
289 );
290 unsafe {
291 ((*api).libsql_result_blob)(
292 libsql_ctx,
293 result_ptr as *const c_void,
294 blob_len,
295 (*api).libsql_free as *const c_void,
296 )
297 }
298 }
299 LIBSQL_NULL => unsafe { ((*api).libsql_result_null)(libsql_ctx) },
300 _ => return Err(format!("Malformed result type byte")),
301 }
302 }
303 _ => return Err(format!("Malformed result type")),
304 }
305 Ok(())
306 };
307
308 match run_wasm(engine, module) {
309 Ok(_) => {}
310 Err(err) => unsafe {
311 ((*api).libsql_result_error)(libsql_ctx, err.as_ptr() as *const u8, err.len() as i32);
312 },
313 }
314}
315
316#[no_mangle]
317pub fn libsql_free_wasm_module(module: *mut *mut Module) {
318 unsafe { Box::from_raw(*module) };
319}