Skip to main content

ic_sqlite_vfs/db/
connection.rs

//! Thin SQLite C connection wrapper bound to the `icstable` VFS.
//!
//! `rusqlite` refuses `SQLITE_THREADSAFE=0`, so this crate keeps a small FFI
//! facade. Write connections are per-message; read-only connections may be
//! reused inside one context cache.
6
7use crate::config::{SQLITE_URI, STATEMENT_CACHE_CAPACITY, VFS_NAME};
8use crate::db::row::{FromColumn, Row};
9use crate::db::statement::Statement;
10use crate::db::value::ToSql;
11use crate::db::{pragmas, DbError};
12use crate::sqlite_vfs::ffi;
13use std::cell::RefCell;
14use std::collections::{BTreeMap, VecDeque};
15use std::ffi::{c_char, c_int, c_void, CStr, CString};
16use std::ops::{Deref, DerefMut};
17use std::ptr::{self, NonNull};
18
19pub struct Connection {
20    raw: NonNull<ffi::sqlite3>,
21    cached: RefCell<StatementCache>,
22}
23
24pub struct CachedStatement<'connection> {
25    statement: Option<Statement<'connection>>,
26    sql: String,
27    cache: &'connection RefCell<StatementCache>,
28}
29
30struct StatementCache {
31    statements: BTreeMap<String, NonNull<ffi::sqlite3_stmt>>,
32    returned_lru: VecDeque<String>,
33}
34
35impl StatementCache {
36    fn new() -> Self {
37        Self {
38            statements: BTreeMap::new(),
39            returned_lru: VecDeque::new(),
40        }
41    }
42
43    fn take(&mut self, sql: &str) -> Option<NonNull<ffi::sqlite3_stmt>> {
44        let raw = self.statements.remove(sql)?;
45        self.returned_lru.retain(|cached_sql| cached_sql != sql);
46        Some(raw)
47    }
48
49    unsafe fn insert(&mut self, sql: String, raw: NonNull<ffi::sqlite3_stmt>) {
50        if let Some(previous) = self.statements.insert(sql.clone(), raw) {
51            ffi::sqlite3_finalize(previous.as_ptr());
52        }
53        self.returned_lru.retain(|cached_sql| cached_sql != &sql);
54        self.returned_lru.push_back(sql);
55        self.evict_over_capacity();
56    }
57
58    unsafe fn evict_over_capacity(&mut self) {
59        while self.statements.len() > STATEMENT_CACHE_CAPACITY {
60            let Some(sql) = self.returned_lru.pop_front() else {
61                return;
62            };
63            if let Some(statement) = self.statements.remove(&sql) {
64                ffi::sqlite3_finalize(statement.as_ptr());
65            }
66        }
67    }
68
69    unsafe fn finalize_all(&mut self) {
70        for (_, statement) in std::mem::take(&mut self.statements) {
71            ffi::sqlite3_finalize(statement.as_ptr());
72        }
73        self.returned_lru.clear();
74    }
75}
76
77pub fn open_read_write() -> Result<Connection, DbError> {
78    let flags = ffi::SQLITE_OPEN_READWRITE
79        | ffi::SQLITE_OPEN_CREATE
80        | ffi::SQLITE_OPEN_URI
81        | ffi::SQLITE_OPEN_NOMUTEX;
82    let connection = Connection::open(flags)?;
83    pragmas::apply_read_write(&connection)?;
84    Ok(connection)
85}
86
87pub fn open_read_only() -> Result<Connection, DbError> {
88    let flags = ffi::SQLITE_OPEN_READONLY | ffi::SQLITE_OPEN_URI | ffi::SQLITE_OPEN_NOMUTEX;
89    let connection = Connection::open(flags)?;
90    pragmas::apply_read_only(&connection)?;
91    Ok(connection)
92}
93
94impl Connection {
95    fn open(flags: c_int) -> Result<Self, DbError> {
96        let filename = CString::new(SQLITE_URI).map_err(|_| DbError::InteriorNul)?;
97        let vfs = CString::new(VFS_NAME).map_err(|_| DbError::InteriorNul)?;
98        let mut db = ptr::null_mut();
99        let rc = unsafe { ffi::sqlite3_open_v2(filename.as_ptr(), &mut db, flags, vfs.as_ptr()) };
100        let Some(raw) = NonNull::new(db) else {
101            return Err(DbError::Sqlite(
102                rc,
103                "sqlite3_open_v2 returned null".to_string(),
104            ));
105        };
106        if rc != ffi::SQLITE_OK {
107            let error = sqlite_error(raw.as_ptr(), rc);
108            unsafe {
109                ffi::sqlite3_close(raw.as_ptr());
110            }
111            return Err(error);
112        }
113        Ok(Self {
114            raw,
115            cached: RefCell::new(StatementCache::new()),
116        })
117    }
118
119    pub fn raw(&self) -> *mut ffi::sqlite3 {
120        self.raw.as_ptr()
121    }
122
123    pub fn execute_batch(&self, sql: &str) -> Result<(), DbError> {
124        let sql = CString::new(sql).map_err(|_| DbError::InteriorNul)?;
125        let mut error = ptr::null_mut();
126        let rc = unsafe {
127            ffi::sqlite3_exec(
128                self.raw.as_ptr(),
129                sql.as_ptr(),
130                None,
131                ptr::null_mut(),
132                &mut error,
133            )
134        };
135        if rc == ffi::SQLITE_OK {
136            return Ok(());
137        }
138        Err(classify_sqlite_error(rc, take_error_message(error)))
139    }
140
141    pub fn execute(&self, sql: &str, values: &[&dyn ToSql]) -> Result<(), DbError> {
142        let mut statement = self.prepare(sql)?;
143        statement.execute(values)
144    }
145
146    pub fn execute_named(&self, sql: &str, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
147        let mut statement = self.prepare(sql)?;
148        statement.execute_named(values)
149    }
150
151    pub fn prepare(&self, sql: &str) -> Result<Statement<'_>, DbError> {
152        let sql = CString::new(sql).map_err(|_| DbError::InteriorNul)?;
153        let mut statement = ptr::null_mut();
154        let mut tail = ptr::null();
155        let rc = unsafe {
156            ffi::sqlite3_prepare_v2(
157                self.raw.as_ptr(),
158                sql.as_ptr(),
159                -1,
160                &mut statement,
161                &mut tail,
162            )
163        };
164        if rc != ffi::SQLITE_OK {
165            return Err(sqlite_error(self.raw.as_ptr(), rc));
166        }
167        let Some(raw) = NonNull::new(statement) else {
168            return Err(DbError::EmptySql);
169        };
170        if !tail_is_empty(tail) {
171            unsafe {
172                ffi::sqlite3_finalize(raw.as_ptr());
173            }
174            return Err(DbError::TrailingSql);
175        }
176        Ok(Statement::new(self.raw.as_ptr(), raw))
177    }
178
179    pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>, DbError> {
180        if let Some(raw) = self.cached.borrow_mut().take(sql) {
181            return Ok(CachedStatement::new(
182                Statement::new(self.raw.as_ptr(), raw),
183                sql.to_string(),
184                &self.cached,
185            ));
186        }
187        let statement = self.prepare(sql)?;
188        Ok(CachedStatement::new(
189            statement,
190            sql.to_string(),
191            &self.cached,
192        ))
193    }
194
195    pub fn query_one<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
196    where
197        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
198    {
199        let mut statement = self.prepare(sql)?;
200        statement.query_one(values, f)
201    }
202
203    pub fn query_one_named<T, F>(
204        &self,
205        sql: &str,
206        values: &[(&str, &dyn ToSql)],
207        f: F,
208    ) -> Result<T, DbError>
209    where
210        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
211    {
212        let mut statement = self.prepare(sql)?;
213        statement.query_one_named(values, f)
214    }
215
216    /// Runs a single-row query.
217    ///
218    /// This is a `rusqlite`-style alias for [`Connection::query_one`].
219    pub fn query_row<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
220    where
221        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
222    {
223        self.query_one(sql, values, f)
224    }
225
226    /// Runs a single-row query with named parameters.
227    ///
228    /// This is a `rusqlite`-style alias for [`Connection::query_one_named`].
229    pub fn query_row_named<T, F>(
230        &self,
231        sql: &str,
232        values: &[(&str, &dyn ToSql)],
233        f: F,
234    ) -> Result<T, DbError>
235    where
236        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
237    {
238        self.query_one_named(sql, values, f)
239    }
240
241    pub fn query_optional<T, F>(
242        &self,
243        sql: &str,
244        values: &[&dyn ToSql],
245        f: F,
246    ) -> Result<Option<T>, DbError>
247    where
248        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
249    {
250        let mut statement = self.prepare(sql)?;
251        statement.query_optional(values, f)
252    }
253
254    pub fn query_optional_named<T, F>(
255        &self,
256        sql: &str,
257        values: &[(&str, &dyn ToSql)],
258        f: F,
259    ) -> Result<Option<T>, DbError>
260    where
261        F: FnOnce(&Row<'_>) -> Result<T, DbError>,
262    {
263        let mut statement = self.prepare(sql)?;
264        statement.query_optional_named(values, f)
265    }
266
267    pub fn query_all<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<Vec<T>, DbError>
268    where
269        F: FnMut(&Row<'_>) -> Result<T, DbError>,
270    {
271        let mut statement = self.prepare(sql)?;
272        statement.query_all(values, f)
273    }
274
275    pub fn query_all_named<T, F>(
276        &self,
277        sql: &str,
278        values: &[(&str, &dyn ToSql)],
279        f: F,
280    ) -> Result<Vec<T>, DbError>
281    where
282        F: FnMut(&Row<'_>) -> Result<T, DbError>,
283    {
284        let mut statement = self.prepare(sql)?;
285        statement.query_all_named(values, f)
286    }
287
288    /// Maps all rows into a `Vec<T>`.
289    ///
290    /// Unlike `rusqlite::Statement::query_map`, this returns a collected
291    /// `Vec<T>`, not an iterator. That keeps the prepared statement lifetime
292    /// inside one synchronous canister message.
293    pub fn query_map<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<Vec<T>, DbError>
294    where
295        F: FnMut(&Row<'_>) -> Result<T, DbError>,
296    {
297        self.query_all(sql, values, f)
298    }
299
300    /// Maps all rows into a `Vec<T>` using named parameters.
301    ///
302    /// Unlike `rusqlite::Statement::query_map`, this returns a collected
303    /// `Vec<T>`, not an iterator. That keeps the prepared statement lifetime
304    /// inside one synchronous canister message.
305    pub fn query_map_named<T, F>(
306        &self,
307        sql: &str,
308        values: &[(&str, &dyn ToSql)],
309        f: F,
310    ) -> Result<Vec<T>, DbError>
311    where
312        F: FnMut(&Row<'_>) -> Result<T, DbError>,
313    {
314        self.query_all_named(sql, values, f)
315    }
316
317    pub fn exists(&self, sql: &str, values: &[&dyn ToSql]) -> Result<bool, DbError> {
318        self.query_optional(sql, values, |row| row.get::<i64>(0))
319            .map(|value| value.unwrap_or(0) != 0)
320    }
321
322    pub fn query_scalar<T: FromColumn>(
323        &self,
324        sql: &str,
325        values: &[&dyn ToSql],
326    ) -> Result<T, DbError> {
327        self.query_one(sql, values, |row| row.get(0))
328    }
329
330    pub fn query_scalar_named<T: FromColumn>(
331        &self,
332        sql: &str,
333        values: &[(&str, &dyn ToSql)],
334    ) -> Result<T, DbError> {
335        self.query_one_named(sql, values, |row| row.get(0))
336    }
337
338    pub fn query_optional_scalar<T: FromColumn>(
339        &self,
340        sql: &str,
341        values: &[&dyn ToSql],
342    ) -> Result<Option<T>, DbError> {
343        self.query_optional(sql, values, |row| row.get(0))
344    }
345
346    pub fn query_optional_string_text(
347        &self,
348        sql: &str,
349        value: &str,
350    ) -> Result<Option<String>, DbError> {
351        let mut statement = self.prepare(sql)?;
352        statement.query_optional_string_text(value)
353    }
354
355    pub fn query_optional_scalar_named<T: FromColumn>(
356        &self,
357        sql: &str,
358        values: &[(&str, &dyn ToSql)],
359    ) -> Result<Option<T>, DbError> {
360        self.query_optional_named(sql, values, |row| row.get(0))
361    }
362
363    pub fn query_column<T: FromColumn>(
364        &self,
365        sql: &str,
366        values: &[&dyn ToSql],
367    ) -> Result<Vec<T>, DbError> {
368        self.query_all(sql, values, |row| row.get(0))
369    }
370
371    pub fn query_column_named<T: FromColumn>(
372        &self,
373        sql: &str,
374        values: &[(&str, &dyn ToSql)],
375    ) -> Result<Vec<T>, DbError> {
376        self.query_all_named(sql, values, |row| row.get(0))
377    }
378}
379
380impl Drop for Connection {
381    fn drop(&mut self) {
382        unsafe {
383            self.cached.get_mut().finalize_all();
384            ffi::sqlite3_close(self.raw.as_ptr());
385        }
386    }
387}
388
389impl<'connection> CachedStatement<'connection> {
390    fn new(
391        statement: Statement<'connection>,
392        sql: String,
393        cache: &'connection RefCell<StatementCache>,
394    ) -> Self {
395        Self {
396            statement: Some(statement),
397            sql,
398            cache,
399        }
400    }
401
402    pub fn discard(mut self) {
403        if let Some(statement) = self.statement.take() {
404            unsafe {
405                ffi::sqlite3_finalize(statement.into_raw().as_ptr());
406            }
407        }
408    }
409}
410
411impl<'connection> Deref for CachedStatement<'connection> {
412    type Target = Statement<'connection>;
413
414    fn deref(&self) -> &Self::Target {
415        self.statement
416            .as_ref()
417            .expect("cached statement is present")
418    }
419}
420
421impl DerefMut for CachedStatement<'_> {
422    fn deref_mut(&mut self) -> &mut Self::Target {
423        self.statement
424            .as_mut()
425            .expect("cached statement is present")
426    }
427}
428
429impl Drop for CachedStatement<'_> {
430    fn drop(&mut self) {
431        let Some(statement) = self.statement.take() else {
432            return;
433        };
434        let raw = statement.into_raw();
435        unsafe {
436            ffi::sqlite3_reset(raw.as_ptr());
437            ffi::sqlite3_clear_bindings(raw.as_ptr());
438            self.cache.borrow_mut().insert(self.sql.clone(), raw);
439        }
440    }
441}
442
443pub(crate) fn sqlite_error(db: *mut ffi::sqlite3, code: c_int) -> DbError {
444    let message = unsafe {
445        let ptr = ffi::sqlite3_errmsg(db);
446        if ptr.is_null() {
447            "unknown sqlite error".to_string()
448        } else {
449            CStr::from_ptr(ptr).to_string_lossy().into_owned()
450        }
451    };
452    classify_sqlite_error(code, message)
453}
454
455fn classify_sqlite_error(code: c_int, message: String) -> DbError {
456    if code == ffi::SQLITE_CONSTRAINT {
457        DbError::Constraint(message)
458    } else {
459        DbError::Sqlite(code, message)
460    }
461}
462
463fn take_error_message(error: *mut c_char) -> String {
464    if error.is_null() {
465        return "unknown sqlite error".to_string();
466    }
467    let message = unsafe { CStr::from_ptr(error).to_string_lossy().into_owned() };
468    unsafe {
469        ffi::sqlite3_free(error.cast::<c_void>());
470    }
471    message
472}
473
/// Reports whether the unparsed remainder of a SQL string is blank.
///
/// A null pointer, an empty string, and a string made only of ASCII
/// whitespace all count as empty. A non-null `tail` must point to a
/// NUL-terminated string.
fn tail_is_empty(tail: *const c_char) -> bool {
    if tail.is_null() {
        return true;
    }
    let rest = unsafe { CStr::from_ptr(tail) };
    !rest
        .to_bytes()
        .iter()
        .any(|byte| !byte.is_ascii_whitespace())
}
481
#[cfg(test)]
mod tests {
    use super::open_read_write;
    use crate::config::STATEMENT_CACHE_CAPACITY;
    use crate::sqlite_vfs::{lock, stable_blob};
    use crate::stable::memory;
    use crate::Db;
    use serial_test::serial;

    /// Runs the shared reset sequence and re-initializes `Db` on the test
    /// memory before each test.
    fn reset() {
        stable_blob::rollback_update();
        stable_blob::invalidate_read_cache();
        memory::reset_for_tests();
        lock::reset_for_tests();
        Db::init(memory::memory_for_tests()).unwrap();
    }

    /// Preparing more distinct statements than the cache can hold keeps the
    /// newest ones and evicts the oldest.
    #[test]
    #[serial]
    fn cached_statements_are_lru_bounded() {
        reset();
        let connection = open_read_write().unwrap();
        let total = STATEMENT_CACHE_CAPACITY + 8;

        for index in 0..total {
            let sql = format!("SELECT {index}");
            let mut statement = connection.prepare_cached(&sql).unwrap();
            let expected = i64::try_from(index).unwrap();
            let value = statement.query_scalar::<i64>(crate::params![]).unwrap();
            assert_eq!(value, expected);
        }

        let newest = format!("SELECT {}", STATEMENT_CACHE_CAPACITY + 7);
        let cache = connection.cached.borrow();
        assert_eq!(cache.statements.len(), STATEMENT_CACHE_CAPACITY);
        assert!(!cache.statements.contains_key("SELECT 0"));
        assert!(cache.statements.contains_key(&newest));
    }

    /// A discarded statement must not reappear in the cache.
    #[test]
    #[serial]
    fn discarded_cached_statement_is_finalized_not_cached() {
        reset();
        let connection = open_read_write().unwrap();

        connection.prepare_cached("SELECT 1").unwrap().discard();

        let cached_count = connection.cached.borrow().statements.len();
        assert_eq!(cached_count, 0);
    }
}