lunatic_sqlite_api/
sqlite_bindings.rs

1use anyhow::Result;
2use hash_map_id::HashMapId;
3use lunatic_common_api::{get_memory, write_to_guest_vec, IntoTrap};
4use lunatic_error_api::ErrorCtx;
5use lunatic_process::state::ProcessState;
6use lunatic_process_api::ProcessConfigCtx;
7use sqlite::{Connection, State, Statement};
8use std::{
9    collections::HashMap,
10    future::Future,
11    path::Path,
12    sync::{Arc, Mutex},
13};
14use wasmtime::{Caller, Linker, ResourceLimiter};
15
16use crate::wire_format::{BindList, SqliteError, SqliteRow, SqliteValue};
17
18pub const SQLITE_ROW: u32 = 100;
19pub const SQLITE_DONE: u32 = 101;
20
21pub type SQLiteConnections = HashMapId<Arc<Mutex<Connection>>>;
22pub type SQLiteResults = HashMapId<Vec<u8>>;
23// sometimes we need to lookup the connection_id for the statement
24pub type SQLiteStatements = HashMapId<(u64, Statement)>;
25// maps connection_id to name of allocation function
26pub type SQLiteGuestAllocators = HashMap<u64, String>;
27pub trait SQLiteCtx {
28    fn sqlite_connections(&self) -> &SQLiteConnections;
29    fn sqlite_connections_mut(&mut self) -> &mut SQLiteConnections;
30
31    fn sqlite_guest_allocator(&self) -> &SQLiteGuestAllocators;
32    fn sqlite_guest_allocator_mut(&mut self) -> &mut SQLiteGuestAllocators;
33
34    fn sqlite_statements(&self) -> &SQLiteStatements;
35    fn sqlite_statements_mut(&mut self) -> &mut SQLiteStatements;
36}
37
38// Register the SqlLite apis
39pub fn register<T: SQLiteCtx + ProcessState + Send + ErrorCtx + ResourceLimiter + Sync + 'static>(
40    linker: &mut Linker<T>,
41) -> Result<()>
42where
43    T::Config: lunatic_process_api::ProcessConfigCtx,
44{
45    linker.func_wrap("lunatic::sqlite", "open", open)?;
46    linker.func_wrap("lunatic::sqlite", "query_prepare", query_prepare)?;
47    linker.func_wrap("lunatic::sqlite", "execute", execute)?;
48    linker.func_wrap("lunatic::sqlite", "bind_value", bind_value)?;
49    linker.func_wrap("lunatic::sqlite", "sqlite3_changes", sqlite3_changes)?;
50    linker.func_wrap("lunatic::sqlite", "statement_reset", statement_reset)?;
51    linker.func_wrap2_async("lunatic::sqlite", "last_error", last_error)?;
52    linker.func_wrap("lunatic::sqlite", "sqlite3_finalize", sqlite3_finalize)?;
53    linker.func_wrap("lunatic::sqlite", "sqlite3_step", sqlite3_step)?;
54    linker.func_wrap3_async("lunatic::sqlite", "read_column", read_column)?;
55    linker.func_wrap2_async("lunatic::sqlite", "column_names", column_names)?;
56    linker.func_wrap2_async("lunatic::sqlite", "read_row", read_row)?;
57    linker.func_wrap("lunatic::sqlite", "column_count", column_count)?;
58    linker.func_wrap3_async("lunatic::sqlite", "column_name", column_name)?;
59    Ok(())
60}
61
62fn open<T>(
63    mut caller: Caller<T>,
64    path_str_ptr: u32,
65    path_str_len: u32,
66    connection_id_ptr: u32,
67) -> Result<u64>
68where
69    T: ProcessState + ErrorCtx + SQLiteCtx,
70    T::Config: lunatic_process_api::ProcessConfigCtx,
71{
72    // obtain the memory and the state
73    let memory = get_memory(&mut caller)?;
74    let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
75
76    // obtain the path as a byte slice reference
77    let path = memory_slice
78        .get(path_str_ptr as usize..(path_str_ptr + path_str_len) as usize)
79        .or_trap("lunatic::sqlite::open")?;
80    let path = std::str::from_utf8(path).or_trap("lunatic::sqlite::open")?;
81    if let Err(error_message) = state.config().can_access_fs_location(Path::new(path)) {
82        let error_id = state
83            .error_resources_mut()
84            .add(anyhow::Error::msg(error_message).context(format!("Failed to access '{path}'")));
85        memory
86            .write(
87                &mut caller,
88                connection_id_ptr as usize,
89                &error_id.to_le_bytes(),
90            )
91            .or_trap("lunatic::sqlite::open")?;
92        return Ok(1);
93    }
94
95    // call the open function, and define the return code
96    let (conn_or_err_id, return_code) = match sqlite::open(path) {
97        Ok(conn) => (
98            caller
99                .data_mut()
100                .sqlite_connections_mut()
101                .add(Arc::new(Mutex::new(conn))),
102            0,
103        ),
104        Err(error) => (caller.data_mut().error_resources_mut().add(error.into()), 1),
105    };
106
107    // write the result into memory and return the return code
108    memory
109        .write(
110            &mut caller,
111            connection_id_ptr as usize,
112            &conn_or_err_id.to_le_bytes(),
113        )
114        .or_trap("lunatic::sqlite::open")?;
115    Ok(return_code)
116}
117
118fn execute<T: ProcessState + ErrorCtx + SQLiteCtx>(
119    mut caller: Caller<T>,
120    conn_id: u64,
121    exec_str_ptr: u32,
122    exec_str_len: u32,
123) -> Result<u32> {
124    let memory = get_memory(&mut caller)?;
125    let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
126    let exec = memory_slice
127        .get(exec_str_ptr as usize..(exec_str_ptr + exec_str_len) as usize)
128        .or_trap("lunatic::sqlite::execute")?;
129    let exec = std::str::from_utf8(exec).or_trap("lunatic::sqlite::execute")?;
130
131    // execute a single sqlite query
132    match state
133        .sqlite_connections()
134        .get(conn_id)
135        .or_trap("lunatic::sqlite::execute")?
136        .lock()
137        .or_trap("lunatic::sqlite::execute")?
138        .execute(exec)
139    {
140        // 1 is equal to SQLITE_ERROR, which is a generic error code
141        Err(e) => Ok(e.code.unwrap_or(1) as u32),
142        Ok(_) => Ok(0),
143    }
144}
145
146fn query_prepare<T: ProcessState + ErrorCtx + SQLiteCtx>(
147    mut caller: Caller<T>,
148    conn_id: u64,
149    query_str_ptr: u32,
150    query_str_len: u32,
151) -> Result<u64> {
152    // get the memory
153    let memory = get_memory(&mut caller)?;
154    let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
155
156    // get the query
157    let query = memory_slice
158        .get(query_str_ptr as usize..(query_str_ptr + query_str_len) as usize)
159        .or_trap("lunatic::sqlite::query_prepare::get_query")?;
160    let query = std::str::from_utf8(query).or_trap("lunatic::sqlite::query_prepare::from_utf8")?;
161
162    let statement = {
163        // obtain the sqlite connection
164        let conn = state
165            .sqlite_connections()
166            .get(conn_id)
167            .take()
168            .or_trap("lunatic::sqlite::query_prepare::obtain_conn")?
169            .lock()
170            .or_trap("lunatic::sqlite::query_prepare::obtain_conn")?;
171
172        // prepare the statement
173        conn.prepare(query)
174            .or_trap("lunatic::sqlite::query_prepare::prepare_statement")?
175    };
176
177    let statement_id = state.sqlite_statements_mut().add((conn_id, statement));
178
179    Ok(statement_id)
180}
181
// Looks up statement `$statement_id` in `$state` and evaluates to
// `(connection_id, &mut Statement)`. When the id is unknown it traps (via `?` in the
// calling function) with "lunatic::sqlite::get_statement_by_id".
macro_rules! get_statement {
    ($state:ident, $statement_id:ident) => {{
        let (connection_id, statement) = $state
            .sqlite_statements_mut()
            .get_mut($statement_id)
            .or_trap("lunatic::sqlite::get_statement_by_id")?;
        (*connection_id, statement)
    }};
}
191
// Looks up connection `$conn_id` in `$state`, locks it, and evaluates to the resulting
// `MutexGuard`. When the id is unknown or the mutex is poisoned it traps (via `?` in the
// calling function) with "lunatic::sqlite::<$fn_name>::obtain_conn".
macro_rules! get_conn {
    ($state:ident, $conn_id:ident, $fn_name:literal) => {{
        const TRAP_INFO: &str = concat!("lunatic::sqlite::", $fn_name, "::obtain_conn");
        let connection = $state
            .sqlite_connections_mut()
            .get($conn_id)
            .or_trap(TRAP_INFO)?;
        connection.lock().or_trap(TRAP_INFO)?
    }};
}
204
205fn bind_value<T: ProcessState + ErrorCtx + SQLiteCtx>(
206    mut caller: Caller<T>,
207    statement_id: u64,
208    bind_data_ptr: u32,
209    bind_data_len: u32,
210) -> Result<()> {
211    // get the memory
212    let memory = get_memory(&mut caller)?;
213    let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
214
215    let (_, statement) = get_statement!(state, statement_id);
216
217    // get the query
218    let bind_data = memory_slice
219        .get(bind_data_ptr as usize..(bind_data_ptr + bind_data_len) as usize)
220        .or_trap("lunatic::sqlite::bind_value::load_bind_data")?;
221
222    let values: BindList = bincode::deserialize(bind_data).unwrap();
223
224    for pair in values.iter() {
225        pair.bind(statement)
226            .or_trap("lunatic::sqlite::bind_value")?;
227    }
228
229    Ok(())
230}
231
232fn sqlite3_changes<T: ProcessState + ErrorCtx + SQLiteCtx>(
233    mut caller: Caller<T>,
234    conn_id: u64,
235) -> Result<u32> {
236    // get state
237    let memory = get_memory(&mut caller)?;
238    let (_, state) = memory.data_and_store_mut(&mut caller);
239    let conn = get_conn!(state, conn_id, "sqlite3_changes");
240
241    Ok(conn.change_count() as u32)
242}
243
244fn statement_reset<T: ProcessState + ErrorCtx + SQLiteCtx>(
245    mut caller: Caller<T>,
246    statement_id: u64,
247) -> Result<()> {
248    // get state
249    let memory = get_memory(&mut caller)?;
250    let (_, state) = memory.data_and_store_mut(&mut caller);
251    let (_, stmt) = get_statement!(state, statement_id);
252
253    stmt.reset().or_trap("lunatic::sqlite::statement_reset")?;
254
255    Ok(())
256}
257
258fn read_column<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
259    mut caller: Caller<T>,
260    statement_id: u64,
261    col_idx: u32,
262    opaque_ptr: u32,
263) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
264    Box::new(async move {
265        // get state
266        let memory = get_memory(&mut caller)?;
267        let (_, state) = memory.data_and_store_mut(&mut caller);
268        let (_, stmt) = get_statement!(state, statement_id);
269
270        let column = bincode::serialize(&SqliteValue::read_column(stmt, col_idx as usize)?)
271            .or_trap("lunatic::sqlite::read_column")?;
272
273        write_to_guest_vec(&mut caller, &memory, &column, opaque_ptr).await
274    })
275}
276
277fn column_names<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
278    mut caller: Caller<T>,
279    statement_id: u64,
280    opaque_ptr: u32,
281) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
282    Box::new(async move {
283        // get state
284        let memory = get_memory(&mut caller)?;
285        let (_, state) = memory.data_and_store_mut(&mut caller);
286        let (_, stmt) = get_statement!(state, statement_id);
287
288        let column_names = stmt.column_names().to_vec();
289
290        let column_names =
291            bincode::serialize(&column_names).or_trap("lunatic::sqlite::column_names")?;
292
293        write_to_guest_vec(&mut caller, &memory, &column_names, opaque_ptr).await
294    })
295}
296
297// this function assumes that the row has not been read yet and therefore
298// starts at column_idx 0
299fn read_row<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
300    mut caller: Caller<T>,
301    statement_id: u64,
302    opaque_ptr: u32,
303) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
304    Box::new(async move {
305        // get state
306        let memory = get_memory(&mut caller)?;
307        let (_, state) = memory.data_and_store_mut(&mut caller);
308        let (_, stmt) = get_statement!(state, statement_id);
309
310        let read_row = SqliteRow::read_row(stmt)?;
311
312        let row = bincode::serialize(&read_row).or_trap("lunatic::sqlite::read_row")?;
313
314        write_to_guest_vec(&mut caller, &memory, &row, opaque_ptr).await
315    })
316}
317
318fn last_error<T: ProcessState + ErrorCtx + SQLiteCtx + ResourceLimiter + Send + Sync>(
319    mut caller: Caller<T>,
320    conn_id: u64,
321    opaque_ptr: u32,
322) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
323    Box::new(async move {
324        // get state
325        let memory = get_memory(&mut caller)?;
326        let err = {
327            let (_, state) = memory.data_and_store_mut(&mut caller);
328            let mut conn = get_conn!(state, conn_id, "last_error");
329
330            let err: SqliteError = conn.last().or_trap("lunatic::sqlite::last_error")?.into();
331            bincode::serialize(&err)
332                .or_trap("lunatic::sqlite::last_error::encode_error_wire_format")?
333        };
334
335        write_to_guest_vec(&mut caller, &memory, &err, opaque_ptr).await
336    })
337}
338
339fn sqlite3_finalize<T: ProcessState + ErrorCtx + SQLiteCtx>(
340    mut caller: Caller<T>,
341    statement_id: u64,
342) -> Result<()> {
343    // get state
344    let memory = get_memory(&mut caller)?;
345    let (_, state) = memory.data_and_store_mut(&mut caller);
346    // dropping the statement should invoke the C function `sqlite3_finalize`
347    state
348        .sqlite_statements_mut()
349        .remove(statement_id)
350        .or_trap("lunatic::sqlite::sqlite3_finalize")?;
351
352    Ok(())
353}
354
355// sends back SQLITE_DONE or SQLITE_ROW depending on whether there's more data available or not
356fn sqlite3_step<T: ProcessState + ErrorCtx + SQLiteCtx>(
357    mut caller: Caller<T>,
358    statement_id: u64,
359) -> Result<u32> {
360    // get state
361    let memory = get_memory(&mut caller)?;
362    let (_, state) = memory.data_and_store_mut(&mut caller);
363    let (_, statement) = get_statement!(state, statement_id);
364
365    match statement.next().or_trap("lunatic::sqlite::sqlite3_step")? {
366        State::Done => Ok(SQLITE_DONE),
367        State::Row => Ok(SQLITE_ROW),
368    }
369}
370
371fn column_count<T: ProcessState + ErrorCtx + SQLiteCtx>(
372    mut caller: Caller<T>,
373    statement_id: u64,
374) -> Result<u32> {
375    // get state
376    let memory = get_memory(&mut caller)?;
377    let (_, state) = memory.data_and_store_mut(&mut caller);
378    let (_, statement) = get_statement!(state, statement_id);
379
380    Ok(statement.column_count() as u32)
381}
382
383fn column_name<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
384    mut caller: Caller<T>,
385    statement_id: u64,
386    column_idx: u32,
387    opaque_ptr: u32,
388) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
389    Box::new(async move {
390        // get state
391        let memory = get_memory(&mut caller)?;
392        let (_, column_name) = {
393            let (_, state) = memory.data_and_store_mut(&mut caller);
394            let (connection_id, statement) = get_statement!(state, statement_id);
395
396            (
397                connection_id,
398                statement
399                    .column_name(column_idx as usize)
400                    .or_trap("lunatic::sqlite::column_name")?
401                    .to_owned(),
402            )
403        };
404
405        write_to_guest_vec(&mut caller, &memory, column_name.as_bytes(), opaque_ptr).await
406    })
407}