1use crate::{
44 read_struct,
45 types::{Address, Bytes},
46 write, write_struct, Memory, WASM_PAGE_SIZE,
47};
48use std::cell::RefCell;
49use std::cmp::min;
50use std::collections::BTreeMap;
51use std::rc::Rc;
52
53const MAGIC: &[u8; 3] = b"MGR";
54const LAYOUT_VERSION: u8 = 1;
55
56const MAX_NUM_MEMORIES: u8 = 255;
58
59const MAX_NUM_BUCKETS: u64 = 32768;
62
63const BUCKET_SIZE_IN_PAGES: u64 = 128;
64
65const UNALLOCATED_BUCKET_MARKER: u8 = MAX_NUM_MEMORIES;
67
68const BUCKETS_OFFSET_IN_PAGES: u64 = 1;
70const BUCKETS_OFFSET_IN_BYTES: u64 = BUCKETS_OFFSET_IN_PAGES * WASM_PAGE_SIZE;
71
72const HEADER_RESERVED_BYTES: usize = 32;
74
75pub struct MemoryManager<M: Memory> {
132 inner: Rc<RefCell<MemoryManagerInner<M>>>,
133}
134
135impl<M: Memory> MemoryManager<M> {
136 pub fn init(memory: M) -> Self {
138 Self::init_with_bucket_size(memory, BUCKET_SIZE_IN_PAGES as u16)
139 }
140
141 pub fn init_with_bucket_size(memory: M, bucket_size_in_pages: u16) -> Self {
143 Self {
144 inner: Rc::new(RefCell::new(MemoryManagerInner::init(
145 memory,
146 bucket_size_in_pages,
147 ))),
148 }
149 }
150
151 pub fn get(&self, id: MemoryId) -> VirtualMemory<M> {
153 VirtualMemory {
154 id,
155 memory_manager: self.inner.clone(),
156 }
157 }
158}
159
/// Header stored at address 0 of the underlying memory.
///
/// `repr(C, packed)` pins the field order and removes padding. The struct is
/// passed as-is to `write_struct`/`read_struct`, so its byte layout is the
/// persisted format. Do not reorder, resize or insert fields without also
/// changing `LAYOUT_VERSION`.
#[repr(C, packed)]
struct Header {
    // Must equal `MAGIC` for the layout to be recognized.
    magic: [u8; 3],

    // Layout version; `load` asserts it equals `LAYOUT_VERSION`.
    version: u8,

    // Number of buckets handed out so far, across all memories.
    num_allocated_buckets: u16,

    // Size of every bucket, in Wasm pages.
    bucket_size_in_pages: u16,

    // Written as zeros by `save_header`; presumably reserved for future use.
    _reserved: [u8; HEADER_RESERVED_BYTES],

    // Current size, in Wasm pages, of each virtual memory, indexed by memory id.
    memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
}

impl Header {
    /// Size of the header in bytes (no padding, since the struct is packed).
    fn size() -> Bytes {
        Bytes::new(core::mem::size_of::<Self>() as u64)
    }
}
184
185#[derive(Clone)]
186pub struct VirtualMemory<M: Memory> {
187 id: MemoryId,
188 memory_manager: Rc<RefCell<MemoryManagerInner<M>>>,
189}
190
191impl<M: Memory> Memory for VirtualMemory<M> {
192 fn size(&self) -> u64 {
193 self.memory_manager.borrow().memory_size(self.id)
194 }
195
196 fn grow(&self, pages: u64) -> i64 {
197 self.memory_manager.borrow_mut().grow(self.id, pages)
198 }
199
200 fn read(&self, offset: u64, dst: &mut [u8]) {
201 self.memory_manager.borrow().read(self.id, offset, dst)
202 }
203
204 fn write(&self, offset: u64, src: &[u8]) {
205 self.memory_manager.borrow().write(self.id, offset, src)
206 }
207}
208
209#[derive(Clone)]
210struct MemoryManagerInner<M: Memory> {
211 memory: M,
212
213 allocated_buckets: u16,
215
216 bucket_size_in_pages: u16,
217
218 memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
220
221 memory_buckets: BTreeMap<MemoryId, Vec<BucketId>>,
223}
224
225impl<M: Memory> MemoryManagerInner<M> {
226 fn init(memory: M, bucket_size_in_pages: u16) -> Self {
227 if memory.size() == 0 {
228 return Self::new(memory, bucket_size_in_pages);
230 }
231
232 let mut dst = vec![0; 3];
234 memory.read(0, &mut dst);
235 if dst != MAGIC {
236 MemoryManagerInner::new(memory, bucket_size_in_pages)
238 } else {
239 MemoryManagerInner::load(memory)
241 }
242 }
243
244 fn new(memory: M, bucket_size_in_pages: u16) -> Self {
245 let mem_mgr = Self {
246 memory,
247 allocated_buckets: 0,
248 memory_sizes_in_pages: [0; MAX_NUM_MEMORIES as usize],
249 memory_buckets: BTreeMap::new(),
250 bucket_size_in_pages,
251 };
252
253 mem_mgr.save_header();
254
255 write(
257 &mem_mgr.memory,
258 bucket_allocations_address(BucketId(0)).get(),
259 &[UNALLOCATED_BUCKET_MARKER; MAX_NUM_BUCKETS as usize],
260 );
261
262 mem_mgr
263 }
264
265 fn load(memory: M) -> Self {
266 let header: Header = read_struct(Address::from(0), &memory);
268 assert_eq!(&header.magic, MAGIC, "Bad magic.");
269 assert_eq!(header.version, LAYOUT_VERSION, "Unsupported version.");
270
271 let mut buckets = vec![0; MAX_NUM_BUCKETS as usize];
272 memory.read(bucket_allocations_address(BucketId(0)).get(), &mut buckets);
273
274 let mut memory_buckets = BTreeMap::new();
275 for (bucket_idx, memory) in buckets.into_iter().enumerate() {
276 if memory != UNALLOCATED_BUCKET_MARKER {
277 memory_buckets
278 .entry(MemoryId(memory))
279 .or_insert_with(Vec::new)
280 .push(BucketId(bucket_idx as u16));
281 }
282 }
283
284 Self {
285 memory,
286 allocated_buckets: header.num_allocated_buckets,
287 bucket_size_in_pages: header.bucket_size_in_pages,
288 memory_sizes_in_pages: header.memory_sizes_in_pages,
289 memory_buckets,
290 }
291 }
292
293 fn save_header(&self) {
294 let header = Header {
295 magic: *MAGIC,
296 version: LAYOUT_VERSION,
297 num_allocated_buckets: self.allocated_buckets,
298 bucket_size_in_pages: self.bucket_size_in_pages,
299 _reserved: [0; HEADER_RESERVED_BYTES],
300 memory_sizes_in_pages: self.memory_sizes_in_pages,
301 };
302
303 write_struct(&header, Address::from(0), &self.memory);
304 }
305
306 fn memory_size(&self, id: MemoryId) -> u64 {
308 self.memory_sizes_in_pages[id.0 as usize]
309 }
310
311 fn grow(&mut self, id: MemoryId, pages: u64) -> i64 {
313 let old_size = self.memory_size(id);
315 let new_size = old_size + pages;
316 let current_buckets = self.num_buckets_needed(old_size);
317 let required_buckets = self.num_buckets_needed(new_size);
318 let new_buckets_needed = required_buckets - current_buckets;
319
320 if new_buckets_needed + self.allocated_buckets as u64 > MAX_NUM_BUCKETS {
321 return -1;
323 }
324
325 for _ in 0..new_buckets_needed {
327 let new_bucket_id = BucketId(self.allocated_buckets);
328
329 self.memory_buckets
330 .entry(id)
331 .or_insert_with(Vec::new)
332 .push(new_bucket_id);
333
334 write(
336 &self.memory,
337 bucket_allocations_address(new_bucket_id).get(),
338 &[id.0],
339 );
340
341 self.allocated_buckets += 1;
342 }
343
344 let pages_needed = BUCKETS_OFFSET_IN_PAGES
346 + self.bucket_size_in_pages as u64 * self.allocated_buckets as u64;
347 if pages_needed > self.memory.size() {
348 let additional_pages_needed = pages_needed - self.memory.size();
349 let prev_pages = self.memory.grow(additional_pages_needed);
350 if prev_pages == -1 {
351 panic!("{id:?}: grow failed");
352 }
353 }
354
355 self.memory_sizes_in_pages[id.0 as usize] = new_size;
357
358 self.save_header();
360 old_size as i64
361 }
362
363 fn write(&self, id: MemoryId, offset: u64, src: &[u8]) {
364 if (offset + src.len() as u64) > self.memory_size(id) * WASM_PAGE_SIZE {
365 panic!("{id:?}: write out of bounds");
366 }
367
368 let mut bytes_written = 0;
369 for Segment { address, length } in self.bucket_iter(id, offset, src.len()) {
370 self.memory.write(
371 address.get(),
372 &src[bytes_written as usize..(bytes_written + length.get()) as usize],
373 );
374
375 bytes_written += length.get();
376 }
377 }
378
379 fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8]) {
380 if (offset + dst.len() as u64) > self.memory_size(id) * WASM_PAGE_SIZE {
381 panic!("{id:?}: read out of bounds");
382 }
383
384 let mut bytes_read = 0;
385 for Segment { address, length } in self.bucket_iter(id, offset, dst.len()) {
386 self.memory.read(
387 address.get(),
388 &mut dst[bytes_read as usize..(bytes_read + length.get()) as usize],
389 );
390
391 bytes_read += length.get();
392 }
393 }
394
395 fn bucket_iter(&self, id: MemoryId, offset: u64, length: usize) -> BucketIterator {
397 let buckets = match self.memory_buckets.get(&id) {
399 Some(s) => s.as_slice(),
400 None => &[],
401 };
402
403 BucketIterator {
404 virtual_segment: Segment {
405 address: Address::from(offset),
406 length: Bytes::from(length as u64),
407 },
408 buckets,
409 bucket_size_in_bytes: self.bucket_size_in_bytes(),
410 }
411 }
412
413 fn bucket_size_in_bytes(&self) -> Bytes {
414 Bytes::from(self.bucket_size_in_pages as u64 * WASM_PAGE_SIZE)
415 }
416
417 fn num_buckets_needed(&self, num_pages: u64) -> u64 {
419 (num_pages + self.bucket_size_in_pages as u64 - 1) / self.bucket_size_in_pages as u64
421 }
422}
423
424struct Segment {
425 address: Address,
426 length: Bytes,
427}
428
429struct BucketIterator<'a> {
453 virtual_segment: Segment,
454 buckets: &'a [BucketId],
455 bucket_size_in_bytes: Bytes,
456}
457
458impl Iterator for BucketIterator<'_> {
459 type Item = Segment;
460
461 fn next(&mut self) -> Option<Self::Item> {
462 if self.virtual_segment.length == Bytes::from(0u64) {
463 return None;
464 }
465
466 let bucket_idx =
468 (self.virtual_segment.address.get() / self.bucket_size_in_bytes.get()) as usize;
469 let bucket_address = self.bucket_address(
470 *self
471 .buckets
472 .get(bucket_idx)
473 .expect("bucket idx out of bounds"),
474 );
475
476 let real_address = bucket_address
477 + Bytes::from(self.virtual_segment.address.get() % self.bucket_size_in_bytes.get());
478
479 let bytes_in_segment = {
481 let next_bucket_address = bucket_address + self.bucket_size_in_bytes;
482
483 min(
485 Bytes::from(next_bucket_address.get() - real_address.get()),
486 self.virtual_segment.length,
487 )
488 };
489
490 self.virtual_segment.length -= bytes_in_segment;
492 self.virtual_segment.address += bytes_in_segment;
493
494 Some(Segment {
495 address: real_address,
496 length: bytes_in_segment,
497 })
498 }
499}
500
501impl<'a> BucketIterator<'a> {
502 fn bucket_address(&self, id: BucketId) -> Address {
504 Address::from(BUCKETS_OFFSET_IN_BYTES) + self.bucket_size_in_bytes * Bytes::from(id.0)
505 }
506}
507
508#[derive(Clone, Copy, Ord, Eq, PartialEq, PartialOrd, Debug)]
509pub struct MemoryId(u8);
510
511impl MemoryId {
512 pub const fn new(id: u8) -> Self {
513 assert!(id != UNALLOCATED_BUCKET_MARKER);
516
517 Self(id)
518 }
519}
520
521impl crate::Storable for MemoryId {
522 fn to_bytes(&self) -> std::borrow::Cow<[u8]> {
523 self.0.to_bytes()
524 }
525
526 fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self {
527 Self(u8::from_bytes(bytes))
528 }
529
530 const BOUND: crate::storable::Bound = crate::storable::Bound::Bounded {
531 max_size: 1,
532 is_fixed_size: true,
533 };
534}
535
536#[derive(Clone, Copy, Debug, PartialEq)]
538struct BucketId(u16);
539
540fn bucket_allocations_address(id: BucketId) -> Address {
541 Address::from(0) + Header::size() + Bytes::from(id.0)
542}
543
#[cfg(test)]
mod test {
    use super::*;
    use maplit::btreemap;
    use proptest::prelude::*;

    // Total pages that can be handed out with the default bucket size.
    const MAX_MEMORY_IN_PAGES: u64 = MAX_NUM_BUCKETS * BUCKET_SIZE_IN_PAGES;

    // A fresh, empty in-memory backing store.
    fn make_memory() -> Rc<RefCell<Vec<u8>>> {
        Rc::new(RefCell::new(Vec::new()))
    }

    #[test]
    fn can_get_memory() {
        let mem_mgr = MemoryManager::init(make_memory());
        let memory = mem_mgr.get(MemoryId(0));
        assert_eq!(memory.size(), 0);
    }

    #[test]
    fn can_allocate_and_use_memory() {
        let mem_mgr = MemoryManager::init(make_memory());
        let memory = mem_mgr.get(MemoryId(0));
        assert_eq!(memory.grow(1), 0);
        assert_eq!(memory.size(), 1);

        memory.write(0, &[1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        assert_eq!(
            mem_mgr.inner.borrow().memory_buckets,
            btreemap! {
                MemoryId(0) => vec![BucketId(0)]
            }
        );
    }

    #[test]
    fn can_allocate_and_use_multiple_memories() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        assert_eq!(memory_0.size(), 1);
        assert_eq!(memory_1.size(), 1);

        assert_eq!(
            mem_mgr.inner.borrow().memory_buckets,
            btreemap! {
                MemoryId(0) => vec![BucketId(0)],
                MemoryId(1) => vec![BucketId(1)],
            }
        );

        memory_0.write(0, &[1, 2, 3]);
        memory_1.write(0, &[4, 5, 6]);

        let mut bytes = vec![0; 3];
        memory_0.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);

        // One header page plus one bucket per memory.
        assert_eq!(mem.size(), 2 * BUCKET_SIZE_IN_PAGES + 1);
    }

    #[test]
    fn can_be_reinitialized_from_memory() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        memory_0.write(0, &[1, 2, 3]);
        memory_1.write(0, &[4, 5, 6]);

        // A new manager over the same memory must see the previously written data.
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        let mut bytes = vec![0; 3];
        memory_0.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);
    }

    #[test]
    fn growing_same_memory_multiple_times_doesnt_increase_underlying_allocation() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));

        // The first grow allocates one bucket after the header page.
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Growing within the same bucket does not touch the underlying memory.
        assert_eq!(memory_0.grow(1), 1);
        assert_eq!(memory_0.size(), 2);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Filling the bucket exactly still needs no new bucket.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES - 2), 2);
        assert_eq!(memory_0.size(), BUCKET_SIZE_IN_PAGES);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // One more page spills into a second bucket.
        assert_eq!(memory_0.grow(1), BUCKET_SIZE_IN_PAGES as i64);
        assert_eq!(memory_0.size(), BUCKET_SIZE_IN_PAGES + 1);
        assert_eq!(mem.size(), 1 + 2 * BUCKET_SIZE_IN_PAGES);
    }

    #[test]
    fn does_not_grow_memory_unnecessarily() {
        let mem = make_memory();
        let initial_size = BUCKET_SIZE_IN_PAGES * 2;

        // Pre-size the underlying memory before creating the manager.
        mem.grow(initial_size);

        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));

        // The first bucket fits in the pre-sized memory.
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(mem.size(), initial_size);

        // A second bucket needs one page more than is available.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES), 1);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES * 2);
    }

    #[test]
    fn growing_beyond_capacity_fails() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.grow(MAX_MEMORY_IN_PAGES + 1), -1);

        // The failed grow left the memory usable...
        assert_eq!(memory_0.grow(1), 0);
        // ...but exceeding the total capacity still fails.
        assert_eq!(memory_0.grow(MAX_MEMORY_IN_PAGES), -1);
    }

    #[test]
    fn can_write_across_bucket_boundaries() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES + 1), 0);

        // Start one byte before the end of the first bucket.
        memory_0.write(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &[1, 2, 3],
        );

        let mut bytes = vec![0; 3];
        memory_0.read(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &mut bytes,
        );
        assert_eq!(bytes, vec![1, 2, 3]);
    }

    #[test]
    fn can_write_across_bucket_boundaries_with_interleaving_memories() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        // memory_0's two buckets are not adjacent: memory_1's bucket sits between them.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES), 0);
        assert_eq!(memory_1.grow(1), 0);
        assert_eq!(memory_0.grow(1), BUCKET_SIZE_IN_PAGES as i64);

        memory_0.write(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &[1, 2, 3],
        );
        memory_1.write(0, &[4, 5, 6]);

        let mut bytes = vec![0; 3];
        memory_0.read(WASM_PAGE_SIZE * BUCKET_SIZE_IN_PAGES - 1, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);
    }

    #[test]
    #[should_panic]
    fn reading_out_of_bounds_should_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        let mut bytes = vec![0; WASM_PAGE_SIZE as usize + 1];
        memory_0.read(0, &mut bytes);
    }

    #[test]
    #[should_panic]
    fn writing_out_of_bounds_should_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        let bytes = vec![0; WASM_PAGE_SIZE as usize + 1];
        memory_0.write(0, &bytes);
    }

    #[test]
    fn reading_zero_bytes_from_empty_memory_should_not_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.size(), 0);
        let mut bytes = vec![];
        memory_0.read(0, &mut bytes);
    }

    #[test]
    fn writing_zero_bytes_to_empty_memory_should_not_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.size(), 0);
        memory_0.write(0, &[]);
    }

    #[test]
    fn write_and_read_random_bytes() {
        let mem = make_memory();
        // Single-page buckets make the multi-page writes below span many buckets.
        let mem_mgr = MemoryManager::init_with_bucket_size(mem, 1);
        let memories: Vec<_> = (0..MAX_NUM_MEMORIES)
            .map(|id| mem_mgr.get(MemoryId(id)))
            .collect();

        proptest!(|(
            num_memories in 0..255usize,
            data in proptest::collection::vec(0..u8::MAX, 0..2*WASM_PAGE_SIZE as usize),
            offset in 0..10*WASM_PAGE_SIZE
        )| {
            for memory in memories.iter().take(num_memories) {
                // The crate-level `write` helper grows the memory as needed.
                write(memory, offset, &data);

                let mut bytes = vec![0; data.len()];
                memory.read(offset, &mut bytes);
                assert_eq!(bytes, data);
            }
        });
    }

    #[test]
    fn init_with_non_default_bucket_size() {
        let bucket_size = 256;
        assert_ne!(bucket_size, BUCKET_SIZE_IN_PAGES as u16);

        let mem = make_memory();
        let mem_mgr = MemoryManager::init_with_bucket_size(mem.clone(), bucket_size);

        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));
        memory_0.grow(300);
        memory_1.grow(100);
        memory_0.write(0, &[1; 1000]);
        memory_1.write(0, &[2; 1000]);

        // Reinitializing with the default bucket size must still pick up the
        // bucket size stored in the header.
        let mem_mgr = MemoryManager::init(mem);

        assert_eq!(mem_mgr.inner.borrow().bucket_size_in_pages, bucket_size);

        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.size(), 300);
        assert_eq!(memory_1.size(), 100);

        let mut buf = vec![0; 1000];
        memory_0.read(0, &mut buf);
        assert_eq!(buf, vec![1; 1000]);

        memory_1.read(0, &mut buf);
        assert_eq!(buf, vec![2; 1000]);
    }
}