use std::io::{self, Read, Seek, SeekFrom, Write};
use crate::Result;
use crate::block::BlockDevice;
/// Trigger condition for the simulated crash in [`CrashInject`].
///
/// Thresholds are compared strictly: the operation that pushes the running
/// total *past* the limit is the one that trips the crash, and reaching the
/// limit exactly does not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailAfter {
    /// Crash on the `n + 1`-th mutating operation; the first `n` are forwarded.
    Writes(u64),
    /// Crash on the operation that takes the cumulative byte count above `n`.
    Bytes(u64),
    /// Never crash; every operation is forwarded to the wrapped device.
    Never,
}
/// Wrapper around a [`BlockDevice`] that simulates a crash after a
/// configurable amount of write traffic.
///
/// Until the threshold selected by [`FailAfter`] is crossed, every call is
/// forwarded to the wrapped device. From then on, mutating calls are
/// discarded while still reporting success.
pub struct CrashInject<B: BlockDevice> {
    /// Sticky flag, set once the threshold has been crossed.
    crashed: bool,
    /// The trigger condition chosen at construction.
    fail: FailAfter,
    /// Mutating operations counted so far (advanced only under `FailAfter::Writes`).
    writes_seen: u64,
    /// Saturating byte total counted so far (advanced only under `FailAfter::Bytes`).
    bytes_written: u64,
    /// The real device that receives forwarded operations.
    inner: B,
}
impl<B: BlockDevice> CrashInject<B> {
    /// Wraps `inner` so that it "crashes" according to `fail`.
    ///
    /// Both counters start at zero and the wrapper starts un-crashed.
    pub fn new(inner: B, fail: FailAfter) -> Self {
        Self {
            crashed: false,
            writes_seen: 0,
            bytes_written: 0,
            fail,
            inner,
        }
    }

    /// Returns `true` once the configured threshold has been crossed.
    ///
    /// The accounting logic only ever sets this flag, never clears it.
    pub fn crashed(&self) -> bool {
        self.crashed
    }

    /// Consumes the wrapper and returns the wrapped device, e.g. to inspect
    /// which writes actually reached it.
    pub fn into_inner(self) -> B {
        let Self { inner, .. } = self;
        inner
    }

    /// Records one mutating operation of `bytes` bytes and trips the crash
    /// flag when the configured limit is exceeded.
    ///
    /// After the crash has fired, the counters are frozen and this is a no-op.
    /// Only the counter relevant to `self.fail` advances. The comparison is
    /// strict (`>`), so `Writes(n)` lets exactly `n` operations through and
    /// `Bytes(n)` tolerates a cumulative total of exactly `n` bytes.
    fn check_crash_and_account(&mut self, bytes: u64) {
        if self.crashed {
            return;
        }
        let limit_exceeded = match self.fail {
            FailAfter::Never => false,
            FailAfter::Writes(limit) => {
                self.writes_seen += 1;
                self.writes_seen > limit
            }
            FailAfter::Bytes(limit) => {
                // Saturate so huge zero_range lengths cannot overflow the total.
                self.bytes_written = self.bytes_written.saturating_add(bytes);
                self.bytes_written > limit
            }
        };
        // `crashed` is known to be false here, so assigning the verdict is
        // equivalent to setting it only when the limit was exceeded.
        self.crashed = limit_exceeded;
    }
}
impl<B: BlockDevice> Read for CrashInject<B> {
    /// Reads straight from the wrapped device.
    ///
    /// Reads are never intercepted, before or after the simulated crash, so
    /// they observe only what was actually forwarded to the inner device.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let device = &mut self.inner;
        device.read(buf)
    }
}
impl<B: BlockDevice> Write for CrashInject<B> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.check_crash_and_account(buf.len() as u64);
if self.crashed {
return Ok(buf.len()); }
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
if self.crashed {
return Ok(());
}
self.inner.flush()
}
}
impl<B: BlockDevice> Seek for CrashInject<B> {
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
self.inner.seek(pos)
}
}
impl<B: BlockDevice> BlockDevice for CrashInject<B> {
    /// Block size of the wrapped device; the wrapper never changes geometry.
    fn block_size(&self) -> u32 {
        let inner = &self.inner;
        inner.block_size()
    }

    /// Capacity of the wrapped device.
    fn total_size(&self) -> u64 {
        let inner = &self.inner;
        inner.total_size()
    }

    /// Forwards `sync` to the wrapped device; once crashed, succeeds without
    /// touching it.
    fn sync(&mut self) -> Result<()> {
        if !self.crashed {
            return self.inner.sync();
        }
        Ok(())
    }

    /// Counts one write of `buf.len()` bytes, then forwards it unless this (or
    /// an earlier) operation tripped the crash. In that case the data is
    /// dropped and `Ok(())` is returned.
    fn write_at(&mut self, offset: u64, buf: &[u8]) -> Result<()> {
        self.check_crash_and_account(buf.len() as u64);
        if self.crashed {
            Ok(())
        } else {
            self.inner.write_at(offset, buf)
        }
    }

    /// Reads are never intercepted; they see whatever reached the wrapped device.
    fn read_at(&mut self, offset: u64, buf: &mut [u8]) -> Result<()> {
        let inner = &mut self.inner;
        inner.read_at(offset, buf)
    }

    /// Zeroing counts as one write of `len` bytes toward the threshold. It is
    /// dropped (reporting success) exactly like `write_at` once crashed.
    fn zero_range(&mut self, offset: u64, len: u64) -> Result<()> {
        self.check_crash_and_account(len);
        if !self.crashed {
            return self.inner.zero_range(offset, len);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::block::MemoryBackend;

    #[test]
    fn fail_after_writes_drops_subsequent() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Writes(1));
        dev.write_at(0, &[0x11; 16]).unwrap();
        assert!(!dev.crashed());
        dev.write_at(16, &[0x22; 16]).unwrap();
        assert!(dev.crashed());
        let mut buf = [0u8; 32];
        dev.read_at(0, &mut buf).unwrap();
        assert_eq!(&buf[..16], &[0x11; 16]);
        assert_eq!(&buf[16..], &[0; 16]);
    }

    #[test]
    fn fail_after_bytes_drops_past_threshold() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Bytes(20));
        dev.write_at(0, &[0xAA; 16]).unwrap();
        assert!(!dev.crashed());
        dev.write_at(16, &[0xBB; 16]).unwrap();
        assert!(dev.crashed());
    }

    #[test]
    fn fail_after_never_is_passthrough() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Never);
        for i in 0..32 {
            dev.write_at(i * 16, &[i as u8; 16]).unwrap();
        }
        assert!(!dev.crashed());
        let mut buf = [0u8; 16];
        dev.read_at(31 * 16, &mut buf).unwrap();
        assert_eq!(buf, [31u8; 16]);
    }

    /// `Writes(0)` means not even the first write may land.
    #[test]
    fn fail_after_zero_writes_drops_first_write() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Writes(0));
        dev.write_at(0, &[0x33; 8]).unwrap();
        assert!(dev.crashed());
        let mut buf = [0xFFu8; 8];
        dev.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [0u8; 8]);
    }

    /// Reaching the byte limit exactly is not a crash; one byte past it is.
    #[test]
    fn fail_after_bytes_exact_threshold_is_not_a_crash() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Bytes(16));
        dev.write_at(0, &[0xAA; 16]).unwrap();
        assert!(!dev.crashed());
        dev.write_at(16, &[0xBB; 1]).unwrap();
        assert!(dev.crashed());
        let mut buf = [0u8; 17];
        dev.read_at(0, &mut buf).unwrap();
        assert_eq!(&buf[..16], &[0xAA; 16]);
        assert_eq!(buf[16], 0);
    }

    /// `zero_range` is accounted like a write of `len` bytes and is dropped
    /// once it crosses the threshold.
    #[test]
    fn zero_range_counts_toward_byte_threshold() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Bytes(12));
        dev.write_at(0, &[0x77; 8]).unwrap();
        assert!(!dev.crashed());
        dev.zero_range(0, 8).unwrap();
        assert!(dev.crashed());
        let mut buf = [0u8; 8];
        dev.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [0x77; 8]);
    }

    /// After the crash, durability calls report success instead of failing.
    #[test]
    fn sync_and_flush_succeed_after_crash() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Writes(0));
        dev.write_at(0, &[0x01; 4]).unwrap();
        assert!(dev.crashed());
        dev.sync().unwrap();
        Write::flush(&mut dev).unwrap();
    }

    /// The unwrapped device holds only the writes that landed before the crash.
    #[test]
    fn into_inner_exposes_only_persisted_writes() {
        let mem = MemoryBackend::new(1024);
        let mut dev = CrashInject::new(mem, FailAfter::Writes(1));
        dev.write_at(0, &[0x44; 8]).unwrap();
        dev.write_at(8, &[0x55; 8]).unwrap();
        assert!(dev.crashed());
        let mut mem = dev.into_inner();
        let mut buf = [0u8; 16];
        mem.read_at(0, &mut buf).unwrap();
        assert_eq!(&buf[..8], &[0x44; 8]);
        assert_eq!(&buf[8..], &[0; 8]);
    }
}