1use crate::config::STABLE_PAGE_SIZE;
7#[cfg(any(test, debug_assertions))]
8use ic_stable_structures::memory_manager::{MemoryId, MemoryManager};
9use ic_stable_structures::{memory_manager::VirtualMemory, DefaultMemoryImpl, Memory};
10use std::cell::{Cell, RefCell};
11use std::collections::BTreeMap;
12
13pub type DbMemory = VirtualMemory<DefaultMemoryImpl>;
14
15#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
16pub struct ContextId(u64);
17
18#[derive(Debug, thiserror::Error)]
19pub enum StableMemoryError {
20 #[error("stable memory backend is not initialized")]
21 NotInitialized,
22 #[error("stable memory backend is already initialized")]
23 AlreadyInitialized,
24 #[error(
25 "stable memory grow failed: current_pages={current_pages}, required_pages={required_pages}"
26 )]
27 GrowFailed {
28 current_pages: u64,
29 required_pages: u64,
30 },
31 #[error("offset overflow")]
32 OffsetOverflow,
33 #[error("import session already started")]
34 ImportAlreadyStarted,
35 #[error("import session not started")]
36 ImportNotStarted,
37 #[error("database update already in progress")]
38 UpdateInProgress,
39 #[error("import chunk out of order: offset={offset}, expected={expected}")]
40 ImportOutOfOrder { offset: u64, expected: u64 },
41 #[error("import chunk out of bounds: offset={offset}, len={len}, db_size={db_size}")]
42 ImportOutOfBounds { offset: u64, len: u64, db_size: u64 },
43 #[error("import incomplete: written_until={written_until}, db_size={db_size}")]
44 ImportIncomplete { written_until: u64, db_size: u64 },
45 #[error("checksum mismatch: expected={expected}, actual={actual}")]
46 ChecksumMismatch { expected: u64, actual: u64 },
47 #[error("checksum refresh chunk size must be greater than zero")]
48 ChecksumRefreshChunkEmpty,
49 #[error("stable blob failpoint: {0}")]
50 Failpoint(&'static str),
51 #[error("superblock metadata checksum mismatch")]
52 MetaChecksumMismatch,
53 #[error("unsupported stable layout version: {0}")]
54 UnsupportedLayoutVersion(u64),
55}
56
/// Fault to inject into the stable-memory layer (test / failpoint builds only).
///
/// `ordinal` is 1-based and counts attempts made for the context the failpoint was armed on,
/// starting from the `set_failpoint` call that armed it.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemoryFailpoint {
    /// Fail the `ordinal`-th grow attempt with `StableMemoryError::GrowFailed`.
    GrowFailed { ordinal: u64 },
    /// Abort right after the `ordinal`-th non-empty write has reached memory.
    TrapAfterWrite { ordinal: u64 },
}
63
64#[cfg(any(test, feature = "canister-api-test-failpoints"))]
65thread_local! {
66 static FAILPOINTS: RefCell<BTreeMap<ContextId, MemoryFailpointState>> = const { RefCell::new(BTreeMap::new()) };
67}
68
69thread_local! {
70 static NEXT_CONTEXT_ID: Cell<u64> = const { Cell::new(1) };
71 static DEFAULT_CONTEXT: Cell<Option<ContextId>> = const { Cell::new(None) };
72 static CURRENT_CONTEXT: Cell<Option<ContextId>> = const { Cell::new(None) };
73 static DB_MEMORY: RefCell<BTreeMap<ContextId, DbMemory>> = const { RefCell::new(BTreeMap::new()) };
74}
75
76pub fn init(memory: DbMemory) -> Result<ContextId, StableMemoryError> {
77 DEFAULT_CONTEXT.with(|default| {
78 if default.get().is_some() {
79 return Err(StableMemoryError::AlreadyInitialized);
80 }
81 let context = init_context(memory);
82 default.set(Some(context));
83 Ok(context)
84 })
85}
86
87pub fn init_context(memory: DbMemory) -> ContextId {
88 let context = NEXT_CONTEXT_ID.with(|next| {
89 let context = ContextId(next.get());
90 next.set(next.get().saturating_add(1));
91 context
92 });
93 DB_MEMORY.with(|slot| {
94 slot.borrow_mut().insert(context, memory);
95 });
96 context
97}
98
99pub fn is_initialized() -> bool {
100 DEFAULT_CONTEXT.with(|context| context.get().is_some())
101}
102
103pub fn default_context() -> Option<ContextId> {
104 DEFAULT_CONTEXT.with(Cell::get)
105}
106
107pub fn active_context_id() -> Result<ContextId, StableMemoryError> {
108 if let Some(context) = CURRENT_CONTEXT.with(Cell::get) {
109 return Ok(context);
110 }
111 default_context().ok_or(StableMemoryError::NotInitialized)
112}
113
114pub fn with_context<T>(context: ContextId, f: impl FnOnce() -> T) -> T {
115 let previous = CURRENT_CONTEXT.with(|current| {
116 let previous = current.get();
117 current.set(Some(context));
118 previous
119 });
120 let _guard = ContextGuard { previous };
121 f()
122}
123
/// Arms `failpoint` for the active context.
///
/// Any failpoint already armed for that context is replaced, and its grow/write counters are
/// reset. This is a silent no-op when no context is active.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_failpoint(failpoint: MemoryFailpoint) {
    let Ok(context) = active_context_id() else {
        return;
    };
    let fresh_state = MemoryFailpointState {
        failpoint: Some(failpoint),
        grow_count: 0,
        write_count: 0,
    };
    FAILPOINTS.with(|table| {
        table.borrow_mut().insert(context, fresh_state);
    });
}
139
/// Disarms failpoints and drops their counters for every context, not just the active one.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_failpoint() {
    FAILPOINTS.with(|table| {
        let mut table = table.borrow_mut();
        table.clear();
    });
}
144
145pub fn size_pages() -> u64 {
146 with_memory(|memory| memory.size()).unwrap_or(0)
147}
148
149pub fn ensure_capacity(end_offset: u64) -> Result<(), StableMemoryError> {
150 let current_bytes = size_pages()
151 .checked_mul(STABLE_PAGE_SIZE)
152 .ok_or(StableMemoryError::OffsetOverflow)?;
153 if end_offset <= current_bytes {
154 return Ok(());
155 }
156
157 let missing = end_offset
158 .checked_sub(current_bytes)
159 .ok_or(StableMemoryError::OffsetOverflow)?;
160 let pages = missing.div_ceil(STABLE_PAGE_SIZE);
161
162 #[cfg(any(test, feature = "canister-api-test-failpoints"))]
163 if hit_grow_failpoint() {
164 return Err(StableMemoryError::GrowFailed {
165 current_pages: current_bytes / STABLE_PAGE_SIZE,
166 required_pages: current_bytes / STABLE_PAGE_SIZE + pages,
167 });
168 }
169
170 let previous = with_memory(|memory| memory.grow(pages))?;
171 if previous < 0 {
172 let required_pages = size_pages()
173 .checked_add(pages)
174 .ok_or(StableMemoryError::OffsetOverflow)?;
175 return Err(StableMemoryError::GrowFailed {
176 current_pages: size_pages(),
177 required_pages,
178 });
179 }
180
181 Ok(())
182}
183
184pub fn read(offset: u64, dst: &mut [u8]) -> Result<(), StableMemoryError> {
185 if dst.is_empty() {
186 return Ok(());
187 }
188 let end = checked_end(offset, dst.len())?;
189 ensure_capacity(end)?;
190
191 with_memory(|memory| memory.read(offset, dst))?;
192
193 Ok(())
194}
195
196pub fn write(offset: u64, bytes: &[u8]) -> Result<(), StableMemoryError> {
197 if bytes.is_empty() {
198 return Ok(());
199 }
200 let end = checked_end(offset, bytes.len())?;
201 ensure_capacity(end)?;
202
203 with_memory(|memory| memory.write(offset, bytes))?;
204
205 #[cfg(any(test, feature = "canister-api-test-failpoints"))]
206 if hit_write_trap_failpoint() {
207 fail_after_stable_write();
208 }
209
210 Ok(())
211}
212
213#[cfg(any(test, debug_assertions))]
214pub fn reset_for_tests() {
215 clear_initialization();
216 #[cfg(any(test, feature = "canister-api-test-failpoints"))]
217 clear_failpoint();
218}
219
220#[cfg(any(test, debug_assertions))]
221pub(crate) fn clear_initialization() {
222 DB_MEMORY.with(|memory| memory.borrow_mut().clear());
223 DEFAULT_CONTEXT.with(|context| context.set(None));
224 CURRENT_CONTEXT.with(|context| context.set(None));
225 NEXT_CONTEXT_ID.with(|next| next.set(1));
226 #[cfg(any(test, feature = "canister-api-test-failpoints"))]
227 clear_failpoint();
228 crate::stable::meta::clear_superblock_cache();
229 crate::sqlite_vfs::stable_blob::invalidate_read_cache();
230}
231
232pub(crate) fn clear_failed_initialization(context: ContextId) {
233 DB_MEMORY.with(|memory| {
234 memory.borrow_mut().remove(&context);
235 });
236 DEFAULT_CONTEXT.with(|default| {
237 if default.get() == Some(context) {
238 default.set(None);
239 }
240 });
241 CURRENT_CONTEXT.with(|current| {
242 if current.get() == Some(context) {
243 current.set(None);
244 }
245 });
246 #[cfg(any(test, feature = "canister-api-test-failpoints"))]
247 FAILPOINTS.with(|slot| {
248 slot.borrow_mut().remove(&context);
249 });
250 crate::stable::meta::clear_superblock_cache();
251 crate::sqlite_vfs::stable_blob::invalidate_read_cache();
252}
253
/// Copies the active context's entire memory (all pages) into a byte vector.
#[cfg(test)]
pub fn snapshot_for_tests() -> Vec<u8> {
    let total_bytes = size_pages().saturating_mul(STABLE_PAGE_SIZE);
    let byte_count = usize::try_from(total_bytes).expect("test memory size fits usize");
    let mut image = vec![0_u8; byte_count];
    read(0, &mut image).expect("test memory snapshot succeeds");
    image
}
262
/// Resets backend state and returns a fresh, unregistered memory preloaded with `snapshot`.
///
/// The memory is grown to whole pages covering the snapshot. The caller is responsible for
/// registering it (e.g. via `init`).
#[cfg(test)]
pub fn restore_for_tests(snapshot: Vec<u8>) -> DbMemory {
    reset_for_tests();
    let memory = memory_for_tests();

    let byte_len = u64::try_from(snapshot.len()).expect("snapshot len fits u64");
    let page_count = byte_len.div_ceil(STABLE_PAGE_SIZE);
    if page_count != 0 {
        let previous_pages = memory.grow(page_count);
        assert!(previous_pages >= 0, "snapshot memory grows");
        memory.write(0, &snapshot);
    }

    crate::stable::meta::clear_superblock_cache();
    memory
}
277
278#[cfg(any(test, debug_assertions))]
279pub fn memory_for_tests() -> DbMemory {
280 MemoryManager::init(DefaultMemoryImpl::default()).get(MemoryId::new(42))
281}
282
283fn with_memory<T>(f: impl FnOnce(&DbMemory) -> T) -> Result<T, StableMemoryError> {
284 let context = active_context_id()?;
285 DB_MEMORY.with(|slot| {
286 let slot = slot.borrow();
287 let memory = slot
288 .get(&context)
289 .ok_or(StableMemoryError::NotInitialized)?;
290 Ok(f(memory))
291 })
292}
293
294struct ContextGuard {
295 previous: Option<ContextId>,
296}
297
298impl Drop for ContextGuard {
299 fn drop(&mut self) {
300 CURRENT_CONTEXT.with(|current| current.set(self.previous));
301 }
302}
303
/// Per-context failpoint bookkeeping (test / failpoint builds only).
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Debug, Clone, Copy)]
struct MemoryFailpointState {
    /// Armed failpoint; reset to `None` once it has fired.
    failpoint: Option<MemoryFailpoint>,
    /// Grow attempts counted since the failpoint was armed.
    grow_count: u64,
    /// Non-empty writes counted since the failpoint was armed.
    write_count: u64,
}
311
312fn checked_end(offset: u64, len: usize) -> Result<u64, StableMemoryError> {
313 let len = u64::try_from(len).map_err(|_| StableMemoryError::OffsetOverflow)?;
314 offset
315 .checked_add(len)
316 .ok_or(StableMemoryError::OffsetOverflow)
317}
318
/// Counts one grow attempt for the active context.
///
/// Returns `true` (and disarms the failpoint) when an armed `GrowFailed` failpoint matches the
/// new count.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_grow_failpoint() -> bool {
    hit_counted_failpoint(|state| {
        state.grow_count += 1;
        MemoryFailpoint::GrowFailed {
            ordinal: state.grow_count,
        }
    })
}

/// Counts one completed write for the active context.
///
/// Returns `true` (and disarms the failpoint) when an armed `TrapAfterWrite` failpoint matches
/// the new count.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_write_trap_failpoint() -> bool {
    hit_counted_failpoint(|state| {
        state.write_count += 1;
        MemoryFailpoint::TrapAfterWrite {
            ordinal: state.write_count,
        }
    })
}

/// Shared logic for the counted failpoints.
///
/// `advance` bumps the relevant counter and returns the failpoint that would fire at the new
/// count. Returns `false` without counting when there is no active context or nothing was ever
/// armed for it. Otherwise it disarms and returns `true` on an exact match.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_counted_failpoint(advance: impl FnOnce(&mut MemoryFailpointState) -> MemoryFailpoint) -> bool {
    let Ok(context) = active_context_id() else {
        return false;
    };
    FAILPOINTS.with(|slot| {
        let mut slot = slot.borrow_mut();
        let Some(state) = slot.get_mut(&context) else {
            return false;
        };
        let candidate = advance(state);
        let fired = state.failpoint == Some(candidate);
        if fired {
            // One-shot: later attempts keep counting but never fire again.
            state.failpoint = None;
        }
        fired
    })
}
366
/// Fires a triggered `TrapAfterWrite` failpoint by calling `ic_cdk::trap`.
///
/// This variant is used in wasm32 builds with the failpoint feature enabled.
#[cfg(all(target_arch = "wasm32", feature = "canister-api-test-failpoints"))]
fn fail_after_stable_write() -> ! {
    ic_cdk::trap("stable write failpoint")
}

/// Fires a triggered `TrapAfterWrite` failpoint by panicking.
///
/// This is the variant for test and non-wasm32 failpoint builds.
#[cfg(all(
    any(test, feature = "canister-api-test-failpoints"),
    not(all(target_arch = "wasm32", feature = "canister-api-test-failpoints"))
))]
fn fail_after_stable_write() -> ! {
    panic!("stable write failpoint")
}