Skip to main content

ic_sqlite_vfs/db/
mod.rs

//! Public SQLite database facade for canister methods.
//!
//! Update paths accept only synchronous closures, which prevents holding a DB
//! transaction across an `await`. Query paths open SQLite in read-only/query-only mode.
5
6pub mod connection;
7pub mod migrate;
8pub mod pragmas;
9pub mod row;
10pub mod statement;
11pub mod transaction;
12pub mod value;
13
14use crate::sqlite_vfs::stable_blob;
15use crate::stable::memory::{self, ContextId, DbMemory};
16use crate::stable::meta::Superblock;
17use connection::Connection;
18pub use row::{FromColumn, Row, TextLen};
19pub use stable_blob::ChecksumRefresh;
20use std::cell::RefCell;
21use std::collections::BTreeMap;
22use std::ffi::c_int;
23use std::rc::Rc;
24pub use transaction::UpdateConnection;
25pub use value::{Null, ToSql, Value, NULL};
26
27thread_local! {
28    static READ_CONNECTIONS: RefCell<BTreeMap<ContextId, Rc<Connection>>> = const { RefCell::new(BTreeMap::new()) };
29    static WRITE_CONNECTIONS: RefCell<BTreeMap<ContextId, Rc<Connection>>> = const { RefCell::new(BTreeMap::new()) };
30    static ACTIVE_READ_CONNECTIONS: RefCell<Vec<(ContextId, usize)>> = const { RefCell::new(Vec::new()) };
31}
32
33#[derive(Debug, thiserror::Error)]
34pub enum DbError {
35    #[error("sqlite error {0}: {1}")]
36    Sqlite(c_int, String),
37    #[error("sqlite constraint failed: {0}")]
38    Constraint(String),
39    #[error("query returned no rows")]
40    NotFound,
41    #[error("column {index} has type {actual}, expected {expected}")]
42    TypeMismatch {
43        index: usize,
44        expected: &'static str,
45        actual: &'static str,
46    },
47    #[error("column index {index} out of range for {count} columns")]
48    ColumnOutOfRange { index: usize, count: usize },
49    #[error("stable memory error: {0}")]
50    Stable(#[from] crate::stable::memory::StableMemoryError),
51    #[error("stable memory backend is not initialized; call Db::init(memory) first")]
52    StableMemoryNotInitialized,
53    #[error("stable memory backend is already initialized")]
54    StableMemoryAlreadyInitialized,
55    #[error("cannot mutate database while a query connection is active")]
56    ReadConnectionInUse,
57    #[error("migration version exceeds SQLite INTEGER range: {0}")]
58    MigrationVersionOutOfRange(u64),
59    #[error("duplicate migration version: {0}")]
60    DuplicateMigrationVersion(u64),
61    #[error("SQL contains an interior NUL byte")]
62    InteriorNul,
63    #[error("SQL contains no statement")]
64    EmptySql,
65    #[error("SQL contains trailing text after the first statement")]
66    TrailingSql,
67    #[error("text value too large")]
68    TextTooLarge,
69    #[error("blob value too large")]
70    BlobTooLarge,
71    #[error("too many SQL parameters")]
72    TooManyParameters,
73    #[error("SQL parameter count mismatch: expected {expected}, actual {actual}")]
74    ParameterCountMismatch { expected: usize, actual: usize },
75    #[error("named bind cannot be used with anonymous SQL parameter at index {index}")]
76    AnonymousParameterInNamedBind { index: usize },
77    #[error("SQL parameter not found: {0}")]
78    ParameterNotFound(String),
79}
80
/// Facade over the default database context registered by [`Db::init`].
///
/// Each associated function resolves the default context and forwards to the
/// matching [`DbHandle`] method.
pub struct Db;
82
83#[derive(Clone, Copy, Debug, Eq, PartialEq)]
84pub struct DbHandle {
85    context: ContextId,
86}
87
88impl Db {
89    pub fn init(memory: DbMemory) -> Result<(), DbError> {
90        let context = memory::init(memory).map_err(|error| match error {
91            crate::stable::memory::StableMemoryError::AlreadyInitialized => {
92                DbError::StableMemoryAlreadyInitialized
93            }
94            crate::stable::memory::StableMemoryError::NotInitialized => {
95                DbError::StableMemoryNotInitialized
96            }
97            error => DbError::Stable(error),
98        })?;
99        clear_read_connection(context);
100        clear_write_connection(context);
101        let handle = DbHandle::from_context(context);
102        let result = handle.initialize();
103        if result.is_err() {
104            clear_read_connection(context);
105            clear_write_connection(context);
106            memory::clear_failed_initialization(context);
107        }
108        result
109    }
110
111    fn default_handle() -> Result<DbHandle, DbError> {
112        memory::default_context()
113            .map(DbHandle::from_context)
114            .ok_or(DbError::StableMemoryNotInitialized)
115    }
116
117    pub fn update<T, F>(f: F) -> Result<T, DbError>
118    where
119        F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
120    {
121        Self::default_handle()?.update(f)
122    }
123
124    pub fn query<T, F>(f: F) -> Result<T, DbError>
125    where
126        F: FnOnce(&Connection) -> Result<T, DbError>,
127    {
128        Self::default_handle()?.query(f)
129    }
130
131    pub fn migrate(migrations: &[migrate::Migration]) -> Result<(), DbError> {
132        Self::default_handle()?.migrate(migrations)
133    }
134
135    pub fn integrity_check() -> Result<String, DbError> {
136        Self::default_handle()?.integrity_check()
137    }
138
139    pub fn export_chunk(offset: u64, len: u64) -> Result<Vec<u8>, DbError> {
140        Self::default_handle()?.export_chunk(offset, len)
141    }
142
143    pub fn db_checksum() -> Result<u64, DbError> {
144        Self::default_handle()?.db_checksum()
145    }
146
147    pub fn refresh_checksum() -> Result<u64, DbError> {
148        Self::default_handle()?.refresh_checksum()
149    }
150
151    pub fn refresh_checksum_chunk(max_bytes: u64) -> Result<ChecksumRefresh, DbError> {
152        Self::default_handle()?.refresh_checksum_chunk(max_bytes)
153    }
154
155    pub fn begin_import(total_size: u64, expected_checksum: u64) -> Result<(), DbError> {
156        Self::default_handle()?.begin_import(total_size, expected_checksum)
157    }
158
159    pub fn import_chunk(offset: u64, bytes: &[u8]) -> Result<(), DbError> {
160        Self::default_handle()?.import_chunk(offset, bytes)
161    }
162
163    pub fn finish_import() -> Result<(), DbError> {
164        Self::default_handle()?.finish_import()
165    }
166
167    pub fn cancel_import() -> Result<(), DbError> {
168        Self::default_handle()?.cancel_import()
169    }
170
171    pub fn compact() -> Result<(), DbError> {
172        Self::default_handle()?.compact()
173    }
174}
175
176impl DbHandle {
177    pub fn init(memory: DbMemory) -> Result<Self, DbError> {
178        let handle = Self::from_context(memory::init_context(memory));
179        clear_read_connection(handle.context);
180        clear_write_connection(handle.context);
181        if let Err(error) = handle.initialize() {
182            clear_read_connection(handle.context);
183            clear_write_connection(handle.context);
184            memory::clear_failed_initialization(handle.context);
185            return Err(error);
186        }
187        Ok(handle)
188    }
189
190    fn from_context(context: ContextId) -> Self {
191        Self { context }
192    }
193
194    fn initialize(self) -> Result<(), DbError> {
195        self.with_context(|| {
196            crate::sqlite_vfs::register();
197            Superblock::load()?;
198            stable_blob::ensure_page_map_layout()?;
199            Ok(())
200        })
201    }
202
203    fn with_context<T>(self, f: impl FnOnce() -> Result<T, DbError>) -> Result<T, DbError> {
204        memory::with_context(self.context, f)
205    }
206
207    pub fn update<T, F>(self, f: F) -> Result<T, DbError>
208    where
209        F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
210    {
211        self.with_context(|| {
212            reject_active_read_connection(self.context)?;
213            clear_read_connection(self.context);
214            let db_size = stable_blob::begin_update()?;
215            let _overlay_guard = OverlayGuard;
216            let connection = write_connection(self.context, db_size)?;
217            let result = transaction::run_immediate(&connection, f);
218            clear_read_connection(self.context);
219            if result.is_err() {
220                clear_write_connection(self.context);
221            }
222            result
223        })
224    }
225
226    pub fn query<T, F>(self, f: F) -> Result<T, DbError>
227    where
228        F: FnOnce(&Connection) -> Result<T, DbError>,
229    {
230        self.with_context(|| with_read_connection(self.context, f))
231    }
232
233    pub fn migrate(self, migrations: &[migrate::Migration]) -> Result<(), DbError> {
234        self.update(|connection| migrate::apply(connection, migrations))?;
235        self.with_context(|| {
236            let target_version = migrations
237                .iter()
238                .map(|migration| migration.version)
239                .max()
240                .unwrap_or(0);
241            let mut block = Superblock::load()?;
242            if block.schema_version < target_version {
243                clear_read_connection(self.context);
244                block.schema_version = target_version;
245                block.store()?;
246            }
247            Ok(())
248        })
249    }
250
251    pub fn integrity_check(self) -> Result<String, DbError> {
252        self.query(|connection| {
253            connection.query_scalar::<String>("PRAGMA integrity_check", crate::params![])
254        })
255    }
256
257    pub fn export_chunk(self, offset: u64, len: u64) -> Result<Vec<u8>, DbError> {
258        self.with_context(|| stable_blob::export_chunk(offset, len).map_err(DbError::from))
259    }
260
261    pub fn db_checksum(self) -> Result<u64, DbError> {
262        self.with_context(|| stable_blob::checksum().map_err(DbError::from))
263    }
264
265    pub fn refresh_checksum(self) -> Result<u64, DbError> {
266        self.with_context(|| {
267            reject_active_read_connection(self.context)?;
268            clear_read_connection(self.context);
269            stable_blob::refresh_checksum().map_err(DbError::from)
270        })
271    }
272
273    pub fn refresh_checksum_chunk(self, max_bytes: u64) -> Result<ChecksumRefresh, DbError> {
274        self.with_context(|| {
275            reject_active_read_connection(self.context)?;
276            clear_read_connection(self.context);
277            stable_blob::refresh_checksum_chunk(max_bytes).map_err(DbError::from)
278        })
279    }
280
281    pub fn begin_import(self, total_size: u64, expected_checksum: u64) -> Result<(), DbError> {
282        self.with_context(|| {
283            reject_active_read_connection(self.context)?;
284            clear_read_connection(self.context);
285            clear_write_connection(self.context);
286            stable_blob::begin_import(total_size, expected_checksum).map_err(DbError::from)
287        })
288    }
289
290    pub fn import_chunk(self, offset: u64, bytes: &[u8]) -> Result<(), DbError> {
291        self.with_context(|| {
292            reject_active_read_connection(self.context)?;
293            clear_read_connection(self.context);
294            stable_blob::import_chunk(offset, bytes).map_err(DbError::from)
295        })
296    }
297
298    pub fn finish_import(self) -> Result<(), DbError> {
299        self.with_context(|| {
300            reject_active_read_connection(self.context)?;
301            clear_read_connection(self.context);
302            clear_write_connection(self.context);
303            stable_blob::finish_import().map_err(DbError::from)
304        })
305    }
306
307    pub fn cancel_import(self) -> Result<(), DbError> {
308        self.with_context(|| {
309            reject_active_read_connection(self.context)?;
310            clear_read_connection(self.context);
311            clear_write_connection(self.context);
312            stable_blob::cancel_import().map_err(DbError::from)
313        })
314    }
315
316    pub fn compact(self) -> Result<(), DbError> {
317        self.with_context(|| {
318            reject_active_read_connection(self.context)?;
319            clear_read_connection(self.context);
320            clear_write_connection(self.context);
321            stable_blob::compact().map_err(DbError::from)
322        })
323    }
324}
325
326fn write_connection(context: ContextId, db_size: u64) -> Result<Rc<Connection>, DbError> {
327    WRITE_CONNECTIONS.with(|slot| {
328        let cached = { slot.borrow().get(&context).cloned() };
329        if let Some(connection) = cached {
330            return Ok(connection);
331        }
332        let connection = if db_size == 0 {
333            connection::open_read_write()?
334        } else {
335            connection::open_read_write_existing()?
336        };
337        let connection = Rc::new(connection);
338        slot.borrow_mut().insert(context, Rc::clone(&connection));
339        Ok(connection)
340    })
341}
342
343fn with_read_connection<T>(
344    context: ContextId,
345    f: impl FnOnce(&Connection) -> Result<T, DbError>,
346) -> Result<T, DbError> {
347    READ_CONNECTIONS.with(|slot| {
348        let cached = { slot.borrow().get(&context).cloned() };
349        let connection = if let Some(connection) = cached {
350            connection
351        } else {
352            let connection = Rc::new(connection::open_read_only()?);
353            slot.borrow_mut().insert(context, Rc::clone(&connection));
354            connection
355        };
356        let _guard = ReadGuard::enter(context);
357        f(&connection)
358    })
359}
360
361fn reject_active_read_connection(context: ContextId) -> Result<(), DbError> {
362    ACTIVE_READ_CONNECTIONS.with(|slot| {
363        let slot = slot.borrow();
364        let active = active_read_index(&slot, context)
365            .map(|index| slot[index].1)
366            .unwrap_or(0);
367        if active == 0 {
368            Ok(())
369        } else {
370            Err(DbError::ReadConnectionInUse)
371        }
372    })
373}
374
375fn clear_read_connection(context: ContextId) {
376    READ_CONNECTIONS.with(|slot| {
377        slot.borrow_mut().remove(&context);
378    });
379}
380
381fn clear_write_connection(context: ContextId) {
382    WRITE_CONNECTIONS.with(|slot| {
383        slot.borrow_mut().remove(&context);
384    });
385}
386
387struct ReadGuard {
388    context: ContextId,
389}
390
391impl ReadGuard {
392    fn enter(context: ContextId) -> Self {
393        ACTIVE_READ_CONNECTIONS.with(|slot| {
394            let mut slot = slot.borrow_mut();
395            if slot.is_empty() {
396                slot.push((context, 1));
397                return;
398            }
399            if let Some(index) = active_read_index(&slot, context) {
400                slot[index].1 += 1;
401            } else {
402                slot.push((context, 1));
403            }
404        });
405        Self { context }
406    }
407}
408
409impl Drop for ReadGuard {
410    fn drop(&mut self) {
411        ACTIVE_READ_CONNECTIONS.with(|slot| {
412            let mut slot = slot.borrow_mut();
413            if slot.len() == 1 && slot[0].0 == self.context {
414                let depth = &mut slot[0].1;
415                *depth = depth.saturating_sub(1);
416                if *depth == 0 {
417                    slot.clear();
418                }
419                return;
420            }
421            let Some(index) = active_read_index(&slot, self.context) else {
422                return;
423            };
424            let depth = &mut slot[index].1;
425            *depth = depth.saturating_sub(1);
426            if *depth == 0 {
427                slot.swap_remove(index);
428            }
429        });
430    }
431}
432
433fn active_read_index(entries: &[(ContextId, usize)], context: ContextId) -> Option<usize> {
434    entries
435        .iter()
436        .position(|(stored_context, _)| *stored_context == context)
437}
438
439struct OverlayGuard;
440
441impl Drop for OverlayGuard {
442    fn drop(&mut self) {
443        stable_blob::rollback_update();
444    }
445}