use crate::{
read_struct,
types::{Address, Bytes, NULL},
write_struct, Memory,
};
// Layout version stored in every `AllocatorHeader`; `Allocator::load` panics on any other value.
const ALLOCATOR_LAYOUT_VERSION: u8 = 1;
// Layout version stored in every `ChunkHeader`; `ChunkHeader::load` panics on any other value.
const CHUNK_LAYOUT_VERSION: u8 = 1;
// Three-byte magic markers placed at the start of the allocator header ("BTA") and of each
// chunk header ("CHK"). They are checked on load to detect reads from a wrong address.
const ALLOCATOR_MAGIC: &[u8; 3] = b"BTA"; const CHUNK_MAGIC: &[u8; 3] = b"CHK";
/// An allocator that hands out fixed-size chunks from a `Memory`.
///
/// Each chunk is a `ChunkHeader` followed by `allocation_size` bytes of payload.
/// Free chunks are linked together through `ChunkHeader::next`, starting at
/// `free_list_head`. The allocator's own state is persisted in an
/// `AllocatorHeader` at `header_addr`.
pub struct Allocator<M: Memory> {
/// Address in `memory` where the `AllocatorHeader` is stored.
header_addr: Address,
/// Payload size of every chunk (the chunk header is not included).
allocation_size: Bytes,
/// Number of chunks currently handed out; incremented by `allocate`, decremented by `deallocate`.
num_allocated_chunks: u64,
/// Address of the chunk header that the next call to `allocate` will hand out.
free_list_head: Address,
/// The memory that holds the allocator header and all chunks.
memory: M,
}
/// The allocator state in the form that `Allocator::save` writes to memory and
/// `Allocator::load` reads back.
///
/// `repr(C, packed)` keeps the fields in declaration order with no
/// compiler-inserted padding; the padding that exists is spelled out explicitly.
#[repr(C, packed)]
#[derive(PartialEq, Debug)]
struct AllocatorHeader {
/// Always `ALLOCATOR_MAGIC`; verified on load.
magic: [u8; 3],
/// Always `ALLOCATOR_LAYOUT_VERSION`; verified on load.
version: u8,
/// Explicit zero padding; together with `magic` and `version` it occupies 8 bytes.
_alignment: [u8; 4],
/// Payload size of every chunk.
allocation_size: Bytes,
/// Number of chunks currently handed out.
num_allocated_chunks: u64,
/// Address of the first chunk header in the free list.
free_list_head: Address,
/// Always written as zeros and never read by this code — presumably reserved
/// for future fields (TODO confirm).
_buffer: [u8; 16],
}
impl AllocatorHeader {
    /// Returns the in-memory size of an `AllocatorHeader`, expressed as `Bytes`.
    fn size() -> Bytes {
        let header_len = core::mem::size_of::<AllocatorHeader>() as u64;
        Bytes::from(header_len)
    }
}
impl<M: Memory> Allocator<M> {
    /// Creates a fresh allocator whose header is stored at `addr` in `memory`.
    ///
    /// Every chunk handed out by the allocator has a payload of
    /// `allocation_size` bytes. The allocator header and the first chunk header
    /// are written to `memory` immediately, overwriting whatever was there.
    pub fn new(memory: M, addr: Address, allocation_size: Bytes) -> Self {
        let mut allocator = Self {
            header_addr: addr,
            allocation_size,
            num_allocated_chunks: 0,
            // Placeholder only: `clear` sets the real head before anything is saved.
            free_list_head: NULL,
            memory,
        };
        allocator.clear();
        allocator
    }

    /// Resets the allocator to the state produced by [`Allocator::new`].
    ///
    /// The free list is re-rooted at a single unallocated chunk placed directly
    /// after the allocator header, and the allocated-chunk counter is set to
    /// zero. Only that first chunk header is rewritten; the headers of other
    /// chunks handed out earlier are left as they were, so addresses obtained
    /// before `clear` must not be passed to [`Allocator::deallocate`] afterwards.
    pub fn clear(&mut self) {
        // The first chunk lives directly after the allocator's header.
        self.free_list_head = self.header_addr + AllocatorHeader::size();
        let chunk = ChunkHeader::null();
        chunk.save(self.free_list_head, &self.memory);
        self.num_allocated_chunks = 0;
        self.save()
    }

    /// Loads an allocator whose header was previously saved at `addr`.
    ///
    /// # Panics
    ///
    /// Panics if the header at `addr` does not carry the expected magic bytes
    /// or layout version.
    pub fn load(memory: M, addr: Address) -> Self {
        let header: AllocatorHeader = read_struct(addr, &memory);
        assert_eq!(&header.magic, ALLOCATOR_MAGIC, "Bad magic.");
        assert_eq!(
            header.version, ALLOCATOR_LAYOUT_VERSION,
            "Unsupported version."
        );
        Self {
            header_addr: addr,
            allocation_size: header.allocation_size,
            num_allocated_chunks: header.num_allocated_chunks,
            free_list_head: header.free_list_head,
            memory,
        }
    }

    /// Allocates one chunk and returns the address of its payload.
    ///
    /// The chunk at the head of the free list is handed out. If it links to
    /// another free chunk, that chunk becomes the new head; otherwise a new
    /// unallocated chunk header is written directly after the chunk being
    /// handed out and becomes the head.
    ///
    /// # Panics
    ///
    /// Panics if the chunk at the head of the free list has a bad magic or
    /// version, or is already marked as allocated.
    pub fn allocate(&mut self) -> Address {
        let chunk_addr = self.free_list_head;
        let mut chunk = ChunkHeader::load(chunk_addr, &self.memory);
        assert!(
            !chunk.allocated,
            "Attempting to allocate an already allocated chunk."
        );
        chunk.allocated = true;
        chunk.save(chunk_addr, &self.memory);

        if chunk.next != NULL {
            // Reuse a previously freed chunk.
            self.free_list_head = chunk.next;
        } else {
            // No free chunk is linked: move the head to the next chunk-sized
            // slot in memory and initialize a header there.
            self.free_list_head += self.chunk_size();
            ChunkHeader::null().save(self.free_list_head, &self.memory);
        }

        self.num_allocated_chunks += 1;
        self.save();

        // Callers receive the payload address, which starts right after the header.
        chunk_addr + ChunkHeader::size()
    }

    /// Returns the chunk whose payload starts at `address` to the free list.
    ///
    /// `address` must be a value previously returned by [`Allocator::allocate`].
    /// The freed chunk becomes the new head of the free list.
    ///
    /// # Panics
    ///
    /// Panics if the chunk header preceding `address` has a bad magic or
    /// version, if the chunk is not marked as allocated, or if the allocator
    /// has no allocated chunks on record (for example because `address` was
    /// obtained before a call to [`Allocator::clear`]).
    pub fn deallocate(&mut self, address: Address) {
        let chunk_addr = address - ChunkHeader::size();
        let mut chunk = ChunkHeader::load(chunk_addr, &self.memory);
        assert!(
            chunk.allocated,
            "Attempting to deallocate a chunk that is not allocated."
        );
        // Validate the counter before touching memory so that a stale address
        // cannot wrap it around in builds without overflow checks.
        let remaining_chunks = self
            .num_allocated_chunks
            .checked_sub(1)
            .expect("Attempting to deallocate a chunk when no chunks are allocated.");

        // Push the chunk onto the front of the free list.
        chunk.allocated = false;
        chunk.next = self.free_list_head;
        chunk.save(chunk_addr, &self.memory);

        self.free_list_head = chunk_addr;
        self.num_allocated_chunks = remaining_chunks;
        self.save();
    }

    /// Writes the allocator's current state to its header address in memory.
    pub fn save(&self) {
        let header = AllocatorHeader {
            magic: *ALLOCATOR_MAGIC,
            version: ALLOCATOR_LAYOUT_VERSION,
            _alignment: [0; 4],
            allocation_size: self.allocation_size,
            num_allocated_chunks: self.num_allocated_chunks,
            free_list_head: self.free_list_head,
            _buffer: [0; 16],
        };
        write_struct(&header, self.header_addr, &self.memory);
    }

    /// Returns the number of chunks currently handed out (test builds only).
    #[cfg(test)]
    pub fn num_allocated_chunks(&self) -> u64 {
        self.num_allocated_chunks
    }

    /// Returns the total size of one chunk: its header plus its payload.
    fn chunk_size(&self) -> Bytes {
        self.allocation_size + ChunkHeader::size()
    }

    /// Consumes the allocator and returns the underlying memory.
    #[inline]
    pub fn into_memory(self) -> M {
        self.memory
    }

    /// Returns a shared reference to the underlying memory.
    #[inline]
    pub fn memory(&self) -> &M {
        &self.memory
    }
}
/// Header stored in memory immediately before each chunk's payload.
///
/// `repr(C, packed)` keeps the fields in declaration order with no
/// compiler-inserted padding; the padding that exists is spelled out explicitly.
#[derive(Debug)]
#[repr(C, packed)]
struct ChunkHeader {
/// Always `CHUNK_MAGIC`; verified on load.
magic: [u8; 3],
/// Always `CHUNK_LAYOUT_VERSION`; verified on load.
version: u8,
/// `true` while the chunk is handed out by `Allocator::allocate`, `false` while it is free.
allocated: bool,
/// Explicit zero padding; together with the preceding fields it occupies 8 bytes.
_alignment: [u8; 3],
/// Address of the next free chunk's header, or `NULL` when no further free chunk is linked.
next: Address,
}
impl ChunkHeader {
fn null() -> Self {
Self {
magic: *CHUNK_MAGIC,
version: CHUNK_LAYOUT_VERSION,
allocated: false,
_alignment: [0; 3],
next: NULL,
}
}
fn save<M: Memory>(&self, address: Address, memory: &M) {
write_struct(self, address, memory);
}
fn load<M: Memory>(address: Address, memory: &M) -> Self {
let header: ChunkHeader = read_struct(address, memory);
assert_eq!(&header.magic, CHUNK_MAGIC, "Bad magic.");
assert_eq!(header.version, CHUNK_LAYOUT_VERSION, "Unsupported version.");
header
}
fn size() -> Bytes {
Bytes::from(core::mem::size_of::<Self>() as u64)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Memory, WASM_PAGE_SIZE};
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Creates an empty, shareable in-memory buffer to back an allocator.
    fn make_memory() -> Rc<RefCell<Vec<u8>>> {
        Rc::new(RefCell::new(vec![]))
    }

    /// Address of the first chunk header, which sits directly after the allocator header.
    fn first_chunk_addr(allocator_addr: Address) -> Address {
        allocator_addr + AllocatorHeader::size()
    }

    /// A freshly created allocator can be loaded back with the same state.
    #[test]
    fn new_and_load() {
        let memory = make_memory();
        let base = Address::from(0);
        let payload_size = Bytes::from(16u64);

        Allocator::new(memory.clone(), base, payload_size);
        let loaded = Allocator::load(memory.clone(), base);

        assert_eq!(loaded.allocation_size, payload_size);
        assert_eq!(loaded.free_list_head, first_chunk_addr(base));

        let head_chunk = ChunkHeader::load(loaded.free_list_head, &memory);
        assert_eq!(head_chunk.next, NULL);
    }

    /// Each allocation on a fresh allocator advances the free list head by one chunk.
    #[test]
    fn allocate() {
        let mut allocator = Allocator::new(make_memory(), Address::from(0), Bytes::from(16u64));
        let initial_head = allocator.free_list_head;

        for allocations_so_far in 1..=3 {
            allocator.allocate();
            assert_eq!(
                allocator.free_list_head,
                initial_head + allocator.chunk_size() * allocations_so_far
            );
        }
    }

    /// Page-sized allocations grow the memory by one page each and survive a reload.
    #[test]
    fn allocate_large() {
        let memory = make_memory();
        assert_eq!(memory.size(), 0);

        let base = Address::from(0);
        let payload_size = Bytes::from(WASM_PAGE_SIZE);
        let mut allocator = Allocator::new(memory.clone(), base, payload_size);
        assert_eq!(memory.size(), 1);

        for expected_pages in 2..=4 {
            allocator.allocate();
            assert_eq!(memory.size(), expected_pages);
        }

        assert_eq!(
            allocator.free_list_head,
            first_chunk_addr(base) + allocator.chunk_size() * 3
        );
        assert_eq!(allocator.num_allocated_chunks, 3);

        // The same state must be visible after loading from memory.
        let reloaded = Allocator::load(memory, Address::from(0));
        assert_eq!(
            reloaded.free_list_head,
            first_chunk_addr(base) + reloaded.chunk_size() * 3
        );
        assert_eq!(reloaded.num_allocated_chunks, 3);
    }

    /// Freeing the only allocated chunk puts it back at the head of the free list.
    #[test]
    fn allocate_then_deallocate() {
        let memory = make_memory();
        let base = Address::from(0);
        let mut allocator = Allocator::new(memory.clone(), base, Bytes::from(16u64));

        let payload_addr = allocator.allocate();
        assert_eq!(
            allocator.free_list_head,
            first_chunk_addr(base) + allocator.chunk_size()
        );

        allocator.deallocate(payload_addr);
        assert_eq!(allocator.free_list_head, first_chunk_addr(base));
        assert_eq!(allocator.num_allocated_chunks, 0);

        // The same state must be visible after loading from memory.
        let reloaded = Allocator::load(memory, base);
        assert_eq!(reloaded.free_list_head, first_chunk_addr(base));
        assert_eq!(reloaded.num_allocated_chunks, 0);
    }

    /// After `clear`, the stored header equals that of a brand-new allocator.
    #[test]
    fn clear_deallocates_all_allocated_chunks() {
        let memory = make_memory();
        let base = Address::from(0);
        let payload_size = Bytes::from(16u64);
        let mut allocator = Allocator::new(memory.clone(), base, payload_size);

        allocator.allocate();
        allocator.allocate();
        assert_eq!(
            allocator.free_list_head,
            first_chunk_addr(base) + allocator.chunk_size() * 2
        );

        allocator.clear();
        let header_after_clear: AllocatorHeader = read_struct(base, &memory);

        Allocator::new(memory.clone(), base, payload_size);
        let header_after_new: AllocatorHeader = read_struct(base, &memory);

        assert_eq!(header_after_clear, header_after_new);
    }

    /// A freed chunk is reused by the next allocation.
    #[test]
    fn allocate_deallocate_2() {
        let payload_size = Bytes::from(16u64);
        let mut allocator = Allocator::new(make_memory(), Address::from(0), payload_size);

        let _first = allocator.allocate();
        let second = allocator.allocate();
        assert_eq!(allocator.free_list_head, second + payload_size);

        allocator.deallocate(second);
        assert_eq!(allocator.free_list_head, second - ChunkHeader::size());

        let third = allocator.allocate();
        assert_eq!(third, second);
        assert_eq!(allocator.free_list_head, third + payload_size);
    }

    /// Deallocating the same chunk twice must panic.
    #[test]
    #[should_panic]
    fn deallocate_free_chunk() {
        let mut allocator = Allocator::new(make_memory(), Address::from(0), Bytes::from(16u64));

        let payload_addr = allocator.allocate();
        allocator.deallocate(payload_addr);
        // The chunk is already free at this point.
        allocator.deallocate(payload_addr);
    }
}