Skip to main content

ic_sqlite_vfs/db/
mod.rs

//! Public SQLite database facade for canister methods.
//!
//! Update paths accept only synchronous closures, so a DB transaction can never be
//! held across an `await`. Query paths open SQLite in read-only, query-only mode.

6pub mod connection;
7pub mod migrate;
8pub mod pragmas;
9pub mod row;
10pub mod statement;
11pub mod transaction;
12pub mod value;
13
14use crate::sqlite_vfs::stable_blob;
15use crate::stable::memory::{self, ContextId, DbMemory};
16use crate::stable::meta::Superblock;
17use connection::Connection;
18pub use row::{FromColumn, Row, TextLen};
19pub use stable_blob::ChecksumRefresh;
20use std::cell::RefCell;
21use std::collections::BTreeMap;
22use std::ffi::c_int;
23use std::rc::Rc;
24pub use transaction::UpdateConnection;
25pub use value::{Null, ToSql, Value, NULL};
26
27thread_local! {
28    static READ_CONNECTIONS: RefCell<BTreeMap<ContextId, Rc<Connection>>> = const { RefCell::new(BTreeMap::new()) };
29    static ACTIVE_READ_CONNECTIONS: RefCell<BTreeMap<ContextId, usize>> = const { RefCell::new(BTreeMap::new()) };
30}
31
32#[derive(Debug, thiserror::Error)]
33pub enum DbError {
34    #[error("sqlite error {0}: {1}")]
35    Sqlite(c_int, String),
36    #[error("sqlite constraint failed: {0}")]
37    Constraint(String),
38    #[error("query returned no rows")]
39    NotFound,
40    #[error("column {index} has type {actual}, expected {expected}")]
41    TypeMismatch {
42        index: usize,
43        expected: &'static str,
44        actual: &'static str,
45    },
46    #[error("column index {index} out of range for {count} columns")]
47    ColumnOutOfRange { index: usize, count: usize },
48    #[error("stable memory error: {0}")]
49    Stable(#[from] crate::stable::memory::StableMemoryError),
50    #[error("stable memory backend is not initialized; call Db::init(memory) first")]
51    StableMemoryNotInitialized,
52    #[error("stable memory backend is already initialized")]
53    StableMemoryAlreadyInitialized,
54    #[error("cannot mutate database while a query connection is active")]
55    ReadConnectionInUse,
56    #[error("migration version exceeds SQLite INTEGER range: {0}")]
57    MigrationVersionOutOfRange(u64),
58    #[error("duplicate migration version: {0}")]
59    DuplicateMigrationVersion(u64),
60    #[error("SQL contains an interior NUL byte")]
61    InteriorNul,
62    #[error("SQL contains no statement")]
63    EmptySql,
64    #[error("SQL contains trailing text after the first statement")]
65    TrailingSql,
66    #[error("text value too large")]
67    TextTooLarge,
68    #[error("blob value too large")]
69    BlobTooLarge,
70    #[error("too many SQL parameters")]
71    TooManyParameters,
72    #[error("SQL parameter count mismatch: expected {expected}, actual {actual}")]
73    ParameterCountMismatch { expected: usize, actual: usize },
74    #[error("named bind cannot be used with anonymous SQL parameter at index {index}")]
75    AnonymousParameterInNamedBind { index: usize },
76    #[error("SQL parameter not found: {0}")]
77    ParameterNotFound(String),
78}
79
80pub struct Db;
81
82#[derive(Clone, Copy, Debug, Eq, PartialEq)]
83pub struct DbHandle {
84    context: ContextId,
85}
86
87impl Db {
88    pub fn init(memory: DbMemory) -> Result<(), DbError> {
89        let context = memory::init(memory).map_err(|error| match error {
90            crate::stable::memory::StableMemoryError::AlreadyInitialized => {
91                DbError::StableMemoryAlreadyInitialized
92            }
93            crate::stable::memory::StableMemoryError::NotInitialized => {
94                DbError::StableMemoryNotInitialized
95            }
96            error => DbError::Stable(error),
97        })?;
98        clear_read_connection(context);
99        let handle = DbHandle::from_context(context);
100        let result = handle.initialize();
101        if result.is_err() {
102            clear_read_connection(context);
103            memory::clear_failed_initialization(context);
104        }
105        result
106    }
107
108    fn default_handle() -> Result<DbHandle, DbError> {
109        memory::default_context()
110            .map(DbHandle::from_context)
111            .ok_or(DbError::StableMemoryNotInitialized)
112    }
113
114    pub fn update<T, F>(f: F) -> Result<T, DbError>
115    where
116        F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
117    {
118        Self::default_handle()?.update(f)
119    }
120
121    pub fn query<T, F>(f: F) -> Result<T, DbError>
122    where
123        F: FnOnce(&Connection) -> Result<T, DbError>,
124    {
125        Self::default_handle()?.query(f)
126    }
127
128    pub fn migrate(migrations: &[migrate::Migration]) -> Result<(), DbError> {
129        Self::default_handle()?.migrate(migrations)
130    }
131
132    pub fn integrity_check() -> Result<String, DbError> {
133        Self::default_handle()?.integrity_check()
134    }
135
136    pub fn export_chunk(offset: u64, len: u64) -> Result<Vec<u8>, DbError> {
137        Self::default_handle()?.export_chunk(offset, len)
138    }
139
140    pub fn db_checksum() -> Result<u64, DbError> {
141        Self::default_handle()?.db_checksum()
142    }
143
144    pub fn refresh_checksum() -> Result<u64, DbError> {
145        Self::default_handle()?.refresh_checksum()
146    }
147
148    pub fn refresh_checksum_chunk(max_bytes: u64) -> Result<ChecksumRefresh, DbError> {
149        Self::default_handle()?.refresh_checksum_chunk(max_bytes)
150    }
151
152    pub fn begin_import(total_size: u64, expected_checksum: u64) -> Result<(), DbError> {
153        Self::default_handle()?.begin_import(total_size, expected_checksum)
154    }
155
156    pub fn import_chunk(offset: u64, bytes: &[u8]) -> Result<(), DbError> {
157        Self::default_handle()?.import_chunk(offset, bytes)
158    }
159
160    pub fn finish_import() -> Result<(), DbError> {
161        Self::default_handle()?.finish_import()
162    }
163
164    pub fn cancel_import() -> Result<(), DbError> {
165        Self::default_handle()?.cancel_import()
166    }
167
168    pub fn compact() -> Result<(), DbError> {
169        Self::default_handle()?.compact()
170    }
171}
172
173impl DbHandle {
174    pub fn init(memory: DbMemory) -> Result<Self, DbError> {
175        let handle = Self::from_context(memory::init_context(memory));
176        clear_read_connection(handle.context);
177        if let Err(error) = handle.initialize() {
178            clear_read_connection(handle.context);
179            memory::clear_failed_initialization(handle.context);
180            return Err(error);
181        }
182        Ok(handle)
183    }
184
185    fn from_context(context: ContextId) -> Self {
186        Self { context }
187    }
188
189    fn initialize(self) -> Result<(), DbError> {
190        self.with_context(|| {
191            crate::sqlite_vfs::register();
192            Superblock::load()?;
193            stable_blob::ensure_page_map_layout()?;
194            Ok(())
195        })
196    }
197
198    fn with_context<T>(self, f: impl FnOnce() -> Result<T, DbError>) -> Result<T, DbError> {
199        memory::with_context(self.context, f)
200    }
201
202    pub fn update<T, F>(self, f: F) -> Result<T, DbError>
203    where
204        F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
205    {
206        self.with_context(|| {
207            reject_active_read_connection(self.context)?;
208            clear_read_connection(self.context);
209            stable_blob::begin_update()?;
210            let _overlay_guard = OverlayGuard;
211            let connection = connection::open_read_write()?;
212            let result = transaction::run_immediate(&connection, f);
213            clear_read_connection(self.context);
214            result
215        })
216    }
217
218    pub fn query<T, F>(self, f: F) -> Result<T, DbError>
219    where
220        F: FnOnce(&Connection) -> Result<T, DbError>,
221    {
222        self.with_context(|| with_read_connection(self.context, f))
223    }
224
225    pub fn migrate(self, migrations: &[migrate::Migration]) -> Result<(), DbError> {
226        self.update(|connection| migrate::apply(connection, migrations))?;
227        self.with_context(|| {
228            let target_version = migrations
229                .iter()
230                .map(|migration| migration.version)
231                .max()
232                .unwrap_or(0);
233            let mut block = Superblock::load()?;
234            if block.schema_version < target_version {
235                clear_read_connection(self.context);
236                block.schema_version = target_version;
237                block.store()?;
238            }
239            Ok(())
240        })
241    }
242
243    pub fn integrity_check(self) -> Result<String, DbError> {
244        self.query(|connection| {
245            connection.query_scalar::<String>("PRAGMA integrity_check", crate::params![])
246        })
247    }
248
249    pub fn export_chunk(self, offset: u64, len: u64) -> Result<Vec<u8>, DbError> {
250        self.with_context(|| stable_blob::export_chunk(offset, len).map_err(DbError::from))
251    }
252
253    pub fn db_checksum(self) -> Result<u64, DbError> {
254        self.with_context(|| stable_blob::checksum().map_err(DbError::from))
255    }
256
257    pub fn refresh_checksum(self) -> Result<u64, DbError> {
258        self.with_context(|| {
259            reject_active_read_connection(self.context)?;
260            clear_read_connection(self.context);
261            stable_blob::refresh_checksum().map_err(DbError::from)
262        })
263    }
264
265    pub fn refresh_checksum_chunk(self, max_bytes: u64) -> Result<ChecksumRefresh, DbError> {
266        self.with_context(|| {
267            reject_active_read_connection(self.context)?;
268            clear_read_connection(self.context);
269            stable_blob::refresh_checksum_chunk(max_bytes).map_err(DbError::from)
270        })
271    }
272
273    pub fn begin_import(self, total_size: u64, expected_checksum: u64) -> Result<(), DbError> {
274        self.with_context(|| {
275            reject_active_read_connection(self.context)?;
276            clear_read_connection(self.context);
277            stable_blob::begin_import(total_size, expected_checksum).map_err(DbError::from)
278        })
279    }
280
281    pub fn import_chunk(self, offset: u64, bytes: &[u8]) -> Result<(), DbError> {
282        self.with_context(|| {
283            reject_active_read_connection(self.context)?;
284            clear_read_connection(self.context);
285            stable_blob::import_chunk(offset, bytes).map_err(DbError::from)
286        })
287    }
288
289    pub fn finish_import(self) -> Result<(), DbError> {
290        self.with_context(|| {
291            reject_active_read_connection(self.context)?;
292            clear_read_connection(self.context);
293            stable_blob::finish_import().map_err(DbError::from)
294        })
295    }
296
297    pub fn cancel_import(self) -> Result<(), DbError> {
298        self.with_context(|| {
299            reject_active_read_connection(self.context)?;
300            clear_read_connection(self.context);
301            stable_blob::cancel_import().map_err(DbError::from)
302        })
303    }
304
305    pub fn compact(self) -> Result<(), DbError> {
306        self.with_context(|| {
307            reject_active_read_connection(self.context)?;
308            clear_read_connection(self.context);
309            stable_blob::compact().map_err(DbError::from)
310        })
311    }
312}
313
314fn with_read_connection<T>(
315    context: ContextId,
316    f: impl FnOnce(&Connection) -> Result<T, DbError>,
317) -> Result<T, DbError> {
318    READ_CONNECTIONS.with(|slot| {
319        let cached = { slot.borrow().get(&context).cloned() };
320        let connection = if let Some(connection) = cached {
321            connection
322        } else {
323            let connection = Rc::new(connection::open_read_only()?);
324            slot.borrow_mut().insert(context, Rc::clone(&connection));
325            connection
326        };
327        let _guard = ReadGuard::enter(context);
328        f(&connection)
329    })
330}
331
332fn reject_active_read_connection(context: ContextId) -> Result<(), DbError> {
333    ACTIVE_READ_CONNECTIONS.with(|slot| {
334        if slot.borrow().get(&context).copied().unwrap_or(0) == 0 {
335            Ok(())
336        } else {
337            Err(DbError::ReadConnectionInUse)
338        }
339    })
340}
341
342fn clear_read_connection(context: ContextId) {
343    READ_CONNECTIONS.with(|slot| {
344        slot.borrow_mut().remove(&context);
345    });
346}
347
348struct ReadGuard {
349    context: ContextId,
350}
351
352impl ReadGuard {
353    fn enter(context: ContextId) -> Self {
354        ACTIVE_READ_CONNECTIONS.with(|slot| {
355            let mut slot = slot.borrow_mut();
356            *slot.entry(context).or_insert(0) += 1;
357        });
358        Self { context }
359    }
360}
361
362impl Drop for ReadGuard {
363    fn drop(&mut self) {
364        ACTIVE_READ_CONNECTIONS.with(|slot| {
365            let mut slot = slot.borrow_mut();
366            let Some(depth) = slot.get_mut(&self.context) else {
367                return;
368            };
369            *depth = depth.saturating_sub(1);
370            if *depth == 0 {
371                slot.remove(&self.context);
372            }
373        });
374    }
375}
376
377struct OverlayGuard;
378
379impl Drop for OverlayGuard {
380    fn drop(&mut self) {
381        stable_blob::rollback_update();
382    }
383}