//! libsql_wasm/lib.rs

1use std::ffi::{c_char, c_void, CStr};
2use wasmtime::{Config, Engine, Linker, Module, Store, Val};
3use wasmtime_wasi::WasiCtxBuilder;
4
// SQL value type codes. They are compared against what the
// `libsql_value_type` callback reports, and they double as the leading tag
// byte of values exchanged with the guest through its linear memory. The
// numbers coincide with SQLite's `SQLITE_INTEGER` .. `SQLITE_NULL` codes.

/// Tag for an SQL INTEGER value.
const LIBSQL_INTEGER: i8 = 1;
/// Tag for an SQL REAL (floating-point) value.
const LIBSQL_FLOAT: i8 = 2;
/// Tag for an SQL TEXT value.
const LIBSQL_TEXT: i8 = 3;
/// Tag for an SQL BLOB value.
const LIBSQL_BLOB: i8 = 4;
/// Tag for an SQL NULL value.
const LIBSQL_NULL: i8 = 5;

/// Stores a copy of `err_str` in `*err_buf` as a NUL-terminated byte string.
///
/// The buffer (message length + 1 bytes) is obtained from `alloc_err`, so it
/// belongs to whoever supplied that allocator. Nothing is written when
/// `err_buf` is null or when `alloc_err` returns null.
fn maybe_set_err_buf(
    err_buf: *mut *const u8,
    err_str: String,
    alloc_err: unsafe extern "C" fn(u64) -> *mut u8,
) {
    if err_buf.is_null() {
        return;
    }
    let msg = err_str.as_bytes();
    // SAFETY: `alloc_err` is the caller-supplied allocator; it returns either
    // null or a writable buffer of at least the requested size.
    let err_ptr = unsafe { alloc_err(msg.len() as u64 + 1) };
    if err_ptr.is_null() {
        return;
    }
    // SAFETY: `err_ptr` is non-null and valid for `msg.len() + 1` bytes, and
    // `err_buf` was checked to be non-null above.
    unsafe {
        std::ptr::copy_nonoverlapping(msg.as_ptr(), err_ptr, msg.len());
        // The extra allocated byte is the terminator that makes this a C string.
        *err_ptr.add(msg.len()) = 0;
        *err_buf = err_ptr as *const u8;
    }
}

24#[no_mangle]
25pub fn libsql_compile_wasm_module(
26    engine: *const wasmtime::Engine,
27    p_src_body: *const u8,
28    n_body: i32,
29    alloc_err: unsafe extern "C" fn(u64) -> *mut u8,
30    err_msg_buf: *mut *const u8,
31) -> *const c_void {
32    let src_body: &[u8] = unsafe { std::slice::from_raw_parts(p_src_body, n_body as usize) };
33
34    let module = match Module::new(unsafe { &*engine }, src_body) {
35        Ok(m) => m,
36        Err(orig_e) => {
37            // If compilation failed, let's assume it's unquoted .wat and retry
38            let src_body_str: &str = match std::str::from_utf8(src_body) {
39                Ok(src) => src,
40                Err(e) => {
41                    maybe_set_err_buf(
42                        err_msg_buf,
43                        format!(
44                            "Failed to compile module: {}, and it's not valid .wat either: {}",
45                            orig_e, e
46                        ),
47                        alloc_err,
48                    );
49                    return std::ptr::null() as *const c_void;
50                }
51            };
52            if src_body_str.len() < 2 {
53                maybe_set_err_buf(
54                    err_msg_buf,
55                    format!("Failed to compile module: {}", orig_e),
56                    alloc_err,
57                );
58                return std::ptr::null() as *const c_void;
59            }
60            let src_body_dequoted =
61                String::from(&src_body_str[1..src_body_str.len() - 2]).replace("''", "'");
62            match Module::new(unsafe { &*engine }, src_body_dequoted.as_bytes()) {
63                Ok(m) => m,
64                Err(e) => {
65                    maybe_set_err_buf(
66                        err_msg_buf,
67                        format!("Failed to compile .wat module: {}", e),
68                        alloc_err,
69                    );
70                    return std::ptr::null();
71                }
72            }
73        }
74    };
75    let module = Box::new(module);
76    let module_ptr = &*module as *const Module as *const c_void;
77    std::mem::forget(module);
78    module_ptr
79}
80
81#[no_mangle]
82pub fn libsql_wasm_engine_new() -> *const c_void {
83    let engine = match Engine::new(&Config::new()) {
84        Ok(eng) => eng,
85        Err(_) => return std::ptr::null() as *const c_void,
86    };
87
88    Box::into_raw(Box::new(engine)) as *const c_void
89}
90
91#[no_mangle]
92pub fn libsql_wasm_engine_free(engine: *mut c_void) {
93    unsafe { Box::from_raw(engine as *mut Engine) };
94}
95
/// Callback taking a single libSQL handle (an argument value or the function
/// context) and returning `R`.
type UdfHandleFn<R> = unsafe extern "C" fn(*const c_void) -> R;

/// Callback setting the function result from the context and one scalar `A`.
type UdfSetScalarFn<A> = unsafe extern "C" fn(*const c_void, A);

/// Table of libSQL callbacks used by [`libsql_run_wasm`] to read UDF
/// arguments, report results and manage memory.
///
/// `#[repr(C)]` fixes the field order and layout. The table is presumably
/// filled in on the C side, so fields must not be reordered or retyped.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct libsql_wasm_udf_api {
    // Argument accessors: each receives an argument value pointer.
    libsql_value_type: UdfHandleFn<i32>,
    libsql_value_int: UdfHandleFn<i32>,
    libsql_value_double: UdfHandleFn<f64>,
    libsql_value_text: UdfHandleFn<*const u8>,
    libsql_value_blob: UdfHandleFn<*const c_void>,
    libsql_value_bytes: UdfHandleFn<i32>,
    // Result setters: each receives the function context pointer first.
    libsql_result_error: unsafe extern "C" fn(*const c_void, *const u8, i32),
    libsql_result_error_nomem: UdfHandleFn<()>,
    libsql_result_int: UdfSetScalarFn<i32>,
    libsql_result_double: UdfSetScalarFn<f64>,
    // Text/blob setters: context, data, byte length, destructor for `data`.
    libsql_result_text: unsafe extern "C" fn(*const c_void, *const u8, i32, *const c_void),
    libsql_result_blob: unsafe extern "C" fn(*const c_void, *const c_void, i32, *const c_void),
    libsql_result_null: UdfHandleFn<()>,
    // Allocator for buffers handed to the text/blob result setters.
    libsql_malloc: unsafe extern "C" fn(i32) -> *mut c_void,
    libsql_free: unsafe extern "C" fn(*mut c_void),
}

115fn alloc_slice(api: *const libsql_wasm_udf_api, s: &[u8]) -> *const c_void {
116    let len = s.len();
117    let ptr = unsafe { ((*api).libsql_malloc)(len as i32) };
118    unsafe { std::slice::from_raw_parts_mut(ptr as *mut u8, len) }.copy_from_slice(s);
119    ptr as *const c_void
120}
121
122#[no_mangle]
123pub fn libsql_run_wasm(
124    api: *const libsql_wasm_udf_api,
125    libsql_ctx: *const c_void,
126    engine: *mut Engine,
127    module: *mut Module,
128    func_name: *const u8,
129    argc: i32,
130    argv: *mut *mut c_void,
131) {
132    let engine = unsafe { &*engine };
133    let module = unsafe { &*module };
134
135    let run_wasm = |engine: &Engine, module: &Module| -> Result<(), String> {
136        let mut linker = Linker::new(&engine);
137        wasmtime_wasi::add_to_linker(&mut linker, |s| s)
138            .map_err(|e| format!("Add WASI failed: {}", e))?;
139        let wasi = WasiCtxBuilder::new()
140            .inherit_stdio()
141            .args(&[])
142            .map_err(|e| format!("Creating WasiCtx failed: {}", e))?
143            .build();
144        let mut store = Store::new(&engine, wasi);
145
146        let instance = linker
147            .instantiate(&mut store, module)
148            .map_err(|e| format!("Creating instance failed: {}", e))?;
149
150        let func_name: &str = unsafe { CStr::from_ptr(func_name as *const c_char) }
151            .to_str()
152            .map_err(|e| format!("Function name is not valid utf-8: {}", e))?;
153
154        let func = instance
155            .get_func(&mut store, func_name)
156            .ok_or_else(|| format!("Function {} not found in Wasm module", func_name))?;
157
158        let memory = instance
159            .get_memory(&mut store, "memory")
160            .ok_or_else(|| format!("Memory \"memory\" not found in wasm module"))?;
161
162        let mem_size = memory.size(&mut store) as usize;
163
164        let mut vals: Vec<Val> = Vec::new();
165        for i in 0..argc {
166            let arg = unsafe { *argv.offset(i as isize) };
167            match unsafe { ((*api).libsql_value_type)(arg) } as i8 {
168                LIBSQL_INTEGER => {
169                    vals.push(Val::I64(unsafe { ((*api).libsql_value_int)(arg) } as i64))
170                }
171                LIBSQL_FLOAT => vals.push(Val::F64(
172                    unsafe { ((*api).libsql_value_double)(arg) }.to_bits(),
173                )),
174                LIBSQL_TEXT => {
175                    let text_len = unsafe { ((*api).libsql_value_bytes)(arg) } as usize;
176                    let text: &[u8] = unsafe {
177                        std::slice::from_raw_parts(((*api).libsql_value_text)(arg), text_len)
178                    };
179
180                    let func_name = "libsql_malloc";
181                    let func_malloc =
182                        instance.get_func(&mut store, func_name).ok_or_else(|| {
183                            format!("Function {} not found in Wasm module", func_name)
184                        })?;
185                    let params = [Val::I32((text_len + 2) as i32)];
186                    let mut result = Val::null();
187                    func_malloc
188                        .call(&mut store, &params, std::slice::from_mut(&mut result))
189                        .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
190
191                    let mem_offset = result.i32().unwrap_or(mem_size as i32) as usize;
192
193                    let data = memory.data_mut(&mut store);
194                    data[mem_offset] = LIBSQL_TEXT as u8;
195                    data[mem_offset + 1..mem_offset + 1 + text_len].copy_from_slice(text);
196                    data[mem_offset + 1 + text_len] = 0;
197
198                    vals.push(Val::I32(mem_offset as i32));
199                }
200                LIBSQL_BLOB => {
201                    let blob_len = unsafe { ((*api).libsql_value_bytes)(arg) } as usize;
202                    let blob: &[u8] = unsafe {
203                        std::slice::from_raw_parts(
204                            ((*api).libsql_value_blob)(arg) as *const u8,
205                            blob_len,
206                        )
207                    };
208                    let blob_len_i32 = blob_len as i32;
209
210                    let func_name = "libsql_malloc";
211                    let func_malloc =
212                        instance.get_func(&mut store, func_name).ok_or_else(|| {
213                            format!("Function {} not found in Wasm module", func_name)
214                        })?;
215                    let params = [Val::I32((blob_len_i32 + 5) as i32)];
216                    let mut result = Val::null();
217                    func_malloc
218                        .call(&mut store, &params, std::slice::from_mut(&mut result))
219                        .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
220
221                    let mem_offset = result.i32().unwrap_or(mem_size as i32) as usize;
222
223                    let data = memory.data_mut(&mut store);
224                    data[mem_offset] = LIBSQL_BLOB as u8;
225                    data[mem_offset + 1..mem_offset + 1 + 4]
226                        .copy_from_slice(&blob_len_i32.to_be_bytes());
227                    data[mem_offset + 1 + 4..mem_offset + 1 + 4 + blob_len].copy_from_slice(blob);
228
229                    vals.push(Val::I32(mem_offset as i32));
230                }
231                LIBSQL_NULL => {
232                    let func_name = "libsql_malloc";
233                    let func_malloc =
234                        instance.get_func(&mut store, func_name).ok_or_else(|| {
235                            format!("Function {} not found in Wasm module", func_name)
236                        })?;
237                    let params = [Val::I32(1)];
238                    let mut result = Val::null();
239                    func_malloc
240                        .call(&mut store, &params, std::slice::from_mut(&mut result))
241                        .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
242
243                    let mem_offset = result.i32().unwrap_or(mem_size as i32) as usize;
244
245                    memory.data_mut(&mut store)[mem_offset] = LIBSQL_NULL as u8;
246
247                    vals.push(Val::I32(mem_offset as i32));
248                }
249                _ => {
250                    return Err(format!("Unknown libSQL type"));
251                }
252            }
253        }
254
255        let mut result = Val::null();
256        func.call(&mut store, &vals, std::slice::from_mut(&mut result))
257            .map_err(|e| format!("Calling function {} failed: {}", func_name, e))?;
258
259        match result {
260            Val::I64(v) => unsafe { ((*api).libsql_result_int)(libsql_ctx, v as i32) },
261            Val::F64(v) => unsafe { ((*api).libsql_result_double)(libsql_ctx, f64::from_bits(v)) },
262            Val::I32(v) => {
263                let v = v as usize;
264                match memory.data(&store)[v] as i8 {
265                    LIBSQL_TEXT => {
266                        let result_str = unsafe {
267                            CStr::from_ptr(
268                                (memory.data(&store).as_ptr() as *const c_char)
269                                    .offset(v as isize + 1),
270                            )
271                        };
272                        let result_ptr = alloc_slice(api, result_str.to_bytes_with_nul());
273                        unsafe {
274                            ((*api).libsql_result_text)(
275                                libsql_ctx,
276                                result_ptr as *const u8,
277                                result_str.to_str().unwrap().len() as i32, // safe to unwrap, created in alloc_slice
278                                (*api).libsql_free as *const c_void,
279                            )
280                        }
281                    }
282                    LIBSQL_BLOB => {
283                        let blob_len = i32::from_be_bytes(
284                            memory.data(&store)[v + 1..v + 1 + 4].try_into().unwrap(), // safe to unwrap, slice size == 4
285                        );
286                        let result_ptr = alloc_slice(
287                            api,
288                            &memory.data(&store)[v + 1 + 4..v + 1 + 4 + blob_len as usize],
289                        );
290                        unsafe {
291                            ((*api).libsql_result_blob)(
292                                libsql_ctx,
293                                result_ptr as *const c_void,
294                                blob_len,
295                                (*api).libsql_free as *const c_void,
296                            )
297                        }
298                    }
299                    LIBSQL_NULL => unsafe { ((*api).libsql_result_null)(libsql_ctx) },
300                    _ => return Err(format!("Malformed result type byte")),
301                }
302            }
303            _ => return Err(format!("Malformed result type")),
304        }
305        Ok(())
306    };
307
308    match run_wasm(engine, module) {
309        Ok(_) => {}
310        Err(err) => unsafe {
311            ((*api).libsql_result_error)(libsql_ctx, err.as_ptr() as *const u8, err.len() as i32);
312        },
313    }
314}
315
316#[no_mangle]
317pub fn libsql_free_wasm_module(module: *mut *mut Module) {
318    unsafe { Box::from_raw(*module) };
319}