Skip to main content

ic_sqlite_vfs/db/
statement.rs

//! Prepared statement lifecycle and row iteration.
//!
//! Each execution first resets the statement and clears any previous bindings,
//! so one prepared statement can be executed repeatedly. `Drop` finalizes the
//! underlying SQLite statement, so its C resources are not leaked.
5
6use crate::db::connection::sqlite_error;
7use crate::db::row::Row;
8use crate::db::value::{bind_all, bind_named_all, ToSql};
9use crate::db::DbError;
10use crate::sqlite_vfs::ffi;
11use std::ptr::NonNull;
12
13#[cfg(any(test, feature = "canister-api-test-failpoints"))]
14use std::cell::RefCell;
15
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
thread_local! {
    /// Number of `step` calls made on this thread since the failpoint was last
    /// set or cleared (it also counts while no failpoint is armed).
    static STEP_COUNT: RefCell<u64> = const { RefCell::new(0) };
    /// Failpoint waiting to fire, if any; it is disarmed after firing once.
    static STEP_FAILPOINT: RefCell<Option<StepFailpoint>> = const { RefCell::new(None) };
}
21
/// Test-only instruction to make a specific `step` call fail.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepFailpoint {
    /// 1-based index of the `step` call, counted from when the failpoint was
    /// armed, that should fail.
    pub ordinal: u64,
    /// Result code carried by the injected `DbError::Sqlite`.
    pub code: std::ffi::c_int,
}
28
29pub struct Statement<'connection> {
30    db: *mut ffi::sqlite3,
31    raw: NonNull<ffi::sqlite3_stmt>,
32    _connection: std::marker::PhantomData<&'connection ()>,
33}
34
35pub struct Rows<'statement, 'connection> {
36    statement: &'statement mut Statement<'connection>,
37    done: bool,
38}
39
40impl<'connection> Statement<'connection> {
41    pub(crate) fn new(db: *mut ffi::sqlite3, raw: NonNull<ffi::sqlite3_stmt>) -> Self {
42        Self {
43            db,
44            raw,
45            _connection: std::marker::PhantomData,
46        }
47    }
48
49    pub fn execute(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
50        self.reset_and_bind(values)?;
51        let rc = step(self.raw.as_ptr())?;
52        if rc == ffi::SQLITE_DONE {
53            Ok(())
54        } else {
55            Err(sqlite_error(self.db, rc))
56        }
57    }
58
59    pub fn execute_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
60        self.reset_and_bind_named(values)?;
61        let rc = step(self.raw.as_ptr())?;
62        if rc == ffi::SQLITE_DONE {
63            Ok(())
64        } else {
65            Err(sqlite_error(self.db, rc))
66        }
67    }
68
69    pub fn execute_with_texts(&mut self, values: &[&str]) -> Result<(), DbError> {
70        let values = values
71            .iter()
72            .map(|value| value as &dyn ToSql)
73            .collect::<Vec<_>>();
74        self.execute(&values)
75    }
76
77    pub fn query<'statement>(
78        &'statement mut self,
79        values: &[&dyn ToSql],
80    ) -> Result<Rows<'statement, 'connection>, DbError> {
81        self.reset_and_bind(values)?;
82        Ok(Rows {
83            statement: self,
84            done: false,
85        })
86    }
87
88    pub fn query_named<'statement>(
89        &'statement mut self,
90        values: &[(&str, &dyn ToSql)],
91    ) -> Result<Rows<'statement, 'connection>, DbError> {
92        self.reset_and_bind_named(values)?;
93        Ok(Rows {
94            statement: self,
95            done: false,
96        })
97    }
98
99    pub fn query_one<T, F>(&mut self, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
100    where
101        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
102    {
103        let mut rows = self.query(values)?;
104        match rows.next_row()? {
105            Some(row) => f(&row),
106            None => Err(DbError::NotFound),
107        }
108    }
109
110    pub fn query_one_named<T, F>(
111        &mut self,
112        values: &[(&str, &dyn ToSql)],
113        f: F,
114    ) -> Result<T, DbError>
115    where
116        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
117    {
118        let mut rows = self.query_named(values)?;
119        match rows.next_row()? {
120            Some(row) => f(&row),
121            None => Err(DbError::NotFound),
122        }
123    }
124
125    pub fn query_optional<T, F>(
126        &mut self,
127        values: &[&dyn ToSql],
128        f: F,
129    ) -> Result<Option<T>, DbError>
130    where
131        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
132    {
133        let mut rows = self.query(values)?;
134        match rows.next_row()? {
135            Some(row) => f(&row).map(Some),
136            None => Ok(None),
137        }
138    }
139
140    pub fn query_optional_named<T, F>(
141        &mut self,
142        values: &[(&str, &dyn ToSql)],
143        f: F,
144    ) -> Result<Option<T>, DbError>
145    where
146        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
147    {
148        let mut rows = self.query_named(values)?;
149        match rows.next_row()? {
150            Some(row) => f(&row).map(Some),
151            None => Ok(None),
152        }
153    }
154
155    pub fn query_all<T, F>(&mut self, values: &[&dyn ToSql], mut f: F) -> Result<Vec<T>, DbError>
156    where
157        F: FnMut(&Row<'_>) -> Result<T, DbError>,
158    {
159        let mut rows = self.query(values)?;
160        let mut output = Vec::new();
161        while let Some(row) = rows.next_row()? {
162            output.push(f(&row)?);
163        }
164        Ok(output)
165    }
166
167    pub fn query_all_named<T, F>(
168        &mut self,
169        values: &[(&str, &dyn ToSql)],
170        mut f: F,
171    ) -> Result<Vec<T>, DbError>
172    where
173        F: FnMut(&Row<'_>) -> Result<T, DbError>,
174    {
175        let mut rows = self.query_named(values)?;
176        let mut output = Vec::new();
177        while let Some(row) = rows.next_row()? {
178            output.push(f(&row)?);
179        }
180        Ok(output)
181    }
182
183    pub fn query_optional_string_with_text(
184        &mut self,
185        value: &str,
186    ) -> Result<Option<String>, DbError> {
187        self.query_optional(&[&value], |row| row.get(0))
188    }
189
190    fn reset_and_bind(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
191        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
192        if reset_rc != ffi::SQLITE_OK {
193            return Err(sqlite_error(self.db, reset_rc));
194        }
195        let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
196        if clear_rc != ffi::SQLITE_OK {
197            return Err(sqlite_error(self.db, clear_rc));
198        }
199        bind_all(self.raw.as_ptr(), values)
200    }
201
202    fn reset_and_bind_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
203        let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
204        if reset_rc != ffi::SQLITE_OK {
205            return Err(sqlite_error(self.db, reset_rc));
206        }
207        let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
208        if clear_rc != ffi::SQLITE_OK {
209            return Err(sqlite_error(self.db, clear_rc));
210        }
211        bind_named_all(self.raw.as_ptr(), values)
212    }
213}
214
215impl Rows<'_, '_> {
216    pub fn next_row(&mut self) -> Result<Option<Row<'_>>, DbError> {
217        if self.done {
218            return Ok(None);
219        }
220        let rc = step(self.statement.raw.as_ptr())?;
221        match rc {
222            ffi::SQLITE_ROW => Ok(Some(Row::new(self.statement.raw.as_ptr()))),
223            ffi::SQLITE_DONE => {
224                self.done = true;
225                Ok(None)
226            }
227            _ => Err(sqlite_error(self.statement.db, rc)),
228        }
229    }
230}
231
232fn step(statement: *mut ffi::sqlite3_stmt) -> Result<std::ffi::c_int, DbError> {
233    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
234    if let Some(code) = hit_step_failpoint() {
235        return Err(DbError::Sqlite(code, "sqlite step failpoint".to_string()));
236    }
237    Ok(unsafe { ffi::sqlite3_step(statement) })
238}
239
/// Arms `failpoint` on the current thread and restarts the step count, so its
/// `ordinal` is counted from the next `step` call (that call is number 1).
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_step_failpoint(failpoint: StepFailpoint) {
    STEP_COUNT.with(|counter| *counter.borrow_mut() = 0);
    STEP_FAILPOINT.with(|armed| *armed.borrow_mut() = Some(failpoint));
}
245
/// Disarms any pending step failpoint on the current thread and restarts the
/// step count.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_step_failpoint() {
    STEP_COUNT.with(|counter| *counter.borrow_mut() = 0);
    STEP_FAILPOINT.with(|armed| *armed.borrow_mut() = None);
}
251
/// Records one `step` call and reports whether the armed failpoint fires on it.
///
/// Returns the failpoint's code when this call's 1-based ordinal matches, and
/// disarms the failpoint so it fires at most once; otherwise returns `None`.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_step_failpoint() -> Option<std::ffi::c_int> {
    // Every step counts toward the ordinal, whether or not a failpoint is armed.
    let ordinal = STEP_COUNT.with(|counter| {
        let mut counter = counter.borrow_mut();
        *counter += 1;
        *counter
    });
    STEP_FAILPOINT.with(|armed| {
        let mut armed = armed.borrow_mut();
        let fired = (*armed).filter(|candidate| candidate.ordinal == ordinal)?;
        *armed = None;
        Some(fired.code)
    })
}
270
271impl Drop for Statement<'_> {
272    fn drop(&mut self) {
273        unsafe {
274            ffi::sqlite3_finalize(self.raw.as_ptr());
275        }
276    }
277}