1use crate::{
44 read_struct,
45 types::{Address, Bytes},
46 write, write_struct, Memory, WASM_PAGE_SIZE,
47};
48use std::cell::RefCell;
49use std::cmp::min;
50use std::collections::BTreeMap;
51use std::rc::Rc;
52
// Magic bytes at offset 0 identifying memory managed by this memory manager.
const MAGIC: &[u8; 3] = b"MGR";
const LAYOUT_VERSION: u8 = 1;

// The maximum number of virtual memories that can be created.
const MAX_NUM_MEMORIES: u8 = 255;

// The maximum number of buckets the memory manager can handle.
const MAX_NUM_BUCKETS: u64 = 32768;

// The default size of a bucket, in WASM pages.
const BUCKET_SIZE_IN_PAGES: u64 = 128;

// Reserved marker byte meaning "this bucket is not allocated to any memory".
// It equals MAX_NUM_MEMORIES, which is why memory id 255 is never valid
// (see `MemoryId::new`).
const UNALLOCATED_BUCKET_MARKER: u8 = MAX_NUM_MEMORIES;

// Bucket data starts at page 1; page 0 holds the header followed by the
// bucket-allocation table (see `bucket_allocations_address`).
const BUCKETS_OFFSET_IN_PAGES: u64 = 1;
const BUCKETS_OFFSET_IN_BYTES: u64 = BUCKETS_OFFSET_IN_PAGES * WASM_PAGE_SIZE;

// Reserved bytes in the header for future extensions.
const HEADER_RESERVED_BYTES: usize = 32;
74
/// Multiplexes multiple virtual memories over a single underlying [`Memory`],
/// allocating space in fixed-size buckets.
pub struct MemoryManager<M: Memory> {
    // Shared with every `VirtualMemory` handed out by `get`.
    inner: Rc<RefCell<MemoryManagerInner<M>>>,
}
134
135impl<M: Memory> MemoryManager<M> {
136 pub fn init(memory: M) -> Self {
138 Self::init_with_bucket_size(memory, BUCKET_SIZE_IN_PAGES as u16)
139 }
140
141 pub fn init_with_bucket_size(memory: M, bucket_size_in_pages: u16) -> Self {
143 Self {
144 inner: Rc::new(RefCell::new(MemoryManagerInner::init(
145 memory,
146 bucket_size_in_pages,
147 ))),
148 }
149 }
150
151 pub fn get(&self, id: MemoryId) -> VirtualMemory<M> {
153 VirtualMemory {
154 id,
155 memory_manager: self.inner.clone(),
156 }
157 }
158}
159
// On-disk header, stored at offset 0 of the underlying memory.
#[repr(C, packed)]
struct Header {
    // Must equal `MAGIC` ("MGR") for memory managed by this manager.
    magic: [u8; 3],

    // Layout version; currently `LAYOUT_VERSION` (1).
    version: u8,

    // Total number of buckets allocated so far.
    num_allocated_buckets: u16,

    // The size of a bucket, in WASM pages.
    bucket_size_in_pages: u16,

    // Reserved bytes for future extensions of the header.
    _reserved: [u8; HEADER_RESERVED_BYTES],

    // The size (in pages) of each virtual memory, indexed by `MemoryId`.
    memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
}
178
179impl Header {
180 fn size() -> Bytes {
181 Bytes::new(core::mem::size_of::<Self>() as u64)
182 }
183}
184
/// A handle to one virtual memory managed by a [`MemoryManager`].
#[derive(Clone)]
pub struct VirtualMemory<M: Memory> {
    id: MemoryId,
    memory_manager: Rc<RefCell<MemoryManagerInner<M>>>,
}
190
// All operations delegate to the shared `MemoryManagerInner`, which
// translates virtual offsets into real bucket addresses.
impl<M: Memory> Memory for VirtualMemory<M> {
    fn size(&self) -> u64 {
        self.memory_manager.borrow().memory_size(self.id)
    }

    fn grow(&self, pages: u64) -> i64 {
        // Mutable borrow: growing updates bucket allocations and the header.
        self.memory_manager.borrow_mut().grow(self.id, pages)
    }

    fn read(&self, offset: u64, dst: &mut [u8]) {
        self.memory_manager.borrow().read(self.id, offset, dst)
    }

    fn write(&self, offset: u64, src: &[u8]) {
        self.memory_manager.borrow().write(self.id, offset, src)
    }
}
208
#[derive(Clone)]
struct MemoryManagerInner<M: Memory> {
    memory: M,

    // The number of buckets allocated so far.
    allocated_buckets: u16,

    // The size of a bucket, in WASM pages.
    bucket_size_in_pages: u16,

    // The size (in pages) of each virtual memory, indexed by `MemoryId`.
    memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],

    // The buckets (in allocation order) assigned to each memory id; rebuilt
    // from the on-disk bucket-allocation table in `load`.
    memory_buckets: BTreeMap<MemoryId, Vec<BucketId>>,
}
224
225impl<M: Memory> MemoryManagerInner<M> {
226 fn init(memory: M, bucket_size_in_pages: u16) -> Self {
227 if memory.size() == 0 {
228 return Self::new(memory, bucket_size_in_pages);
230 }
231
232 let mut dst = vec![0; 3];
234 memory.read(0, &mut dst);
235 if dst != MAGIC {
236 MemoryManagerInner::new(memory, bucket_size_in_pages)
238 } else {
239 MemoryManagerInner::load(memory)
241 }
242 }
243
244 fn new(memory: M, bucket_size_in_pages: u16) -> Self {
245 let mem_mgr = Self {
246 memory,
247 allocated_buckets: 0,
248 memory_sizes_in_pages: [0; MAX_NUM_MEMORIES as usize],
249 memory_buckets: BTreeMap::new(),
250 bucket_size_in_pages,
251 };
252
253 mem_mgr.save_header();
254
255 write(
257 &mem_mgr.memory,
258 bucket_allocations_address(BucketId(0)).get(),
259 &[UNALLOCATED_BUCKET_MARKER; MAX_NUM_BUCKETS as usize],
260 );
261
262 mem_mgr
263 }
264
265 fn load(memory: M) -> Self {
266 let header: Header = read_struct(Address::from(0), &memory);
268 assert_eq!(&header.magic, MAGIC, "Bad magic.");
269 assert_eq!(header.version, LAYOUT_VERSION, "Unsupported version.");
270
271 let mut buckets = vec![0; MAX_NUM_BUCKETS as usize];
272 memory.read(bucket_allocations_address(BucketId(0)).get(), &mut buckets);
273
274 let mut memory_buckets = BTreeMap::new();
275 for (bucket_idx, memory) in buckets.into_iter().enumerate() {
276 if memory != UNALLOCATED_BUCKET_MARKER {
277 memory_buckets
278 .entry(MemoryId(memory))
279 .or_insert_with(Vec::new)
280 .push(BucketId(bucket_idx as u16));
281 }
282 }
283
284 Self {
285 memory,
286 allocated_buckets: header.num_allocated_buckets,
287 bucket_size_in_pages: header.bucket_size_in_pages,
288 memory_sizes_in_pages: header.memory_sizes_in_pages,
289 memory_buckets,
290 }
291 }
292
293 fn save_header(&self) {
294 let header = Header {
295 magic: *MAGIC,
296 version: LAYOUT_VERSION,
297 num_allocated_buckets: self.allocated_buckets,
298 bucket_size_in_pages: self.bucket_size_in_pages,
299 _reserved: [0; HEADER_RESERVED_BYTES],
300 memory_sizes_in_pages: self.memory_sizes_in_pages,
301 };
302
303 write_struct(&header, Address::from(0), &self.memory);
304 }
305
306 fn memory_size(&self, id: MemoryId) -> u64 {
308 self.memory_sizes_in_pages[id.0 as usize]
309 }
310
311 fn grow(&mut self, id: MemoryId, pages: u64) -> i64 {
313 let old_size = self.memory_size(id);
315 let new_size = old_size + pages;
316 let current_buckets = self.num_buckets_needed(old_size);
317 let required_buckets = self.num_buckets_needed(new_size);
318 let new_buckets_needed = required_buckets - current_buckets;
319
320 if new_buckets_needed + self.allocated_buckets as u64 > MAX_NUM_BUCKETS {
321 return -1;
323 }
324
325 for _ in 0..new_buckets_needed {
327 let new_bucket_id = BucketId(self.allocated_buckets);
328
329 self.memory_buckets
330 .entry(id)
331 .or_insert_with(Vec::new)
332 .push(new_bucket_id);
333
334 write(
336 &self.memory,
337 bucket_allocations_address(new_bucket_id).get(),
338 &[id.0],
339 );
340
341 self.allocated_buckets += 1;
342 }
343
344 let pages_needed = BUCKETS_OFFSET_IN_PAGES
346 + self.bucket_size_in_pages as u64 * self.allocated_buckets as u64;
347 if pages_needed > self.memory.size() {
348 let additional_pages_needed = pages_needed - self.memory.size();
349 let prev_pages = self.memory.grow(additional_pages_needed);
350 if prev_pages == -1 {
351 panic!("{id:?}: grow failed");
352 }
353 }
354
355 self.memory_sizes_in_pages[id.0 as usize] = new_size;
357
358 self.save_header();
360 old_size as i64
361 }
362
363 fn write(&self, id: MemoryId, offset: u64, src: &[u8]) {
364 if (offset + src.len() as u64) > self.memory_size(id) * WASM_PAGE_SIZE {
365 panic!("{id:?}: write out of bounds");
366 }
367
368 let mut bytes_written = 0;
369 for Segment { address, length } in self.bucket_iter(id, offset, src.len()) {
370 self.memory.write(
371 address.get(),
372 &src[bytes_written as usize..(bytes_written + length.get()) as usize],
373 );
374
375 bytes_written += length.get();
376 }
377 }
378
379 fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8]) {
380 if (offset + dst.len() as u64) > self.memory_size(id) * WASM_PAGE_SIZE {
381 panic!("{id:?}: read out of bounds");
382 }
383
384 let mut bytes_read = 0;
385 for Segment { address, length } in self.bucket_iter(id, offset, dst.len()) {
386 self.memory.read(
387 address.get(),
388 &mut dst[bytes_read as usize..(bytes_read + length.get()) as usize],
389 );
390
391 bytes_read += length.get();
392 }
393 }
394
395 fn bucket_iter(&self, id: MemoryId, offset: u64, length: usize) -> BucketIterator {
397 let buckets = match self.memory_buckets.get(&id) {
399 Some(s) => s.as_slice(),
400 None => &[],
401 };
402
403 BucketIterator {
404 virtual_segment: Segment {
405 address: Address::from(offset),
406 length: Bytes::from(length as u64),
407 },
408 buckets,
409 bucket_size_in_bytes: self.bucket_size_in_bytes(),
410 }
411 }
412
413 fn bucket_size_in_bytes(&self) -> Bytes {
414 Bytes::from(self.bucket_size_in_pages as u64 * WASM_PAGE_SIZE)
415 }
416
417 fn num_buckets_needed(&self, num_pages: u64) -> u64 {
419 (num_pages + self.bucket_size_in_pages as u64 - 1) / self.bucket_size_in_pages as u64
421 }
422}
423
// A contiguous range of real addresses in the underlying memory.
struct Segment {
    address: Address,
    length: Bytes,
}
428
// Iterates over the real-address segments backing a virtual byte range.
struct BucketIterator<'a> {
    // The remaining (virtual address, length) still to be translated.
    virtual_segment: Segment,
    // The buckets, in order, backing the virtual memory being traversed.
    buckets: &'a [BucketId],
    bucket_size_in_bytes: Bytes,
}
457
impl Iterator for BucketIterator<'_> {
    type Item = Segment;

    // Translates the head of the remaining virtual range into one real
    // segment, clipped at the current bucket's boundary.
    fn next(&mut self) -> Option<Self::Item> {
        if self.virtual_segment.length == Bytes::from(0u64) {
            return None;
        }

        // Map the current virtual address to its bucket, then to that
        // bucket's real start address. Panics if the range runs past the
        // allocated buckets (callers bounds-check before iterating).
        let bucket_idx =
            (self.virtual_segment.address.get() / self.bucket_size_in_bytes.get()) as usize;
        let bucket_address = self.bucket_address(
            *self
                .buckets
                .get(bucket_idx)
                .expect("bucket idx out of bounds"),
        );

        // Real address = bucket start + offset within the bucket.
        let real_address = bucket_address
            + Bytes::from(self.virtual_segment.address.get() % self.bucket_size_in_bytes.get());

        // A segment never crosses a bucket boundary: yield whichever is
        // smaller, the bytes left in this bucket or the bytes remaining.
        let bytes_in_segment = {
            let next_bucket_address = bucket_address + self.bucket_size_in_bytes;

            min(
                Bytes::from(next_bucket_address.get() - real_address.get()),
                self.virtual_segment.length,
            )
        };

        // Advance the virtual cursor past the bytes about to be yielded.
        self.virtual_segment.length -= bytes_in_segment;
        self.virtual_segment.address += bytes_in_segment;

        Some(Segment {
            address: real_address,
            length: bytes_in_segment,
        })
    }
}
500
impl<'a> BucketIterator<'a> {
    // Returns the real address where the given bucket's data begins:
    // buckets are laid out contiguously starting at BUCKETS_OFFSET_IN_BYTES.
    fn bucket_address(&self, id: BucketId) -> Address {
        Address::from(BUCKETS_OFFSET_IN_BYTES) + self.bucket_size_in_bytes * Bytes::from(id.0)
    }
}
507
/// Identifies one of the virtual memories managed by a [`MemoryManager`].
#[derive(Clone, Copy, Ord, Eq, PartialEq, PartialOrd, Debug)]
pub struct MemoryId(u8);
510
impl MemoryId {
    /// Creates a new `MemoryId`.
    ///
    /// # Panics
    ///
    /// Panics if `id` equals `UNALLOCATED_BUCKET_MARKER` (255), which is
    /// reserved to mark free buckets in the on-disk allocation table.
    pub const fn new(id: u8) -> Self {
        assert!(id != UNALLOCATED_BUCKET_MARKER);

        Self(id)
    }
}
520
// Index of a bucket in the underlying memory (0-based, allocation order).
#[derive(Clone, Copy, Debug, PartialEq)]
struct BucketId(u16);
524
// Returns the address of the given bucket's entry in the on-disk allocation
// table: one owner byte per bucket, stored immediately after the header.
fn bucket_allocations_address(id: BucketId) -> Address {
    Address::from(0) + Header::size() + Bytes::from(id.0)
}
528
#[cfg(test)]
mod test {
    //! Tests for the memory manager, exercised against an in-memory
    //! `Rc<RefCell<Vec<u8>>>` backing store.
    use super::*;
    use maplit::btreemap;
    use proptest::prelude::*;

    // Total capacity of a manager with the default bucket size.
    const MAX_MEMORY_IN_PAGES: u64 = MAX_NUM_BUCKETS * BUCKET_SIZE_IN_PAGES;

    // A fresh, growable in-memory backing store for tests.
    fn make_memory() -> Rc<RefCell<Vec<u8>>> {
        Rc::new(RefCell::new(Vec::new()))
    }

    #[test]
    fn can_get_memory() {
        let mem_mgr = MemoryManager::init(make_memory());
        let memory = mem_mgr.get(MemoryId(0));
        assert_eq!(memory.size(), 0);
    }

    #[test]
    fn can_allocate_and_use_memory() {
        let mem_mgr = MemoryManager::init(make_memory());
        let memory = mem_mgr.get(MemoryId(0));
        assert_eq!(memory.grow(1), 0);
        assert_eq!(memory.size(), 1);

        memory.write(0, &[1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        // The single page fits in one bucket, owned by memory 0.
        assert_eq!(
            mem_mgr.inner.borrow().memory_buckets,
            btreemap! {
                MemoryId(0) => vec![BucketId(0)]
            }
        );
    }

    #[test]
    fn can_allocate_and_use_multiple_memories() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        assert_eq!(memory_0.size(), 1);
        assert_eq!(memory_1.size(), 1);

        // Each memory gets its own bucket, in allocation order.
        assert_eq!(
            mem_mgr.inner.borrow().memory_buckets,
            btreemap! {
                MemoryId(0) => vec![BucketId(0)],
                MemoryId(1) => vec![BucketId(1)],
            }
        );

        memory_0.write(0, &[1, 2, 3]);
        memory_0.write(0, &[1, 2, 3]);
        memory_1.write(0, &[4, 5, 6]);

        let mut bytes = vec![0; 3];
        memory_0.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);

        // Underlying memory: 1 header page + 2 buckets.
        assert_eq!(mem.size(), 2 * BUCKET_SIZE_IN_PAGES + 1);
    }

    #[test]
    fn can_be_reinitialized_from_memory() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        memory_0.write(0, &[1, 2, 3]);
        memory_1.write(0, &[4, 5, 6]);

        // Re-initializing from the same backing memory must load the
        // previously persisted state.
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        let mut bytes = vec![0; 3];
        memory_0.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);
    }

    #[test]
    fn growing_same_memory_multiple_times_doesnt_increase_underlying_allocation() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));

        // Grow by one page: allocates a whole bucket (+ header page).
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Growing within the bucket doesn't grow the underlying memory.
        assert_eq!(memory_0.grow(1), 1);
        assert_eq!(memory_0.size(), 2);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Filling the bucket exactly still needs no new allocation.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES - 2), 2);
        assert_eq!(memory_0.size(), BUCKET_SIZE_IN_PAGES);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // One page past the bucket triggers a second bucket.
        assert_eq!(memory_0.grow(1), BUCKET_SIZE_IN_PAGES as i64);
        assert_eq!(memory_0.size(), BUCKET_SIZE_IN_PAGES + 1);
        assert_eq!(mem.size(), 1 + 2 * BUCKET_SIZE_IN_PAGES);
    }

    #[test]
    fn does_not_grow_memory_unnecessarily() {
        let mem = make_memory();
        let initial_size = BUCKET_SIZE_IN_PAGES * 2;

        // Pre-grow the backing memory before handing it to the manager.
        mem.grow(initial_size);

        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));

        // Fits in the pre-grown space: no underlying grow.
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(mem.size(), initial_size);

        // Needs a second bucket beyond the pre-grown space.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES), 1);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES * 2);
    }

    #[test]
    fn growing_beyond_capacity_fails() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.grow(MAX_MEMORY_IN_PAGES + 1), -1);

        // Grow the memory by 1 page. This should succeed; then growing
        // past the manager's total capacity must fail.
        assert_eq!(memory_0.grow(1), 0); assert_eq!(memory_0.grow(MAX_MEMORY_IN_PAGES), -1); }

    #[test]
    fn can_write_across_bucket_boundaries() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES + 1), 0);

        // Write straddling the first bucket's last byte.
        memory_0.write(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &[1, 2, 3],
        );

        let mut bytes = vec![0; 3];
        memory_0.read(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &mut bytes,
        );
        assert_eq!(bytes, vec![1, 2, 3]);
    }

    #[test]
    fn can_write_across_bucket_boundaries_with_interleaving_memories() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        // memory_0's two buckets are separated by memory_1's bucket.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES), 0);
        assert_eq!(memory_1.grow(1), 0);
        assert_eq!(memory_0.grow(1), BUCKET_SIZE_IN_PAGES as i64);

        memory_0.write(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &[1, 2, 3],
        );
        memory_1.write(0, &[4, 5, 6]);

        let mut bytes = vec![0; 3];
        memory_0.read(WASM_PAGE_SIZE * BUCKET_SIZE_IN_PAGES - 1, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);
    }

    #[test]
    #[should_panic]
    fn reading_out_of_bounds_should_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        // One byte past memory_0's single page.
        let mut bytes = vec![0; WASM_PAGE_SIZE as usize + 1];
        memory_0.read(0, &mut bytes);
    }

    #[test]
    #[should_panic]
    fn writing_out_of_bounds_should_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        // One byte past memory_0's single page.
        let bytes = vec![0; WASM_PAGE_SIZE as usize + 1];
        memory_0.write(0, &bytes);
    }

    #[test]
    fn reading_zero_bytes_from_empty_memory_should_not_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.size(), 0);
        let mut bytes = vec![];
        memory_0.read(0, &mut bytes);
    }

    #[test]
    fn writing_zero_bytes_to_empty_memory_should_not_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.size(), 0);
        memory_0.write(0, &[]);
    }

    #[test]
    fn write_and_read_random_bytes() {
        let mem = make_memory();
        // A small bucket size stresses bucket-boundary handling.
        let mem_mgr = MemoryManager::init_with_bucket_size(mem, 1); let memories: Vec<_> = (0..MAX_NUM_MEMORIES)
            .map(|id| mem_mgr.get(MemoryId(id)))
            .collect();

        proptest!(|(
            num_memories in 0..255usize,
            data in proptest::collection::vec(0..u8::MAX, 0..2*WASM_PAGE_SIZE as usize),
            offset in 0..10*WASM_PAGE_SIZE
        )| {
            for memory in memories.iter().take(num_memories) {
                // Write a random blob into the memory, growing as needed.
                write(memory, offset, &data);

                // Verify the blob can be read back.
                let mut bytes = vec![0; data.len()];
                memory.read(offset, &mut bytes);
                assert_eq!(bytes, data);
            }
        });
    }

    #[test]
    fn init_with_non_default_bucket_size() {
        // Choose a bucket size that's different from the default.
        let bucket_size = 256;
        assert_ne!(bucket_size, BUCKET_SIZE_IN_PAGES as u16);

        // Initialize the memory manager with the custom bucket size.
        let mem = make_memory();
        let mem_mgr = MemoryManager::init_with_bucket_size(mem.clone(), bucket_size);

        // Write data to a couple of memories.
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));
        memory_0.grow(300);
        memory_1.grow(100);
        memory_0.write(0, &[1; 1000]);
        memory_1.write(0, &[2; 1000]);

        // Reinitialize; the persisted (non-default) bucket size must be used.
        let mem_mgr = MemoryManager::init(mem);

        assert_eq!(mem_mgr.inner.borrow().bucket_size_in_pages, bucket_size);

        // The data written before reinitialization must still be readable.
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.size(), 300);
        assert_eq!(memory_1.size(), 100);

        let mut buf = vec![0; 1000];
        memory_0.read(0, &mut buf);
        assert_eq!(buf, vec![1; 1000]);

        memory_1.read(0, &mut buf);
        assert_eq!(buf, vec![2; 1000]);
    }
}
856}