Skip to main content

ic_sqlite_vfs/db/
statement.rs

//! Prepared statement lifecycle and row iteration.
//!
//! Each execution first resets the statement and clears its bindings. `Drop`
//! finalizes the SQLite statement, so prepared statements can be reused without
//! leaking C resources.
5
6use crate::db::connection::sqlite_error;
7use crate::db::row::{FromColumn, Row};
8use crate::db::value::{bind_all, bind_named_all, ToSql};
9use crate::db::DbError;
10use crate::sqlite_vfs::ffi;
11use std::ptr::NonNull;
12
13#[cfg(any(test, feature = "canister-api-test-failpoints"))]
14use std::cell::RefCell;
15#[cfg(any(test, feature = "canister-api-test-failpoints"))]
16use std::collections::BTreeMap;
17
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
thread_local! {
    /// Armed step failpoints, keyed by the stable-memory context they belong to.
    static STEP_FAILPOINTS: RefCell<BTreeMap<crate::stable::memory::ContextId, StepFailpointState>> =
        const { RefCell::new(BTreeMap::new()) };
}
22
/// An injected `sqlite3_step` failure for test builds.
///
/// Once armed via `set_step_failpoint`, the `ordinal`-th step (counting from 1)
/// taken in the same context returns `code` as an error instead of calling SQLite.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepFailpoint {
    /// 1-based index of the step call that should fail.
    pub ordinal: u64,
    /// SQLite result code carried by the injected failure.
    pub code: std::ffi::c_int,
}
29
/// Bookkeeping for an armed failpoint: what to trigger and how many steps have run.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy)]
struct StepFailpointState {
    /// Failpoint to trigger once `count` reaches its ordinal.
    failpoint: StepFailpoint,
    /// Steps observed in this context since the failpoint was armed.
    count: u64,
}
36
37pub struct Statement<'connection> {
38    db: *mut ffi::sqlite3,
39    raw: NonNull<ffi::sqlite3_stmt>,
40    parameter_count: usize,
41    _connection: std::marker::PhantomData<&'connection ()>,
42}
43
44pub struct Rows<'statement, 'connection> {
45    statement: &'statement mut Statement<'connection>,
46    done: bool,
47}
48
/// Instruction-count breakdown of one `query_optional_string_text*_profiled` call.
///
/// Each field is the performance-counter delta for that phase. Non-wasm32 builds
/// record zeros because no counter is available there.
#[cfg(feature = "bench-profile")]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct QueryOptionalStringTextProfile {
    /// Resetting the statement and binding the single text parameter.
    pub reset_bind: u64,
    /// The `sqlite3_step` call.
    pub step: u64,
    /// Reading column 0; stays 0 when the query returns no row.
    pub column_read: u64,
}
56
57impl<'connection> Statement<'connection> {
58    pub(crate) fn new(db: *mut ffi::sqlite3, raw: NonNull<ffi::sqlite3_stmt>) -> Self {
59        let parameter_count =
60            usize::try_from(unsafe { ffi::sqlite3_bind_parameter_count(raw.as_ptr()) })
61                .unwrap_or(0);
62        Self {
63            db,
64            raw,
65            parameter_count,
66            _connection: std::marker::PhantomData,
67        }
68    }
69
70    pub(crate) fn into_raw(self) -> NonNull<ffi::sqlite3_stmt> {
71        let raw = self.raw;
72        std::mem::forget(self);
73        raw
74    }
75
76    pub fn execute(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
77        self.reset_and_bind(values)?;
78        let rc = step(self.raw.as_ptr())?;
79        if rc == ffi::SQLITE_DONE {
80            Ok(())
81        } else {
82            Err(sqlite_error(self.db, rc))
83        }
84    }
85
86    pub fn execute_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
87        self.reset_and_bind_named(values)?;
88        let rc = step(self.raw.as_ptr())?;
89        if rc == ffi::SQLITE_DONE {
90            Ok(())
91        } else {
92            Err(sqlite_error(self.db, rc))
93        }
94    }
95
96    pub fn query<'statement>(
97        &'statement mut self,
98        values: &[&dyn ToSql],
99    ) -> Result<Rows<'statement, 'connection>, DbError> {
100        self.reset_and_bind(values)?;
101        Ok(Rows {
102            statement: self,
103            done: false,
104        })
105    }
106
107    pub fn query_named<'statement>(
108        &'statement mut self,
109        values: &[(&str, &dyn ToSql)],
110    ) -> Result<Rows<'statement, 'connection>, DbError> {
111        self.reset_and_bind_named(values)?;
112        Ok(Rows {
113            statement: self,
114            done: false,
115        })
116    }
117
118    pub fn query_one<T, F>(&mut self, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
119    where
120        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
121    {
122        let mut rows = self.query(values)?;
123        match rows.next_row()? {
124            Some(row) => f(&row),
125            None => Err(DbError::NotFound),
126        }
127    }
128
129    pub fn query_one_named<T, F>(
130        &mut self,
131        values: &[(&str, &dyn ToSql)],
132        f: F,
133    ) -> Result<T, DbError>
134    where
135        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
136    {
137        let mut rows = self.query_named(values)?;
138        match rows.next_row()? {
139            Some(row) => f(&row),
140            None => Err(DbError::NotFound),
141        }
142    }
143
144    pub fn query_optional<T, F>(
145        &mut self,
146        values: &[&dyn ToSql],
147        f: F,
148    ) -> Result<Option<T>, DbError>
149    where
150        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
151    {
152        let mut rows = self.query(values)?;
153        match rows.next_row()? {
154            Some(row) => f(&row).map(Some),
155            None => Ok(None),
156        }
157    }
158
159    pub fn query_optional_named<T, F>(
160        &mut self,
161        values: &[(&str, &dyn ToSql)],
162        f: F,
163    ) -> Result<Option<T>, DbError>
164    where
165        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
166    {
167        let mut rows = self.query_named(values)?;
168        match rows.next_row()? {
169            Some(row) => f(&row).map(Some),
170            None => Ok(None),
171        }
172    }
173
174    pub fn query_all<T, F>(&mut self, values: &[&dyn ToSql], mut f: F) -> Result<Vec<T>, DbError>
175    where
176        F: FnMut(&Row<'_>) -> Result<T, DbError>,
177    {
178        let mut rows = self.query(values)?;
179        let mut output = Vec::new();
180        while let Some(row) = rows.next_row()? {
181            output.push(f(&row)?);
182        }
183        Ok(output)
184    }
185
186    pub fn query_all_named<T, F>(
187        &mut self,
188        values: &[(&str, &dyn ToSql)],
189        mut f: F,
190    ) -> Result<Vec<T>, DbError>
191    where
192        F: FnMut(&Row<'_>) -> Result<T, DbError>,
193    {
194        let mut rows = self.query_named(values)?;
195        let mut output = Vec::new();
196        while let Some(row) = rows.next_row()? {
197            output.push(f(&row)?);
198        }
199        Ok(output)
200    }
201
202    pub fn query_scalar<T: FromColumn>(&mut self, values: &[&dyn ToSql]) -> Result<T, DbError> {
203        self.query_one(values, |row| row.get(0))
204    }
205    pub fn query_scalar_named<T: FromColumn>(
206        &mut self,
207        values: &[(&str, &dyn ToSql)],
208    ) -> Result<T, DbError> {
209        self.query_one_named(values, |row| row.get(0))
210    }
211
212    pub fn query_optional_scalar<T: FromColumn>(
213        &mut self,
214        values: &[&dyn ToSql],
215    ) -> Result<Option<T>, DbError> {
216        self.query_optional(values, |row| row.get(0))
217    }
218
219    pub fn query_optional_string_text(&mut self, value: &str) -> Result<Option<String>, DbError> {
220        self.reset_and_bind_single_text(value)?;
221        match step(self.raw.as_ptr())? {
222            ffi::SQLITE_ROW => read_string_column_zero(self.raw.as_ptr()).map(Some),
223            ffi::SQLITE_DONE => Ok(None),
224            rc => Err(sqlite_error(self.db, rc)),
225        }
226    }
227
228    pub fn query_optional_string_text_len(
229        &mut self,
230        value: &str,
231    ) -> Result<Option<usize>, DbError> {
232        self.reset_and_bind_single_text(value)?;
233        match step(self.raw.as_ptr())? {
234            ffi::SQLITE_ROW => read_string_column_zero_len(self.raw.as_ptr()).map(Some),
235            ffi::SQLITE_DONE => Ok(None),
236            rc => Err(sqlite_error(self.db, rc)),
237        }
238    }
239
240    #[cfg(feature = "bench-profile")]
241    #[doc(hidden)]
242    pub fn query_optional_string_text_profiled(
243        &mut self,
244        value: &str,
245    ) -> Result<(Option<String>, QueryOptionalStringTextProfile), DbError> {
246        let mut profile = QueryOptionalStringTextProfile::default();
247
248        let start = instruction_counter();
249        self.reset_and_bind_single_text(value)?;
250        profile.reset_bind = instruction_counter().saturating_sub(start);
251
252        let start = instruction_counter();
253        let rc = step(self.raw.as_ptr())?;
254        profile.step = instruction_counter().saturating_sub(start);
255
256        match rc {
257            ffi::SQLITE_ROW => {
258                let start = instruction_counter();
259                let value = read_string_column_zero(self.raw.as_ptr()).map(Some);
260                profile.column_read = instruction_counter().saturating_sub(start);
261                value.map(|value| (value, profile))
262            }
263            ffi::SQLITE_DONE => Ok((None, profile)),
264            rc => Err(sqlite_error(self.db, rc)),
265        }
266    }
267
268    #[cfg(feature = "bench-profile")]
269    #[doc(hidden)]
270    pub fn query_optional_string_text_len_profiled(
271        &mut self,
272        value: &str,
273    ) -> Result<(Option<usize>, QueryOptionalStringTextProfile), DbError> {
274        let mut profile = QueryOptionalStringTextProfile::default();
275
276        let start = instruction_counter();
277        self.reset_and_bind_single_text(value)?;
278        profile.reset_bind = instruction_counter().saturating_sub(start);
279
280        let start = instruction_counter();
281        let rc = step(self.raw.as_ptr())?;
282        profile.step = instruction_counter().saturating_sub(start);
283
284        match rc {
285            ffi::SQLITE_ROW => {
286                let start = instruction_counter();
287                let value = read_string_column_zero_len(self.raw.as_ptr()).map(Some);
288                profile.column_read = instruction_counter().saturating_sub(start);
289                value.map(|value| (value, profile))
290            }
291            ffi::SQLITE_DONE => Ok((None, profile)),
292            rc => Err(sqlite_error(self.db, rc)),
293        }
294    }
295
296    pub fn query_optional_scalar_named<T: FromColumn>(
297        &mut self,
298        values: &[(&str, &dyn ToSql)],
299    ) -> Result<Option<T>, DbError> {
300        self.query_optional_named(values, |row| row.get(0))
301    }
302
303    pub fn query_column<T: FromColumn>(
304        &mut self,
305        values: &[&dyn ToSql],
306    ) -> Result<Vec<T>, DbError> {
307        self.query_all(values, |row| row.get(0))
308    }
309
310    pub fn query_column_named<T: FromColumn>(
311        &mut self,
312        values: &[(&str, &dyn ToSql)],
313    ) -> Result<Vec<T>, DbError> {
314        self.query_all_named(values, |row| row.get(0))
315    }
316
317    fn reset_and_bind(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
318        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
319        if reset_rc != ffi::SQLITE_OK {
320            return Err(sqlite_error(self.db, reset_rc));
321        }
322        let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
323        if clear_rc != ffi::SQLITE_OK {
324            return Err(sqlite_error(self.db, clear_rc));
325        }
326        bind_all(self.raw.as_ptr(), values)
327    }
328
329    fn reset_and_bind_single_text(&mut self, value: &str) -> Result<(), DbError> {
330        if self.parameter_count != 1 {
331            return Err(DbError::ParameterCountMismatch {
332                expected: self.parameter_count,
333                actual: 1,
334            });
335        }
336        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
337        if reset_rc != ffi::SQLITE_OK {
338            return Err(sqlite_error(self.db, reset_rc));
339        }
340        let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
341        if clear_rc != ffi::SQLITE_OK {
342            return Err(sqlite_error(self.db, clear_rc));
343        }
344        let len = std::ffi::c_int::try_from(value.len()).map_err(|_| DbError::TextTooLarge)?;
345        let bind_rc = unsafe {
346            ffi::sqlite3_bind_text(
347                self.raw.as_ptr(),
348                1,
349                value.as_ptr().cast(),
350                len,
351                ffi::SQLITE_TRANSIENT(),
352            )
353        };
354        if bind_rc == ffi::SQLITE_OK {
355            Ok(())
356        } else {
357            Err(DbError::Sqlite(bind_rc, "sqlite bind failed".to_string()))
358        }
359    }
360
361    fn reset_and_bind_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
362        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
363        if reset_rc != ffi::SQLITE_OK {
364            return Err(sqlite_error(self.db, reset_rc));
365        }
366        let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
367        if clear_rc != ffi::SQLITE_OK {
368            return Err(sqlite_error(self.db, clear_rc));
369        }
370        bind_named_all(self.raw.as_ptr(), values)
371    }
372}
373
374fn read_string_column_zero(statement: *mut ffi::sqlite3_stmt) -> Result<String, DbError> {
375    let actual = unsafe { ffi::sqlite3_column_type(statement, 0) };
376    if actual != ffi::SQLITE_TEXT {
377        return Err(DbError::TypeMismatch {
378            index: 0,
379            expected: "TEXT",
380            actual: sqlite_type_name(actual),
381        });
382    }
383    let text = unsafe { ffi::sqlite3_column_text(statement, 0) };
384    let len = unsafe { ffi::sqlite3_column_bytes(statement, 0) };
385    let len = usize::try_from(len).map_err(|_| DbError::TextTooLarge)?;
386    if len == 0 || text.is_null() {
387        return Ok(String::new());
388    }
389    let bytes = unsafe { std::slice::from_raw_parts(text.cast::<u8>(), len) };
390    Ok(String::from_utf8_lossy(bytes).into_owned())
391}
392
393fn read_string_column_zero_len(statement: *mut ffi::sqlite3_stmt) -> Result<usize, DbError> {
394    let actual = unsafe { ffi::sqlite3_column_type(statement, 0) };
395    if actual != ffi::SQLITE_TEXT {
396        return Err(DbError::TypeMismatch {
397            index: 0,
398            expected: "TEXT",
399            actual: sqlite_type_name(actual),
400        });
401    }
402    let len = unsafe { ffi::sqlite3_column_bytes(statement, 0) };
403    usize::try_from(len).map_err(|_| DbError::TextTooLarge)
404}
405
406fn sqlite_type_name(code: std::ffi::c_int) -> &'static str {
407    match code {
408        ffi::SQLITE_INTEGER => "INTEGER",
409        ffi::SQLITE_FLOAT => "REAL",
410        ffi::SQLITE_TEXT => "TEXT",
411        ffi::SQLITE_BLOB => "BLOB",
412        ffi::SQLITE_NULL => "NULL",
413        _ => "UNKNOWN",
414    }
415}
416
417impl Rows<'_, '_> {
418    pub fn next_row(&mut self) -> Result<Option<Row<'_>>, DbError> {
419        if self.done {
420            return Ok(None);
421        }
422        let rc = step(self.statement.raw.as_ptr())?;
423        match rc {
424            ffi::SQLITE_ROW => Ok(Some(Row::new(self.statement.raw.as_ptr()))),
425            ffi::SQLITE_DONE => {
426                self.done = true;
427                Ok(None)
428            }
429            _ => Err(sqlite_error(self.statement.db, rc)),
430        }
431    }
432}
433
434fn step(statement: *mut ffi::sqlite3_stmt) -> Result<std::ffi::c_int, DbError> {
435    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
436    if let Some(code) = hit_step_failpoint() {
437        return Err(DbError::Sqlite(code, "sqlite step failpoint".to_string()));
438    }
439    Ok(unsafe { ffi::sqlite3_step(statement) })
440}
441
/// Reads the IC instruction counter for profiling.
///
/// Non-wasm32 (host) builds have no such counter and always report 0, so the
/// profiled code paths still compile and run off-chain.
#[cfg(feature = "bench-profile")]
fn instruction_counter() -> u64 {
    #[cfg(not(target_arch = "wasm32"))]
    {
        0
    }
    #[cfg(target_arch = "wasm32")]
    {
        ic_cdk::api::performance_counter(0)
    }
}
453
/// Arms `failpoint` for the currently active context, resetting its step count.
///
/// Replaces any failpoint already armed for that context. Does nothing when no
/// context is active.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_step_failpoint(failpoint: StepFailpoint) {
    let Ok(context) = crate::stable::memory::active_context_id() else {
        return;
    };
    let armed = StepFailpointState {
        failpoint,
        count: 0,
    };
    STEP_FAILPOINTS.with(|failpoints| {
        failpoints.borrow_mut().insert(context, armed);
    });
}
468
/// Disarms the step failpoints of every context, not only the active one.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_step_failpoint() {
    STEP_FAILPOINTS.with(|failpoints| {
        failpoints.borrow_mut().clear();
    });
}
473
/// Records one step for the active context and reports whether its failpoint fires.
///
/// Returns the injected result code on the step whose 1-based count equals the
/// failpoint's `ordinal`, disarming it at the same time. Returns `None` otherwise,
/// including when no context is active or nothing is armed.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_step_failpoint() -> Option<std::ffi::c_int> {
    let context = crate::stable::memory::active_context_id().ok()?;
    STEP_FAILPOINTS.with(|failpoints| {
        let mut failpoints = failpoints.borrow_mut();
        let state = failpoints.get_mut(&context)?;
        state.count += 1;
        if state.count != state.failpoint.ordinal {
            return None;
        }
        let code = state.failpoint.code;
        // One-shot: remove so later steps in this context run normally.
        failpoints.remove(&context);
        Some(code)
    })
}
492
493impl Drop for Statement<'_> {
494    fn drop(&mut self) {
495        unsafe {
496            ffi::sqlite3_finalize(self.raw.as_ptr());
497        }
498    }
499}