use std::collections::HashSet;
use std::path::Path;
use crate::error::{PageError, PageResult};
use crate::file::PageFile;
use crate::page::{Page, PageId, PageSize};
use crate::store::PageStore;
use crate::sync::{self, Mutex};
/// Sentinel id meaning "no page": terminates the free chain and marks an empty free list.
const NO_PAGE: u64 = u64::MAX;
/// Superblock magic: the ASCII bytes `PGSB` interpreted as a little-endian `u32`.
const SB_MAGIC: u32 = u32::from_le_bytes(*b"PGSB");
/// On-disk superblock format version.
const SB_VERSION: u16 = 1;

// Superblock payload layout (byte offsets):
//   0..4    magic (u32)
//   4..6    version (u16); bytes 6..8 are not written
//   8..16   head of the free chain (u64)
//   16..24  high-water mark / next never-allocated id (u64)
//   24..32  number of pages on the free chain (u64)
const SB_MAGIC_OFF: usize = 0;
const SB_VERSION_OFF: usize = SB_MAGIC_OFF + 4;
const SB_HEAD_OFF: usize = 8;
const SB_NEXT_OFF: usize = SB_HEAD_OFF + 8;
const SB_FREECOUNT_OFF: usize = SB_NEXT_OFF + 8;

/// Offset, within a free page's payload, of the id of the next free page.
const LINK_NEXT_OFF: usize = 0;
/// Mutable allocator state, kept behind the `PageAllocator::state` mutex.
struct AllocState {
    /// Ids of free pages, used as a stack: `allocate` pops from the end and `free` pushes.
    free_list: Vec<u64>,
    /// High-water mark: the id `allocate` hands out next when the free list is empty.
    next_new: u64,
    /// Reusable page buffer for reading/writing the superblock and free-chain link pages.
    scratch: Page,
}
/// Page-id allocator backed by a [`PageStore`].
///
/// Page 0 holds a superblock recording the high-water mark and the head of a singly linked
/// chain of free pages. Each free page stores the id of the next one. `allocate` and `free`
/// only change in-memory state. The chain and superblock are written to the store by `sync`,
/// and by `new` when it initialises a fresh store.
pub struct PageAllocator<S = PageFile> {
    /// Backing page store.
    store: S,
    /// Free list, high-water mark and scratch buffer. The mutex lets the allocator be shared
    /// across threads.
    state: Mutex<AllocState>,
}
impl PageAllocator<PageFile> {
    /// Opens the page file at `path` with `page_size` and builds an allocator over it.
    ///
    /// The allocator is built with [`PageAllocator::new`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `PageFile::open` or from [`PageAllocator::new`].
    pub fn open<P: AsRef<Path>>(path: P, page_size: PageSize) -> PageResult<Self> {
        Self::new(PageFile::open(path, page_size)?)
    }
}
impl<S: PageStore> PageAllocator<S> {
    /// Builds an allocator over `store`, loading persisted state from page 0 when it exists.
    ///
    /// If page 0 reads successfully, it must be a valid superblock: the magic and version must
    /// match, and the high-water mark must be at least 1. The free-page chain it points to is
    /// then walked and validated.
    ///
    /// If reading page 0 fails with [`PageError::ShortRead`], the store is treated as fresh.
    /// The free list starts empty and the high-water mark starts at 1, because page 0 is
    /// reserved for the superblock. An initial superblock is then written to page 0; that
    /// write is not followed by a store sync.
    ///
    /// # Errors
    ///
    /// - [`PageError::InvalidSuperblock`] if page 0 or the free-page chain fails validation.
    /// - Any other error reported by the store while reading or writing pages.
    pub fn new(store: S) -> PageResult<Self> {
        let mut scratch = store.allocate_page();
        let (free_list, next_new, fresh) = match store.read_into(PageId::new(0), &mut scratch) {
            Ok(()) => {
                let payload = scratch.payload();
                if read_u32(payload, SB_MAGIC_OFF) != SB_MAGIC
                    || read_u16(payload, SB_VERSION_OFF) != SB_VERSION
                {
                    return Err(PageError::InvalidSuperblock);
                }
                let head = read_u64(payload, SB_HEAD_OFF);
                let next_new = read_u64(payload, SB_NEXT_OFF);
                let free_count = read_u64(payload, SB_FREECOUNT_OFF);
                // Page 0 is the superblock itself, so the high-water mark is never below 1.
                if next_new < 1 {
                    return Err(PageError::InvalidSuperblock);
                }
                let free_list = walk_free_chain(&store, &mut scratch, head, next_new, free_count)?;
                (free_list, next_new, false)
            }
            // Page 0 does not exist yet: start a brand-new allocator.
            Err(PageError::ShortRead { .. }) => (Vec::new(), 1, true),
            Err(err) => return Err(err),
        };
        let allocator = Self {
            store,
            state: Mutex::new(AllocState {
                free_list,
                next_new,
                scratch,
            }),
        };
        if fresh {
            let mut state = sync::lock(&allocator.state);
            allocator.persist_superblock(&mut state)?;
        }
        Ok(allocator)
    }

    /// Hands out a page id.
    ///
    /// The most recently freed page is reused first (LIFO). When the free list is empty, the
    /// high-water mark is returned and advanced.
    ///
    /// # Errors
    ///
    /// [`PageError::InvalidPageId`] when the id space is exhausted, i.e. advancing the
    /// high-water mark would overflow or reach the `NO_PAGE` sentinel.
    pub fn allocate(&self) -> PageResult<PageId> {
        let mut state = sync::lock(&self.state);
        if let Some(id) = state.free_list.pop() {
            return Ok(PageId::new(id));
        }
        let id = state.next_new;
        match id.checked_add(1) {
            Some(next) if next != NO_PAGE => {
                state.next_new = next;
                Ok(PageId::new(id))
            }
            _ => Err(PageError::InvalidPageId { page_id: id }),
        }
    }

    /// Returns `id` to the free list so a later `allocate` can reuse it.
    ///
    /// # Errors
    ///
    /// [`PageError::InvalidPageId`] in any of these cases:
    /// - `id` is page 0 (the superblock);
    /// - `id` is at or above the high-water mark, so it was never handed out;
    /// - `id` is already on the free list (a double free).
    pub fn free(&self, id: PageId) -> PageResult<()> {
        let raw = id.get();
        if raw == 0 {
            return Err(PageError::InvalidPageId { page_id: 0 });
        }
        let mut state = sync::lock(&self.state);
        // Accepting a double free would let `allocate` return the same page to two callers.
        // It would also make `sync` persist a free chain that `new` rejects as an invalid
        // superblock. The membership test is linear in the free-list length.
        if raw >= state.next_new || state.free_list.contains(&raw) {
            return Err(PageError::InvalidPageId { page_id: raw });
        }
        state.free_list.push(raw);
        Ok(())
    }

    /// Returns the high-water mark: the id `allocate` hands out next when the free list is
    /// empty.
    #[must_use]
    pub fn high_water(&self) -> u64 {
        sync::lock(&self.state).next_new
    }

    /// Returns the number of page ids currently on the free list.
    #[must_use]
    pub fn free_count(&self) -> u64 {
        sync::lock(&self.state).free_list.len() as u64
    }

    /// Writes the free chain and superblock to the store, then syncs the store.
    ///
    /// The allocator lock is released before the store sync runs.
    ///
    /// # Errors
    ///
    /// Any error reported by the store while writing pages or syncing.
    pub fn sync(&self) -> PageResult<()> {
        {
            let mut state = sync::lock(&self.state);
            self.persist_superblock(&mut state)?;
        }
        self.store.sync()
    }

    /// Writes the in-memory free list to the store as a linked chain, then writes the
    /// superblock that points at it.
    ///
    /// Every free page is rewritten on each call. Its payload is reset, and the id of its
    /// successor (or `NO_PAGE` for the last entry) is stored at `LINK_NEXT_OFF`. The chain
    /// follows `free_list` order, so loading it back rebuilds the same vector and LIFO reuse
    /// order survives a reopen. The superblock is written last.
    fn persist_superblock(&self, state: &mut AllocState) -> PageResult<()> {
        let AllocState {
            free_list,
            next_new,
            scratch,
        } = state;
        // Each entry links to the one after it; the final entry terminates the chain.
        let successors = free_list
            .iter()
            .skip(1)
            .copied()
            .chain(std::iter::once(NO_PAGE));
        for (&id, next) in free_list.iter().zip(successors) {
            scratch.reset();
            write_u64(scratch.payload_mut(), LINK_NEXT_OFF, next);
            self.store.write_page(PageId::new(id), &mut *scratch)?;
        }
        let head = free_list.first().copied().unwrap_or(NO_PAGE);
        scratch.reset();
        let payload = scratch.payload_mut();
        write_u32(payload, SB_MAGIC_OFF, SB_MAGIC);
        write_u16(payload, SB_VERSION_OFF, SB_VERSION);
        write_u64(payload, SB_HEAD_OFF, head);
        write_u64(payload, SB_NEXT_OFF, *next_new);
        write_u64(payload, SB_FREECOUNT_OFF, free_list.len() as u64);
        self.store.write_page(PageId::new(0), &mut *scratch)
    }
}
/// Follows the persisted free-page chain from `head` and returns the ids in chain order.
///
/// Every id must be non-zero and below `next_new`, must appear only once, and the chain must
/// hold exactly `free_count` entries. `scratch` is overwritten by each link page read.
///
/// # Errors
///
/// - [`PageError::InvalidSuperblock`] if any of those checks fails.
/// - Any error the store reports while reading a link page.
fn walk_free_chain<S: PageStore>(
    store: &S,
    scratch: &mut Page,
    head: u64,
    next_new: u64,
    free_count: u64,
) -> PageResult<Vec<u64>> {
    let mut chain: Vec<u64> = Vec::new();
    let mut visited: HashSet<u64> = HashSet::new();
    let mut page = head;
    while page != NO_PAGE {
        // Checking the length before reading bounds the walk by `free_count`.
        let overlong = chain.len() as u64 >= free_count;
        let out_of_range = page == 0 || page >= next_new;
        if overlong || out_of_range || !visited.insert(page) {
            return Err(PageError::InvalidSuperblock);
        }
        chain.push(page);
        store.read_into(PageId::new(page), scratch)?;
        page = read_u64(scratch.payload(), LINK_NEXT_OFF);
    }
    if chain.len() as u64 == free_count {
        Ok(chain)
    } else {
        Err(PageError::InvalidSuperblock)
    }
}
/// Decodes a little-endian `u16` starting at `off`.
///
/// Panics if fewer than 2 bytes are available at `off`.
#[inline]
fn read_u16(bytes: &[u8], off: usize) -> u16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&bytes[off..off + 2]);
    u16::from_le_bytes(raw)
}
/// Decodes a little-endian `u32` starting at `off`.
///
/// Panics if fewer than 4 bytes are available at `off`.
#[inline]
fn read_u32(bytes: &[u8], off: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&bytes[off..off + 4]);
    u32::from_le_bytes(raw)
}
/// Decodes a little-endian `u64` starting at `off`.
///
/// Panics if fewer than 8 bytes are available at `off`.
#[inline]
fn read_u64(bytes: &[u8], off: usize) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&bytes[off..off + 8]);
    u64::from_le_bytes(raw)
}
/// Encodes `value` as little-endian into `bytes[off..off + 2]`.
///
/// Panics if fewer than 2 bytes are available at `off`.
#[inline]
fn write_u16(bytes: &mut [u8], off: usize, value: u16) {
    let encoded = value.to_le_bytes();
    bytes[off..off + encoded.len()].copy_from_slice(&encoded);
}
/// Encodes `value` as little-endian into `bytes[off..off + 4]`.
///
/// Panics if fewer than 4 bytes are available at `off`.
#[inline]
fn write_u32(bytes: &mut [u8], off: usize, value: u32) {
    let encoded = value.to_le_bytes();
    bytes[off..off + encoded.len()].copy_from_slice(&encoded);
}
/// Encodes `value` as little-endian into `bytes[off..off + 8]`.
///
/// Panics if fewer than 8 bytes are available at `off`.
#[inline]
fn write_u64(bytes: &mut [u8], off: usize, value: u64) {
    let encoded = value.to_le_bytes();
    bytes[off..off + encoded.len()].copy_from_slice(&encoded);
}
#[cfg(all(test, not(loom)))]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]
    use std::collections::HashSet;

    use proptest::prelude::*;

    use super::*;
    use crate::test_store::MemStore;

    const PS: usize = 4096;

    /// Fresh allocator over an empty in-memory store.
    fn allocator() -> PageAllocator<MemStore> {
        PageAllocator::new(MemStore::new(PS)).unwrap()
    }

    #[test]
    fn test_allocate_starts_at_one_and_increments() {
        let alloc = allocator();
        assert_eq!(alloc.allocate().unwrap(), PageId::new(1));
        assert_eq!(alloc.allocate().unwrap(), PageId::new(2));
        assert_eq!(alloc.allocate().unwrap(), PageId::new(3));
        assert_eq!(alloc.high_water(), 4);
    }

    #[test]
    fn test_free_then_allocate_reuses_id() {
        let alloc = allocator();
        let a = alloc.allocate().unwrap();
        let b = alloc.allocate().unwrap();
        alloc.free(a).unwrap();
        assert_eq!(alloc.free_count(), 1);
        let c = alloc.allocate().unwrap();
        assert_eq!(c, a);
        assert_ne!(c, b);
        assert_eq!(alloc.free_count(), 0);
    }

    #[test]
    fn test_free_list_is_lifo() {
        let alloc = allocator();
        let ids: Vec<_> = (0..4).map(|_| alloc.allocate().unwrap()).collect();
        for &id in &ids {
            alloc.free(id).unwrap();
        }
        let mut reused = Vec::new();
        for _ in 0..4 {
            reused.push(alloc.allocate().unwrap());
        }
        let expected: Vec<_> = ids.into_iter().rev().collect();
        assert_eq!(reused, expected);
    }

    #[test]
    fn test_free_rejects_superblock_and_unallocated() {
        let alloc = allocator();
        let _ = alloc.allocate().unwrap();
        assert!(matches!(
            alloc.free(PageId::new(0)),
            Err(PageError::InvalidPageId { page_id: 0 })
        ));
        assert!(matches!(
            alloc.free(PageId::new(5)),
            Err(PageError::InvalidPageId { page_id: 5 })
        ));
    }

    #[test]
    fn test_state_survives_reopen() {
        let alloc = PageAllocator::new(MemStore::new(PS)).unwrap();
        let _ = alloc.allocate().unwrap();
        let b = alloc.allocate().unwrap();
        let _ = alloc.allocate().unwrap();
        alloc.free(b).unwrap();
        alloc.sync().unwrap();

        let alloc2 = PageAllocator::new(alloc.into_store()).unwrap();
        assert_eq!(alloc2.high_water(), 4);
        assert_eq!(alloc2.free_count(), 1);
        assert_eq!(alloc2.allocate().unwrap(), PageId::new(2));
    }

    #[test]
    fn test_fresh_store_reopens_empty() {
        // `new` writes an initial superblock for a fresh store, so reopening must succeed.
        let alloc = allocator();
        let reopened = PageAllocator::new(alloc.into_store()).unwrap();
        assert_eq!(reopened.high_water(), 1);
        assert_eq!(reopened.free_count(), 0);
        assert_eq!(reopened.allocate().unwrap(), PageId::new(1));
    }

    #[test]
    fn test_reuse_order_survives_reopen() {
        let alloc = allocator();
        let ids: Vec<_> = (0..3).map(|_| alloc.allocate().unwrap()).collect();
        alloc.free(ids[0]).unwrap();
        alloc.free(ids[2]).unwrap();
        alloc.sync().unwrap();

        let reopened = PageAllocator::new(alloc.into_store()).unwrap();
        assert_eq!(reopened.allocate().unwrap(), ids[2]);
        assert_eq!(reopened.allocate().unwrap(), ids[0]);
        assert_eq!(reopened.allocate().unwrap(), PageId::new(4));
    }

    #[test]
    fn test_new_rejects_non_superblock_page_zero() {
        let store = MemStore::new(PS);
        {
            let mut page = store.allocate_page();
            page.payload_mut()[0] = 0xFF;
            store.write_page(PageId::new(0), &mut page).unwrap();
        }
        assert!(matches!(
            PageAllocator::new(store),
            Err(PageError::InvalidSuperblock)
        ));
    }

    /// Writes a hand-built superblock to page 0 of `store`.
    fn write_superblock(store: &MemStore, head: u64, next_new: u64, free_count: u64) {
        let mut page = store.allocate_page();
        let payload = page.payload_mut();
        write_u32(payload, SB_MAGIC_OFF, SB_MAGIC);
        write_u16(payload, SB_VERSION_OFF, SB_VERSION);
        write_u64(payload, SB_HEAD_OFF, head);
        write_u64(payload, SB_NEXT_OFF, next_new);
        write_u64(payload, SB_FREECOUNT_OFF, free_count);
        store.write_page(PageId::new(0), &mut page).unwrap();
    }

    /// Writes a free-chain link page `id -> next` to `store`.
    fn write_link(store: &MemStore, id: u64, next: u64) {
        let mut page = store.allocate_page();
        write_u64(page.payload_mut(), LINK_NEXT_OFF, next);
        store.write_page(PageId::new(id), &mut page).unwrap();
    }

    #[test]
    fn test_new_rejects_cycled_free_chain() {
        let store = MemStore::new(PS);
        write_superblock(&store, 1, 3, 10);
        write_link(&store, 1, 2);
        write_link(&store, 2, 1);
        assert!(matches!(
            PageAllocator::new(store),
            Err(PageError::InvalidSuperblock)
        ));
    }

    #[test]
    fn test_new_rejects_out_of_range_link() {
        let store = MemStore::new(PS);
        write_superblock(&store, 5, 3, 1);
        assert!(matches!(
            PageAllocator::new(store),
            Err(PageError::InvalidSuperblock)
        ));
    }

    #[test]
    fn test_new_rejects_free_count_mismatch() {
        let store = MemStore::new(PS);
        write_superblock(&store, 1, 3, 5);
        write_link(&store, 1, NO_PAGE);
        assert!(matches!(
            PageAllocator::new(store),
            Err(PageError::InvalidSuperblock)
        ));
    }

    impl PageAllocator<MemStore> {
        /// Consumes the allocator and hands back its store, for reopen tests.
        fn into_store(self) -> MemStore {
            self.store
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(48))]
        #[test]
        fn allocate_free_never_double_allocates(
            ops in proptest::collection::vec(any::<bool>(), 1..200),
        ) {
            let alloc = allocator();
            let mut live: HashSet<u64> = HashSet::new();
            let mut freed_pool: Vec<u64> = Vec::new();
            for want_alloc in ops {
                if want_alloc || live.is_empty() {
                    let id = alloc.allocate().unwrap().get();
                    prop_assert!(!live.contains(&id), "id {} double-allocated", id);
                    prop_assert!(id >= 1, "id 0 is reserved");
                    let _ = live.insert(id);
                    freed_pool.retain(|&f| f != id);
                } else {
                    let victim = *live.iter().next().unwrap();
                    let _ = live.remove(&victim);
                    alloc.free(PageId::new(victim)).unwrap();
                    freed_pool.push(victim);
                }
                prop_assert_eq!(alloc.free_count(), freed_pool.len() as u64);
            }
        }
    }
}
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use crate::sync::Arc;
    use crate::test_store::MemStore;

    /// Under every loom interleaving, two threads racing on `allocate` over a fresh store
    /// must receive distinct page ids.
    #[test]
    fn loom_concurrent_allocate_is_unique() {
        loom::model(|| {
            let shared = Arc::new(PageAllocator::new(MemStore::new(4096)).unwrap());
            let worker = {
                let handle = Arc::clone(&shared);
                loom::thread::spawn(move || handle.allocate().unwrap())
            };
            let from_main = shared.allocate().unwrap();
            let from_worker = worker.join().unwrap();
            assert_ne!(from_main, from_worker);
        });
    }
}