use std::fs::{File, OpenOptions};
use std::io::{Seek, SeekFrom, Write as IoWrite};
use std::path::Path;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use memmap2::{MmapMut, MmapOptions};
use parking_lot::RwLock;
use super::block_storage::BlockStorage;
use super::error::{PersistentARTrieError, Result};
/// Size in bytes of every block in the file, including the header block 0 (256 KiB).
pub const BLOCK_SIZE: usize = 256 * 1024;
/// Maximum number of blocks (2^24) that `allocate_block` will hand out.
///
/// Together with `BLOCK_SIZE` this caps the file at 4 TiB. It also keeps a block id
/// shifted left by 40 bits (the free-list pointer encoding) inside a `u64`.
pub const MAX_BLOCK_COUNT: u32 = 1 << 24;
/// Value required in `FileHeader::magic`.
///
/// Its most significant four bytes are 0x50 0x41 0x52 0x54, which is ASCII "PART".
pub const MAGIC_NUMBER: u64 = 0x5041_5254_0001_0000;
/// Newest header format version this code writes and accepts.
///
/// Version 2 adds `image_checkpoint_lsn` to the header checksum.
pub const FORMAT_VERSION: u32 = 2;
/// File header stored in the first 64 bytes of block 0.
///
/// `FileHeader::to_bytes` serializes it little-endian with this layout:
/// `magic` 0..8, `version` 8..12, `flags` 12..16, `root_ptr` 16..24,
/// `block_count` 24..28, padding 28..32, `free_list_head` 32..40,
/// `entry_count` 40..48, `checksum` 48..56, `image_checkpoint_lsn` 56..64.
#[repr(C, align(64))]
#[derive(Debug)]
pub struct FileHeader {
/// Must equal `MAGIC_NUMBER`; checked by `validate`.
pub magic: u64,
/// Header format version; `validate` rejects values above `FORMAT_VERSION`.
pub version: u32,
/// Flag bits; `new` sets 0. NOTE(review): no bits are interpreted in this file.
pub flags: u32,
/// Root pointer, stored verbatim; its encoding is defined by callers (TODO confirm).
pub root_ptr: AtomicU64,
/// Block count recorded in the header, counting block 0 (`new` starts at 1).
pub block_count: AtomicU32,
/// Explicit padding; always serialized as zero.
_pad1: u32,
/// Free-list head: 0 when empty, otherwise a block id shifted left by 40 bits.
pub free_list_head: AtomicU64,
/// Entry count, stored verbatim on behalf of the caller.
pub entry_count: AtomicU64,
/// 64-bit FNV-1a checksum over the other fields (see `compute_checksum`).
pub checksum: u64,
/// Image checkpoint LSN; covered by the checksum only when `version >= 2`.
pub image_checkpoint_lsn: AtomicU64,
}
impl FileHeader {
pub fn new() -> Self {
Self {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
flags: 0,
root_ptr: AtomicU64::new(0),
block_count: AtomicU32::new(1), _pad1: 0,
free_list_head: AtomicU64::new(0),
entry_count: AtomicU64::new(0),
checksum: 0,
image_checkpoint_lsn: AtomicU64::new(0),
}
}
pub fn validate(&self) -> Result<()> {
if self.magic != MAGIC_NUMBER {
return Err(PersistentARTrieError::InvalidMagic {
expected: MAGIC_NUMBER,
found: self.magic,
});
}
if self.version > FORMAT_VERSION {
return Err(PersistentARTrieError::UnsupportedVersion {
max_supported: FORMAT_VERSION,
found: self.version,
});
}
Ok(())
}
pub fn compute_checksum(&self) -> u64 {
let mut hash: u64 = 0xcbf29ce484222325; let prime: u64 = 0x100000001b3;
for byte in self.magic.to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
for byte in self.version.to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
for byte in self.flags.to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
for byte in self.root_ptr.load(Ordering::SeqCst).to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
for byte in self.block_count.load(Ordering::SeqCst).to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
for byte in self.free_list_head.load(Ordering::SeqCst).to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
for byte in self.entry_count.load(Ordering::SeqCst).to_le_bytes() {
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
if self.version >= 2 {
for byte in self
.image_checkpoint_lsn
.load(Ordering::SeqCst)
.to_le_bytes()
{
hash ^= byte as u64;
hash = hash.wrapping_mul(prime);
}
}
hash
}
pub fn update_checksum(&mut self) {
self.checksum = self.compute_checksum();
}
pub fn verify_checksum(&self) -> bool {
self.checksum == self.compute_checksum()
}
pub fn to_bytes(&self) -> [u8; 64] {
let mut bytes = [0u8; 64];
bytes[0..8].copy_from_slice(&self.magic.to_le_bytes());
bytes[8..12].copy_from_slice(&self.version.to_le_bytes());
bytes[12..16].copy_from_slice(&self.flags.to_le_bytes());
bytes[16..24].copy_from_slice(&self.root_ptr.load(Ordering::SeqCst).to_le_bytes());
bytes[24..28].copy_from_slice(&self.block_count.load(Ordering::SeqCst).to_le_bytes());
bytes[28..32].copy_from_slice(&0u32.to_le_bytes()); bytes[32..40].copy_from_slice(&self.free_list_head.load(Ordering::SeqCst).to_le_bytes());
bytes[40..48].copy_from_slice(&self.entry_count.load(Ordering::SeqCst).to_le_bytes());
bytes[48..56].copy_from_slice(&self.checksum.to_le_bytes());
bytes[56..64].copy_from_slice(
&self
.image_checkpoint_lsn
.load(Ordering::SeqCst)
.to_le_bytes(),
);
bytes
}
pub fn from_bytes(bytes: &[u8; 64]) -> Self {
Self {
magic: u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
version: u32::from_le_bytes(bytes[8..12].try_into().unwrap()),
flags: u32::from_le_bytes(bytes[12..16].try_into().unwrap()),
root_ptr: AtomicU64::new(u64::from_le_bytes(bytes[16..24].try_into().unwrap())),
block_count: AtomicU32::new(u32::from_le_bytes(bytes[24..28].try_into().unwrap())),
_pad1: 0,
free_list_head: AtomicU64::new(u64::from_le_bytes(bytes[32..40].try_into().unwrap())),
entry_count: AtomicU64::new(u64::from_le_bytes(bytes[40..48].try_into().unwrap())),
checksum: u64::from_le_bytes(bytes[48..56].try_into().unwrap()),
image_checkpoint_lsn: AtomicU64::new(u64::from_le_bytes(
bytes[56..64].try_into().unwrap(),
)),
}
}
}
impl Default for FileHeader {
fn default() -> Self {
Self::new()
}
}
/// Layout of the link stored at the start of a freed block: a single `u64`.
///
/// NOTE(review): not referenced by the visible code. `MmapDiskManager` reads and
/// writes the same 8-byte little-endian link directly.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct FreeBlockEntry {
/// Encoded pointer to the next free block, in the same encoding as
/// `FileHeader::free_list_head`; 0 terminates the list.
pub next: u64,
}
/// Default disk manager: the memory-mapped implementation below.
pub type DiskManager = MmapDiskManager;
/// Block storage backed by one file that is mapped into memory with `memmap2`.
///
/// Block and header bytes are accessed through the `RwLock`-protected mapping.
/// `file_size` and `block_count` are in-memory counters kept next to it.
pub struct MmapDiskManager {
/// Read/write handle to the backing file; kept open for the mapping's lifetime.
file: File,
/// Writable mapping of the file.
///
/// Expected to be `Some` after construction; methods report `None` as `CorruptedFile`.
mmap: Option<RwLock<MmapMut>>,
/// Number of bytes that are mapped and may be accessed; raised after each remap.
file_size: AtomicU64,
/// Number of block ids reserved so far, including block 0.
///
/// Incremented before the file is extended, so it may briefly run ahead of `file_size`.
block_count: AtomicU32,
/// Backing file path (lossy UTF-8), used in error messages.
path: String,
}
impl MmapDiskManager {
fn validate_byte_range(
block_id: u32,
offset_in_block: usize,
len: usize,
) -> Result<(usize, usize)> {
if offset_in_block > BLOCK_SIZE || len > BLOCK_SIZE.saturating_sub(offset_in_block) {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Range [{}, {}) exceeds block size {}",
offset_in_block,
offset_in_block.saturating_add(len),
BLOCK_SIZE
),
});
}
let block_offset = (block_id as usize).checked_mul(BLOCK_SIZE).ok_or_else(|| {
PersistentARTrieError::InvalidBlockId {
block_id,
reason: "Block offset overflowed usize".to_string(),
}
})?;
let file_offset = block_offset.checked_add(offset_in_block).ok_or_else(|| {
PersistentARTrieError::InvalidBlockId {
block_id,
reason: "File offset overflowed usize".to_string(),
}
})?;
let end_offset =
file_offset
.checked_add(len)
.ok_or_else(|| PersistentARTrieError::InvalidBlockId {
block_id,
reason: "End offset overflowed usize".to_string(),
})?;
Ok((file_offset, end_offset))
}
pub fn create<P: AsRef<Path>>(path: P) -> Result<Self> {
let path_str = path.as_ref().to_string_lossy().to_string();
if let Some(parent) = path.as_ref().parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent).map_err(|e| PersistentARTrieError::IoError {
operation: "create parent directory".to_string(),
path: parent.display().to_string(),
source: e,
})?;
}
}
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&path)
.map_err(|e| PersistentARTrieError::IoError {
operation: "create file".to_string(),
path: path_str.clone(),
source: e,
})?;
let metadata = file
.metadata()
.map_err(|e| PersistentARTrieError::IoError {
operation: "get metadata".to_string(),
path: path_str.clone(),
source: e,
})?;
let file_size = metadata.len();
if file_size == 0 {
Self::initialize_file(&file, &path_str)?;
}
let file_size = file
.metadata()
.map_err(|e| PersistentARTrieError::IoError {
operation: "get metadata after init".to_string(),
path: path_str.clone(),
source: e,
})?
.len();
let mmap = if file_size > 0 {
let mmap = unsafe {
MmapOptions::new()
.len(file_size as usize)
.map_mut(&file)
.map_err(|e| PersistentARTrieError::MmapError {
operation: "create mmap".to_string(),
source: e,
})?
};
Some(RwLock::new(mmap))
} else {
None
};
let block_count = (file_size / BLOCK_SIZE as u64) as u32;
Ok(Self {
file,
mmap,
file_size: AtomicU64::new(file_size),
block_count: AtomicU32::new(block_count),
path: path_str,
})
}
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path_str = path.as_ref().to_string_lossy().to_string();
let file = OpenOptions::new()
.read(true)
.write(true)
.open(&path)
.map_err(|e| PersistentARTrieError::IoError {
operation: "open file".to_string(),
path: path_str.clone(),
source: e,
})?;
let file_size = file
.metadata()
.map_err(|e| PersistentARTrieError::IoError {
operation: "get metadata".to_string(),
path: path_str.clone(),
source: e,
})?
.len();
if file_size < BLOCK_SIZE as u64 {
return Err(PersistentARTrieError::CorruptedFile {
reason: "File too small to contain header block".to_string(),
});
}
let mmap = unsafe {
MmapOptions::new()
.len(file_size as usize)
.map_mut(&file)
.map_err(|e| PersistentARTrieError::MmapError {
operation: "create mmap".to_string(),
source: e,
})?
};
let block_count = (file_size / BLOCK_SIZE as u64) as u32;
let manager = Self {
file,
mmap: Some(RwLock::new(mmap)),
file_size: AtomicU64::new(file_size),
block_count: AtomicU32::new(block_count),
path: path_str,
};
let header = manager.read_header()?;
header.validate()?;
if !header.verify_checksum() {
return Err(PersistentARTrieError::ChecksumMismatch {
block_id: 0,
expected: header.compute_checksum(),
found: header.checksum,
});
}
Ok(manager)
}
pub fn open_without_validation<P: AsRef<Path>>(path: P) -> Result<Self> {
let path_str = path.as_ref().to_string_lossy().to_string();
let file = OpenOptions::new()
.read(true)
.write(true)
.open(&path)
.map_err(|e| PersistentARTrieError::IoError {
operation: "open file".to_string(),
path: path_str.clone(),
source: e,
})?;
let file_size = file
.metadata()
.map_err(|e| PersistentARTrieError::IoError {
operation: "get metadata".to_string(),
path: path_str.clone(),
source: e,
})?
.len();
if file_size < BLOCK_SIZE as u64 {
return Err(PersistentARTrieError::CorruptedFile {
reason: "File too small to contain header block".to_string(),
});
}
let mmap = unsafe {
MmapOptions::new()
.len(file_size as usize)
.map_mut(&file)
.map_err(|e| PersistentARTrieError::MmapError {
operation: "create mmap".to_string(),
source: e,
})?
};
let block_count = (file_size / BLOCK_SIZE as u64) as u32;
Ok(Self {
file,
mmap: Some(RwLock::new(mmap)),
file_size: AtomicU64::new(file_size),
block_count: AtomicU32::new(block_count),
path: path_str,
})
}
fn initialize_file(file: &File, path: &str) -> Result<()> {
file.set_len(BLOCK_SIZE as u64)
.map_err(|e| PersistentARTrieError::IoError {
operation: "set initial file length".to_string(),
path: path.to_string(),
source: e,
})?;
let mut header = FileHeader::new();
header.update_checksum();
let mut file_writer = file;
file_writer
.seek(SeekFrom::Start(0))
.map_err(|e| PersistentARTrieError::IoError {
operation: "seek to header".to_string(),
path: path.to_string(),
source: e,
})?;
file_writer
.write_all(&header.to_bytes())
.map_err(|e| PersistentARTrieError::IoError {
operation: "write header".to_string(),
path: path.to_string(),
source: e,
})?;
file_writer
.sync_all()
.map_err(|e| PersistentARTrieError::IoError {
operation: "sync after header write".to_string(),
path: path.to_string(),
source: e,
})?;
Ok(())
}
pub fn read_header(&self) -> Result<FileHeader> {
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mmap = mmap_guard.read();
if mmap.len() < 64 {
return Err(PersistentARTrieError::CorruptedFile {
reason: "File too small for header".to_string(),
});
}
let bytes: [u8; 64] =
mmap[0..64]
.try_into()
.map_err(|_| PersistentARTrieError::CorruptedFile {
reason: "Failed to read header bytes".to_string(),
})?;
Ok(FileHeader::from_bytes(&bytes))
}
pub fn write_header(&self, header: &FileHeader) -> Result<()> {
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mut mmap = mmap_guard.write();
let bytes = header.to_bytes();
mmap[0..64].copy_from_slice(&bytes);
Ok(())
}
pub fn path(&self) -> &str {
&self.path
}
pub fn allocate_block(&self) -> Result<u32> {
let header = self.read_header()?;
let free_head = header.free_list_head.load(Ordering::Acquire);
if free_head != 0 {
let block_id = (free_head >> 40) as u32;
let next = self.read_free_block_next(block_id)?;
header.free_list_head.store(next, Ordering::Release);
self.write_header(&header)?;
return Ok(block_id);
}
loop {
let current_count = self.block_count.load(Ordering::Acquire);
if current_count >= MAX_BLOCK_COUNT {
return Err(PersistentARTrieError::OutOfSpace {
current_blocks: current_count,
max_blocks: MAX_BLOCK_COUNT,
});
}
let new_block_id = current_count;
let new_count = current_count + 1;
let new_file_size = new_count as u64 * BLOCK_SIZE as u64;
match self.block_count.compare_exchange(
current_count,
new_count,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
{
let mmap_guard = self.mmap.as_ref().expect("mmap should exist");
let mut mmap = mmap_guard.write();
let current_actual_size = self
.file
.metadata()
.map_err(|e| PersistentARTrieError::IoError {
operation: "get file metadata before extend".to_string(),
path: self.path.clone(),
source: e,
})?
.len();
if new_file_size > current_actual_size {
self.file.set_len(new_file_size).map_err(|e| {
PersistentARTrieError::IoError {
operation: "extend file".to_string(),
path: self.path.clone(),
source: e,
}
})?;
}
#[cfg(unix)]
{
use std::os::unix::fs::FileExt;
let offset = new_block_id as u64 * BLOCK_SIZE as u64;
let zeros = [0u8; 8];
let _ = self.file.write_at(&zeros, offset);
}
let actual_file_size = self
.file
.metadata()
.map_err(|e| PersistentARTrieError::IoError {
operation: "get file metadata for remap".to_string(),
path: self.path.clone(),
source: e,
})?
.len();
let remap_size = actual_file_size.max(new_file_size);
if remap_size as usize > mmap.len() {
let new_mmap = unsafe {
MmapOptions::new()
.len(remap_size as usize)
.map_mut(&self.file)
.map_err(|e| PersistentARTrieError::MmapError {
operation: "remap after extend".to_string(),
source: e,
})?
};
*mmap = new_mmap;
}
std::sync::atomic::fence(Ordering::SeqCst);
loop {
let current = self.file_size.load(Ordering::Acquire);
if remap_size <= current {
break;
}
match self.file_size.compare_exchange(
current,
remap_size,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break,
Err(_) => continue, }
}
}
self.persist_header_block_count(new_count);
return Ok(new_block_id);
}
Err(_) => {
std::hint::spin_loop();
continue;
}
}
}
}
fn persist_header_block_count(&self, count: u32) {
if let Some(mmap_guard) = self.mmap.as_ref() {
if let Some(mut mmap) = mmap_guard.try_write() {
const BLOCK_COUNT_OFFSET: usize = 24;
mmap[BLOCK_COUNT_OFFSET..BLOCK_COUNT_OFFSET + 4]
.copy_from_slice(&count.to_le_bytes());
}
}
}
pub fn free_block(&self, block_id: u32) -> Result<()> {
if block_id == 0 {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: "Cannot free header block".to_string(),
});
}
let block_count = self.block_count.load(Ordering::Acquire);
let header = self.read_header()?;
if block_id >= block_count {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!("Block ID {} >= block count {}", block_id, block_count),
});
}
let old_head = header.free_list_head.load(Ordering::SeqCst);
self.write_free_block_next(block_id, old_head)?;
let new_head = (block_id as u64) << 40;
header.free_list_head.store(new_head, Ordering::SeqCst);
let mut updated_header = self.read_header()?;
updated_header
.free_list_head
.store(new_head, Ordering::SeqCst);
updated_header.checksum = updated_header.compute_checksum();
self.write_header(&updated_header)?;
Ok(())
}
fn read_free_block_next(&self, block_id: u32) -> Result<u64> {
let offset = block_id as usize * BLOCK_SIZE;
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mmap = mmap_guard.read();
if offset + 8 > mmap.len() {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: "Block offset exceeds file size".to_string(),
});
}
let bytes: [u8; 8] = mmap[offset..offset + 8].try_into().map_err(|_| {
PersistentARTrieError::CorruptedFile {
reason: "Failed to read free block next pointer".to_string(),
}
})?;
Ok(u64::from_le_bytes(bytes))
}
fn write_free_block_next(&self, block_id: u32, next: u64) -> Result<()> {
let offset = block_id as usize * BLOCK_SIZE;
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mut mmap = mmap_guard.write();
if offset + 8 > mmap.len() {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: "Block offset exceeds file size".to_string(),
});
}
mmap[offset..offset + 8].copy_from_slice(&next.to_le_bytes());
Ok(())
}
pub fn read_block(&self, block_id: u32, buffer: &mut [u8; BLOCK_SIZE]) -> Result<()> {
let offset = block_id as usize * BLOCK_SIZE;
let end_offset = offset + BLOCK_SIZE;
let current_block_count = self.block_count.load(Ordering::Acquire);
if block_id >= current_block_count {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Block ID {} >= block count {}",
block_id, current_block_count
),
});
}
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mmap = mmap_guard.read();
let current_file_size = self.file_size.load(Ordering::Acquire);
if end_offset as u64 > current_file_size {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Block {} not yet accessible (file_size={}, need={})",
block_id, current_file_size, end_offset
),
});
}
buffer.copy_from_slice(&mmap[offset..end_offset]);
Ok(())
}
pub fn write_block(&self, block_id: u32, buffer: &[u8; BLOCK_SIZE]) -> Result<()> {
let offset = block_id as usize * BLOCK_SIZE;
let end_offset = offset + BLOCK_SIZE;
let current_block_count = self.block_count.load(Ordering::Acquire);
if block_id >= current_block_count {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Block ID {} >= block count {}",
block_id, current_block_count
),
});
}
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mut mmap = mmap_guard.write();
let current_file_size = self.file_size.load(Ordering::Acquire);
if end_offset as u64 > current_file_size {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Block {} not yet accessible (file_size={}, need={})",
block_id, current_file_size, end_offset
),
});
}
mmap[offset..end_offset].copy_from_slice(buffer);
Ok(())
}
pub fn read_bytes(
&self,
block_id: u32,
offset_in_block: usize,
buffer: &mut [u8],
) -> Result<()> {
let (file_offset, end_offset) =
Self::validate_byte_range(block_id, offset_in_block, buffer.len())?;
let current_block_count = self.block_count.load(Ordering::Acquire);
if block_id >= current_block_count {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Block ID {} >= block count {}",
block_id, current_block_count
),
});
}
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mmap = mmap_guard.read();
let current_file_size = self.file_size.load(Ordering::Acquire);
if end_offset as u64 > current_file_size {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Read range [{}, {}) not accessible (file_size={})",
file_offset, end_offset, current_file_size
),
});
}
buffer.copy_from_slice(&mmap[file_offset..end_offset]);
Ok(())
}
pub fn write_bytes(&self, block_id: u32, offset_in_block: usize, buffer: &[u8]) -> Result<()> {
let (file_offset, end_offset) =
Self::validate_byte_range(block_id, offset_in_block, buffer.len())?;
let current_block_count = self.block_count.load(Ordering::Acquire);
if block_id >= current_block_count {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Block ID {} >= block count {}",
block_id, current_block_count
),
});
}
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mut mmap = mmap_guard.write();
let current_file_size = self.file_size.load(Ordering::Acquire);
if end_offset as u64 > current_file_size {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Write range [{}, {}) not accessible (file_size={})",
file_offset, end_offset, current_file_size
),
});
}
mmap[file_offset..end_offset].copy_from_slice(buffer);
Ok(())
}
pub fn sync(&self) -> Result<()> {
if let Some(mmap_guard) = &self.mmap {
let mut mmap = mmap_guard.write();
if mmap.len() < 64 {
return Err(PersistentARTrieError::CorruptedFile {
reason: "File too small for header checksum refresh".to_string(),
});
}
let header_bytes: [u8; 64] =
mmap[0..64]
.try_into()
.map_err(|_| PersistentARTrieError::CorruptedFile {
reason: "Failed to read header bytes for checksum refresh".to_string(),
})?;
let mut header = FileHeader::from_bytes(&header_bytes);
header.checksum = header.compute_checksum();
mmap[0..64].copy_from_slice(&header.to_bytes());
mmap.flush().map_err(|e| PersistentARTrieError::MmapError {
operation: "flush mmap".to_string(),
source: e,
})?;
}
self.file
.sync_all()
.map_err(|e| PersistentARTrieError::IoError {
operation: "sync file".to_string(),
path: self.path.clone(),
source: e,
})?;
Ok(())
}
pub fn file_size(&self) -> u64 {
self.file_size.load(Ordering::SeqCst)
}
pub fn block_count(&self) -> Result<u32> {
Ok(self.block_count.load(Ordering::Acquire))
}
pub fn entry_count(&self) -> Result<u64> {
let header = self.read_header()?;
Ok(header.entry_count.load(Ordering::SeqCst))
}
pub fn set_entry_count(&self, count: u64) -> Result<()> {
let header = self.read_header()?;
header.entry_count.store(count, Ordering::SeqCst);
let mut updated_header = header;
updated_header.checksum = updated_header.compute_checksum();
self.write_header(&updated_header)?;
Ok(())
}
pub fn root_ptr(&self) -> Result<u64> {
let header = self.read_header()?;
Ok(header.root_ptr.load(Ordering::SeqCst))
}
pub fn set_root_ptr(&self, ptr: u64) -> Result<()> {
let header = self.read_header()?;
header.root_ptr.store(ptr, Ordering::SeqCst);
let mut updated_header = header;
updated_header.checksum = updated_header.compute_checksum();
self.write_header(&updated_header)?;
Ok(())
}
pub fn image_checkpoint_lsn(&self) -> Result<u64> {
let header = self.read_header()?;
Ok(header.image_checkpoint_lsn.load(Ordering::SeqCst))
}
pub fn set_image_checkpoint_lsn(&self, lsn: u64) -> Result<()> {
let header = self.read_header()?;
header.image_checkpoint_lsn.store(lsn, Ordering::SeqCst);
let mut updated_header = header;
updated_header.version = FORMAT_VERSION;
updated_header.checksum = updated_header.compute_checksum();
self.write_header(&updated_header)?;
Ok(())
}
pub fn write_header_bytes(&self, bytes: &[u8]) -> Result<()> {
self.write_bytes(0, 0, bytes)
}
pub fn read_header_bytes(&self, buffer: &mut [u8]) -> Result<()> {
self.read_bytes(0, 0, buffer)
}
pub unsafe fn raw_ptr(&self, block_id: u32, offset_in_block: usize) -> Result<*const u8> {
if offset_in_block >= BLOCK_SIZE {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!(
"Offset {} exceeds block size {}",
offset_in_block, BLOCK_SIZE
),
});
}
let (file_offset, _) = Self::validate_byte_range(block_id, offset_in_block, 1)?;
let mmap_guard =
self.mmap
.as_ref()
.ok_or_else(|| PersistentARTrieError::CorruptedFile {
reason: "No memory map available".to_string(),
})?;
let mmap = mmap_guard.read();
if file_offset >= mmap.len() {
return Err(PersistentARTrieError::InvalidBlockId {
block_id,
reason: format!("Offset {} exceeds file size {}", file_offset, mmap.len()),
});
}
Ok(mmap.as_ptr().add(file_offset))
}
}
/// [`BlockStorage`] backed by the memory-mapped manager.
///
/// Every trait method forwards to the inherent method of the same name.
/// `Self::name` paths resolve to the inherent method first, so nothing recurses.
impl BlockStorage for MmapDiskManager {
    fn read_block(&self, block_id: u32, buffer: &mut [u8; BLOCK_SIZE]) -> Result<()> {
        Self::read_block(self, block_id, buffer)
    }

    fn write_block(&self, block_id: u32, buffer: &[u8; BLOCK_SIZE]) -> Result<()> {
        Self::write_block(self, block_id, buffer)
    }

    fn read_bytes(&self, block_id: u32, offset: usize, buffer: &mut [u8]) -> Result<()> {
        Self::read_bytes(self, block_id, offset, buffer)
    }

    fn write_bytes(&self, block_id: u32, offset: usize, data: &[u8]) -> Result<()> {
        Self::write_bytes(self, block_id, offset, data)
    }

    fn allocate_block(&self) -> Result<u32> {
        Self::allocate_block(self)
    }

    fn free_block(&self, block_id: u32) -> Result<()> {
        Self::free_block(self, block_id)
    }

    fn read_header(&self) -> Result<FileHeader> {
        Self::read_header(self)
    }

    fn write_header(&self, header: &FileHeader) -> Result<()> {
        Self::write_header(self, header)
    }

    fn read_header_bytes(&self, buffer: &mut [u8]) -> Result<()> {
        Self::read_header_bytes(self, buffer)
    }

    fn write_header_bytes(&self, bytes: &[u8]) -> Result<()> {
        Self::write_header_bytes(self, bytes)
    }

    fn root_ptr(&self) -> Result<u64> {
        Self::root_ptr(self)
    }

    fn set_root_ptr(&self, ptr: u64) -> Result<()> {
        Self::set_root_ptr(self, ptr)
    }

    fn entry_count(&self) -> Result<u64> {
        Self::entry_count(self)
    }

    fn set_entry_count(&self, count: u64) -> Result<()> {
        Self::set_entry_count(self, count)
    }

    fn image_checkpoint_lsn(&self) -> Result<u64> {
        Self::image_checkpoint_lsn(self)
    }

    fn set_image_checkpoint_lsn(&self, lsn: u64) -> Result<()> {
        Self::set_image_checkpoint_lsn(self, lsn)
    }

    fn file_size(&self) -> u64 {
        Self::file_size(self)
    }

    fn block_count(&self) -> Result<u32> {
        Self::block_count(self)
    }

    fn path(&self) -> &str {
        Self::path(self)
    }

    fn sync(&self) -> Result<()> {
        Self::sync(self)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_file_header_serialization() {
let header = FileHeader::new();
let bytes = header.to_bytes();
let restored = FileHeader::from_bytes(&bytes);
assert_eq!(restored.magic, MAGIC_NUMBER);
assert_eq!(restored.version, FORMAT_VERSION);
assert_eq!(restored.flags, 0);
assert_eq!(restored.root_ptr.load(Ordering::SeqCst), 0);
assert_eq!(restored.block_count.load(Ordering::SeqCst), 1);
assert_eq!(restored.free_list_head.load(Ordering::SeqCst), 0);
assert_eq!(restored.entry_count.load(Ordering::SeqCst), 0);
}
#[test]
fn test_header_checksum() {
let mut header = FileHeader::new();
header.update_checksum();
assert!(header.verify_checksum());
header.entry_count.store(42, Ordering::SeqCst);
assert!(!header.verify_checksum());
header.update_checksum();
assert!(header.verify_checksum());
}
#[test]
fn image_checkpoint_lsn_round_trips_through_v2_bytes() {
let mut header = FileHeader::new();
assert_eq!(header.version, FORMAT_VERSION, "new headers are v2");
header.image_checkpoint_lsn.store(42, Ordering::SeqCst);
header.update_checksum();
let restored = FileHeader::from_bytes(&header.to_bytes());
assert_eq!(
restored.image_checkpoint_lsn.load(Ordering::SeqCst),
42,
"image_checkpoint_lsn must round-trip through bytes[56..64]"
);
assert!(
restored.verify_checksum(),
"the round-tripped v2 checksum must verify"
);
}
#[test]
fn image_checkpoint_lsn_is_covered_by_the_v2_checksum() {
let mut header = FileHeader::new();
header.image_checkpoint_lsn.store(7, Ordering::SeqCst);
header.update_checksum();
assert!(header.verify_checksum());
header.image_checkpoint_lsn.store(9, Ordering::SeqCst);
assert!(
!header.verify_checksum(),
"#48: the image-coverage frontier must be inside the v2 checksum (fail-closed on a torn write)"
);
}
#[test]
fn v1_header_excludes_image_checkpoint_lsn_from_checksum() {
let mut h0 = FileHeader::new();
h0.version = 1;
h0.image_checkpoint_lsn.store(0, Ordering::SeqCst);
let mut h99 = FileHeader::new();
h99.version = 1;
h99.image_checkpoint_lsn.store(99, Ordering::SeqCst);
assert_eq!(
h0.compute_checksum(),
h99.compute_checksum(),
"a v1 checksum must NOT depend on the coverage field (byte-identical to pre-#48)"
);
}
#[test]
fn test_create_and_open() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test.part");
{
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
assert_eq!(dm.file_size(), BLOCK_SIZE as u64);
let header = dm.read_header().expect("Failed to read header");
assert_eq!(header.magic, MAGIC_NUMBER);
assert_eq!(header.block_count.load(Ordering::SeqCst), 1);
}
{
let dm = DiskManager::open(&path).expect("Failed to open DiskManager");
let header = dm.read_header().expect("Failed to read header");
assert_eq!(header.magic, MAGIC_NUMBER);
}
}
#[test]
fn test_allocate_blocks() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_alloc.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
let block1 = dm.allocate_block().expect("Failed to allocate block 1");
let block2 = dm.allocate_block().expect("Failed to allocate block 2");
let block3 = dm.allocate_block().expect("Failed to allocate block 3");
assert_eq!(block1, 1);
assert_eq!(block2, 2);
assert_eq!(block3, 3);
assert_eq!(dm.block_count().expect("Failed to get block count"), 4);
assert_eq!(dm.file_size(), 4 * BLOCK_SIZE as u64);
}
#[test]
fn test_free_list() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_free.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
let block1 = dm.allocate_block().expect("alloc 1");
let block2 = dm.allocate_block().expect("alloc 2");
let block3 = dm.allocate_block().expect("alloc 3");
assert_eq!(block1, 1);
assert_eq!(block2, 2);
assert_eq!(block3, 3);
dm.free_block(block2).expect("free block 2");
let block4 = dm.allocate_block().expect("alloc 4");
assert_eq!(block4, 2);
let block5 = dm.allocate_block().expect("alloc 5");
assert_eq!(block5, 4);
}
#[test]
fn test_read_write_block() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_rw.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
let block_id = dm.allocate_block().expect("Failed to allocate block");
let mut write_buf = [0u8; BLOCK_SIZE];
write_buf[0] = 0xDE;
write_buf[1] = 0xAD;
write_buf[2] = 0xBE;
write_buf[3] = 0xEF;
write_buf[BLOCK_SIZE - 1] = 0xFF;
dm.write_block(block_id, &write_buf)
.expect("Failed to write block");
let mut read_buf = [0u8; BLOCK_SIZE];
dm.read_block(block_id, &mut read_buf)
.expect("Failed to read block");
assert_eq!(read_buf[0], 0xDE);
assert_eq!(read_buf[1], 0xAD);
assert_eq!(read_buf[2], 0xBE);
assert_eq!(read_buf[3], 0xEF);
assert_eq!(read_buf[BLOCK_SIZE - 1], 0xFF);
}
#[test]
fn test_read_write_bytes() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_bytes.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
let block_id = dm.allocate_block().expect("Failed to allocate block");
let data = b"Hello, World!";
dm.write_bytes(block_id, 100, data)
.expect("Failed to write bytes");
let mut read_buf = [0u8; 13];
dm.read_bytes(block_id, 100, &mut read_buf)
.expect("Failed to read bytes");
assert_eq!(&read_buf, data);
}
#[test]
fn test_root_ptr() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_root.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
assert_eq!(dm.root_ptr().expect("root_ptr"), 0);
dm.set_root_ptr(0x123456789ABCDEF0).expect("set_root_ptr");
assert_eq!(
dm.root_ptr().expect("root_ptr after set"),
0x123456789ABCDEF0
);
dm.sync().expect("sync");
drop(dm);
let dm2 = DiskManager::open(&path).expect("reopen");
assert_eq!(
dm2.root_ptr().expect("root_ptr after reopen"),
0x123456789ABCDEF0
);
}
#[test]
fn test_entry_count() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_entry.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
assert_eq!(dm.entry_count().expect("entry_count"), 0);
dm.set_entry_count(12345).expect("set_entry_count");
assert_eq!(dm.entry_count().expect("entry_count after set"), 12345);
}
#[test]
fn test_cannot_free_header_block() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_no_free_header.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
let result = dm.free_block(0);
assert!(result.is_err());
if let Err(PersistentARTrieError::InvalidBlockId { block_id, reason }) = result {
assert_eq!(block_id, 0);
assert!(reason.contains("header"));
} else {
panic!("Expected InvalidBlockId error");
}
}
#[test]
fn test_invalid_block_id() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().join("test_invalid.part");
let dm = DiskManager::create(&path).expect("Failed to create DiskManager");
let mut buf = [0u8; BLOCK_SIZE];
let result = dm.read_block(999, &mut buf);
assert!(result.is_err());
}
/// Several threads allocate, write and read back blocks concurrently.
///
/// Checks that no block id is handed out twice, that every requested
/// allocation succeeded, and that `block_count` equals the header block plus
/// all allocated blocks.
#[test]
fn test_concurrent_block_allocation() {
    use std::sync::Arc;
    use std::thread;

    const NUM_THREADS: usize = 8;
    const BLOCKS_PER_THREAD: usize = 100;
    const TOTAL_BLOCKS: usize = NUM_THREADS * BLOCKS_PER_THREAD;

    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("test_concurrent_alloc.part");
    let shared = Arc::new(DiskManager::create(&path).expect("Failed to create DiskManager"));

    // Spawn every worker up front. Each one stamps its blocks with
    // (thread id, sequence number), verifies the round-trip and returns the
    // ids it received, in allocation order.
    let workers: Vec<_> = (0..NUM_THREADS)
        .map(|thread_id| {
            let disk = Arc::clone(&shared);
            thread::spawn(move || {
                let owner_tag = (thread_id as u64).to_le_bytes();
                (0..BLOCKS_PER_THREAD)
                    .map(|i| {
                        let seq_tag = (i as u64).to_le_bytes();
                        let block_id = disk.allocate_block().unwrap_or_else(|e| {
                            panic!(
                                "Thread {} failed to allocate block {}: {:?}",
                                thread_id, i, e
                            )
                        });

                        let mut outgoing = [0u8; BLOCK_SIZE];
                        outgoing[..8].copy_from_slice(&owner_tag);
                        outgoing[8..16].copy_from_slice(&seq_tag);
                        disk.write_block(block_id, &outgoing).unwrap_or_else(|e| {
                            panic!(
                                "Thread {} failed to write block {} (id={}): {:?}",
                                thread_id, i, block_id, e
                            )
                        });

                        let mut incoming = [0u8; BLOCK_SIZE];
                        disk.read_block(block_id, &mut incoming).unwrap_or_else(|e| {
                            panic!(
                                "Thread {} failed to read block {} (id={}): {:?}",
                                thread_id, i, block_id, e
                            )
                        });
                        assert_eq!(&incoming[..8], &owner_tag);
                        assert_eq!(&incoming[8..16], &seq_tag);

                        block_id
                    })
                    .collect::<Vec<_>>()
            })
        })
        .collect();

    let mut all_ids: Vec<u32> = Vec::with_capacity(TOTAL_BLOCKS);
    for worker in workers {
        all_ids.extend(worker.join().expect("Thread panicked"));
    }

    // Uniqueness check: after sorting, any duplicate id collapses under dedup.
    all_ids.sort_unstable();
    let original_len = all_ids.len();
    all_ids.dedup();
    let unique_len = all_ids.len();
    assert_eq!(
        unique_len,
        original_len,
        "Duplicate block IDs were allocated! Expected {} unique IDs, got {}",
        original_len,
        unique_len
    );
    assert_eq!(
        original_len,
        TOTAL_BLOCKS,
        "Expected {} allocated blocks, got {}",
        TOTAL_BLOCKS,
        original_len
    );

    // The header block accounts for the extra 1.
    let block_count = shared.block_count().expect("Failed to get block count");
    assert_eq!(block_count as usize, 1 + TOTAL_BLOCKS, "Block count mismatch");
}
/// Allocator threads grow the file while accessor threads concurrently read
/// blocks that have already been allocated, written and verified.
///
/// Accessors run until *all* allocators have finished. The test then checks
/// that block ids are unique and that the allocation totals and `block_count`
/// are consistent.
#[test]
fn test_concurrent_allocate_and_access() {
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering as AtomicOrdering};
    use std::sync::Arc;
    use std::thread;
    use std::thread::JoinHandle;

    const NUM_ALLOCATORS: usize = 4;
    const NUM_ACCESSORS: usize = 4;
    const ALLOCATIONS_PER_THREAD: usize = 50;
    const TOTAL_ALLOCATIONS: usize = NUM_ALLOCATORS * ALLOCATIONS_PER_THREAD;

    let dir = tempdir().expect("Failed to create temp dir");
    let path = dir.path().join("test_concurrent_access.part");
    let dm = Arc::new(DiskManager::create(&path).expect("Failed to create DiskManager"));
    // Raised by the main thread only after every allocator has been joined.
    // Previously the first allocator to finish raised it, which stopped the
    // accessors while other allocators were still running.
    let stop = Arc::new(AtomicBool::new(false));
    // Highest block id that some allocator has allocated, written and read back.
    // Accessors read ids in 1..=this value. NOTE(review): this assumes ids in a
    // fresh file are handed out densely starting at 1. Confirm against
    // `allocate_block`.
    let safe_to_read = Arc::new(AtomicU64::new(0));
    let mut allocator_handles: Vec<JoinHandle<Vec<u32>>> = Vec::with_capacity(NUM_ALLOCATORS);
    let mut accessor_handles: Vec<JoinHandle<u64>> = Vec::with_capacity(NUM_ACCESSORS);

    for thread_id in 0..NUM_ALLOCATORS {
        let dm = Arc::clone(&dm);
        let safe_to_read = Arc::clone(&safe_to_read);
        allocator_handles.push(thread::spawn(move || {
            let mut ids = Vec::with_capacity(ALLOCATIONS_PER_THREAD);
            for i in 0..ALLOCATIONS_PER_THREAD {
                let block_id = dm.allocate_block().unwrap_or_else(|e| {
                    panic!("Allocator {} failed at {}: {:?}", thread_id, i, e)
                });
                ids.push(block_id);
                let mut buf = [0u8; BLOCK_SIZE];
                buf[0..4].copy_from_slice(&block_id.to_le_bytes());
                dm.write_block(block_id, &buf).unwrap_or_else(|e| {
                    panic!(
                        "Allocator {} failed to write block {}: {:?}",
                        thread_id, block_id, e
                    )
                });
                let mut read_buf = [0u8; BLOCK_SIZE];
                dm.read_block(block_id, &mut read_buf).unwrap_or_else(|e| {
                    panic!(
                        "Allocator {} failed to read-back block {}: {:?}",
                        thread_id, block_id, e
                    )
                });
                assert_eq!(
                    &read_buf[0..4],
                    &block_id.to_le_bytes(),
                    "Allocator {} read-back mismatch for block {}",
                    thread_id,
                    block_id
                );
                // Publish a new high-water mark. `fetch_max` is the atomic
                // monotonic max the previous compare-exchange loop implemented.
                safe_to_read.fetch_max(u64::from(block_id), AtomicOrdering::AcqRel);
            }
            ids
        }));
    }

    for thread_id in 0..NUM_ACCESSORS {
        let dm = Arc::clone(&dm);
        let stop = Arc::clone(&stop);
        let safe_to_read = Arc::clone(&safe_to_read);
        accessor_handles.push(thread::spawn(move || {
            let mut successful_reads = 0u64;
            while !stop.load(AtomicOrdering::Acquire) {
                let safe_block = safe_to_read.load(AtomicOrdering::Acquire);
                if safe_block >= 1 {
                    // Cycle through ids 1..=safe_block.
                    let block_id = ((successful_reads % safe_block) + 1) as u32;
                    let mut buf = [0u8; BLOCK_SIZE];
                    match dm.read_block(block_id, &mut buf) {
                        Ok(_) => successful_reads += 1,
                        Err(e) => {
                            panic!(
                                "Accessor {} failed to read safe block {} (safe_to_read={}): {:?}",
                                thread_id, block_id, safe_block, e
                            );
                        }
                    }
                }
                std::hint::spin_loop();
            }
            successful_reads
        }));
    }

    // Join every allocator before raising `stop`. The join results are
    // collected first, so a panicking allocator still lets the accessors shut
    // down before the test fails.
    let allocator_results: Vec<_> = allocator_handles
        .into_iter()
        .map(JoinHandle::join)
        .collect();
    stop.store(true, AtomicOrdering::Release);

    let mut all_allocated: Vec<u32> = Vec::with_capacity(TOTAL_ALLOCATIONS);
    for result in allocator_results {
        all_allocated.extend(result.expect("Allocator thread panicked"));
    }
    let total_reads: u64 = accessor_handles
        .into_iter()
        .map(|handle| handle.join().expect("Accessor thread panicked"))
        .sum();

    all_allocated.sort_unstable();
    let original_len = all_allocated.len();
    all_allocated.dedup();
    assert_eq!(all_allocated.len(), original_len, "Duplicate block IDs!");
    assert_eq!(
        original_len,
        TOTAL_ALLOCATIONS,
        "Expected {} allocated blocks, got {}",
        TOTAL_ALLOCATIONS,
        original_len
    );
    // Same invariant as `test_concurrent_block_allocation`: header block + allocations.
    let block_count = dm.block_count().expect("Failed to get block count");
    assert_eq!(
        block_count as usize,
        1 + TOTAL_ALLOCATIONS,
        "Block count mismatch"
    );

    eprintln!(
        "Concurrent access test: {} blocks allocated, {} successful reads",
        original_len, total_reads
    );
}
}