// ic_sqlite_vfs/stable/memory_manager.rs

//! Minimal fork of the `MemoryManager` layout from `ic-stable-structures` 0.7.
//!
//! The fork keeps the existing stable-memory format byte-for-byte, but drops the
//! unrelated stable data structures from this crate's dependency graph.
5
6use crate::config::STABLE_PAGE_SIZE;
7pub use crate::stable::memory_layout::MemoryId;
8use crate::stable::memory_layout::{
9    bucket_allocations_address, write_growing, BucketCache, BucketId, VirtualSegment,
10    BUCKETS_OFFSET_IN_BYTES, BUCKETS_OFFSET_IN_PAGES, BUCKET_SIZE_IN_PAGES, HEADER_RESERVED_BYTES,
11    HEADER_SIZE, LAYOUT_VERSION, MAGIC, MAX_NUM_BUCKETS, MAX_NUM_MEMORIES,
12    UNALLOCATED_BUCKET_MARKER,
13};
14use crate::stable::memory_manager_validation::{load_validated_layout, try_load_validated_layout};
15use crate::stable::raw_memory::Memory;
16use std::cell::RefCell;
17use std::rc::Rc;
18
19#[derive(Clone)]
20pub struct MemoryManager<M: Memory> {
21    inner: Rc<RefCell<MemoryManagerInner<M>>>,
22}
23
24#[derive(Debug, thiserror::Error)]
25pub enum MemoryManagerInitError {
26    #[error("bucket size must be greater than zero")]
27    BucketSizeIsZero,
28    #[error("non-empty memory does not contain a MemoryManager layout")]
29    NonMemoryManagerLayout,
30    #[error("{0}")]
31    InvalidLayout(String),
32}
33
34impl<M: Memory> MemoryManager<M> {
35    pub fn init(memory: M) -> Self {
36        Self::init_with_bucket_size(memory, BUCKET_SIZE_IN_PAGES as u16)
37    }
38
39    pub fn init_strict(memory: M) -> Result<Self, MemoryManagerInitError> {
40        Self::init_strict_with_bucket_size(memory, BUCKET_SIZE_IN_PAGES as u16)
41    }
42
43    pub fn init_with_bucket_size(memory: M, bucket_size_in_pages: u16) -> Self {
44        if bucket_size_in_pages == 0 {
45            panic!("bucket size must be greater than zero");
46        }
47        Self {
48            inner: Rc::new(RefCell::new(MemoryManagerInner::init(
49                memory,
50                bucket_size_in_pages,
51            ))),
52        }
53    }
54
55    pub fn init_strict_with_bucket_size(
56        memory: M,
57        bucket_size_in_pages: u16,
58    ) -> Result<Self, MemoryManagerInitError> {
59        if bucket_size_in_pages == 0 {
60            return Err(MemoryManagerInitError::BucketSizeIsZero);
61        }
62        Ok(Self {
63            inner: Rc::new(RefCell::new(MemoryManagerInner::init_strict(
64                memory,
65                bucket_size_in_pages,
66            )?)),
67        })
68    }
69
70    pub fn get(&self, id: MemoryId) -> VirtualMemory<M> {
71        VirtualMemory {
72            id,
73            memory_manager: Rc::clone(&self.inner),
74            cache: BucketCache::new(),
75        }
76    }
77}
78#[derive(Clone)]
79pub struct VirtualMemory<M: Memory> {
80    id: MemoryId,
81    memory_manager: Rc<RefCell<MemoryManagerInner<M>>>,
82    cache: BucketCache,
83}
84impl<M: Memory> Memory for VirtualMemory<M> {
85    fn size(&self) -> u64 {
86        self.memory_manager.borrow().memory_size(self.id)
87    }
88
89    fn grow(&self, pages: u64) -> i64 {
90        self.memory_manager.borrow_mut().grow(self.id, pages)
91    }
92
93    fn read(&self, offset: u64, dst: &mut [u8]) {
94        self.memory_manager
95            .borrow()
96            .read(self.id, offset, dst, &self.cache);
97    }
98
99    unsafe fn read_unsafe(&self, offset: u64, dst: *mut u8, count: usize) {
100        self.memory_manager
101            .borrow()
102            .read_unsafe(self.id, offset, dst, count, &self.cache);
103    }
104
105    fn write(&self, offset: u64, src: &[u8]) {
106        self.memory_manager
107            .borrow()
108            .write(self.id, offset, src, &self.cache);
109    }
110}
111
112#[derive(Clone)]
113struct MemoryManagerInner<M: Memory> {
114    memory: M,
115    allocated_buckets: u16,
116    bucket_size_in_pages: u16,
117    memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
118    memory_buckets: Vec<Vec<BucketId>>,
119}
120impl<M: Memory> MemoryManagerInner<M> {
121    fn init(memory: M, bucket_size_in_pages: u16) -> Self {
122        if memory.size() == 0 {
123            return Self::new(memory, bucket_size_in_pages);
124        }
125
126        let mut magic = [0_u8; 3];
127        memory.read(0, &mut magic);
128        if &magic == MAGIC {
129            Self::load(memory)
130        } else {
131            Self::new(memory, bucket_size_in_pages)
132        }
133    }
134
135    fn init_strict(memory: M, bucket_size_in_pages: u16) -> Result<Self, MemoryManagerInitError> {
136        if memory.size() == 0 {
137            return Ok(Self::new(memory, bucket_size_in_pages));
138        }
139
140        let mut magic = [0_u8; 3];
141        memory.read(0, &mut magic);
142        if &magic != MAGIC {
143            return Err(MemoryManagerInitError::NonMemoryManagerLayout);
144        }
145        Self::try_load(memory)
146    }
147
148    fn new(memory: M, bucket_size_in_pages: u16) -> Self {
149        let manager = Self {
150            memory,
151            allocated_buckets: 0,
152            bucket_size_in_pages,
153            memory_sizes_in_pages: [0; MAX_NUM_MEMORIES as usize],
154            memory_buckets: vec![Vec::new(); MAX_NUM_MEMORIES as usize],
155        };
156        write_growing(
157            &manager.memory,
158            bucket_allocations_address(BucketId(0)),
159            &[UNALLOCATED_BUCKET_MARKER; MAX_NUM_BUCKETS as usize],
160        );
161        manager.save_header();
162        manager
163    }
164    fn load(memory: M) -> Self {
165        let mut header = vec![0_u8; HEADER_SIZE as usize];
166        memory.read(0, &mut header);
167        assert_eq!(&header[0..3], MAGIC, "Bad magic.");
168        assert_eq!(header[3], LAYOUT_VERSION, "Unsupported version.");
169        let layout = load_validated_layout(&memory, &header);
170
171        Self {
172            memory,
173            allocated_buckets: layout.allocated_buckets,
174            bucket_size_in_pages: layout.bucket_size_in_pages,
175            memory_sizes_in_pages: layout.memory_sizes_in_pages,
176            memory_buckets: layout.memory_buckets,
177        }
178    }
179
180    fn try_load(memory: M) -> Result<Self, MemoryManagerInitError> {
181        let mut header = vec![0_u8; HEADER_SIZE as usize];
182        memory.read(0, &mut header);
183        if &header[0..3] != MAGIC {
184            return Err(MemoryManagerInitError::NonMemoryManagerLayout);
185        }
186        if header[3] != LAYOUT_VERSION {
187            return Err(MemoryManagerInitError::InvalidLayout(
188                "Unsupported version.".to_string(),
189            ));
190        }
191        let layout = try_load_validated_layout(&memory, &header)
192            .map_err(|error| MemoryManagerInitError::InvalidLayout(error.to_string()))?;
193
194        Ok(Self {
195            memory,
196            allocated_buckets: layout.allocated_buckets,
197            bucket_size_in_pages: layout.bucket_size_in_pages,
198            memory_sizes_in_pages: layout.memory_sizes_in_pages,
199            memory_buckets: layout.memory_buckets,
200        })
201    }
202
203    fn save_header(&self) {
204        let mut header = [0_u8; HEADER_SIZE as usize];
205        header[0..3].copy_from_slice(MAGIC);
206        header[3] = LAYOUT_VERSION;
207        header[4..6].copy_from_slice(&self.allocated_buckets.to_le_bytes());
208        header[6..8].copy_from_slice(&self.bucket_size_in_pages.to_le_bytes());
209        let mut offset = 3 + 1 + 2 + 2 + HEADER_RESERVED_BYTES;
210        for size in self.memory_sizes_in_pages {
211            header[offset..offset + 8].copy_from_slice(&size.to_le_bytes());
212            offset += 8;
213        }
214        write_growing(&self.memory, 0, &header);
215    }
216
217    fn memory_size(&self, id: MemoryId) -> u64 {
218        self.memory_sizes_in_pages[id.0 as usize]
219    }
220
221    fn grow(&mut self, id: MemoryId, pages: u64) -> i64 {
222        let old_size = self.memory_size(id);
223        let Some(new_size) = old_size.checked_add(pages) else {
224            return -1;
225        };
226        let current_buckets = self.num_buckets_needed(old_size);
227        let required_buckets = self.num_buckets_needed(new_size);
228        let new_buckets = required_buckets - current_buckets;
229        let Some(target_allocated_buckets) =
230            new_buckets.checked_add(u64::from(self.allocated_buckets))
231        else {
232            return -1;
233        };
234        if target_allocated_buckets > MAX_NUM_BUCKETS {
235            return -1;
236        }
237        let Ok(new_buckets_len) = usize::try_from(new_buckets) else {
238            return -1;
239        };
240        let memory_bucket = &mut self.memory_buckets[id.0 as usize];
241        if memory_bucket.try_reserve(new_buckets_len).is_err() {
242            return -1;
243        }
244        let mut rollback_buckets = Vec::new();
245        if rollback_buckets.try_reserve(new_buckets_len).is_err() {
246            return -1;
247        }
248
249        let Some(data_pages) =
250            u64::from(self.bucket_size_in_pages).checked_mul(target_allocated_buckets)
251        else {
252            return -1;
253        };
254        let Some(pages_needed) = BUCKETS_OFFSET_IN_PAGES.checked_add(data_pages) else {
255            return -1;
256        };
257        let current_pages = self.memory.size();
258        if pages_needed > current_pages {
259            let previous = self.memory.grow(pages_needed - current_pages);
260            if previous < 0 {
261                return -1;
262            }
263        }
264
265        let mut rollback = AllocationRollback {
266            memory: std::ptr::addr_of!(self.memory),
267            buckets: rollback_buckets,
268            committed: false,
269            _memory: std::marker::PhantomData,
270        };
271        for _ in 0..new_buckets {
272            let bucket = BucketId(self.allocated_buckets);
273            memory_bucket.push(bucket);
274            write_growing(&self.memory, bucket_allocations_address(bucket), &[id.0]);
275            rollback.buckets.push(bucket);
276            self.allocated_buckets = self
277                .allocated_buckets
278                .checked_add(1)
279                .expect("allocated bucket count overflow");
280        }
281
282        self.memory_sizes_in_pages[id.0 as usize] = new_size;
283        self.save_header();
284        rollback.committed = true;
285        old_size as i64
286    }
287
288    fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8], cache: &BucketCache) {
289        unsafe { self.read_unsafe(id, offset, dst.as_mut_ptr(), dst.len(), cache) }
290    }
291
292    unsafe fn read_unsafe(
293        &self,
294        id: MemoryId,
295        offset: u64,
296        dst: *mut u8,
297        count: usize,
298        cache: &BucketCache,
299    ) {
300        if count == 0 {
301            return;
302        }
303        self.assert_bounds(id, offset, count as u64, "read");
304        if let Some(real) = cache.get(VirtualSegment::new(offset, count as u64)) {
305            self.memory.read_unsafe(real, dst, count);
306            return;
307        }
308        let mut bytes_read = 0_u64;
309        self.for_each_bucket(id, offset, count as u64, cache, |address, len| {
310            self.memory
311                .read_unsafe(address, dst.add(bytes_read as usize), len as usize);
312            bytes_read += len;
313        });
314    }
315
316    fn write(&self, id: MemoryId, offset: u64, src: &[u8], cache: &BucketCache) {
317        if src.is_empty() {
318            return;
319        }
320        self.assert_bounds(id, offset, src.len() as u64, "write");
321        if let Some(real) = cache.get(VirtualSegment::new(offset, src.len() as u64)) {
322            self.memory.write(real, src);
323            return;
324        }
325        let mut written = 0_u64;
326        self.for_each_bucket(id, offset, src.len() as u64, cache, |address, len| {
327            self.memory
328                .write(address, &src[written as usize..(written + len) as usize]);
329            written += len;
330        });
331    }
332
333    fn for_each_bucket(
334        &self,
335        MemoryId(id): MemoryId,
336        offset: u64,
337        mut len: u64,
338        cache: &BucketCache,
339        mut f: impl FnMut(u64, u64),
340    ) {
341        let bucket_size = self.bucket_size_in_bytes();
342        let buckets = self.memory_buckets[id as usize].as_slice();
343        let mut bucket_idx = (offset / bucket_size) as usize;
344        let mut bucket_offset = offset % bucket_size;
345        while len > 0 {
346            let bucket = buckets.get(bucket_idx).expect("bucket idx out of bounds");
347            let bucket_address = self.bucket_address(*bucket);
348            let segment_len = (bucket_size - bucket_offset).min(len);
349            cache.store(
350                VirtualSegment::new(bucket_idx as u64 * bucket_size, bucket_size),
351                bucket_address,
352            );
353            f(bucket_address + bucket_offset, segment_len);
354            len -= segment_len;
355            bucket_idx += 1;
356            bucket_offset = 0;
357        }
358    }
359
360    fn assert_bounds(&self, id: MemoryId, offset: u64, len: u64, operation: &str) {
361        let end = offset
362            .checked_add(len)
363            .unwrap_or_else(|| panic!("{id:?}: {operation} out of bounds"));
364        let capacity = self
365            .memory_size(id)
366            .checked_mul(STABLE_PAGE_SIZE)
367            .unwrap_or_else(|| panic!("{id:?}: {operation} out of bounds"));
368        assert!(end <= capacity, "{id:?}: {operation} out of bounds");
369    }
370
371    fn bucket_size_in_bytes(&self) -> u64 {
372        u64::from(self.bucket_size_in_pages) * STABLE_PAGE_SIZE
373    }
374
375    fn num_buckets_needed(&self, pages: u64) -> u64 {
376        pages.div_ceil(u64::from(self.bucket_size_in_pages))
377    }
378
379    fn bucket_address(&self, id: BucketId) -> u64 {
380        BUCKETS_OFFSET_IN_BYTES + self.bucket_size_in_bytes() * u64::from(id.0)
381    }
382}
383
384struct AllocationRollback<'memory, M: Memory> {
385    memory: *const M,
386    buckets: Vec<BucketId>,
387    committed: bool,
388    _memory: std::marker::PhantomData<&'memory M>,
389}
390
391impl<M: Memory> Drop for AllocationRollback<'_, M> {
392    fn drop(&mut self) {
393        if self.committed || !std::thread::panicking() {
394            return;
395        }
396        for bucket in self.buckets.iter().copied() {
397            let memory = unsafe { &*self.memory };
398            write_growing(
399                memory,
400                bucket_allocations_address(bucket),
401                &[UNALLOCATED_BUCKET_MARKER],
402            );
403        }
404    }
405}