use crate::encoding::dyn_size::candid_decode_one_allow_trailing;
use crate::encoding::{AsDynSizeBytes, AsFixedSizeBytes, Buffer};
use crate::mem::free_block::FreeBlock;
use crate::mem::s_slice::SSlice;
use crate::mem::StablePtr;
use crate::primitive::s_box::SBox;
use crate::primitive::StableType;
use crate::utils::math::ceil_div;
use crate::{stable, OutOfMemory, PAGE_SIZE_BYTES};
use candid::{encode_one, CandidType, Deserialize};
use std::collections::{BTreeMap, HashMap};

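// Layout of stable memory: the very first word (at ALLOCATOR_PTR) holds a
// pointer to the serialized allocator itself, so user data starts at MIN_PTR.
// EMPTY_PTR serves as a null sentinel.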
pub(crate) const ALLOCATOR_PTR: StablePtr = 0;
pub(crate) const MIN_PTR: StablePtr = u64::SIZE as u64;
pub(crate) const EMPTY_PTR: StablePtr = u64::MAX;

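/// A best-fit allocator over IC stable memory: free blocks are kept in a
/// `BTreeMap` keyed by block size, so the smallest block that can satisfy a
/// request is found in logarithmic time.
///
/// A minimal usage sketch, mirroring the flow exercised by
/// `basic_flow_works_fine` in the tests below (marked `ignore` since it needs
/// a canister or test memory environment):
///
/// ```ignore
/// let mut sma = StableMemoryAllocator::init(0); // 0 == no page limit
/// let slice = sma.allocate(100).unwrap(); // at least 100 bytes
/// let slice = sma.reallocate(slice, 200).unwrap(); // grow, possibly moving
/// sma.deallocate(slice); // back to the free lists
/// sma.store().unwrap(); // persist before an upgrade
/// let sma = StableMemoryAllocator::retrieve(); // restore after an upgrade
/// ```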
#[doc(hidden)]
#[derive(Debug, CandidType, Deserialize, Eq, PartialEq)]
pub struct StableMemoryAllocator {
    free_blocks: BTreeMap<u64, Vec<FreeBlock>>,
    custom_data_pointers: HashMap<usize, StablePtr>,
    free_size: u64,
    available_size: u64,
    max_ptr: StablePtr,
    max_pages: u64,
}

impl StableMemoryAllocator {
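    /// Creates a fresh allocator over all currently available stable memory.
    ///
    /// `max_pages == 0` means "no page limit". If memory has already grown
    /// past a non-zero `max_pages`, the limit is raised to the actual size.
    /// Everything between `MIN_PTR` and the end of memory becomes one free
    /// block.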
    pub fn init(max_pages: u64) -> Self {
        let mut it = Self {
            max_ptr: MIN_PTR,
            free_blocks: BTreeMap::default(),
            custom_data_pointers: HashMap::default(),
            free_size: 0,
            available_size: 0,
            max_pages,
        };

        let available_pages = stable::size_pages();
        if it.max_pages != 0 && available_pages > it.max_pages {
            it.max_pages = available_pages;
        }

        let real_max_ptr = available_pages * PAGE_SIZE_BYTES;
        if real_max_ptr > it.max_ptr {
            let free_block = FreeBlock::new_total_size(it.max_ptr, real_max_ptr - it.max_ptr);
            it.more_free_size(free_block.get_total_size_bytes());
            it.more_available_size(free_block.get_total_size_bytes());

            it.push_free_block(free_block);
            it.max_ptr = real_max_ptr;
        }

        it
    }

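    /// Makes sure a following `allocate(size)` call cannot fail, growing
    /// stable memory ahead of time if necessary. Returns `false` on OOM.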
    pub fn make_sure_can_allocate(&mut self, mut size: u64) -> bool {
        size = Self::pad_size(size);

        if self.free_blocks.range(size..).next().is_some() {
            return true;
        }

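        // If the block right before the end of memory is free, growing only
        // needs to cover the difference: the fresh memory will merge with it.
        // That block is necessarily smaller than `size` (the lookup above
        // failed), so the subtraction cannot underflow.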
        if self.max_ptr > MIN_PTR {
            if let Some(last_free_block) =
                FreeBlock::from_rear_ptr(self.max_ptr - StablePtr::SIZE as u64)
            {
                size -= last_free_block.get_size_bytes();
            }
        }

        match self.grow(size) {
            Ok(fb) => {
                self.more_available_size(fb.get_total_size_bytes());
                self.more_free_size(fb.get_total_size_bytes());

                self.push_free_block(fb);

                true
            }
            Err(_) => false,
        }
    }

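    /// Allocates a block of at least `size` bytes (after padding), growing
    /// stable memory when no suitable free block exists. The `loop` below
    /// never iterates; it is only used as a breakable block, hence the
    /// `clippy::never_loop` allow.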
    #[allow(clippy::never_loop)]
    pub fn allocate(&mut self, mut size: u64) -> Result<SSlice, OutOfMemory> {
        size = Self::pad_size(size);

        let free_block = loop {
            if let Some(fb) = self.pop_free_block(size) {
                break fb;
            } else {
                if self.max_ptr > MIN_PTR {
                    if let Some(last_free_block) =
                        FreeBlock::from_rear_ptr(self.max_ptr - StablePtr::SIZE as u64)
                    {
                        let fb = self.grow(size - last_free_block.get_size_bytes())?;

                        self.more_available_size(fb.get_total_size_bytes());
                        self.more_free_size(fb.get_total_size_bytes());

                        self.remove_free_block(&last_free_block);

                        break FreeBlock::merge(last_free_block, fb);
                    }
                }

                let fb = self.grow(size)?;

                self.more_available_size(fb.get_total_size_bytes());
                self.more_free_size(fb.get_total_size_bytes());

                break fb;
            }
        };

        let slice = if FreeBlock::can_split(free_block.get_size_bytes(), size) {
            let (a, b) = free_block.split(size);
            let s = a.to_allocated();

            self.push_free_block(b);

            s
        } else {
            free_block.to_allocated()
        };

        self.less_free_size(slice.get_total_size_bytes());

        Ok(slice)
    }

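    /// Returns a slice to the free lists, merging it with any free neighbors.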
    #[inline]
    pub fn deallocate(&mut self, slice: SSlice) {
        let free_block = slice.to_free_block();

        self.more_free_size(free_block.get_total_size_bytes());
        self.push_free_block(free_block);
    }

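    /// Grows `slice` to at least `new_size` bytes, first trying to extend it
    /// in place and only then falling back to allocate-copy-deallocate. If
    /// `new_size` is not larger than the current size, the slice is returned
    /// unchanged; otherwise the returned slice replaces the old one, which
    /// may have moved.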
    pub fn reallocate(&mut self, slice: SSlice, mut new_size: u64) -> Result<SSlice, OutOfMemory> {
        new_size = Self::pad_size(new_size);

        if new_size <= slice.get_size_bytes() {
            return Ok(slice);
        }

        let free_block = slice.to_free_block();

        if let Ok(fb) = self.try_reallocate_in_place(free_block, new_size) {
            return Ok(fb);
        }

        if !self.make_sure_can_allocate(new_size) {
            return Err(OutOfMemory);
        }

        // Copy the old contents out, release the old block, then move the
        // data into a fresh allocation.
        let mut b = vec![0u8; slice.get_size_bytes().try_into().unwrap()];
        unsafe { crate::mem::read_bytes(slice.offset(0), &mut b) };

        self.more_free_size(free_block.get_total_size_bytes());
        self.push_free_block(free_block);

        // Cannot fail: `make_sure_can_allocate` succeeded above.
        let new_slice = self.allocate(new_size).unwrap();

        unsafe { crate::mem::write_bytes(new_slice.offset(0), &b) };

        Ok(new_slice)
    }

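    /// Serializes the allocator into a block of its own memory and records
    /// that block's pointer at `ALLOCATOR_PTR`, so that `retrieve` can find
    /// it after a canister upgrade.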
    pub fn store(&mut self) -> Result<(), OutOfMemory> {
        let buf = self.as_dyn_size_bytes();

        // The allocation below mutates the free lists and thus the encoding,
        // so reserve some slack and re-serialize afterwards.
        let slice = self.allocate(buf.len() as u64 + 100)?;

        let buf = self.as_dyn_size_bytes();

        unsafe { crate::mem::write_bytes(slice.offset(0), &buf) };
        unsafe { crate::mem::write_fixed(ALLOCATOR_PTR, &mut slice.as_ptr()) };

        Ok(())
    }

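    /// Reads the allocator back from the location recorded by `store` and
    /// frees the block it was serialized into.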
    pub fn retrieve() -> Self {
        let slice_ptr = unsafe { crate::mem::read_fixed_for_reference(ALLOCATOR_PTR) };
        let slice = unsafe { SSlice::from_ptr(slice_ptr).unwrap() };

        let mut buf = vec![0u8; slice.get_size_bytes() as usize];
        unsafe { crate::mem::read_bytes(slice.offset(0), &mut buf) };

        let mut it = Self::from_dyn_size_bytes(&buf);
        it.deallocate(slice);

        it
    }

    #[inline]
    pub fn get_allocated_size(&self) -> u64 {
        self.available_size - self.free_size
    }

    #[inline]
    pub fn get_available_size(&self) -> u64 {
        self.available_size
    }

    #[inline]
    pub fn get_free_size(&self) -> u64 {
        self.free_size
    }

    #[inline]
    fn more_available_size(&mut self, additional: u64) {
        self.available_size += additional;
    }

    #[inline]
    fn more_free_size(&mut self, additional: u64) {
        self.free_size += additional;
    }

    #[inline]
    fn less_free_size(&mut self, removed: u64) {
        self.free_size -= removed;
    }

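    // Custom data slots let boxed values be parked across upgrades. The
    // stable-drop flag is switched off while a box is stored here, so that
    // dropping the `SBox` wrapper does not free the underlying memory, and is
    // switched back on upon retrieval.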
    #[inline]
    pub fn store_custom_data<T: AsDynSizeBytes + StableType>(
        &mut self,
        idx: usize,
        mut data: SBox<T>,
    ) {
        unsafe { data.stable_drop_flag_off() };

        self.custom_data_pointers.insert(idx, data.as_ptr());
    }

    #[inline]
    pub fn retrieve_custom_data<T: AsDynSizeBytes + StableType>(
        &mut self,
        idx: usize,
    ) -> Option<SBox<T>> {
        let mut b = unsafe { SBox::from_ptr(self.custom_data_pointers.remove(&idx)?) };
        unsafe { SBox::<T>::stable_drop_flag_on(&mut b) };

        Some(b)
    }

    #[inline]
    pub fn get_max_pages(&self) -> u64 {
        self.max_pages
    }

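    // In-place growth: if the next physical neighbor of the block is free,
    // extend into it; if the block sits at the very end of memory, grow more
    // pages first. `Err(Ok(block))` hands the untouched block back to the
    // caller, `Err(Err(OutOfMemory))` signals that growing failed.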
    fn try_reallocate_in_place(
        &mut self,
        mut free_block: FreeBlock,
        new_size: u64,
    ) -> Result<SSlice, Result<FreeBlock, OutOfMemory>> {
        if let Some(mut next_neighbor) = free_block.next_neighbor_is_free(self.max_ptr) {
            let mut merged_size = FreeBlock::merged_size(&free_block, &next_neighbor);

            if merged_size < new_size {
                if next_neighbor.get_next_neighbor_ptr() != self.max_ptr {
                    return Err(Ok(free_block));
                }

                let fb = self.grow(new_size).map_err(Err)?;

                self.more_available_size(fb.get_total_size_bytes());

                self.less_free_size(next_neighbor.get_total_size_bytes());
                self.remove_free_block(&next_neighbor);

                next_neighbor = FreeBlock::merge(next_neighbor, fb);
                merged_size = FreeBlock::merged_size(&free_block, &next_neighbor);
            } else {
                self.less_free_size(next_neighbor.get_total_size_bytes());
                self.remove_free_block(&next_neighbor);
            }

            free_block = FreeBlock::merge(free_block, next_neighbor);

            if !FreeBlock::can_split(merged_size, new_size) {
                return Ok(free_block.to_allocated());
            }

            let (free_block, b) = free_block.split(new_size);

            let slice = free_block.to_allocated();

            self.more_free_size(b.get_total_size_bytes());
            self.push_free_block(b);

            return Ok(slice);
        }

        Err(Ok(free_block))
    }

    fn try_merge_with_neighbors(&mut self, mut free_block: FreeBlock) -> FreeBlock {
        if let Some(prev_neighbor) = free_block.prev_neighbor_is_free() {
            self.remove_free_block(&prev_neighbor);

            free_block = FreeBlock::merge(prev_neighbor, free_block);
        }

        if let Some(next_neighbor) = free_block.next_neighbor_is_free(self.max_ptr) {
            self.remove_free_block(&next_neighbor);

            free_block = FreeBlock::merge(free_block, next_neighbor);
        }

        free_block
    }

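    // The free lists form a segregated structure: a BTreeMap from block size
    // to a Vec of blocks of that size, each Vec kept sorted by pointer so
    // membership can be resolved with a binary search.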
    fn push_free_block(&mut self, mut free_block: FreeBlock) {
        free_block = self.try_merge_with_neighbors(free_block);

        free_block.persist();

        let blocks = self
            .free_blocks
            .entry(free_block.get_size_bytes())
            .or_default();

        let idx = match blocks.binary_search(&free_block) {
            Ok(_) => unreachable!("there can't be two blocks at the same ptr"),
            Err(idx) => idx,
        };

        blocks.insert(idx, free_block);
    }

    fn pop_free_block(&mut self, size: u64) -> Option<FreeBlock> {
        let (&actual_size, blocks) = self.free_blocks.range_mut(size..).next()?;

        // Safe: buckets are removed from the map as soon as they turn empty.
        let free_block = unsafe { blocks.pop().unwrap_unchecked() };

        if blocks.is_empty() {
            self.free_blocks.remove(&actual_size);
        }

        Some(free_block)
    }

    fn remove_free_block(&mut self, block: &FreeBlock) {
        let blocks = self.free_blocks.get_mut(&block.get_size_bytes()).unwrap();

        match blocks.binary_search(block) {
            Ok(idx) => {
                blocks.remove(idx);

                if blocks.is_empty() {
                    self.free_blocks.remove(&block.get_size_bytes());
                }
            }
            Err(_) => unreachable!("free block not found {:?} {:?}", block, self.free_blocks),
        };
    }

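    // Grows stable memory by as many whole pages as needed to fit `size` more
    // bytes, respecting `max_pages`, and returns a free block covering the
    // newly added region.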
    fn grow(&mut self, mut size: u64) -> Result<FreeBlock, OutOfMemory> {
        size = FreeBlock::to_total_size(size);
        let pages_to_grow = ceil_div(size, PAGE_SIZE_BYTES);
        let available_pages = stable::size_pages();

        if self.max_pages != 0 && available_pages + pages_to_grow > self.max_pages {
            return Err(OutOfMemory);
        }

        if stable::grow(pages_to_grow).is_err() {
            return Err(OutOfMemory);
        }

        let new_max_ptr = (available_pages + pages_to_grow) * PAGE_SIZE_BYTES;
        let it = FreeBlock::new_total_size(self.max_ptr, new_max_ptr - self.max_ptr);

        self.max_ptr = new_max_ptr;

        Ok(it)
    }

    pub fn debug_validate_free_blocks(&self) {
        assert!(
            self.available_size == 0
                || self.available_size == stable::size_pages() * PAGE_SIZE_BYTES - MIN_PTR
        );

        let mut total_free_size = 0u64;
        for blocks in self.free_blocks.values() {
            for free_block in blocks {
                free_block.debug_validate();

                total_free_size += free_block.get_total_size_bytes();
            }
        }

        assert_eq!(total_free_size, self.free_size);
    }

    pub fn _free_blocks_count(&self) -> usize {
        self.free_blocks.values().map(Vec::len).sum()
    }

    #[inline]
    fn pad_size(size: u64) -> u64 {
        // A block must be big enough to hold two pointers, and sizes are
        // rounded up to a multiple of 8 bytes.
        if size < (StablePtr::SIZE * 2) as u64 {
            return (StablePtr::SIZE * 2) as u64;
        }

        (size + 7) & !7
    }
}

impl AsDynSizeBytes for StableMemoryAllocator {
    #[inline]
    fn as_dyn_size_bytes(&self) -> Vec<u8> {
        encode_one(self).unwrap()
    }

    #[inline]
    fn from_dyn_size_bytes(buf: &[u8]) -> Self {
        candid_decode_one_allow_trailing(buf).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use crate::encoding::AsDynSizeBytes;
    use crate::mem::allocator::StableMemoryAllocator;
    use crate::primitive::s_box::SBox;
    use crate::utils::mem_context::stable;
    use crate::SSlice;
    use rand::rngs::ThreadRng;
    use rand::seq::SliceRandom;
    use rand::{thread_rng, Rng};

    #[test]
    fn encoding_works_fine() {
        stable::clear();

        let mut sma = StableMemoryAllocator::init(0);
        sma.allocate(100).unwrap();

        let buf = sma.as_dyn_size_bytes();
        let sma_1 = StableMemoryAllocator::from_dyn_size_bytes(&buf);

        assert_eq!(sma, sma_1);

        println!("original {:?}", sma);
        println!("new {:?}", sma_1);
    }

    #[test]
    fn initialization_growing_works_fine() {
        stable::clear();
        stable::grow(1).unwrap();

        let mut sma = StableMemoryAllocator::init(0);
        println!("{:?}", sma);

        sma.allocate(100).unwrap();
        println!("{:?}", sma);

        assert_eq!(sma._free_blocks_count(), 1);

        sma.store().unwrap();

        println!("after store {:?}", sma);
        let sma = StableMemoryAllocator::retrieve();

        println!("after retrieve {:?}", sma);
        assert_eq!(sma._free_blocks_count(), 1);

        sma.debug_validate_free_blocks();
    }

    #[test]
    fn initialization_not_growing_works_fine() {
        stable::clear();

        let mut sma = StableMemoryAllocator::init(0);
        sma.allocate(100).unwrap();

        assert_eq!(sma._free_blocks_count(), 1);

        sma.store().unwrap();

        let sma = StableMemoryAllocator::retrieve();
        assert_eq!(sma._free_blocks_count(), 1);

        sma.debug_validate_free_blocks();
    }

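    // Randomized stress test scaffolding: drives the allocator through a mix
    // of allocations, deallocations, reallocations, and simulated canister
    // upgrades, checking after every step that the free-block lists are
    // intact and the accounted sizes match.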
    #[derive(Debug)]
    enum Action {
        Alloc(SSlice),
        AllocOOM(u64),
        Dealloc(SSlice),
        Realloc(SSlice, SSlice),
        ReallocOOM(u64),
        CanisterUpgrade,
        CanisterUpgradeOOM,
    }

    struct Fuzzer {
        allocator: StableMemoryAllocator,
        slices: Vec<SSlice>,
        log: Vec<Action>,
        total_allocated_size: u64,
        rng: ThreadRng,
    }

    impl Fuzzer {
        fn new(max_pages: u64) -> Self {
            Self {
                allocator: StableMemoryAllocator::init(max_pages),
                slices: Vec::default(),
                log: Vec::default(),
                total_allocated_size: 0,
                rng: thread_rng(),
            }
        }

        fn next(&mut self) {
            match self.rng.gen_range(0..100) {
                // 51%: allocate a random-sized block and verify it is writable
                0..=50 => {
                    let size = self.rng.gen_range(0..(u16::MAX as u64 * 2));

                    if self.allocator.make_sure_can_allocate(size) {
                        let slice = self.allocator.allocate(size).unwrap();

                        self.log.push(Action::Alloc(slice));
                        self.slices.push(slice);

                        let buf = vec![100u8; slice.get_size_bytes() as usize];
                        unsafe { crate::mem::write_bytes(slice.offset(0), &buf) };

                        let mut buf2 = vec![0u8; slice.get_size_bytes() as usize];
                        unsafe { crate::mem::read_bytes(slice.offset(0), &mut buf2) };

                        assert_eq!(buf, buf2);

                        self.total_allocated_size += slice.get_total_size_bytes();
                    } else {
                        assert!(self.allocator.allocate(size).is_err());
                        self.log.push(Action::AllocOOM(size));
                    }
                }
                // 25%: deallocate a random slice
                51..=75 => {
                    if self.slices.len() < 2 {
                        return self.next();
                    }

                    let slice = self.slices.remove(self.rng.gen_range(0..self.slices.len()));
                    self.log.push(Action::Dealloc(slice));

                    self.total_allocated_size -= slice.get_total_size_bytes();

                    self.allocator.deallocate(slice);
                }
                // 23%: reallocate a random slice to a random size
                76..=98 => {
                    if self.slices.len() < 2 {
                        return self.next();
                    }

                    let idx_to_remove = self.rng.gen_range(0..self.slices.len());
                    let size = self.rng.gen_range(0..(u16::MAX as u64 * 2));

                    let slice = self.slices[idx_to_remove];
                    if let Ok(slice1) = self.allocator.reallocate(slice, size) {
                        self.total_allocated_size -= slice.get_total_size_bytes();

                        self.slices.remove(idx_to_remove);
                        self.total_allocated_size += slice1.get_total_size_bytes();

                        self.log.push(Action::Realloc(slice, slice1));
                        self.slices.push(slice1);

                        let buf = vec![100u8; slice1.get_size_bytes() as usize];
                        unsafe { crate::mem::write_bytes(slice1.offset(0), &buf) };

                        let mut buf2 = vec![0u8; slice1.get_size_bytes() as usize];
                        unsafe { crate::mem::read_bytes(slice1.offset(0), &mut buf2) };

                        assert_eq!(buf, buf2);
                    } else {
                        self.log.push(Action::ReallocOOM(size));
                    }
                }
                // 1%: simulate a canister upgrade via store + retrieve
                _ => {
                    if self.allocator.store().is_ok() {
                        self.allocator = StableMemoryAllocator::retrieve();

                        self.log.push(Action::CanisterUpgrade);
                    } else {
                        self.log.push(Action::CanisterUpgradeOOM);
                    }
                }
            };

            let res = std::panic::catch_unwind(|| {
                self.allocator.debug_validate_free_blocks();
                assert_eq!(
                    self.allocator.get_allocated_size(),
                    self.total_allocated_size
                );
            });

            if res.is_err() {
                panic!("{:?} {:?}", self.log.last().unwrap(), self.allocator);
            }
        }
    }

    #[test]
    fn random_works_fine() {
        stable::clear();

        let mut fuzzer = Fuzzer::new(0);

        for _ in 0..10_000 {
            fuzzer.next();
        }

        // With no page limit, the fuzzer should never hit OOM
        for action in &fuzzer.log {
            match action {
                Action::Alloc(_)
                | Action::Realloc(_, _)
                | Action::Dealloc(_)
                | Action::CanisterUpgrade => {}
                _ => panic!("Fuzzer can't OOM here"),
            }
        }

        let mut fuzzer = Fuzzer::new(30);

        for _ in 0..10_000 {
            fuzzer.next();
        }
    }

    #[test]
    fn allocation_works_fine() {
        stable::clear();

        let mut sma = StableMemoryAllocator::init(0);

        let mut slices = vec![];

        for i in 0..1024 {
            let slice = sma.allocate(1024).unwrap();

            assert!(
                slice.get_size_bytes() >= 1024,
                "Invalid slice size at {}",
                i
            );

            slices.push(slice);
        }

        assert!(sma.get_allocated_size() >= 1024 * 1024);

        for i in 0..1024 {
            let slice = sma.reallocate(slices[i], 2 * 1024).unwrap();

            assert!(
                slice.get_size_bytes() >= 2 * 1024,
                "Invalid slice size at {}",
                i
            );

            slices[i] = slice;
        }

        assert!(sma.get_allocated_size() >= 2 * 1024 * 1024);

        for slice in slices {
            sma.deallocate(slice);
        }

        assert_eq!(sma.get_allocated_size(), 0);

        sma.debug_validate_free_blocks();
    }

    #[test]
    fn basic_flow_works_fine() {
        stable::clear();

        let mut allocator = StableMemoryAllocator::init(0);
        allocator.store().unwrap();

        let mut allocator = StableMemoryAllocator::retrieve();

        println!("before all - {:?}", allocator);

        let slice1 = allocator.allocate(100).unwrap();

        println!("allocate 100 (1) - {:?}", allocator);

        let slice1 = allocator.reallocate(slice1, 200).unwrap();

        println!("reallocate 100 to 200 (1) - {:?}", allocator);

        let slice2 = allocator.allocate(100).unwrap();

        println!("allocate 100 more (2) - {:?}", allocator);

        let slice3 = allocator.allocate(100).unwrap();

        println!("allocate 100 more (3) - {:?}", allocator);

        allocator.deallocate(slice1);

        println!("deallocate (1) - {:?}", allocator);

        let slice2 = allocator.reallocate(slice2, 200).unwrap();

        println!("reallocate (2) - {:?}", allocator);

        allocator.deallocate(slice3);

        println!("deallocate (3) - {:?}", allocator);

        allocator.deallocate(slice2);

        println!("deallocate (2) - {:?}", allocator);

        allocator.store().unwrap();

        let mut allocator = StableMemoryAllocator::retrieve();

        let mut slices = Vec::new();
        for _ in 0..5000 {
            slices.push(allocator.allocate(100).unwrap());
        }

        slices.shuffle(&mut thread_rng());

        for slice in slices {
            allocator.deallocate(slice);
        }

        assert_eq!(allocator.get_allocated_size(), 0);
        allocator.debug_validate_free_blocks();
        println!("{:?}", allocator);
    }
}