use rand::{RngExt, SeedableRng};
use rand_chacha::ChaCha8Rng;
use std::io;
/// Size in bytes of one storage sector; written/fault state is tracked per sector.
pub const SECTOR_SIZE: usize = 512;
/// Number of overlay slots; used as the length of the `overlays` array in `InMemoryStorage`.
const MAX_OVERLAYS: usize = 2;
/// Fixed-length bitmap holding one flag per sector.
///
/// Flags are packed into 64-bit words. Every accessor asserts that the sector
/// index is below `len`, so bits past `len` in the last word are never set
/// through this API.
#[derive(Debug, Clone)]
pub struct SectorBitSet {
    /// Packed flags: sector `i` lives in word `i / 64`, bit `i % 64`.
    bits: Vec<u64>,
    /// Number of addressable sectors.
    len: usize,
}

impl SectorBitSet {
    /// Number of flags stored in each backing word.
    const WORD_BITS: usize = 64;

    /// Creates a bitset covering `num_sectors` sectors with every flag cleared.
    pub fn new(num_sectors: usize) -> Self {
        let num_words = num_sectors.div_ceil(Self::WORD_BITS);
        Self {
            bits: vec![0; num_words],
            len: num_sectors,
        }
    }

    /// Splits a sector index into `(word index, bit index within the word)`.
    ///
    /// # Panics
    ///
    /// Panics if `sector >= self.len`.
    fn indices(&self, sector: usize) -> (usize, usize) {
        assert!(sector < self.len, "sector index out of bounds");
        (sector / Self::WORD_BITS, sector % Self::WORD_BITS)
    }

    /// Sets the flag for `sector`.
    ///
    /// # Panics
    ///
    /// Panics if `sector` is out of bounds.
    pub fn set(&mut self, sector: usize) {
        let (word, bit) = self.indices(sector);
        self.bits[word] |= 1u64 << bit;
    }

    /// Clears the flag for `sector`.
    ///
    /// # Panics
    ///
    /// Panics if `sector` is out of bounds.
    pub fn clear(&mut self, sector: usize) {
        let (word, bit) = self.indices(sector);
        self.bits[word] &= !(1u64 << bit);
    }

    /// Returns `true` if the flag for `sector` is set.
    ///
    /// # Panics
    ///
    /// Panics if `sector` is out of bounds.
    pub fn is_set(&self, sector: usize) -> bool {
        let (word, bit) = self.indices(sector);
        (self.bits[word] & (1u64 << bit)) != 0
    }

    /// Number of sectors this bitset covers.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the bitset covers zero sectors.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Builds a bitset of `new_len` sectors carrying over the flags of `other`.
    ///
    /// Flags for sectors below `min(other.len(), new_len)` are copied; all
    /// other flags in the result are cleared.
    pub fn resize_copy(other: &Self, new_len: usize) -> Self {
        let mut resized = Self::new(new_len);
        let copy_len = other.len.min(new_len);

        // Copy whole words instead of testing and setting every bit individually.
        let copy_words = copy_len.div_ceil(Self::WORD_BITS);
        resized.bits[..copy_words].copy_from_slice(&other.bits[..copy_words]);

        // When shrinking, the last copied word may carry flags for sectors at or
        // beyond `copy_len`; mask them off so they cannot reappear later.
        let tail_bits = copy_len % Self::WORD_BITS;
        if tail_bits != 0 {
            resized.bits[copy_words - 1] &= (1u64 << tail_bits) - 1;
        }
        resized
    }
}
/// A byte range whose contents replace what a read would otherwise return.
///
/// Created by `InMemoryStorage::apply_misdirected_write` and consumed by
/// `InMemoryStorage::apply_overlays`.
#[derive(Debug, Clone)]
struct WriteOverlay {
/// Absolute byte offset in storage at which the overlay starts.
offset: u64,
/// Length of the overlaid range in bytes (set to the length of `data` on creation).
size: u32,
/// Bytes served to readers for the overlaid range.
data: Vec<u8>,
/// Overlays with `active == false` are skipped when reads are patched.
active: bool,
}
/// Record of a write that is still pending when a crash is simulated.
///
/// Pushed by `InMemoryStorage::write` (when `is_synced` is false) and by
/// `InMemoryStorage::record_phantom_write`; drained by `sync` and `apply_crash`.
#[derive(Debug, Clone)]
struct PendingWrite {
/// Absolute byte offset the write targeted.
offset: u64,
/// Copy of the written bytes; `apply_crash` uses its length to find the touched sectors.
data: Vec<u8>,
/// `true` for entries added by `record_phantom_write`; `apply_crash` skips these.
is_phantom: bool,
}
/// In-memory byte storage with per-sector bookkeeping used to inject faults:
/// unwritten sectors, corrupted sectors, read overlays and pending (unsynced) writes.
#[derive(Debug)]
pub struct InMemoryStorage {
/// Backing bytes; kept at `size` bytes by `new` and `resize`.
data: Vec<u8>,
/// One flag per sector: set once a write has touched the sector.
written: SectorBitSet,
/// One flag per sector: set while reads of the sector should be corrupted.
faults: SectorBitSet,
/// Overlay slots applied on top of read results (filled by `apply_misdirected_write`).
overlays: [Option<WriteOverlay>; MAX_OVERLAYS],
/// Writes recorded since the last `sync` or `apply_crash`.
pending_writes: Vec<PendingWrite>,
/// Logical size of the storage in bytes.
size: u64,
/// Seed for the deterministic pseudo-random content and fault decisions.
seed: u64,
}
impl InMemoryStorage {
    /// Creates a zero-filled storage of `size` bytes.
    ///
    /// `seed` drives every pseudo-random decision (unwritten-sector content,
    /// misdirected read targets, crash faults), so two storages built with the
    /// same arguments behave identically.
    pub fn new(size: u64, seed: u64) -> Self {
        let num_sectors = (size as usize).div_ceil(SECTOR_SIZE);
        Self {
            data: vec![0; size as usize],
            written: SectorBitSet::new(num_sectors),
            faults: SectorBitSet::new(num_sectors),
            overlays: [const { None }; MAX_OVERLAYS],
            pending_writes: Vec::new(),
            size,
            seed,
        }
    }

    /// Logical size of the storage in bytes.
    pub fn size(&self) -> u64 {
        self.size
    }

    /// Number of sectors tracked by the written/fault bitmaps.
    pub fn num_sectors(&self) -> usize {
        self.written.len()
    }

    /// Grows or shrinks the storage to `new_size` bytes.
    ///
    /// New bytes are zero. Written/fault flags of sectors that survive the
    /// resize are kept; flags of new sectors start cleared.
    pub fn resize(&mut self, new_size: u64) {
        let old_size = self.size;
        self.size = new_size;
        self.data.resize(new_size as usize, 0);
        let new_num_sectors = (new_size as usize).div_ceil(SECTOR_SIZE);
        let old_num_sectors = self.written.len();
        if new_num_sectors != old_num_sectors {
            self.written = SectorBitSet::resize_copy(&self.written, new_num_sectors);
            self.faults = SectorBitSet::resize_copy(&self.faults, new_num_sectors);
        }
        tracing::trace!(
            "InMemoryStorage resized from {} to {} bytes ({} to {} sectors)",
            old_size,
            new_size,
            old_num_sectors,
            new_num_sectors
        );
    }

    /// Returns the sector indices touched by `len` bytes starting at byte `offset`.
    ///
    /// An empty byte range touches no sector, even when `offset` lies in the
    /// middle of one; rounding the end up unconditionally would otherwise
    /// report one sector for a zero-length operation at an unaligned offset.
    fn sector_span(offset: usize, len: usize) -> std::ops::Range<usize> {
        if len == 0 {
            return 0..0;
        }
        (offset / SECTOR_SIZE)..(offset + len).div_ceil(SECTOR_SIZE)
    }

    /// Reads `buf.len()` bytes starting at `offset` into `buf`.
    ///
    /// The stored bytes are copied first. Then, per touched sector: a sector
    /// that was never written is replaced by seed-derived pseudo-random bytes,
    /// and a faulted sector gets one bit flipped. Active overlays are applied
    /// last and therefore win over both.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `offset + buf.len()` overflows or lies past
    /// the end of the storage.
    pub fn read(&self, offset: u64, buf: &mut [u8]) -> io::Result<()> {
        let end = offset
            .checked_add(buf.len() as u64)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "offset overflow"))?;
        if end > self.size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "read past end of storage: offset={}, len={}, size={}",
                    offset,
                    buf.len(),
                    self.size
                ),
            ));
        }
        let offset_usize = offset as usize;
        buf.copy_from_slice(&self.data[offset_usize..offset_usize + buf.len()]);
        for sector in Self::sector_span(offset_usize, buf.len()) {
            if sector >= self.written.len() {
                break;
            }
            // Map the sector onto the window of `buf` it covers.
            let sector_start = sector * SECTOR_SIZE;
            let sector_end = sector_start + SECTOR_SIZE;
            let buf_start = sector_start.saturating_sub(offset_usize);
            let buf_end = (sector_end.saturating_sub(offset_usize)).min(buf.len());
            if buf_start >= buf_end {
                continue;
            }
            let sector_buf = &mut buf[buf_start..buf_end];
            if !self.written.is_set(sector) {
                self.fill_unwritten_sector(sector, sector_buf, sector_start, offset_usize);
            }
            if self.faults.is_set(sector) {
                self.apply_corruption(sector, sector_buf);
            }
        }
        self.apply_overlays(offset, buf);
        Ok(())
    }

    /// Fills `buf` with the pseudo-random content of a never-written sector.
    ///
    /// The full sector content is generated from `seed + sector`, so any read
    /// of the same sector sees the same bytes. `buf` is the part of the read
    /// buffer covering this sector; `sector_start` and `read_offset` locate it
    /// within the sector.
    fn fill_unwritten_sector(
        &self,
        sector: usize,
        buf: &mut [u8],
        sector_start: usize,
        read_offset: usize,
    ) {
        let mut rng = ChaCha8Rng::seed_from_u64(self.seed.wrapping_add(sector as u64));
        let mut sector_data = [0u8; SECTOR_SIZE];
        rng.fill(&mut sector_data);
        // Non-zero only when the read starts inside this sector.
        let offset_in_sector = read_offset.saturating_sub(sector_start);
        let copy_start = offset_in_sector.min(SECTOR_SIZE);
        let copy_len = buf.len().min(SECTOR_SIZE - copy_start);
        buf[..copy_len].copy_from_slice(&sector_data[copy_start..copy_start + copy_len]);
    }

    /// Flips exactly one bit in `buf`, the part of a read covering a faulted sector.
    ///
    /// The RNG is seeded from the first eight stored bytes of the sector (zero
    /// if fewer than eight bytes of it are stored), so repeated reads of the
    /// same range corrupt the same bit. The byte is chosen within `buf`, so the
    /// flipped position depends on which part of the sector the read covers.
    fn apply_corruption(&self, sector: usize, buf: &mut [u8]) {
        if buf.is_empty() {
            return;
        }
        let sector_start = sector * SECTOR_SIZE;
        let mut seed_bytes = [0u8; 8];
        if sector_start + 8 <= self.data.len() {
            seed_bytes.copy_from_slice(&self.data[sector_start..sector_start + 8]);
        }
        let seed = u64::from_le_bytes(seed_bytes);
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let byte_idx = rng.random_range(0..buf.len());
        let bit_idx = rng.random_range(0..8u8);
        buf[byte_idx] ^= 1 << bit_idx;
    }

    /// Patches `buf` (a read starting at `offset`) with every active overlay
    /// that intersects the read range. Later slots overwrite earlier ones
    /// where they overlap.
    fn apply_overlays(&self, offset: u64, buf: &mut [u8]) {
        for overlay in self.overlays.iter().flatten() {
            if !overlay.active {
                continue;
            }
            let overlay_end = overlay.offset + overlay.size as u64;
            let read_end = offset + buf.len() as u64;
            if overlay.offset >= read_end || overlay_end <= offset {
                continue;
            }
            let intersect_start = overlay.offset.max(offset);
            let intersect_end = overlay_end.min(read_end);
            let buf_offset = (intersect_start - offset) as usize;
            let overlay_offset = (intersect_start - overlay.offset) as usize;
            let copy_len = (intersect_end - intersect_start) as usize;
            buf[buf_offset..buf_offset + copy_len]
                .copy_from_slice(&overlay.data[overlay_offset..overlay_offset + copy_len]);
        }
    }

    /// Writes `data` at `offset`, growing the storage if the write ends past
    /// its current size.
    ///
    /// Every touched sector is marked written and has its fault cleared. When
    /// `is_synced` is false the write is also recorded as pending, so a later
    /// `apply_crash` may fault one of its sectors. A zero-length write touches
    /// no sector.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `offset + data.len()` overflows.
    pub fn write(&mut self, offset: u64, data: &[u8], is_synced: bool) -> io::Result<()> {
        let end = offset
            .checked_add(data.len() as u64)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "offset overflow"))?;
        if end > self.size {
            self.resize(end);
        }
        let offset_usize = offset as usize;
        for sector in Self::sector_span(offset_usize, data.len()) {
            if sector < self.written.len() {
                self.written.set(sector);
                self.faults.clear(sector);
            }
        }
        self.data[offset_usize..offset_usize + data.len()].copy_from_slice(data);
        if !is_synced {
            self.pending_writes.push(PendingWrite {
                offset,
                data: data.to_vec(),
                is_phantom: false,
            });
        }
        Ok(())
    }

    /// Marks all pending writes as durable by discarding the pending list.
    pub fn sync(&mut self) {
        self.pending_writes.clear();
    }

    /// Simulates a write meant for `intended_offset` that appears at
    /// `mistaken_offset` instead.
    ///
    /// `data` is stored at the intended offset in the backing bytes, but two
    /// overlays (replacing any existing ones) mask it for readers: slot 0
    /// serves the previous stored bytes at `intended_offset`, slot 1 serves
    /// `data` at `mistaken_offset`. After `clear_overlays`, reads of the
    /// intended range return `data`. The storage grows if either range ends
    /// past its current size; the intended sectors are marked written.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if either offset plus `data.len()` overflows, or
    /// if `data.len()` does not fit the overlay's 32-bit size field.
    pub fn apply_misdirected_write(
        &mut self,
        intended_offset: u64,
        mistaken_offset: u64,
        data: &[u8],
    ) -> io::Result<()> {
        let intended_end = intended_offset
            .checked_add(data.len() as u64)
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "intended offset overflow")
            })?;
        let mistaken_end = mistaken_offset
            .checked_add(data.len() as u64)
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "mistaken offset overflow")
            })?;
        // Checked before any mutation: a truncating cast would make the overlay
        // cover fewer bytes than it holds.
        let overlay_size = u32::try_from(data.len()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "misdirected write too large for overlay",
            )
        })?;
        let required_size = intended_end.max(mistaken_end);
        if required_size > self.size {
            self.resize(required_size);
        }
        let intended_usize = intended_offset as usize;
        let old_data = self.data[intended_usize..intended_usize + data.len()].to_vec();
        self.overlays[0] = Some(WriteOverlay {
            offset: intended_offset,
            size: overlay_size,
            data: old_data,
            active: true,
        });
        self.overlays[1] = Some(WriteOverlay {
            offset: mistaken_offset,
            size: overlay_size,
            data: data.to_vec(),
            active: true,
        });
        self.data[intended_usize..intended_usize + data.len()].copy_from_slice(data);
        for sector in Self::sector_span(intended_usize, data.len()) {
            if sector < self.written.len() {
                self.written.set(sector);
            }
        }
        Ok(())
    }

    /// Removes every overlay so reads return the stored bytes again.
    pub fn clear_overlays(&mut self) {
        for overlay in &mut self.overlays {
            *overlay = None;
        }
    }

    /// Reads `buf.len()` bytes from a pseudo-randomly chosen offset instead of
    /// the requested `offset`.
    ///
    /// The target is derived from `seed + offset`, so the same request is
    /// always misdirected to the same place. Whenever the storage has more
    /// than one valid start offset for a read of this length, the chosen
    /// target differs from an in-range `offset`. An empty `buf` is a no-op.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read`, e.g. when `buf` is larger than the storage.
    pub fn read_misdirected(&self, offset: u64, buf: &mut [u8]) -> io::Result<()> {
        if buf.is_empty() {
            return Ok(());
        }
        let mut rng = ChaCha8Rng::seed_from_u64(self.seed.wrapping_add(offset));
        // Largest start offset at which the read still fits.
        let max_offset = self.size.saturating_sub(buf.len() as u64);
        if max_offset == 0 {
            return self.read(0, buf);
        }
        let mut misdirected_offset = rng.random_range(0..max_offset);
        if misdirected_offset == offset {
            misdirected_offset = (misdirected_offset + SECTOR_SIZE as u64) % max_offset;
            if misdirected_offset == offset {
                // The shift wrapped back onto the requested offset (happens when
                // `max_offset` divides `SECTOR_SIZE`). `max_offset` is a valid
                // start and is never drawn above, so it cannot equal `offset`.
                misdirected_offset = max_offset;
            }
        }
        self.read(misdirected_offset, buf)
    }

    /// Records a write that was acknowledged but never reached the storage.
    ///
    /// The stored bytes are not changed; the entry is only added to the
    /// pending list, where `apply_crash` skips it.
    pub fn record_phantom_write(&mut self, offset: u64, data: &[u8]) {
        self.pending_writes.push(PendingWrite {
            offset,
            data: data.to_vec(),
            is_phantom: true,
        });
    }

    /// Simulates a crash: each pending non-phantom write gets, with probability
    /// `crash_fault_probability`, one of its sectors marked faulted.
    ///
    /// The RNG is re-seeded from the storage seed on every call. Pending writes
    /// that are empty or reach past the tracked sectors are never faulted. The
    /// pending list is cleared afterwards.
    pub fn apply_crash(&mut self, crash_fault_probability: f64) {
        let mut rng = ChaCha8Rng::seed_from_u64(self.seed);
        for pending in &self.pending_writes {
            if pending.is_phantom {
                continue;
            }
            if rng.random::<f64>() >= crash_fault_probability {
                continue;
            }
            let sectors = Self::sector_span(pending.offset as usize, pending.data.len());
            if !sectors.is_empty() && sectors.end <= self.faults.len() {
                let faulted_sector = rng.random_range(sectors);
                self.faults.set(faulted_sector);
            }
        }
        self.pending_writes.clear();
    }

    /// Marks `sector` as faulted; out-of-range sectors are ignored.
    pub fn set_fault(&mut self, sector: usize) {
        if sector < self.faults.len() {
            self.faults.set(sector);
        }
    }

    /// Clears the fault on `sector`; out-of-range sectors are ignored.
    pub fn clear_fault(&mut self, sector: usize) {
        if sector < self.faults.len() {
            self.faults.clear(sector);
        }
    }

    /// Returns `true` if `sector` is in range and currently faulted.
    pub fn has_fault(&self, sector: usize) -> bool {
        sector < self.faults.len() && self.faults.is_set(sector)
    }

    /// Returns `true` if `sector` is in range and has been written.
    pub fn is_written(&self, sector: usize) -> bool {
        sector < self.written.len() && self.written.is_set(sector)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Data written with `write` is returned unchanged by `read`.
    #[test]
    fn test_basic_write_read() {
        let payload = b"Hello, World!";
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, payload, true).expect("write failed");

        let mut readback = vec![0u8; payload.len()];
        disk.read(0, &mut readback).expect("read failed");
        assert_eq!(&readback, payload);
    }

    /// Unwritten sectors read identically for equal seeds and differently for
    /// different seeds.
    #[test]
    fn test_unwritten_sector_deterministic() {
        let first_sector = |seed: u64, failure: &str| {
            let disk = InMemoryStorage::new(4096, seed);
            let mut sector = vec![0u8; SECTOR_SIZE];
            disk.read(0, &mut sector).expect(failure);
            sector
        };

        let seeded_a = first_sector(42, "read1 failed");
        let seeded_b = first_sector(42, "read2 failed");
        assert_eq!(seeded_a, seeded_b);

        let other_seed = first_sector(99, "read3 failed");
        assert_ne!(seeded_a, other_seed);
    }

    /// A faulted sector differs from the written data by exactly one bit.
    #[test]
    fn test_fault_corruption() {
        let written = vec![0xAA; SECTOR_SIZE];
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &written, true).expect("write failed");

        let mut clean = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut clean).expect("read failed");
        assert_eq!(clean, written);

        disk.set_fault(0);
        let mut corrupted = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut corrupted).expect("read failed");
        assert_ne!(corrupted, written);

        let mut flipped_bits: u32 = 0;
        for (good, bad) in clean.iter().zip(corrupted.iter()) {
            flipped_bits += (*good ^ *bad).count_ones();
        }
        assert_eq!(flipped_bits, 1, "Expected exactly one bit flip");
    }

    /// Reading a faulted sector twice yields the same corrupted bytes.
    #[test]
    fn test_corruption_determinism() {
        let written = vec![0xAA; SECTOR_SIZE];
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &written, true).expect("write failed");
        disk.set_fault(0);

        let mut first = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut first).expect("read1 failed");
        let mut second = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut second).expect("read2 failed");
        assert_eq!(first, second);
    }

    /// A misdirected write shows the old data at the intended offset and the
    /// new data at the mistaken offset until the overlays are cleared.
    #[test]
    fn test_misdirected_write() {
        let intended_at = 0u64;
        let mistaken_at = SECTOR_SIZE as u64;
        let before_intended = vec![0x11; SECTOR_SIZE];
        let before_mistaken = vec![0x22; SECTOR_SIZE];
        let payload = vec![0xFF; SECTOR_SIZE];

        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(intended_at, &before_intended, true)
            .expect("write1 failed");
        disk.write(mistaken_at, &before_mistaken, true)
            .expect("write2 failed");
        disk.apply_misdirected_write(intended_at, mistaken_at, &payload)
            .expect("misdirect failed");

        let mut seen_intended = vec![0u8; SECTOR_SIZE];
        disk.read(intended_at, &mut seen_intended)
            .expect("read failed");
        assert_eq!(seen_intended, before_intended);

        let mut seen_mistaken = vec![0u8; SECTOR_SIZE];
        disk.read(mistaken_at, &mut seen_mistaken)
            .expect("read failed");
        assert_eq!(seen_mistaken, payload);

        disk.clear_overlays();
        disk.read(intended_at, &mut seen_intended)
            .expect("read failed");
        assert_eq!(seen_intended, payload);
    }

    /// A phantom write never changes what is read, before or after a crash.
    #[test]
    fn test_phantom_write_lost_on_crash() {
        let durable = vec![0x11; SECTOR_SIZE];
        let phantom = vec![0xFF; SECTOR_SIZE];

        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &durable, true).expect("write failed");
        disk.record_phantom_write(0, &phantom);

        let mut seen = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut seen).expect("read failed");
        assert_eq!(seen, durable);

        disk.apply_crash(1.0);
        disk.read(0, &mut seen).expect("read failed");
        assert_eq!(seen, durable);
    }

    /// With fault probability 1.0 a crash faults the sector of an unsynced write.
    #[test]
    fn test_crash_faults_pending_writes() {
        let written = vec![0xAA; SECTOR_SIZE];
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &written, false).expect("write failed");

        disk.apply_crash(1.0);
        assert!(disk.has_fault(0));

        let mut seen = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut seen).expect("read failed");
        assert_ne!(seen, written);
    }

    /// After `sync`, a crash leaves previously pending writes intact.
    #[test]
    fn test_sync_clears_pending() {
        let written = vec![0xAA; SECTOR_SIZE];
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &written, false).expect("write failed");
        disk.sync();

        disk.apply_crash(1.0);
        assert!(!disk.has_fault(0));

        let mut seen = vec![0u8; SECTOR_SIZE];
        disk.read(0, &mut seen).expect("read failed");
        assert_eq!(seen, written);
    }

    /// Basic set / clear / query behaviour of `SectorBitSet`.
    #[test]
    fn test_sector_bitset() {
        let probes = [0usize, 50, 99];
        let mut flags = SectorBitSet::new(100);
        for &sector in &probes {
            assert!(!flags.is_set(sector));
        }
        for &sector in &probes {
            flags.set(sector);
        }
        for &sector in &probes {
            assert!(flags.is_set(sector));
        }
        assert!(!flags.is_set(1));

        flags.clear(50);
        assert!(!flags.is_set(50));
        assert_eq!(flags.len(), 100);
    }

    /// Reading beyond the end of the storage is an error.
    #[test]
    fn test_read_past_end() {
        let disk = InMemoryStorage::new(1024, 42);
        let mut seen = vec![0u8; 100];
        assert!(disk.read(1000, &mut seen).is_err());
    }

    /// Writing beyond the end grows the storage to fit the write.
    #[test]
    fn test_write_past_end_auto_extends() {
        let written = vec![0xAB; 100];
        let mut disk = InMemoryStorage::new(1024, 42);
        assert!(disk.write(1000, &written, true).is_ok());
        assert_eq!(disk.size(), 1100);

        let mut seen = vec![0u8; 100];
        disk.read(1000, &mut seen).expect("read failed");
        assert_eq!(seen, written);
    }

    /// A misdirected read of sector 0 does not return sector 0's data.
    #[test]
    fn test_read_misdirected() {
        let sector0 = vec![0x11; SECTOR_SIZE];
        let sector1 = vec![0x22; SECTOR_SIZE];
        let sector2 = vec![0x33; SECTOR_SIZE];

        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &sector0, true).expect("write failed");
        disk.write(SECTOR_SIZE as u64, &sector1, true)
            .expect("write failed");
        disk.write(2 * SECTOR_SIZE as u64, &sector2, true)
            .expect("write failed");

        let mut seen = vec![0u8; SECTOR_SIZE];
        disk.read_misdirected(0, &mut seen).expect("read failed");
        assert_ne!(seen, sector0);
    }

    /// A read that starts and ends inside one sector returns the matching slice.
    #[test]
    fn test_partial_sector_read() {
        let pattern: Vec<u8> = (0..SECTOR_SIZE).map(|i| (i % 256) as u8).collect();
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &pattern, true).expect("write failed");

        let mut seen = vec![0u8; 100];
        disk.read(50, &mut seen).expect("read failed");
        assert_eq!(seen, &pattern[50..150]);
    }

    /// A read spanning several sectors returns all written data.
    #[test]
    fn test_multi_sector_read() {
        let written = vec![0xAB; SECTOR_SIZE * 3];
        let mut disk = InMemoryStorage::new(4096, 42);
        disk.write(0, &written, true).expect("write failed");

        let mut seen = vec![0u8; SECTOR_SIZE * 3];
        disk.read(0, &mut seen).expect("read failed");
        assert_eq!(seen, written);
    }
}