use crate::{
read_struct,
types::{Address, Bytes},
write, write_struct, Memory, WASM_PAGE_SIZE,
};
use std::cell::RefCell;
use std::cmp::min;
use std::collections::BTreeMap;
use std::rc::Rc;
/// Magic bytes stored at offset 0 of the underlying memory; `init` uses them to
/// decide whether to load an existing layout or write a fresh one.
const MAGIC: &[u8; 3] = b"MGR";
/// Version of the persisted layout; `load` rejects any other version.
const LAYOUT_VERSION: u8 = 1;
/// Number of slots in the per-memory size table. Memory ids are `u8` and the
/// value 255 is reserved as `UNALLOCATED_BUCKET_MARKER`, leaving ids 0..=254.
const MAX_NUM_MEMORIES: u8 = 255;
/// Maximum number of buckets that can ever be allocated; also the number of
/// one-byte entries in the bucket allocation table.
const MAX_NUM_BUCKETS: u64 = 32768;
/// Default bucket size, in Wasm pages, used by `MemoryManager::init`.
const BUCKET_SIZE_IN_PAGES: u64 = 1024;
/// Bucket allocation table entry marking a bucket that belongs to no memory.
const UNALLOCATED_BUCKET_MARKER: u8 = MAX_NUM_MEMORIES;
/// Buckets start after this many pages of the underlying memory; the space
/// before them is reserved for the header and the bucket allocation table.
const BUCKETS_OFFSET_IN_PAGES: u64 = 1;
/// `BUCKETS_OFFSET_IN_PAGES` expressed in bytes.
const BUCKETS_OFFSET_IN_BYTES: u64 = BUCKETS_OFFSET_IN_PAGES * WASM_PAGE_SIZE;
/// Size of the header's reserved area (written as zeros).
const HEADER_RESERVED_BYTES: usize = 32;
/// Multiplexes a single underlying [`Memory`] into independent virtual
/// memories, each identified by a [`MemoryId`].
///
/// All handles returned by `get` share one piece of manager state through an
/// `Rc<RefCell<_>>`.
pub struct MemoryManager<M: Memory> {
    inner: Rc<RefCell<MemoryManagerInner<M>>>,
}

impl<M: Memory> MemoryManager<M> {
    /// Initializes a memory manager on top of `memory` using the default bucket
    /// size of `BUCKET_SIZE_IN_PAGES` pages.
    ///
    /// If `memory` already starts with a memory manager header, that layout is
    /// loaded. Otherwise a new header and bucket allocation table are written.
    ///
    /// # Panics
    ///
    /// Panics if an existing layout was created with a different bucket size.
    pub fn init(memory: M) -> Self {
        let default_bucket_size = BUCKET_SIZE_IN_PAGES as u16;
        Self::init_with_buckets(memory, default_bucket_size)
    }

    /// Same as `init`, but with an explicit bucket size in Wasm pages.
    fn init_with_buckets(memory: M, bucket_size_in_pages: u16) -> Self {
        let state = MemoryManagerInner::init(memory, bucket_size_in_pages);
        let inner = Rc::new(RefCell::new(state));
        Self { inner }
    }

    /// Returns a handle to the virtual memory identified by `id`.
    pub fn get(&self, id: MemoryId) -> VirtualMemory<M> {
        let memory_manager = Rc::clone(&self.inner);
        VirtualMemory { id, memory_manager }
    }
}
/// The memory manager's header, persisted at offset 0 of the underlying memory
/// by `save_header` and read back by `load`.
///
/// This is an on-disk format, so its byte layout must never change. `repr(C)`
/// pins the declaration order of the fields, which the language does not
/// guarantee for `repr(packed)` alone. `packed` removes all padding.
#[repr(C, packed)]
struct Header {
    magic: [u8; 3],
    version: u8,
    // The number of buckets allocated so far.
    num_allocated_buckets: u16,
    // The size of a bucket in Wasm pages.
    bucket_size_in_pages: u16,
    // Reserved for future extensions; written as zeros.
    _reserved: [u8; HEADER_RESERVED_BYTES],
    // The size in pages of each virtual memory, indexed by `MemoryId`.
    memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
}

impl Header {
    /// Size of the header in bytes. It has no padding because the struct is packed.
    fn size() -> Bytes {
        Bytes::new(core::mem::size_of::<Self>() as u64)
    }
}
/// A handle to one virtual memory managed by a `MemoryManager`.
///
/// Every operation is forwarded to the shared manager state, which maps
/// offsets in this virtual memory onto the buckets it owns.
#[derive(Clone)]
pub struct VirtualMemory<M: Memory> {
    id: MemoryId,
    memory_manager: Rc<RefCell<MemoryManagerInner<M>>>,
}

impl<M: Memory> Memory for VirtualMemory<M> {
    fn size(&self) -> u64 {
        let manager = self.memory_manager.borrow();
        manager.memory_size(self.id)
    }

    fn grow(&self, pages: u64) -> i64 {
        let mut manager = self.memory_manager.borrow_mut();
        manager.grow(self.id, pages)
    }

    fn read(&self, offset: u64, dst: &mut [u8]) {
        let manager = self.memory_manager.borrow();
        manager.read(self.id, offset, dst)
    }

    fn write(&self, offset: u64, src: &[u8]) {
        let manager = self.memory_manager.borrow();
        manager.write(self.id, offset, src)
    }
}
/// The memory manager's state: which buckets belong to which virtual memory,
/// and how large each virtual memory is.
///
/// The same information is persisted in `memory`. The header lives at offset 0
/// and the bucket allocation table sits right after it, so the state can be
/// rebuilt with `load`.
#[derive(Clone)]
struct MemoryManagerInner<M: Memory> {
    // The underlying memory that all virtual memories are multiplexed onto.
    memory: M,
    // Number of buckets handed out so far; bucket ids are allocated sequentially.
    allocated_buckets: u16,
    // Size of each bucket in Wasm pages.
    bucket_size_in_pages: u16,
    // Size in pages of each virtual memory, indexed by `MemoryId`.
    memory_sizes_in_pages: [u64; MAX_NUM_MEMORIES as usize],
    // Buckets owned by each virtual memory, in virtual-address order.
    memory_buckets: BTreeMap<MemoryId, Vec<BucketId>>,
}

impl<M: Memory> MemoryManagerInner<M> {
    /// Initializes the manager on top of `memory`.
    ///
    /// If `memory` is empty or does not start with `MAGIC`, a fresh header and
    /// bucket allocation table are written. Otherwise the existing layout is loaded.
    ///
    /// # Panics
    ///
    /// Panics if an existing layout uses a different bucket size, or if its
    /// header fails the checks in `load`.
    fn init(memory: M, bucket_size_in_pages: u16) -> Self {
        if memory.size() == 0 {
            return Self::new(memory, bucket_size_in_pages);
        }

        let mut magic = [0u8; 3];
        memory.read(0, &mut magic);
        if magic != *MAGIC {
            return Self::new(memory, bucket_size_in_pages);
        }

        let mem_mgr = Self::load(memory);
        assert_eq!(mem_mgr.bucket_size_in_pages, bucket_size_in_pages);
        mem_mgr
    }

    /// Creates a manager with no allocated buckets and persists that empty layout.
    fn new(memory: M, bucket_size_in_pages: u16) -> Self {
        let mem_mgr = Self {
            memory,
            allocated_buckets: 0,
            memory_sizes_in_pages: [0; MAX_NUM_MEMORIES as usize],
            memory_buckets: BTreeMap::new(),
            bucket_size_in_pages,
        };

        mem_mgr.save_header();

        // Mark every bucket as unallocated.
        write(
            &mem_mgr.memory,
            bucket_allocations_address(BucketId(0)).get(),
            &[UNALLOCATED_BUCKET_MARKER; MAX_NUM_BUCKETS as usize],
        );

        mem_mgr
    }

    /// Rebuilds the manager from the layout persisted in `memory`.
    ///
    /// # Panics
    ///
    /// Panics with "Bad magic." or "Unsupported version." if the header is not
    /// a valid layout of this version.
    fn load(memory: M) -> Self {
        let header: Header = read_struct(Address::from(0), &memory);
        assert_eq!(&header.magic, MAGIC, "Bad magic.");
        assert_eq!(header.version, LAYOUT_VERSION, "Unsupported version.");

        let mut buckets = vec![0; MAX_NUM_BUCKETS as usize];
        memory.read(bucket_allocations_address(BucketId(0)).get(), &mut buckets);

        // Buckets are scanned in id order. Ids are handed out sequentially by
        // `grow`, so each memory's bucket list comes out in allocation order,
        // which is also its virtual-address order.
        let mut memory_buckets = BTreeMap::new();
        for (bucket_idx, owner) in buckets.into_iter().enumerate() {
            if owner != UNALLOCATED_BUCKET_MARKER {
                memory_buckets
                    .entry(MemoryId(owner))
                    .or_insert_with(Vec::new)
                    .push(BucketId(bucket_idx as u16));
            }
        }

        Self {
            memory,
            allocated_buckets: header.num_allocated_buckets,
            bucket_size_in_pages: header.bucket_size_in_pages,
            memory_sizes_in_pages: header.memory_sizes_in_pages,
            memory_buckets,
        }
    }

    /// Writes the current header (sizes and bucket count) to offset 0.
    fn save_header(&self) {
        let header = Header {
            magic: *MAGIC,
            version: LAYOUT_VERSION,
            num_allocated_buckets: self.allocated_buckets,
            bucket_size_in_pages: self.bucket_size_in_pages,
            _reserved: [0; HEADER_RESERVED_BYTES],
            memory_sizes_in_pages: self.memory_sizes_in_pages,
        };

        write_struct(&header, Address::from(0), &self.memory);
    }

    /// Returns the size of the virtual memory `id` in Wasm pages.
    fn memory_size(&self, id: MemoryId) -> u64 {
        self.memory_sizes_in_pages[id.0 as usize]
    }

    /// Grows the virtual memory `id` by `pages` Wasm pages.
    ///
    /// Returns the previous size in pages, or `-1` if the grow cannot be
    /// satisfied: either the new size overflows `u64`, or it would need more
    /// than `MAX_NUM_BUCKETS` buckets in total. On `-1` no state is modified.
    ///
    /// # Panics
    ///
    /// Panics if the underlying memory fails to grow.
    fn grow(&mut self, id: MemoryId, pages: u64) -> i64 {
        let old_size = self.memory_size(id);
        // `pages` is caller-controlled. An overflowing request must be reported
        // as a failed grow rather than panicking (debug) or wrapping (release).
        let new_size = match old_size.checked_add(pages) {
            Some(size) => size,
            None => return -1,
        };

        let current_buckets = self.num_buckets_needed(old_size);
        let required_buckets = self.num_buckets_needed(new_size);
        let new_buckets_needed = required_buckets - current_buckets;

        // `allocated_buckets` never exceeds `MAX_NUM_BUCKETS`. Saturating only
        // guards against an enormous `new_buckets_needed`.
        if new_buckets_needed.saturating_add(self.allocated_buckets as u64) > MAX_NUM_BUCKETS {
            // Exceeded the memory that can be managed.
            return -1;
        }

        // Allocate new buckets sequentially and record their owner in the
        // persisted bucket allocation table.
        for _ in 0..new_buckets_needed {
            let new_bucket_id = BucketId(self.allocated_buckets);

            self.memory_buckets
                .entry(id)
                .or_insert_with(Vec::new)
                .push(new_bucket_id);

            write(
                &self.memory,
                bucket_allocations_address(new_bucket_id).get(),
                &[id.0],
            );

            self.allocated_buckets += 1;
        }

        // Grow the underlying memory so that every allocated bucket is backed.
        let pages_needed = BUCKETS_OFFSET_IN_PAGES
            + self.bucket_size_in_pages as u64 * self.allocated_buckets as u64;
        if pages_needed > self.memory.size() {
            let additional_pages_needed = pages_needed - self.memory.size();
            let prev_pages = self.memory.grow(additional_pages_needed);
            if prev_pages == -1 {
                panic!("{:?}: grow failed", id);
            }
        }

        self.memory_sizes_in_pages[id.0 as usize] = new_size;
        self.save_header();
        old_size as i64
    }

    /// Writes `src` into the virtual memory `id` starting at `offset`.
    ///
    /// # Panics
    ///
    /// Panics if `offset..offset + src.len()` is not entirely within the
    /// virtual memory's current size.
    fn write(&self, id: MemoryId, offset: u64, src: &[u8]) {
        if !self.is_in_bounds(id, offset, src.len()) {
            panic!("{:?}: write out of bounds", id);
        }

        let mut bytes_written = 0;
        for Segment { address, length } in self.bucket_iter(id, offset, src.len()) {
            self.memory.write(
                address.get(),
                &src[bytes_written as usize..(bytes_written + length.get()) as usize],
            );
            bytes_written += length.get();
        }
    }

    /// Reads `dst.len()` bytes from the virtual memory `id` starting at `offset`.
    ///
    /// # Panics
    ///
    /// Panics if `offset..offset + dst.len()` is not entirely within the
    /// virtual memory's current size.
    fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8]) {
        if !self.is_in_bounds(id, offset, dst.len()) {
            panic!("{:?}: read out of bounds", id);
        }

        let mut bytes_read = 0;
        for Segment { address, length } in self.bucket_iter(id, offset, dst.len()) {
            self.memory.read(
                address.get(),
                &mut dst[bytes_read as usize..(bytes_read + length.get()) as usize],
            );
            bytes_read += length.get();
        }
    }

    /// Returns `true` if `len` bytes starting at `offset` fit within the
    /// virtual memory `id`. If `offset + len` overflows `u64`, the range is
    /// treated as out of bounds, so a wrapped sum cannot slip past the check.
    fn is_in_bounds(&self, id: MemoryId, offset: u64, len: usize) -> bool {
        match offset.checked_add(len as u64) {
            Some(end) => end <= self.memory_size(id) * WASM_PAGE_SIZE,
            None => false,
        }
    }

    /// Returns an iterator over the underlying-memory segments that back
    /// `length` bytes of the virtual memory `id` starting at `offset`.
    fn bucket_iter(&self, id: MemoryId, offset: u64, length: usize) -> BucketIterator<'_> {
        let buckets = self
            .memory_buckets
            .get(&id)
            .map(Vec::as_slice)
            .unwrap_or(&[]);

        BucketIterator {
            virtual_segment: Segment {
                address: Address::from(offset),
                length: Bytes::from(length as u64),
            },
            buckets,
            bucket_size_in_bytes: self.bucket_size_in_bytes(),
        }
    }

    /// Size of one bucket in bytes.
    fn bucket_size_in_bytes(&self) -> Bytes {
        Bytes::from(self.bucket_size_in_pages as u64 * WASM_PAGE_SIZE)
    }

    /// Number of buckets needed to hold `num_pages` pages, i.e.
    /// `ceil(num_pages / bucket_size_in_pages)`. It is computed without the
    /// `num_pages + bucket_size - 1` intermediate, so it cannot overflow for
    /// large `num_pages`.
    fn num_buckets_needed(&self, num_pages: u64) -> u64 {
        let bucket_size = self.bucket_size_in_pages as u64;
        num_pages / bucket_size + u64::from(num_pages % bucket_size != 0)
    }
}
/// A contiguous byte range: `length` bytes starting at `address`.
struct Segment {
    address: Address,
    length: Bytes,
}

/// Splits a range of a virtual memory into the corresponding ranges of the
/// underlying memory, yielding one `Segment` per bucket the range touches.
struct BucketIterator<'a> {
    // The part of the virtual range that has not been yielded yet.
    virtual_segment: Segment,
    // The virtual memory's buckets, in virtual-address order.
    buckets: &'a [BucketId],
    bucket_size_in_bytes: Bytes,
}

impl Iterator for BucketIterator<'_> {
    type Item = Segment;

    fn next(&mut self) -> Option<Self::Item> {
        if self.virtual_segment.length == Bytes::from(0u64) {
            return None;
        }

        let virtual_offset = self.virtual_segment.address.get();
        let bucket_len = self.bucket_size_in_bytes.get();
        let offset_in_bucket = virtual_offset % bucket_len;

        // Find the bucket backing the current virtual offset.
        let bucket = self
            .buckets
            .get((virtual_offset / bucket_len) as usize)
            .copied()
            .expect("bucket idx out of bounds");
        let real_address = self.bucket_address(bucket) + Bytes::from(offset_in_bucket);

        // Stop at whichever comes first: the end of this bucket or the end of
        // the requested range.
        let chunk = min(
            Bytes::from(bucket_len - offset_in_bucket),
            self.virtual_segment.length,
        );

        self.virtual_segment.address += chunk;
        self.virtual_segment.length -= chunk;

        Some(Segment {
            address: real_address,
            length: chunk,
        })
    }
}

impl BucketIterator<'_> {
    /// Address of the first byte of bucket `id` in the underlying memory.
    fn bucket_address(&self, id: BucketId) -> Address {
        let offset_of_bucket = self.bucket_size_in_bytes * Bytes::from(id.0);
        Address::from(BUCKETS_OFFSET_IN_BYTES) + offset_of_bucket
    }
}
/// Identifies one virtual memory of a memory manager.
///
/// The value `UNALLOCATED_BUCKET_MARKER` is reserved and cannot be used as an id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MemoryId(u8);

impl MemoryId {
    /// Creates a `MemoryId` from a raw id.
    ///
    /// # Panics
    ///
    /// Panics if `id` equals `UNALLOCATED_BUCKET_MARKER`.
    pub const fn new(id: u8) -> Self {
        assert!(id != UNALLOCATED_BUCKET_MARKER);
        MemoryId(id)
    }
}
/// Index of a bucket in the underlying memory.
#[derive(Clone, Copy, Debug, PartialEq)]
struct BucketId(u16);

/// Address of bucket `id`'s entry in the bucket allocation table. The table
/// starts right after the header and has one byte per bucket.
fn bucket_allocations_address(id: BucketId) -> Address {
    let table_start = Address::from(0) + Header::size();
    table_start + Bytes::from(id.0)
}
#[cfg(test)]
mod test {
    use super::*;
    use maplit::btreemap;
    use proptest::prelude::*;

    /// Largest virtual memory size, in pages, that the default configuration can manage.
    const MAX_MEMORY_IN_PAGES: u64 = MAX_NUM_BUCKETS * BUCKET_SIZE_IN_PAGES;

    /// Returns a fresh, empty in-memory backing store.
    fn make_memory() -> Rc<RefCell<Vec<u8>>> {
        Rc::new(RefCell::new(Vec::new()))
    }

    #[test]
    fn can_get_memory() {
        let mem_mgr = MemoryManager::init(make_memory());
        let memory = mem_mgr.get(MemoryId(0));
        assert_eq!(memory.size(), 0);
    }

    #[test]
    fn can_allocate_and_use_memory() {
        let mem_mgr = MemoryManager::init(make_memory());
        let memory = mem_mgr.get(MemoryId(0));
        assert_eq!(memory.grow(1), 0);
        assert_eq!(memory.size(), 1);

        memory.write(0, &[1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        assert_eq!(
            mem_mgr.inner.borrow().memory_buckets,
            btreemap! {
                MemoryId(0) => vec![BucketId(0)]
            }
        );
    }

    #[test]
    fn can_allocate_and_use_multiple_memories() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        assert_eq!(memory_0.size(), 1);
        assert_eq!(memory_1.size(), 1);

        assert_eq!(
            mem_mgr.inner.borrow().memory_buckets,
            btreemap! {
                MemoryId(0) => vec![BucketId(0)],
                MemoryId(1) => vec![BucketId(1)],
            }
        );

        memory_0.write(0, &[1, 2, 3]);
        memory_1.write(0, &[4, 5, 6]);

        let mut bytes = vec![0; 3];
        memory_0.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);

        // One page of metadata plus one bucket per memory.
        assert_eq!(mem.size(), 2 * BUCKET_SIZE_IN_PAGES + 1);
    }

    #[test]
    fn can_be_reinitialized_from_memory() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        memory_0.write(0, &[1, 2, 3]);
        memory_1.write(0, &[4, 5, 6]);

        // Re-initializing over the same backing store must load, not reset.
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        let mut bytes = vec![0; 3];
        memory_0.read(0, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);
    }

    #[test]
    fn growing_same_memory_multiple_times_doesnt_increase_underlying_allocation() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));

        // Grow the memory by 1 page. This should increase the underlying allocation
        // by `BUCKET_SIZE_IN_PAGES` pages.
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Growing within the already-allocated bucket doesn't touch the underlying memory.
        assert_eq!(memory_0.grow(1), 1);
        assert_eq!(memory_0.size(), 2);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Filling the bucket exactly still doesn't touch the underlying memory.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES - 2), 2);
        assert_eq!(memory_0.size(), BUCKET_SIZE_IN_PAGES);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES);

        // Crossing into a new bucket allocates another `BUCKET_SIZE_IN_PAGES` pages.
        assert_eq!(memory_0.grow(1), BUCKET_SIZE_IN_PAGES as i64);
        assert_eq!(memory_0.size(), BUCKET_SIZE_IN_PAGES + 1);
        assert_eq!(mem.size(), 1 + 2 * BUCKET_SIZE_IN_PAGES);
    }

    #[test]
    fn does_not_grow_memory_unnecessarily() {
        let mem = make_memory();
        let initial_size = BUCKET_SIZE_IN_PAGES * 2;

        // Pre-grow the underlying memory past what the first bucket needs.
        mem.grow(initial_size);

        let mem_mgr = MemoryManager::init(mem.clone());
        let memory_0 = mem_mgr.get(MemoryId(0));

        // The first bucket fits into the already-available pages.
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(mem.size(), initial_size);

        // A second bucket needs exactly one more page.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES), 1);
        assert_eq!(mem.size(), 1 + BUCKET_SIZE_IN_PAGES * 2);
    }

    #[test]
    fn growing_beyond_capacity_fails() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.grow(MAX_MEMORY_IN_PAGES + 1), -1);

        // Once the memory holds a page, growing by `MAX_MEMORY_IN_PAGES` more
        // would exceed the manageable capacity.
        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_0.grow(MAX_MEMORY_IN_PAGES), -1);
    }

    #[test]
    fn can_write_across_bucket_boundaries() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES + 1), 0);

        memory_0.write(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &[1, 2, 3],
        );

        let mut bytes = vec![0; 3];
        memory_0.read(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &mut bytes,
        );
        assert_eq!(bytes, vec![1, 2, 3]);
    }

    #[test]
    fn can_write_across_bucket_boundaries_with_interleaving_memories() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        // memory_0's two buckets end up non-adjacent, with memory_1's bucket between them.
        assert_eq!(memory_0.grow(BUCKET_SIZE_IN_PAGES), 0);
        assert_eq!(memory_1.grow(1), 0);
        assert_eq!(memory_0.grow(1), BUCKET_SIZE_IN_PAGES as i64);

        memory_0.write(
            mem_mgr.inner.borrow().bucket_size_in_bytes().get() - 1,
            &[1, 2, 3],
        );
        memory_1.write(0, &[4, 5, 6]);

        let mut bytes = vec![0; 3];
        memory_0.read(WASM_PAGE_SIZE * BUCKET_SIZE_IN_PAGES - 1, &mut bytes);
        assert_eq!(bytes, vec![1, 2, 3]);

        let mut bytes = vec![0; 3];
        memory_1.read(0, &mut bytes);
        assert_eq!(bytes, vec![4, 5, 6]);
    }

    #[test]
    #[should_panic]
    fn reading_out_of_bounds_should_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        let mut bytes = vec![0; WASM_PAGE_SIZE as usize + 1];
        memory_0.read(0, &mut bytes);
    }

    #[test]
    #[should_panic]
    fn writing_out_of_bounds_should_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));
        let memory_1 = mem_mgr.get(MemoryId(1));

        assert_eq!(memory_0.grow(1), 0);
        assert_eq!(memory_1.grow(1), 0);

        let bytes = vec![0; WASM_PAGE_SIZE as usize + 1];
        memory_0.write(0, &bytes);
    }

    #[test]
    fn reading_zero_bytes_from_empty_memory_should_not_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.size(), 0);
        let mut bytes = vec![];
        memory_0.read(0, &mut bytes);
    }

    #[test]
    fn writing_zero_bytes_to_empty_memory_should_not_panic() {
        let mem = make_memory();
        let mem_mgr = MemoryManager::init(mem);
        let memory_0 = mem_mgr.get(MemoryId(0));

        assert_eq!(memory_0.size(), 0);
        memory_0.write(0, &[]);
    }

    #[test]
    fn write_and_read_random_bytes() {
        let mem = make_memory();
        // One-page buckets maximise the number of bucket boundaries exercised.
        let mem_mgr = MemoryManager::init_with_buckets(mem, 1);

        let memories: Vec<_> = (0..MAX_NUM_MEMORIES)
            .map(|id| mem_mgr.get(MemoryId(id)))
            .collect();

        proptest!(|(
            num_memories in 0..255usize,
            data in proptest::collection::vec(any::<u8>(), 0..2*WASM_PAGE_SIZE as usize),
            offset in 0..10*WASM_PAGE_SIZE
        )| {
            for memory in memories.iter().take(num_memories) {
                // Write a random blob into the memory, growing it if needed.
                write(memory, offset, &data);

                // Verify the blob can be read back.
                let mut bytes = vec![0; data.len()];
                memory.read(offset, &mut bytes);
                assert_eq!(bytes, data);
            }
        });
    }
}