1use crate::db::connection::sqlite_error;
7use crate::db::row::Row;
8use crate::db::value::{bind_all, bind_named_all, ToSql};
9use crate::db::DbError;
10use crate::sqlite_vfs::ffi;
11use std::ptr::NonNull;
12
13#[cfg(any(test, feature = "canister-api-test-failpoints"))]
14use std::cell::RefCell;
15
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
thread_local! {
    /// Number of `step` calls seen on this thread since the failpoint was last set or cleared.
    static STEP_COUNT: RefCell<u64> = const { RefCell::new(0) };

    /// The armed step failpoint for this thread, or `None` when no failure is scheduled.
    static STEP_FAILPOINT: RefCell<Option<StepFailpoint>> = const { RefCell::new(None) };
}
21
/// A scheduled, one-shot failure for the statement step failpoint.
///
/// Once armed, the `ordinal`-th call to `step` (counting from 1 after arming) fails with
/// the SQLite result code `code` instead of calling into SQLite.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepFailpoint {
    /// 1-based position of the `step` call that should fail.
    pub ordinal: u64,
    /// Result code carried by the injected `DbError::Sqlite` error.
    pub code: std::ffi::c_int,
}
28
29pub struct Statement<'connection> {
30 db: *mut ffi::sqlite3,
31 raw: NonNull<ffi::sqlite3_stmt>,
32 _connection: std::marker::PhantomData<&'connection ()>,
33}
34
35pub struct Rows<'statement, 'connection> {
36 statement: &'statement mut Statement<'connection>,
37 done: bool,
38}
39
40impl<'connection> Statement<'connection> {
41 pub(crate) fn new(db: *mut ffi::sqlite3, raw: NonNull<ffi::sqlite3_stmt>) -> Self {
42 Self {
43 db,
44 raw,
45 _connection: std::marker::PhantomData,
46 }
47 }
48
49 pub fn execute(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
50 self.reset_and_bind(values)?;
51 let rc = step(self.raw.as_ptr())?;
52 if rc == ffi::SQLITE_DONE {
53 Ok(())
54 } else {
55 Err(sqlite_error(self.db, rc))
56 }
57 }
58
59 pub fn execute_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
60 self.reset_and_bind_named(values)?;
61 let rc = step(self.raw.as_ptr())?;
62 if rc == ffi::SQLITE_DONE {
63 Ok(())
64 } else {
65 Err(sqlite_error(self.db, rc))
66 }
67 }
68
69 pub fn execute_with_texts(&mut self, values: &[&str]) -> Result<(), DbError> {
70 let values = values
71 .iter()
72 .map(|value| value as &dyn ToSql)
73 .collect::<Vec<_>>();
74 self.execute(&values)
75 }
76
77 pub fn query<'statement>(
78 &'statement mut self,
79 values: &[&dyn ToSql],
80 ) -> Result<Rows<'statement, 'connection>, DbError> {
81 self.reset_and_bind(values)?;
82 Ok(Rows {
83 statement: self,
84 done: false,
85 })
86 }
87
88 pub fn query_named<'statement>(
89 &'statement mut self,
90 values: &[(&str, &dyn ToSql)],
91 ) -> Result<Rows<'statement, 'connection>, DbError> {
92 self.reset_and_bind_named(values)?;
93 Ok(Rows {
94 statement: self,
95 done: false,
96 })
97 }
98
99 pub fn query_one<T, F>(&mut self, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
100 where
101 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
102 {
103 let mut rows = self.query(values)?;
104 match rows.next_row()? {
105 Some(row) => f(&row),
106 None => Err(DbError::NotFound),
107 }
108 }
109
110 pub fn query_one_named<T, F>(
111 &mut self,
112 values: &[(&str, &dyn ToSql)],
113 f: F,
114 ) -> Result<T, DbError>
115 where
116 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
117 {
118 let mut rows = self.query_named(values)?;
119 match rows.next_row()? {
120 Some(row) => f(&row),
121 None => Err(DbError::NotFound),
122 }
123 }
124
125 pub fn query_optional<T, F>(
126 &mut self,
127 values: &[&dyn ToSql],
128 f: F,
129 ) -> Result<Option<T>, DbError>
130 where
131 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
132 {
133 let mut rows = self.query(values)?;
134 match rows.next_row()? {
135 Some(row) => f(&row).map(Some),
136 None => Ok(None),
137 }
138 }
139
140 pub fn query_optional_named<T, F>(
141 &mut self,
142 values: &[(&str, &dyn ToSql)],
143 f: F,
144 ) -> Result<Option<T>, DbError>
145 where
146 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
147 {
148 let mut rows = self.query_named(values)?;
149 match rows.next_row()? {
150 Some(row) => f(&row).map(Some),
151 None => Ok(None),
152 }
153 }
154
155 pub fn query_all<T, F>(&mut self, values: &[&dyn ToSql], mut f: F) -> Result<Vec<T>, DbError>
156 where
157 F: FnMut(&Row<'_>) -> Result<T, DbError>,
158 {
159 let mut rows = self.query(values)?;
160 let mut output = Vec::new();
161 while let Some(row) = rows.next_row()? {
162 output.push(f(&row)?);
163 }
164 Ok(output)
165 }
166
167 pub fn query_all_named<T, F>(
168 &mut self,
169 values: &[(&str, &dyn ToSql)],
170 mut f: F,
171 ) -> Result<Vec<T>, DbError>
172 where
173 F: FnMut(&Row<'_>) -> Result<T, DbError>,
174 {
175 let mut rows = self.query_named(values)?;
176 let mut output = Vec::new();
177 while let Some(row) = rows.next_row()? {
178 output.push(f(&row)?);
179 }
180 Ok(output)
181 }
182
183 pub fn query_optional_string_with_text(
184 &mut self,
185 value: &str,
186 ) -> Result<Option<String>, DbError> {
187 self.query_optional(&[&value], |row| row.get(0))
188 }
189
190 fn reset_and_bind(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
191 let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
192 if reset_rc != ffi::SQLITE_OK {
193 return Err(sqlite_error(self.db, reset_rc));
194 }
195 let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
196 if clear_rc != ffi::SQLITE_OK {
197 return Err(sqlite_error(self.db, clear_rc));
198 }
199 bind_all(self.raw.as_ptr(), values)
200 }
201
202 fn reset_and_bind_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
203 let reset_rc = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
204 if reset_rc != ffi::SQLITE_OK {
205 return Err(sqlite_error(self.db, reset_rc));
206 }
207 let clear_rc = unsafe { ffi::sqlite3_clear_bindings(self.raw.as_ptr()) };
208 if clear_rc != ffi::SQLITE_OK {
209 return Err(sqlite_error(self.db, clear_rc));
210 }
211 bind_named_all(self.raw.as_ptr(), values)
212 }
213}
214
215impl Rows<'_, '_> {
216 pub fn next_row(&mut self) -> Result<Option<Row<'_>>, DbError> {
217 if self.done {
218 return Ok(None);
219 }
220 let rc = step(self.statement.raw.as_ptr())?;
221 match rc {
222 ffi::SQLITE_ROW => Ok(Some(Row::new(self.statement.raw.as_ptr()))),
223 ffi::SQLITE_DONE => {
224 self.done = true;
225 Ok(None)
226 }
227 _ => Err(sqlite_error(self.statement.db, rc)),
228 }
229 }
230}
231
232fn step(statement: *mut ffi::sqlite3_stmt) -> Result<std::ffi::c_int, DbError> {
233 #[cfg(any(test, feature = "canister-api-test-failpoints"))]
234 if let Some(code) = hit_step_failpoint() {
235 return Err(DbError::Sqlite(code, "sqlite step failpoint".to_string()));
236 }
237 Ok(unsafe { ffi::sqlite3_step(statement) })
238}
239
/// Arms `failpoint` for the current thread.
///
/// Any previously armed failpoint is replaced and the step counter restarts, so
/// `failpoint.ordinal` counts `step` calls from 1 starting now. The failpoint fires at most
/// once.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_step_failpoint(failpoint: StepFailpoint) {
    STEP_COUNT.with(|counter| {
        *counter.borrow_mut() = 0;
    });
    STEP_FAILPOINT.with(|armed| {
        *armed.borrow_mut() = Some(failpoint);
    });
}
245
/// Disarms any pending step failpoint on the current thread and restarts the step counter.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_step_failpoint() {
    STEP_COUNT.with(|counter| {
        *counter.borrow_mut() = 0;
    });
    STEP_FAILPOINT.with(|armed| {
        *armed.borrow_mut() = None;
    });
}
251
/// Records one `step` call and reports whether the armed failpoint fires on it.
///
/// The counter is bumped on every call, whether or not a failpoint is armed. When the
/// call's 1-based position equals the armed `ordinal`, the failpoint is disarmed and its
/// `code` is returned. Otherwise returns `None`.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_step_failpoint() -> Option<std::ffi::c_int> {
    let position = STEP_COUNT.with(|counter| {
        let mut counter = counter.borrow_mut();
        *counter += 1;
        *counter
    });
    STEP_FAILPOINT.with(|armed| {
        let mut armed = armed.borrow_mut();
        let pending = *armed;
        match pending {
            Some(failpoint) if failpoint.ordinal == position => {
                // One-shot: disarm before reporting so later steps run normally.
                *armed = None;
                Some(failpoint.code)
            }
            _ => None,
        }
    })
}
270
271impl Drop for Statement<'_> {
272 fn drop(&mut self) {
273 unsafe {
274 ffi::sqlite3_finalize(self.raw.as_ptr());
275 }
276 }
277}