1use anyhow::Result;
2use hash_map_id::HashMapId;
3use lunatic_common_api::{get_memory, write_to_guest_vec, IntoTrap};
4use lunatic_error_api::ErrorCtx;
5use lunatic_process::state::ProcessState;
6use lunatic_process_api::ProcessConfigCtx;
7use sqlite::{Connection, State, Statement};
8use std::{
9 collections::HashMap,
10 future::Future,
11 path::Path,
12 sync::{Arc, Mutex},
13};
14use wasmtime::{Caller, Linker, ResourceLimiter};
15
16use crate::wire_format::{BindList, SqliteError, SqliteRow, SqliteValue};
17
/// Value reported to the guest by `sqlite3_step` when a result row is available
/// (same numeric value as SQLite's `SQLITE_ROW`).
pub const SQLITE_ROW: u32 = 100;
/// Value reported to the guest by `sqlite3_step` when the statement has run to completion
/// (same numeric value as SQLite's `SQLITE_DONE`).
pub const SQLITE_DONE: u32 = 101;
20
21pub type SQLiteConnections = HashMapId<Arc<Mutex<Connection>>>;
22pub type SQLiteResults = HashMapId<Vec<u8>>;
23pub type SQLiteStatements = HashMapId<(u64, Statement)>;
25pub type SQLiteGuestAllocators = HashMap<u64, String>;
27pub trait SQLiteCtx {
28 fn sqlite_connections(&self) -> &SQLiteConnections;
29 fn sqlite_connections_mut(&mut self) -> &mut SQLiteConnections;
30
31 fn sqlite_guest_allocator(&self) -> &SQLiteGuestAllocators;
32 fn sqlite_guest_allocator_mut(&mut self) -> &mut SQLiteGuestAllocators;
33
34 fn sqlite_statements(&self) -> &SQLiteStatements;
35 fn sqlite_statements_mut(&mut self) -> &mut SQLiteStatements;
36}
37
38pub fn register<T: SQLiteCtx + ProcessState + Send + ErrorCtx + ResourceLimiter + Sync + 'static>(
40 linker: &mut Linker<T>,
41) -> Result<()>
42where
43 T::Config: lunatic_process_api::ProcessConfigCtx,
44{
45 linker.func_wrap("lunatic::sqlite", "open", open)?;
46 linker.func_wrap("lunatic::sqlite", "query_prepare", query_prepare)?;
47 linker.func_wrap("lunatic::sqlite", "execute", execute)?;
48 linker.func_wrap("lunatic::sqlite", "bind_value", bind_value)?;
49 linker.func_wrap("lunatic::sqlite", "sqlite3_changes", sqlite3_changes)?;
50 linker.func_wrap("lunatic::sqlite", "statement_reset", statement_reset)?;
51 linker.func_wrap2_async("lunatic::sqlite", "last_error", last_error)?;
52 linker.func_wrap("lunatic::sqlite", "sqlite3_finalize", sqlite3_finalize)?;
53 linker.func_wrap("lunatic::sqlite", "sqlite3_step", sqlite3_step)?;
54 linker.func_wrap3_async("lunatic::sqlite", "read_column", read_column)?;
55 linker.func_wrap2_async("lunatic::sqlite", "column_names", column_names)?;
56 linker.func_wrap2_async("lunatic::sqlite", "read_row", read_row)?;
57 linker.func_wrap("lunatic::sqlite", "column_count", column_count)?;
58 linker.func_wrap3_async("lunatic::sqlite", "column_name", column_name)?;
59 Ok(())
60}
61
62fn open<T>(
63 mut caller: Caller<T>,
64 path_str_ptr: u32,
65 path_str_len: u32,
66 connection_id_ptr: u32,
67) -> Result<u64>
68where
69 T: ProcessState + ErrorCtx + SQLiteCtx,
70 T::Config: lunatic_process_api::ProcessConfigCtx,
71{
72 let memory = get_memory(&mut caller)?;
74 let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
75
76 let path = memory_slice
78 .get(path_str_ptr as usize..(path_str_ptr + path_str_len) as usize)
79 .or_trap("lunatic::sqlite::open")?;
80 let path = std::str::from_utf8(path).or_trap("lunatic::sqlite::open")?;
81 if let Err(error_message) = state.config().can_access_fs_location(Path::new(path)) {
82 let error_id = state
83 .error_resources_mut()
84 .add(anyhow::Error::msg(error_message).context(format!("Failed to access '{path}'")));
85 memory
86 .write(
87 &mut caller,
88 connection_id_ptr as usize,
89 &error_id.to_le_bytes(),
90 )
91 .or_trap("lunatic::sqlite::open")?;
92 return Ok(1);
93 }
94
95 let (conn_or_err_id, return_code) = match sqlite::open(path) {
97 Ok(conn) => (
98 caller
99 .data_mut()
100 .sqlite_connections_mut()
101 .add(Arc::new(Mutex::new(conn))),
102 0,
103 ),
104 Err(error) => (caller.data_mut().error_resources_mut().add(error.into()), 1),
105 };
106
107 memory
109 .write(
110 &mut caller,
111 connection_id_ptr as usize,
112 &conn_or_err_id.to_le_bytes(),
113 )
114 .or_trap("lunatic::sqlite::open")?;
115 Ok(return_code)
116}
117
118fn execute<T: ProcessState + ErrorCtx + SQLiteCtx>(
119 mut caller: Caller<T>,
120 conn_id: u64,
121 exec_str_ptr: u32,
122 exec_str_len: u32,
123) -> Result<u32> {
124 let memory = get_memory(&mut caller)?;
125 let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
126 let exec = memory_slice
127 .get(exec_str_ptr as usize..(exec_str_ptr + exec_str_len) as usize)
128 .or_trap("lunatic::sqlite::execute")?;
129 let exec = std::str::from_utf8(exec).or_trap("lunatic::sqlite::execute")?;
130
131 match state
133 .sqlite_connections()
134 .get(conn_id)
135 .or_trap("lunatic::sqlite::execute")?
136 .lock()
137 .or_trap("lunatic::sqlite::execute")?
138 .execute(exec)
139 {
140 Err(e) => Ok(e.code.unwrap_or(1) as u32),
142 Ok(_) => Ok(0),
143 }
144}
145
146fn query_prepare<T: ProcessState + ErrorCtx + SQLiteCtx>(
147 mut caller: Caller<T>,
148 conn_id: u64,
149 query_str_ptr: u32,
150 query_str_len: u32,
151) -> Result<u64> {
152 let memory = get_memory(&mut caller)?;
154 let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
155
156 let query = memory_slice
158 .get(query_str_ptr as usize..(query_str_ptr + query_str_len) as usize)
159 .or_trap("lunatic::sqlite::query_prepare::get_query")?;
160 let query = std::str::from_utf8(query).or_trap("lunatic::sqlite::query_prepare::from_utf8")?;
161
162 let statement = {
163 let conn = state
165 .sqlite_connections()
166 .get(conn_id)
167 .take()
168 .or_trap("lunatic::sqlite::query_prepare::obtain_conn")?
169 .lock()
170 .or_trap("lunatic::sqlite::query_prepare::obtain_conn")?;
171
172 conn.prepare(query)
174 .or_trap("lunatic::sqlite::query_prepare::prepare_statement")?
175 };
176
177 let statement_id = state.sqlite_statements_mut().add((conn_id, statement));
178
179 Ok(statement_id)
180}
181
/// Resolves `$statement_id` in `$state`'s statement table to
/// `(connection_id, &mut Statement)`.
///
/// Uses `?` to propagate a trap tagged `lunatic::sqlite::get_statement_by_id` from the enclosing
/// function when the id is unknown.
macro_rules! get_statement {
    ($state:ident, $statement_id:ident) => {{
        let (connection_id, statement) = $state
            .sqlite_statements_mut()
            .get_mut($statement_id)
            .or_trap("lunatic::sqlite::get_statement_by_id")?;
        (*connection_id, statement)
    }};
}
191
/// Looks up connection `$conn_id` in `$state` and locks it, yielding the `MutexGuard`.
///
/// Uses `?` to propagate a trap tagged `lunatic::sqlite::<$fn_name>::obtain_conn` from the
/// enclosing function when the id is unknown or the mutex is poisoned. Only a shared borrow of
/// the connection table is needed, because the `Mutex` grants mutable access to the connection.
macro_rules! get_conn {
    ($state:ident, $conn_id:ident, $fn_name:literal) => {{
        let trap_str = concat!("lunatic::sqlite::", $fn_name, "::obtain_conn");
        $state
            .sqlite_connections()
            .get($conn_id)
            .or_trap(trap_str)?
            .lock()
            .or_trap(trap_str)?
    }};
}
204
205fn bind_value<T: ProcessState + ErrorCtx + SQLiteCtx>(
206 mut caller: Caller<T>,
207 statement_id: u64,
208 bind_data_ptr: u32,
209 bind_data_len: u32,
210) -> Result<()> {
211 let memory = get_memory(&mut caller)?;
213 let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
214
215 let (_, statement) = get_statement!(state, statement_id);
216
217 let bind_data = memory_slice
219 .get(bind_data_ptr as usize..(bind_data_ptr + bind_data_len) as usize)
220 .or_trap("lunatic::sqlite::bind_value::load_bind_data")?;
221
222 let values: BindList = bincode::deserialize(bind_data).unwrap();
223
224 for pair in values.iter() {
225 pair.bind(statement)
226 .or_trap("lunatic::sqlite::bind_value")?;
227 }
228
229 Ok(())
230}
231
232fn sqlite3_changes<T: ProcessState + ErrorCtx + SQLiteCtx>(
233 mut caller: Caller<T>,
234 conn_id: u64,
235) -> Result<u32> {
236 let memory = get_memory(&mut caller)?;
238 let (_, state) = memory.data_and_store_mut(&mut caller);
239 let conn = get_conn!(state, conn_id, "sqlite3_changes");
240
241 Ok(conn.change_count() as u32)
242}
243
244fn statement_reset<T: ProcessState + ErrorCtx + SQLiteCtx>(
245 mut caller: Caller<T>,
246 statement_id: u64,
247) -> Result<()> {
248 let memory = get_memory(&mut caller)?;
250 let (_, state) = memory.data_and_store_mut(&mut caller);
251 let (_, stmt) = get_statement!(state, statement_id);
252
253 stmt.reset().or_trap("lunatic::sqlite::statement_reset")?;
254
255 Ok(())
256}
257
258fn read_column<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
259 mut caller: Caller<T>,
260 statement_id: u64,
261 col_idx: u32,
262 opaque_ptr: u32,
263) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
264 Box::new(async move {
265 let memory = get_memory(&mut caller)?;
267 let (_, state) = memory.data_and_store_mut(&mut caller);
268 let (_, stmt) = get_statement!(state, statement_id);
269
270 let column = bincode::serialize(&SqliteValue::read_column(stmt, col_idx as usize)?)
271 .or_trap("lunatic::sqlite::read_column")?;
272
273 write_to_guest_vec(&mut caller, &memory, &column, opaque_ptr).await
274 })
275}
276
277fn column_names<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
278 mut caller: Caller<T>,
279 statement_id: u64,
280 opaque_ptr: u32,
281) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
282 Box::new(async move {
283 let memory = get_memory(&mut caller)?;
285 let (_, state) = memory.data_and_store_mut(&mut caller);
286 let (_, stmt) = get_statement!(state, statement_id);
287
288 let column_names = stmt.column_names().to_vec();
289
290 let column_names =
291 bincode::serialize(&column_names).or_trap("lunatic::sqlite::column_names")?;
292
293 write_to_guest_vec(&mut caller, &memory, &column_names, opaque_ptr).await
294 })
295}
296
297fn read_row<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
300 mut caller: Caller<T>,
301 statement_id: u64,
302 opaque_ptr: u32,
303) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
304 Box::new(async move {
305 let memory = get_memory(&mut caller)?;
307 let (_, state) = memory.data_and_store_mut(&mut caller);
308 let (_, stmt) = get_statement!(state, statement_id);
309
310 let read_row = SqliteRow::read_row(stmt)?;
311
312 let row = bincode::serialize(&read_row).or_trap("lunatic::sqlite::read_row")?;
313
314 write_to_guest_vec(&mut caller, &memory, &row, opaque_ptr).await
315 })
316}
317
318fn last_error<T: ProcessState + ErrorCtx + SQLiteCtx + ResourceLimiter + Send + Sync>(
319 mut caller: Caller<T>,
320 conn_id: u64,
321 opaque_ptr: u32,
322) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
323 Box::new(async move {
324 let memory = get_memory(&mut caller)?;
326 let err = {
327 let (_, state) = memory.data_and_store_mut(&mut caller);
328 let mut conn = get_conn!(state, conn_id, "last_error");
329
330 let err: SqliteError = conn.last().or_trap("lunatic::sqlite::last_error")?.into();
331 bincode::serialize(&err)
332 .or_trap("lunatic::sqlite::last_error::encode_error_wire_format")?
333 };
334
335 write_to_guest_vec(&mut caller, &memory, &err, opaque_ptr).await
336 })
337}
338
339fn sqlite3_finalize<T: ProcessState + ErrorCtx + SQLiteCtx>(
340 mut caller: Caller<T>,
341 statement_id: u64,
342) -> Result<()> {
343 let memory = get_memory(&mut caller)?;
345 let (_, state) = memory.data_and_store_mut(&mut caller);
346 state
348 .sqlite_statements_mut()
349 .remove(statement_id)
350 .or_trap("lunatic::sqlite::sqlite3_finalize")?;
351
352 Ok(())
353}
354
355fn sqlite3_step<T: ProcessState + ErrorCtx + SQLiteCtx>(
357 mut caller: Caller<T>,
358 statement_id: u64,
359) -> Result<u32> {
360 let memory = get_memory(&mut caller)?;
362 let (_, state) = memory.data_and_store_mut(&mut caller);
363 let (_, statement) = get_statement!(state, statement_id);
364
365 match statement.next().or_trap("lunatic::sqlite::sqlite3_step")? {
366 State::Done => Ok(SQLITE_DONE),
367 State::Row => Ok(SQLITE_ROW),
368 }
369}
370
371fn column_count<T: ProcessState + ErrorCtx + SQLiteCtx>(
372 mut caller: Caller<T>,
373 statement_id: u64,
374) -> Result<u32> {
375 let memory = get_memory(&mut caller)?;
377 let (_, state) = memory.data_and_store_mut(&mut caller);
378 let (_, statement) = get_statement!(state, statement_id);
379
380 Ok(statement.column_count() as u32)
381}
382
383fn column_name<T: ProcessState + ErrorCtx + SQLiteCtx + Send + Sync>(
384 mut caller: Caller<T>,
385 statement_id: u64,
386 column_idx: u32,
387 opaque_ptr: u32,
388) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
389 Box::new(async move {
390 let memory = get_memory(&mut caller)?;
392 let (_, column_name) = {
393 let (_, state) = memory.data_and_store_mut(&mut caller);
394 let (connection_id, statement) = get_statement!(state, statement_id);
395
396 (
397 connection_id,
398 statement
399 .column_name(column_idx as usize)
400 .or_trap("lunatic::sqlite::column_name")?
401 .to_owned(),
402 )
403 };
404
405 write_to_guest_vec(&mut caller, &memory, column_name.as_bytes(), opaque_ptr).await
406 })
407}