Skip to main content

ic_sqlite_vfs/stable/
memory.rs

1//! Byte-addressed SQLite memory wrapper.
2//!
3//! The crate stores SQLite inside a user-provided `VirtualMemory`, so it can
4//! coexist with other stable structures managed by the same MemoryManager.
5
6use crate::config::STABLE_PAGE_SIZE;
7#[cfg(any(test, debug_assertions))]
8use ic_stable_structures::memory_manager::{MemoryId, MemoryManager};
9use ic_stable_structures::{memory_manager::VirtualMemory, DefaultMemoryImpl, Memory};
10use std::cell::{Cell, RefCell};
11use std::collections::BTreeMap;
12
13pub type DbMemory = VirtualMemory<DefaultMemoryImpl>;
14
/// Opaque handle naming one registered database memory.
///
/// Ids are handed out by `init_context`; the inner value is private so callers
/// cannot forge one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ContextId(u64);
17
18#[derive(Debug, thiserror::Error)]
19pub enum StableMemoryError {
20    #[error("stable memory backend is not initialized")]
21    NotInitialized,
22    #[error("stable memory backend is already initialized")]
23    AlreadyInitialized,
24    #[error(
25        "stable memory grow failed: current_pages={current_pages}, required_pages={required_pages}"
26    )]
27    GrowFailed {
28        current_pages: u64,
29        required_pages: u64,
30    },
31    #[error("offset overflow")]
32    OffsetOverflow,
33    #[error("import session already started")]
34    ImportAlreadyStarted,
35    #[error("import session not started")]
36    ImportNotStarted,
37    #[error("database update already in progress")]
38    UpdateInProgress,
39    #[error("import chunk out of order: offset={offset}, expected={expected}")]
40    ImportOutOfOrder { offset: u64, expected: u64 },
41    #[error("import chunk out of bounds: offset={offset}, len={len}, db_size={db_size}")]
42    ImportOutOfBounds { offset: u64, len: u64, db_size: u64 },
43    #[error("import incomplete: written_until={written_until}, db_size={db_size}")]
44    ImportIncomplete { written_until: u64, db_size: u64 },
45    #[error("checksum mismatch: expected={expected}, actual={actual}")]
46    ChecksumMismatch { expected: u64, actual: u64 },
47    #[error("checksum refresh chunk size must be greater than zero")]
48    ChecksumRefreshChunkEmpty,
49    #[error("stable blob failpoint: {0}")]
50    Failpoint(&'static str),
51    #[error("superblock metadata checksum mismatch")]
52    MetaChecksumMismatch,
53    #[error("unsupported stable layout version: {0}")]
54    UnsupportedLayoutVersion(u64),
55}
56
/// Test-only fault injection for the stable-memory layer.
///
/// `ordinal` is 1-based and counts events in the active context since the
/// failpoint was armed with `set_failpoint`.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryFailpoint {
    /// Fail the `ordinal`-th growth attempt made by `ensure_capacity`.
    GrowFailed { ordinal: u64 },
    /// Trap (or panic off-canister) right after the `ordinal`-th non-empty write.
    TrapAfterWrite { ordinal: u64 },
}
63
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
thread_local! {
    /// Armed failpoint plus grow/write counters, one entry per context.
    static FAILPOINTS: RefCell<BTreeMap<ContextId, MemoryFailpointState>> =
        const { RefCell::new(BTreeMap::new()) };
}
68
69thread_local! {
70    static NEXT_CONTEXT_ID: Cell<u64> = const { Cell::new(1) };
71    static DEFAULT_CONTEXT: Cell<Option<ContextId>> = const { Cell::new(None) };
72    static CURRENT_CONTEXT: Cell<Option<ContextId>> = const { Cell::new(None) };
73    static DB_MEMORY: RefCell<BTreeMap<ContextId, DbMemory>> = const { RefCell::new(BTreeMap::new()) };
74}
75
76pub fn init(memory: DbMemory) -> Result<ContextId, StableMemoryError> {
77    DEFAULT_CONTEXT.with(|default| {
78        if default.get().is_some() {
79            return Err(StableMemoryError::AlreadyInitialized);
80        }
81        let context = init_context(memory);
82        default.set(Some(context));
83        Ok(context)
84    })
85}
86
87pub fn init_context(memory: DbMemory) -> ContextId {
88    let context = NEXT_CONTEXT_ID.with(|next| {
89        let context = ContextId(next.get());
90        next.set(next.get().saturating_add(1));
91        context
92    });
93    DB_MEMORY.with(|slot| {
94        slot.borrow_mut().insert(context, memory);
95    });
96    context
97}
98
99pub fn is_initialized() -> bool {
100    DEFAULT_CONTEXT.with(|context| context.get().is_some())
101}
102
103pub fn default_context() -> Option<ContextId> {
104    DEFAULT_CONTEXT.with(Cell::get)
105}
106
107pub fn active_context_id() -> Result<ContextId, StableMemoryError> {
108    if let Some(context) = CURRENT_CONTEXT.with(Cell::get) {
109        return Ok(context);
110    }
111    default_context().ok_or(StableMemoryError::NotInitialized)
112}
113
114pub fn with_context<T>(context: ContextId, f: impl FnOnce() -> T) -> T {
115    let previous = CURRENT_CONTEXT.with(|current| {
116        let previous = current.get();
117        current.set(Some(context));
118        previous
119    });
120    let _guard = ContextGuard { previous };
121    f()
122}
123
/// Arms `failpoint` for the active context and zeroes its grow/write counters.
///
/// Any failpoint previously armed for that context is replaced. This is a
/// silent no-op when no context is active.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_failpoint(failpoint: MemoryFailpoint) {
    let Ok(context) = active_context_id() else {
        return;
    };
    let fresh_state = MemoryFailpointState {
        failpoint: Some(failpoint),
        grow_count: 0,
        write_count: 0,
    };
    FAILPOINTS.with_borrow_mut(|failpoints| {
        failpoints.insert(context, fresh_state);
    });
}
139
/// Disarms failpoints for every context, not only the active one, and discards
/// their counters.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_failpoint() {
    FAILPOINTS.with_borrow_mut(|failpoints| failpoints.clear());
}
144
145pub fn size_pages() -> u64 {
146    with_memory(|memory| memory.size()).unwrap_or(0)
147}
148
149pub fn ensure_capacity(end_offset: u64) -> Result<(), StableMemoryError> {
150    let current_bytes = size_pages()
151        .checked_mul(STABLE_PAGE_SIZE)
152        .ok_or(StableMemoryError::OffsetOverflow)?;
153    if end_offset <= current_bytes {
154        return Ok(());
155    }
156
157    let missing = end_offset
158        .checked_sub(current_bytes)
159        .ok_or(StableMemoryError::OffsetOverflow)?;
160    let pages = missing.div_ceil(STABLE_PAGE_SIZE);
161
162    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
163    if hit_grow_failpoint() {
164        return Err(StableMemoryError::GrowFailed {
165            current_pages: current_bytes / STABLE_PAGE_SIZE,
166            required_pages: current_bytes / STABLE_PAGE_SIZE + pages,
167        });
168    }
169
170    let previous = with_memory(|memory| memory.grow(pages))?;
171    if previous < 0 {
172        let required_pages = size_pages()
173            .checked_add(pages)
174            .ok_or(StableMemoryError::OffsetOverflow)?;
175        return Err(StableMemoryError::GrowFailed {
176            current_pages: size_pages(),
177            required_pages,
178        });
179    }
180
181    Ok(())
182}
183
184pub fn read(offset: u64, dst: &mut [u8]) -> Result<(), StableMemoryError> {
185    if dst.is_empty() {
186        return Ok(());
187    }
188    let end = checked_end(offset, dst.len())?;
189    ensure_capacity(end)?;
190
191    with_memory(|memory| memory.read(offset, dst))?;
192
193    Ok(())
194}
195
196pub fn write(offset: u64, bytes: &[u8]) -> Result<(), StableMemoryError> {
197    if bytes.is_empty() {
198        return Ok(());
199    }
200    let end = checked_end(offset, bytes.len())?;
201    ensure_capacity(end)?;
202
203    with_memory(|memory| memory.write(offset, bytes))?;
204
205    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
206    if hit_write_trap_failpoint() {
207        fail_after_stable_write();
208    }
209
210    Ok(())
211}
212
213#[cfg(any(test, debug_assertions))]
214pub fn reset_for_tests() {
215    clear_initialization();
216    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
217    clear_failpoint();
218}
219
220#[cfg(any(test, debug_assertions))]
221pub(crate) fn clear_initialization() {
222    DB_MEMORY.with(|memory| memory.borrow_mut().clear());
223    DEFAULT_CONTEXT.with(|context| context.set(None));
224    CURRENT_CONTEXT.with(|context| context.set(None));
225    NEXT_CONTEXT_ID.with(|next| next.set(1));
226    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
227    clear_failpoint();
228    crate::stable::meta::clear_superblock_cache();
229    crate::sqlite_vfs::stable_blob::invalidate_read_cache();
230}
231
232pub(crate) fn clear_failed_initialization(context: ContextId) {
233    DB_MEMORY.with(|memory| {
234        memory.borrow_mut().remove(&context);
235    });
236    DEFAULT_CONTEXT.with(|default| {
237        if default.get() == Some(context) {
238            default.set(None);
239        }
240    });
241    CURRENT_CONTEXT.with(|current| {
242        if current.get() == Some(context) {
243            current.set(None);
244        }
245    });
246    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
247    FAILPOINTS.with(|slot| {
248        slot.borrow_mut().remove(&context);
249    });
250    crate::stable::meta::clear_superblock_cache();
251    crate::sqlite_vfs::stable_blob::invalidate_read_cache();
252}
253
/// Copies every allocated page of the active context's memory into a `Vec`.
///
/// # Panics
///
/// Panics if the memory size does not fit in `usize` or the read fails.
#[cfg(test)]
pub fn snapshot_for_tests() -> Vec<u8> {
    let total_bytes = size_pages().saturating_mul(STABLE_PAGE_SIZE);
    let byte_len = usize::try_from(total_bytes).expect("test memory size fits usize");
    let mut snapshot = vec![0_u8; byte_len];
    read(0, &mut snapshot).expect("test memory snapshot succeeds");
    snapshot
}
262
/// Resets all module state and returns a fresh test memory holding `snapshot`.
///
/// The memory is grown to the whole number of pages that covers the snapshot,
/// and the bytes are written at offset 0. The returned memory is not
/// registered; presumably callers pass it to `init` or `init_context`.
///
/// # Panics
///
/// Panics if the snapshot length does not fit in `u64` or the grow fails.
#[cfg(test)]
pub fn restore_for_tests(snapshot: Vec<u8>) -> DbMemory {
    reset_for_tests();
    let memory = memory_for_tests();
    let byte_len = u64::try_from(snapshot.len()).expect("snapshot len fits u64");
    let pages = byte_len.div_ceil(STABLE_PAGE_SIZE);
    if pages != 0 {
        let previous_pages = memory.grow(pages);
        assert!(previous_pages >= 0, "snapshot memory grows");
        memory.write(0, &snapshot);
    }
    crate::stable::meta::clear_superblock_cache();
    memory
}
277
278#[cfg(any(test, debug_assertions))]
279pub fn memory_for_tests() -> DbMemory {
280    MemoryManager::init(DefaultMemoryImpl::default()).get(MemoryId::new(42))
281}
282
283fn with_memory<T>(f: impl FnOnce(&DbMemory) -> T) -> Result<T, StableMemoryError> {
284    let context = active_context_id()?;
285    DB_MEMORY.with(|slot| {
286        let slot = slot.borrow();
287        let memory = slot
288            .get(&context)
289            .ok_or(StableMemoryError::NotInitialized)?;
290        Ok(f(memory))
291    })
292}
293
294struct ContextGuard {
295    previous: Option<ContextId>,
296}
297
298impl Drop for ContextGuard {
299    fn drop(&mut self) {
300        CURRENT_CONTEXT.with(|current| current.set(self.previous));
301    }
302}
303
/// Failpoint bookkeeping for one context.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy)]
struct MemoryFailpointState {
    /// Currently armed failpoint; reset to `None` once it fires.
    failpoint: Option<MemoryFailpoint>,
    /// Growth attempts seen since `set_failpoint` armed this state.
    grow_count: u64,
    /// Non-empty writes seen since `set_failpoint` armed this state.
    write_count: u64,
}
311
312fn checked_end(offset: u64, len: usize) -> Result<u64, StableMemoryError> {
313    let len = u64::try_from(len).map_err(|_| StableMemoryError::OffsetOverflow)?;
314    offset
315        .checked_add(len)
316        .ok_or(StableMemoryError::OffsetOverflow)
317}
318
/// Counts one growth attempt for the active context.
///
/// Returns `true` exactly when an armed `GrowFailed` failpoint matches the new
/// count; a matching failpoint is disarmed. Returns `false` without counting
/// when there is no active context or no failpoint state for it.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_grow_failpoint() -> bool {
    let Ok(context) = active_context_id() else {
        return false;
    };
    FAILPOINTS.with_borrow_mut(|failpoints| {
        failpoints.get_mut(&context).is_some_and(|state| {
            state.grow_count += 1;
            let target = MemoryFailpoint::GrowFailed {
                ordinal: state.grow_count,
            };
            let fired = state.failpoint == Some(target);
            if fired {
                state.failpoint = None;
            }
            fired
        })
    })
}
342
/// Counts one non-empty write for the active context.
///
/// Returns `true` exactly when an armed `TrapAfterWrite` failpoint matches the
/// new count; a matching failpoint is disarmed. Returns `false` without
/// counting when there is no active context or no failpoint state for it.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_write_trap_failpoint() -> bool {
    let Ok(context) = active_context_id() else {
        return false;
    };
    FAILPOINTS.with_borrow_mut(|failpoints| {
        failpoints.get_mut(&context).is_some_and(|state| {
            state.write_count += 1;
            let target = MemoryFailpoint::TrapAfterWrite {
                ordinal: state.write_count,
            };
            let fired = state.failpoint == Some(target);
            if fired {
                state.failpoint = None;
            }
            fired
        })
    })
}
366
/// Canister build with failpoints enabled: aborts the current message through
/// `ic_cdk::trap` after a stable write.
#[cfg(all(target_arch = "wasm32", feature = "canister-api-test-failpoints"))]
fn fail_after_stable_write() -> ! {
    ic_cdk::trap("stable write failpoint");
}

/// Native and unit-test stand-in for the canister trap.
///
/// Panics with the same message the canister build traps with.
#[cfg(all(
    any(test, feature = "canister-api-test-failpoints"),
    not(all(target_arch = "wasm32", feature = "canister-api-test-failpoints"))
))]
fn fail_after_stable_write() -> ! {
    panic!("stable write failpoint");
}