1use crate::config::STABLE_PAGE_SIZE;
7pub use crate::stable::memory_layout::MemoryId;
8use crate::stable::memory_layout::{
9 bucket_allocations_address, write_growing, BucketCache, BucketId, VirtualSegment,
10 BUCKETS_OFFSET_IN_BYTES, BUCKETS_OFFSET_IN_PAGES, BUCKET_SIZE_IN_PAGES, HEADER_RESERVED_BYTES,
11 HEADER_SIZE, LAYOUT_VERSION, MAGIC, MAX_NUM_BUCKETS, MAX_NUM_MEMORIES,
12 UNALLOCATED_BUCKET_MARKER,
13};
14use crate::stable::memory_manager_validation::{load_validated_layout, try_load_validated_layout};
15use crate::stable::raw_memory::Memory;
16use std::cell::RefCell;
17use std::rc::Rc;
18
19#[derive(Clone)]
20pub struct MemoryManager<M: Memory> {
21 inner: Rc<RefCell<MemoryManagerInner<M>>>,
22}
23
24#[derive(Debug, thiserror::Error)]
25pub enum MemoryManagerInitError {
26 #[error("bucket size must be greater than zero")]
27 BucketSizeIsZero,
28 #[error("non-empty memory does not contain a MemoryManager layout")]
29 NonMemoryManagerLayout,
30 #[error("{0}")]
31 InvalidLayout(String),
32}
33
34impl<M: Memory> MemoryManager<M> {
35 pub fn init(memory: M) -> Self {
36 Self::init_with_bucket_size(memory, BUCKET_SIZE_IN_PAGES as u16)
37 }
38
39 pub fn init_strict(memory: M) -> Result<Self, MemoryManagerInitError> {
40 Self::init_strict_with_bucket_size(memory, BUCKET_SIZE_IN_PAGES as u16)
41 }
42
43 pub fn init_with_bucket_size(memory: M, bucket_size_in_pages: u16) -> Self {
44 if bucket_size_in_pages == 0 {
45 panic!("bucket size must be greater than zero");
46 }
47 Self {
48 inner: Rc::new(RefCell::new(MemoryManagerInner::init(
49 memory,
50 bucket_size_in_pages,
51 ))),
52 }
53 }
54
55 pub fn init_strict_with_bucket_size(
56 memory: M,
57 bucket_size_in_pages: u16,
58 ) -> Result<Self, MemoryManagerInitError> {
59 if bucket_size_in_pages == 0 {
60 return Err(MemoryManagerInitError::BucketSizeIsZero);
61 }
62 Ok(Self {
63 inner: Rc::new(RefCell::new(MemoryManagerInner::init_strict(
64 memory,
65 bucket_size_in_pages,
66 )?)),
67 })
68 }
69
70 pub fn get(&self, id: MemoryId) -> VirtualMemory<M> {
71 VirtualMemory {
72 id,
73 memory_manager: Rc::clone(&self.inner),
74 cache: BucketCache::new(),
75 }
76 }
77}
78#[derive(Clone)]
79pub struct VirtualMemory<M: Memory> {
80 id: MemoryId,
81 memory_manager: Rc<RefCell<MemoryManagerInner<M>>>,
82 cache: BucketCache,
83}
84impl<M: Memory> Memory for VirtualMemory<M> {
85 fn size(&self) -> u64 {
86 self.memory_manager.borrow().memory_size(self.id)
87 }
88
89 fn grow(&self, pages: u64) -> i64 {
90 self.memory_manager.borrow_mut().grow(self.id, pages)
91 }
92
93 fn read(&self, offset: u64, dst: &mut [u8]) {
94 self.memory_manager
95 .borrow()
96 .read(self.id, offset, dst, &self.cache);
97 }
98
99 unsafe fn read_unsafe(&self, offset: u64, dst: *mut u8, count: usize) {
100 self.memory_manager
101 .borrow()
102 .read_unsafe(self.id, offset, dst, count, &self.cache);
103 }
104
105 fn write(&self, offset: u64, src: &[u8]) {
106 self.memory_manager
107 .borrow()
108 .write(self.id, offset, src, &self.cache);
109 }
110}
111
112#[derive(Clone)]
113struct MemoryManagerInner<M: Memory> {
114 memory: M,
115 allocated_buckets: u16,
116 bucket_size_in_pages: u16,
117 memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
118 memory_buckets: Vec<Vec<BucketId>>,
119}
120impl<M: Memory> MemoryManagerInner<M> {
121 fn init(memory: M, bucket_size_in_pages: u16) -> Self {
122 if memory.size() == 0 {
123 return Self::new(memory, bucket_size_in_pages);
124 }
125
126 let mut magic = [0_u8; 3];
127 memory.read(0, &mut magic);
128 if &magic == MAGIC {
129 Self::load(memory)
130 } else {
131 Self::new(memory, bucket_size_in_pages)
132 }
133 }
134
135 fn init_strict(memory: M, bucket_size_in_pages: u16) -> Result<Self, MemoryManagerInitError> {
136 if memory.size() == 0 {
137 return Ok(Self::new(memory, bucket_size_in_pages));
138 }
139
140 let mut magic = [0_u8; 3];
141 memory.read(0, &mut magic);
142 if &magic != MAGIC {
143 return Err(MemoryManagerInitError::NonMemoryManagerLayout);
144 }
145 Self::try_load(memory)
146 }
147
148 fn new(memory: M, bucket_size_in_pages: u16) -> Self {
149 let manager = Self {
150 memory,
151 allocated_buckets: 0,
152 bucket_size_in_pages,
153 memory_sizes_in_pages: [0; MAX_NUM_MEMORIES as usize],
154 memory_buckets: vec![Vec::new(); MAX_NUM_MEMORIES as usize],
155 };
156 write_growing(
157 &manager.memory,
158 bucket_allocations_address(BucketId(0)),
159 &[UNALLOCATED_BUCKET_MARKER; MAX_NUM_BUCKETS as usize],
160 );
161 manager.save_header();
162 manager
163 }
164 fn load(memory: M) -> Self {
165 let mut header = vec![0_u8; HEADER_SIZE as usize];
166 memory.read(0, &mut header);
167 assert_eq!(&header[0..3], MAGIC, "Bad magic.");
168 assert_eq!(header[3], LAYOUT_VERSION, "Unsupported version.");
169 let layout = load_validated_layout(&memory, &header);
170
171 Self {
172 memory,
173 allocated_buckets: layout.allocated_buckets,
174 bucket_size_in_pages: layout.bucket_size_in_pages,
175 memory_sizes_in_pages: layout.memory_sizes_in_pages,
176 memory_buckets: layout.memory_buckets,
177 }
178 }
179
180 fn try_load(memory: M) -> Result<Self, MemoryManagerInitError> {
181 let mut header = vec![0_u8; HEADER_SIZE as usize];
182 memory.read(0, &mut header);
183 if &header[0..3] != MAGIC {
184 return Err(MemoryManagerInitError::NonMemoryManagerLayout);
185 }
186 if header[3] != LAYOUT_VERSION {
187 return Err(MemoryManagerInitError::InvalidLayout(
188 "Unsupported version.".to_string(),
189 ));
190 }
191 let layout = try_load_validated_layout(&memory, &header)
192 .map_err(|error| MemoryManagerInitError::InvalidLayout(error.to_string()))?;
193
194 Ok(Self {
195 memory,
196 allocated_buckets: layout.allocated_buckets,
197 bucket_size_in_pages: layout.bucket_size_in_pages,
198 memory_sizes_in_pages: layout.memory_sizes_in_pages,
199 memory_buckets: layout.memory_buckets,
200 })
201 }
202
203 fn save_header(&self) {
204 let mut header = [0_u8; HEADER_SIZE as usize];
205 header[0..3].copy_from_slice(MAGIC);
206 header[3] = LAYOUT_VERSION;
207 header[4..6].copy_from_slice(&self.allocated_buckets.to_le_bytes());
208 header[6..8].copy_from_slice(&self.bucket_size_in_pages.to_le_bytes());
209 let mut offset = 3 + 1 + 2 + 2 + HEADER_RESERVED_BYTES;
210 for size in self.memory_sizes_in_pages {
211 header[offset..offset + 8].copy_from_slice(&size.to_le_bytes());
212 offset += 8;
213 }
214 write_growing(&self.memory, 0, &header);
215 }
216
217 fn memory_size(&self, id: MemoryId) -> u64 {
218 self.memory_sizes_in_pages[id.0 as usize]
219 }
220
221 fn grow(&mut self, id: MemoryId, pages: u64) -> i64 {
222 let old_size = self.memory_size(id);
223 let Some(new_size) = old_size.checked_add(pages) else {
224 return -1;
225 };
226 let current_buckets = self.num_buckets_needed(old_size);
227 let required_buckets = self.num_buckets_needed(new_size);
228 let new_buckets = required_buckets - current_buckets;
229 let Some(target_allocated_buckets) =
230 new_buckets.checked_add(u64::from(self.allocated_buckets))
231 else {
232 return -1;
233 };
234 if target_allocated_buckets > MAX_NUM_BUCKETS {
235 return -1;
236 }
237 let Ok(new_buckets_len) = usize::try_from(new_buckets) else {
238 return -1;
239 };
240 let memory_bucket = &mut self.memory_buckets[id.0 as usize];
241 if memory_bucket.try_reserve(new_buckets_len).is_err() {
242 return -1;
243 }
244 let mut rollback_buckets = Vec::new();
245 if rollback_buckets.try_reserve(new_buckets_len).is_err() {
246 return -1;
247 }
248
249 let Some(data_pages) =
250 u64::from(self.bucket_size_in_pages).checked_mul(target_allocated_buckets)
251 else {
252 return -1;
253 };
254 let Some(pages_needed) = BUCKETS_OFFSET_IN_PAGES.checked_add(data_pages) else {
255 return -1;
256 };
257 let current_pages = self.memory.size();
258 if pages_needed > current_pages {
259 let previous = self.memory.grow(pages_needed - current_pages);
260 if previous < 0 {
261 return -1;
262 }
263 }
264
265 let mut rollback = AllocationRollback {
266 memory: std::ptr::addr_of!(self.memory),
267 buckets: rollback_buckets,
268 committed: false,
269 _memory: std::marker::PhantomData,
270 };
271 for _ in 0..new_buckets {
272 let bucket = BucketId(self.allocated_buckets);
273 memory_bucket.push(bucket);
274 write_growing(&self.memory, bucket_allocations_address(bucket), &[id.0]);
275 rollback.buckets.push(bucket);
276 self.allocated_buckets = self
277 .allocated_buckets
278 .checked_add(1)
279 .expect("allocated bucket count overflow");
280 }
281
282 self.memory_sizes_in_pages[id.0 as usize] = new_size;
283 self.save_header();
284 rollback.committed = true;
285 old_size as i64
286 }
287
288 fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8], cache: &BucketCache) {
289 unsafe { self.read_unsafe(id, offset, dst.as_mut_ptr(), dst.len(), cache) }
290 }
291
292 unsafe fn read_unsafe(
293 &self,
294 id: MemoryId,
295 offset: u64,
296 dst: *mut u8,
297 count: usize,
298 cache: &BucketCache,
299 ) {
300 if count == 0 {
301 return;
302 }
303 self.assert_bounds(id, offset, count as u64, "read");
304 if let Some(real) = cache.get(VirtualSegment::new(offset, count as u64)) {
305 self.memory.read_unsafe(real, dst, count);
306 return;
307 }
308 let mut bytes_read = 0_u64;
309 self.for_each_bucket(id, offset, count as u64, cache, |address, len| {
310 self.memory
311 .read_unsafe(address, dst.add(bytes_read as usize), len as usize);
312 bytes_read += len;
313 });
314 }
315
316 fn write(&self, id: MemoryId, offset: u64, src: &[u8], cache: &BucketCache) {
317 if src.is_empty() {
318 return;
319 }
320 self.assert_bounds(id, offset, src.len() as u64, "write");
321 if let Some(real) = cache.get(VirtualSegment::new(offset, src.len() as u64)) {
322 self.memory.write(real, src);
323 return;
324 }
325 let mut written = 0_u64;
326 self.for_each_bucket(id, offset, src.len() as u64, cache, |address, len| {
327 self.memory
328 .write(address, &src[written as usize..(written + len) as usize]);
329 written += len;
330 });
331 }
332
333 fn for_each_bucket(
334 &self,
335 MemoryId(id): MemoryId,
336 offset: u64,
337 mut len: u64,
338 cache: &BucketCache,
339 mut f: impl FnMut(u64, u64),
340 ) {
341 let bucket_size = self.bucket_size_in_bytes();
342 let buckets = self.memory_buckets[id as usize].as_slice();
343 let mut bucket_idx = (offset / bucket_size) as usize;
344 let mut bucket_offset = offset % bucket_size;
345 while len > 0 {
346 let bucket = buckets.get(bucket_idx).expect("bucket idx out of bounds");
347 let bucket_address = self.bucket_address(*bucket);
348 let segment_len = (bucket_size - bucket_offset).min(len);
349 cache.store(
350 VirtualSegment::new(bucket_idx as u64 * bucket_size, bucket_size),
351 bucket_address,
352 );
353 f(bucket_address + bucket_offset, segment_len);
354 len -= segment_len;
355 bucket_idx += 1;
356 bucket_offset = 0;
357 }
358 }
359
360 fn assert_bounds(&self, id: MemoryId, offset: u64, len: u64, operation: &str) {
361 let end = offset
362 .checked_add(len)
363 .unwrap_or_else(|| panic!("{id:?}: {operation} out of bounds"));
364 let capacity = self
365 .memory_size(id)
366 .checked_mul(STABLE_PAGE_SIZE)
367 .unwrap_or_else(|| panic!("{id:?}: {operation} out of bounds"));
368 assert!(end <= capacity, "{id:?}: {operation} out of bounds");
369 }
370
371 fn bucket_size_in_bytes(&self) -> u64 {
372 u64::from(self.bucket_size_in_pages) * STABLE_PAGE_SIZE
373 }
374
375 fn num_buckets_needed(&self, pages: u64) -> u64 {
376 pages.div_ceil(u64::from(self.bucket_size_in_pages))
377 }
378
379 fn bucket_address(&self, id: BucketId) -> u64 {
380 BUCKETS_OFFSET_IN_BYTES + self.bucket_size_in_bytes() * u64::from(id.0)
381 }
382}
383
384struct AllocationRollback<'memory, M: Memory> {
385 memory: *const M,
386 buckets: Vec<BucketId>,
387 committed: bool,
388 _memory: std::marker::PhantomData<&'memory M>,
389}
390
391impl<M: Memory> Drop for AllocationRollback<'_, M> {
392 fn drop(&mut self) {
393 if self.committed || !std::thread::panicking() {
394 return;
395 }
396 for bucket in self.buckets.iter().copied() {
397 let memory = unsafe { &*self.memory };
398 write_growing(
399 memory,
400 bucket_allocations_address(bucket),
401 &[UNALLOCATED_BUCKET_MARKER],
402 );
403 }
404 }
405}