Skip to main content

ic_sqlite_vfs/db/
statement.rs

1//! Prepared statement lifecycle and row iteration.
2//!
3//! Each execution resets and rebinds the statement. `Drop` finalizes regular
4//! statements, while cached statements clear bindings before returning to cache.
5
6use crate::db::connection::sqlite_error;
7use crate::db::row::{FromColumn, Row};
8use crate::db::value::{bind_all, bind_named_all, ToSql};
9use crate::db::DbError;
10use crate::sqlite_vfs::ffi;
11use std::ptr::NonNull;
12
13#[cfg(any(test, feature = "canister-api-test-failpoints"))]
14use std::cell::RefCell;
15#[cfg(any(test, feature = "canister-api-test-failpoints"))]
16use std::collections::BTreeMap;
17
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
thread_local! {
    /// Armed step failpoints, keyed by the stable-memory context that was
    /// active when each one was set. At most one failpoint per context.
    static STEP_FAILPOINTS: RefCell<BTreeMap<crate::stable::memory::ContextId, StepFailpointState>> = const { RefCell::new(BTreeMap::new()) };
}
22
/// Test-only instruction to make one specific `sqlite3_step` call fail.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StepFailpoint {
    /// 1-based index of the step call (counted in the active context since
    /// the failpoint was set) that should fail. An ordinal of 0 never fires.
    pub ordinal: u64,
    /// Result code carried by the injected `DbError::Sqlite`.
    pub code: std::ffi::c_int,
}
29
/// An armed failpoint plus the number of step calls observed so far.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Clone, Copy, Debug)]
struct StepFailpointState {
    failpoint: StepFailpoint,
    /// Step calls seen in this context since the failpoint was set.
    count: u64,
}
36
/// A prepared SQLite statement tied to the lifetime of its connection.
///
/// Dropping it finalizes the statement; `Statement::into_raw` hands the
/// handle back without finalizing it.
pub struct Statement<'connection> {
    /// Connection handle, used to build error messages.
    db: *mut ffi::sqlite3,
    /// Owned prepared-statement handle.
    raw: NonNull<ffi::sqlite3_stmt>,
    /// `sqlite3_bind_parameter_count`, read once at construction.
    parameter_count: usize,
    /// Ties the statement to a borrow of the connection without storing it.
    _connection: std::marker::PhantomData<&'connection ()>,
}
43
/// Cursor over the result rows of a statement, produced by
/// `Statement::query` / `Statement::query_named`.
pub struct Rows<'statement, 'connection> {
    /// Statement being stepped; mutably borrowed for the cursor's lifetime.
    statement: &'statement mut Statement<'connection>,
    /// Set once `sqlite3_step` has returned `SQLITE_DONE`.
    done: bool,
}
48
/// Per-phase instruction counts from the profiled single-text query paths.
///
/// Counts come from the IC performance counter on wasm32 and are zero on
/// other targets.
#[cfg(feature = "bench-profile")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct QueryOptionalStringTextProfile {
    /// Cost of resetting the statement and binding the text parameter.
    pub reset_bind: u64,
    /// Cost of the single `sqlite3_step` call.
    pub step: u64,
    /// Cost of reading column 0 (left at zero when no row was returned).
    pub column_read: u64,
}
56
57impl<'connection> Statement<'connection> {
58    pub(crate) fn new(db: *mut ffi::sqlite3, raw: NonNull<ffi::sqlite3_stmt>) -> Self {
59        let parameter_count =
60            usize::try_from(unsafe { ffi::sqlite3_bind_parameter_count(raw.as_ptr()) })
61                .unwrap_or(0);
62        Self {
63            db,
64            raw,
65            parameter_count,
66            _connection: std::marker::PhantomData,
67        }
68    }
69
70    pub(crate) fn into_raw(self) -> NonNull<ffi::sqlite3_stmt> {
71        let raw = self.raw;
72        std::mem::forget(self);
73        raw
74    }
75
76    pub fn execute(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
77        self.reset_and_bind(values)?;
78        let rc = step(self.raw.as_ptr())?;
79        if rc == ffi::SQLITE_DONE {
80            Ok(())
81        } else {
82            Err(sqlite_error(self.db, rc))
83        }
84    }
85
86    pub fn execute_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
87        self.reset_and_bind_named(values)?;
88        let rc = step(self.raw.as_ptr())?;
89        if rc == ffi::SQLITE_DONE {
90            Ok(())
91        } else {
92            Err(sqlite_error(self.db, rc))
93        }
94    }
95
96    pub fn query<'statement>(
97        &'statement mut self,
98        values: &[&dyn ToSql],
99    ) -> Result<Rows<'statement, 'connection>, DbError> {
100        self.reset_and_bind(values)?;
101        Ok(Rows {
102            statement: self,
103            done: false,
104        })
105    }
106
107    pub fn query_named<'statement>(
108        &'statement mut self,
109        values: &[(&str, &dyn ToSql)],
110    ) -> Result<Rows<'statement, 'connection>, DbError> {
111        self.reset_and_bind_named(values)?;
112        Ok(Rows {
113            statement: self,
114            done: false,
115        })
116    }
117
118    pub fn query_one<T, F>(&mut self, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
119    where
120        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
121    {
122        let mut rows = self.query(values)?;
123        match rows.next_row()? {
124            Some(row) => f(&row),
125            None => Err(DbError::NotFound),
126        }
127    }
128
129    pub fn query_one_named<T, F>(
130        &mut self,
131        values: &[(&str, &dyn ToSql)],
132        f: F,
133    ) -> Result<T, DbError>
134    where
135        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
136    {
137        let mut rows = self.query_named(values)?;
138        match rows.next_row()? {
139            Some(row) => f(&row),
140            None => Err(DbError::NotFound),
141        }
142    }
143
144    pub fn query_optional<T, F>(
145        &mut self,
146        values: &[&dyn ToSql],
147        f: F,
148    ) -> Result<Option<T>, DbError>
149    where
150        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
151    {
152        let mut rows = self.query(values)?;
153        match rows.next_row()? {
154            Some(row) => f(&row).map(Some),
155            None => Ok(None),
156        }
157    }
158
159    pub fn query_optional_named<T, F>(
160        &mut self,
161        values: &[(&str, &dyn ToSql)],
162        f: F,
163    ) -> Result<Option<T>, DbError>
164    where
165        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
166    {
167        let mut rows = self.query_named(values)?;
168        match rows.next_row()? {
169            Some(row) => f(&row).map(Some),
170            None => Ok(None),
171        }
172    }
173
174    pub fn query_all<T, F>(&mut self, values: &[&dyn ToSql], mut f: F) -> Result<Vec<T>, DbError>
175    where
176        F: FnMut(&Row<'_>) -> Result<T, DbError>,
177    {
178        let mut rows = self.query(values)?;
179        let mut output = Vec::new();
180        while let Some(row) = rows.next_row()? {
181            output.push(f(&row)?);
182        }
183        Ok(output)
184    }
185
186    pub fn query_all_named<T, F>(
187        &mut self,
188        values: &[(&str, &dyn ToSql)],
189        mut f: F,
190    ) -> Result<Vec<T>, DbError>
191    where
192        F: FnMut(&Row<'_>) -> Result<T, DbError>,
193    {
194        let mut rows = self.query_named(values)?;
195        let mut output = Vec::new();
196        while let Some(row) = rows.next_row()? {
197            output.push(f(&row)?);
198        }
199        Ok(output)
200    }
201
202    pub fn query_scalar<T: FromColumn>(&mut self, values: &[&dyn ToSql]) -> Result<T, DbError> {
203        self.query_one(values, |row| row.get(0))
204    }
205    pub fn query_scalar_named<T: FromColumn>(
206        &mut self,
207        values: &[(&str, &dyn ToSql)],
208    ) -> Result<T, DbError> {
209        self.query_one_named(values, |row| row.get(0))
210    }
211
212    pub fn query_optional_scalar<T: FromColumn>(
213        &mut self,
214        values: &[&dyn ToSql],
215    ) -> Result<Option<T>, DbError> {
216        self.query_optional(values, |row| row.get(0))
217    }
218
219    pub fn query_optional_string_text(&mut self, value: &str) -> Result<Option<String>, DbError> {
220        self.reset_and_bind_single_text(value)?;
221        match step(self.raw.as_ptr())? {
222            ffi::SQLITE_ROW => read_string_column_zero(self.raw.as_ptr()).map(Some),
223            ffi::SQLITE_DONE => Ok(None),
224            rc => Err(sqlite_error(self.db, rc)),
225        }
226    }
227
228    pub fn query_optional_string_text_len(
229        &mut self,
230        value: &str,
231    ) -> Result<Option<usize>, DbError> {
232        self.reset_and_bind_single_text(value)?;
233        match step(self.raw.as_ptr())? {
234            ffi::SQLITE_ROW => read_string_column_zero_len(self.raw.as_ptr()).map(Some),
235            ffi::SQLITE_DONE => Ok(None),
236            rc => Err(sqlite_error(self.db, rc)),
237        }
238    }
239
240    #[cfg(feature = "bench-profile")]
241    #[doc(hidden)]
242    pub fn query_optional_string_text_profiled(
243        &mut self,
244        value: &str,
245    ) -> Result<(Option<String>, QueryOptionalStringTextProfile), DbError> {
246        let mut profile = QueryOptionalStringTextProfile::default();
247
248        let start = instruction_counter();
249        self.reset_and_bind_single_text(value)?;
250        profile.reset_bind = instruction_counter().saturating_sub(start);
251
252        let start = instruction_counter();
253        let rc = step(self.raw.as_ptr())?;
254        profile.step = instruction_counter().saturating_sub(start);
255
256        match rc {
257            ffi::SQLITE_ROW => {
258                let start = instruction_counter();
259                let value = read_string_column_zero(self.raw.as_ptr()).map(Some);
260                profile.column_read = instruction_counter().saturating_sub(start);
261                value.map(|value| (value, profile))
262            }
263            ffi::SQLITE_DONE => Ok((None, profile)),
264            rc => Err(sqlite_error(self.db, rc)),
265        }
266    }
267
268    #[cfg(feature = "bench-profile")]
269    #[doc(hidden)]
270    pub fn query_optional_string_text_len_profiled(
271        &mut self,
272        value: &str,
273    ) -> Result<(Option<usize>, QueryOptionalStringTextProfile), DbError> {
274        let mut profile = QueryOptionalStringTextProfile::default();
275
276        let start = instruction_counter();
277        self.reset_and_bind_single_text(value)?;
278        profile.reset_bind = instruction_counter().saturating_sub(start);
279
280        let start = instruction_counter();
281        let rc = step(self.raw.as_ptr())?;
282        profile.step = instruction_counter().saturating_sub(start);
283
284        match rc {
285            ffi::SQLITE_ROW => {
286                let start = instruction_counter();
287                let value = read_string_column_zero_len(self.raw.as_ptr()).map(Some);
288                profile.column_read = instruction_counter().saturating_sub(start);
289                value.map(|value| (value, profile))
290            }
291            ffi::SQLITE_DONE => Ok((None, profile)),
292            rc => Err(sqlite_error(self.db, rc)),
293        }
294    }
295
296    pub fn query_optional_scalar_named<T: FromColumn>(
297        &mut self,
298        values: &[(&str, &dyn ToSql)],
299    ) -> Result<Option<T>, DbError> {
300        self.query_optional_named(values, |row| row.get(0))
301    }
302
303    pub fn query_column<T: FromColumn>(
304        &mut self,
305        values: &[&dyn ToSql],
306    ) -> Result<Vec<T>, DbError> {
307        self.query_all(values, |row| row.get(0))
308    }
309
310    pub fn query_column_named<T: FromColumn>(
311        &mut self,
312        values: &[(&str, &dyn ToSql)],
313    ) -> Result<Vec<T>, DbError> {
314        self.query_all_named(values, |row| row.get(0))
315    }
316
317    fn reset_and_bind(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
318        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
319        if reset_rc != ffi::SQLITE_OK {
320            return Err(sqlite_error(self.db, reset_rc));
321        }
322        bind_all(self.raw.as_ptr(), values)
323    }
324
325    fn reset_and_bind_single_text(&mut self, value: &str) -> Result<(), DbError> {
326        if self.parameter_count != 1 {
327            return Err(DbError::ParameterCountMismatch {
328                expected: self.parameter_count,
329                actual: 1,
330            });
331        }
332        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
333        if reset_rc != ffi::SQLITE_OK {
334            return Err(sqlite_error(self.db, reset_rc));
335        }
336        let len = std::ffi::c_int::try_from(value.len()).map_err(|_| DbError::TextTooLarge)?;
337        let bind_rc = unsafe {
338            ffi::sqlite3_bind_text(
339                self.raw.as_ptr(),
340                1,
341                value.as_ptr().cast(),
342                len,
343                ffi::SQLITE_TRANSIENT(),
344            )
345        };
346        if bind_rc == ffi::SQLITE_OK {
347            Ok(())
348        } else {
349            Err(DbError::Sqlite(bind_rc, "sqlite bind failed".to_string()))
350        }
351    }
352
353    fn reset_and_bind_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
354        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
355        if reset_rc != ffi::SQLITE_OK {
356            return Err(sqlite_error(self.db, reset_rc));
357        }
358        bind_named_all(self.raw.as_ptr(), values)
359    }
360}
361
362fn read_string_column_zero(statement: *mut ffi::sqlite3_stmt) -> Result<String, DbError> {
363    let actual = unsafe { ffi::sqlite3_column_type(statement, 0) };
364    if actual != ffi::SQLITE_TEXT {
365        return Err(DbError::TypeMismatch {
366            index: 0,
367            expected: "TEXT",
368            actual: sqlite_type_name(actual),
369        });
370    }
371    let text = unsafe { ffi::sqlite3_column_text(statement, 0) };
372    let len = unsafe { ffi::sqlite3_column_bytes(statement, 0) };
373    let len = usize::try_from(len).map_err(|_| DbError::TextTooLarge)?;
374    if len == 0 || text.is_null() {
375        return Ok(String::new());
376    }
377    let bytes = unsafe { std::slice::from_raw_parts(text.cast::<u8>(), len) };
378    Ok(String::from_utf8_lossy(bytes).into_owned())
379}
380
381fn read_string_column_zero_len(statement: *mut ffi::sqlite3_stmt) -> Result<usize, DbError> {
382    let actual = unsafe { ffi::sqlite3_column_type(statement, 0) };
383    if actual != ffi::SQLITE_TEXT {
384        return Err(DbError::TypeMismatch {
385            index: 0,
386            expected: "TEXT",
387            actual: sqlite_type_name(actual),
388        });
389    }
390    let len = unsafe { ffi::sqlite3_column_bytes(statement, 0) };
391    usize::try_from(len).map_err(|_| DbError::TextTooLarge)
392}
393
394fn sqlite_type_name(code: std::ffi::c_int) -> &'static str {
395    match code {
396        ffi::SQLITE_INTEGER => "INTEGER",
397        ffi::SQLITE_FLOAT => "REAL",
398        ffi::SQLITE_TEXT => "TEXT",
399        ffi::SQLITE_BLOB => "BLOB",
400        ffi::SQLITE_NULL => "NULL",
401        _ => "UNKNOWN",
402    }
403}
404
405impl Rows<'_, '_> {
406    pub fn next_row(&mut self) -> Result<Option<Row<'_>>, DbError> {
407        if self.done {
408            return Ok(None);
409        }
410        let rc = step(self.statement.raw.as_ptr())?;
411        match rc {
412            ffi::SQLITE_ROW => Ok(Some(Row::new(self.statement.raw.as_ptr()))),
413            ffi::SQLITE_DONE => {
414                self.done = true;
415                Ok(None)
416            }
417            _ => Err(sqlite_error(self.statement.db, rc)),
418        }
419    }
420}
421
422fn step(statement: *mut ffi::sqlite3_stmt) -> Result<std::ffi::c_int, DbError> {
423    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
424    if let Some(code) = hit_step_failpoint() {
425        return Err(DbError::Sqlite(code, "sqlite step failpoint".to_string()));
426    }
427    Ok(unsafe { ffi::sqlite3_step(statement) })
428}
429
/// Current IC instruction count for the profiled query paths (wasm32).
#[cfg(all(feature = "bench-profile", target_arch = "wasm32"))]
fn instruction_counter() -> u64 {
    ic_cdk::api::performance_counter(0)
}

/// Off-chain stand-in: no instruction counter exists, so report zero.
#[cfg(all(feature = "bench-profile", not(target_arch = "wasm32")))]
fn instruction_counter() -> u64 {
    0
}
441
/// Arms `failpoint` for the active stable-memory context, replacing any
/// failpoint already set for it and restarting the step count at zero.
///
/// Does nothing when no context is active.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_step_failpoint(failpoint: StepFailpoint) {
    let Ok(context) = crate::stable::memory::active_context_id() else {
        return;
    };
    let state = StepFailpointState {
        failpoint,
        count: 0,
    };
    STEP_FAILPOINTS.with(|failpoints| {
        failpoints.borrow_mut().insert(context, state);
    });
}
456
/// Disarms every step failpoint, for all contexts (not only the active one).
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_step_failpoint() {
    STEP_FAILPOINTS.with(|failpoints| {
        let mut failpoints = failpoints.borrow_mut();
        failpoints.clear();
    });
}
461
/// Counts one step call against the active context's failpoint.
///
/// Returns the injected result code when this call is the armed ordinal; the
/// failpoint is then removed so it fires only once. Returns `None` when no
/// context is active, no failpoint is armed, or the ordinal is not reached.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_step_failpoint() -> Option<std::ffi::c_int> {
    let context = crate::stable::memory::active_context_id().ok()?;
    STEP_FAILPOINTS.with(|failpoints| {
        let mut failpoints = failpoints.borrow_mut();
        let state = failpoints.get_mut(&context)?;
        state.count += 1;
        if state.count != state.failpoint.ordinal {
            return None;
        }
        let code = state.failpoint.code;
        failpoints.remove(&context);
        Some(code)
    })
}
480
481impl Drop for Statement<'_> {
482    fn drop(&mut self) {
483        unsafe {
484            ffi::sqlite3_finalize(self.raw.as_ptr());
485        }
486    }
487}