use super::{
FileHint, Fs, FsCapabilities, FsDirEntry, FsFile, FsMetadata, FsOpenOptions, SyncMode,
};
use crate::io;
use crate::path::{Path, PathBuf};
use alloc::boxed::Box;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
/// Identifies the filesystem or file operation a [`FaultRule`] applies to.
///
/// Each variant is the tag that the wrappers in this module hand to the
/// injector before forwarding the corresponding call to the wrapped backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FaultOp {
/// Checked by `FaultFs::open`.
Open,
/// Checked by `FaultFs::create_dir_all`.
CreateDirAll,
/// Checked by `FaultFs::create_dir`.
CreateDir,
/// Checked by `FaultFs::remove_file`.
RemoveFile,
/// Checked by `FaultFs::rename` (matched against the destination path).
Rename,
/// Checked by both `FaultFs::sync_directory` and `FaultFs::sync_directory_with`.
SyncDirectory,
/// Checked by `std::io::Write::write` on a wrapped file.
Write,
/// Checked by `std::io::Write::flush` on a wrapped file.
Flush,
/// Checked by `FsFile::read_at` on a wrapped file.
ReadAt,
/// Checked by `std::io::Read::read` on a wrapped file.
Read,
/// Checked by both `FsFile::sync_all` and `FsFile::sync_all_with` on a wrapped file.
SyncAll,
/// Checked by both `FsFile::sync_data` and `FsFile::sync_data_with` on a wrapped file.
SyncData,
/// Checked by `FsFile::set_len` on a wrapped file.
SetLen,
}
/// The failure a [`FaultRule`] injects when it fires.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Fault {
/// Fail the operation with an error of the given kind instead of calling
/// the wrapped backend.
Error(io::ErrorKind),
/// Only honoured by the wrapped file's `write`: at most this many bytes of
/// the caller's buffer are written and reported. Every other call site in
/// this file reacts only to `Error`, so a `ShortWrite` armed for another
/// operation lets the call proceed normally.
ShortWrite(usize),
}
/// A single armed fault: which operation to hit, on which paths, and how often.
///
/// Built with [`FaultRule::new`] and refined with the builder methods on the
/// type; evaluated by [`FaultInjector`].
#[derive(Clone, Debug)]
pub struct FaultRule {
// Operation this rule applies to.
op: FaultOp,
// When set, the rule only matches paths whose lossy string form contains
// this substring; `None` matches every call of `op`.
path_substr: Option<String>,
// Number of matching calls still to let through before the rule may fire.
skip: u64,
// Remaining number of times the fault may fire; `u64::MAX` is treated as
// "unlimited" and is never decremented.
count: u64,
// What to inject when the rule fires.
fault: Fault,
}
impl FaultRule {
    /// Creates a rule that injects `fault` on `op`.
    ///
    /// The fresh rule has no path filter, skips nothing, and may fire an
    /// unlimited number of times (`count` starts at `u64::MAX`).
    #[must_use]
    pub const fn new(op: FaultOp, fault: Fault) -> Self {
        Self {
            op,
            fault,
            path_substr: None,
            skip: 0,
            count: u64::MAX,
        }
    }

    /// Restricts the rule to paths whose lossy string form contains `substr`.
    ///
    /// Calls that carry no path never match a rule with a path filter.
    #[must_use]
    pub fn on_path(self, substr: impl Into<String>) -> Self {
        Self {
            path_substr: Some(substr.into()),
            ..self
        }
    }

    /// Lets the first `n` matching calls through before the rule may fire.
    #[must_use]
    pub const fn skip(mut self, n: u64) -> Self {
        // Plain field assignment keeps this usable in `const` contexts.
        self.skip = n;
        self
    }

    /// Limits the rule to firing at most `n` times.
    ///
    /// Passing `u64::MAX` means "unlimited", matching the default from
    /// [`FaultRule::new`].
    #[must_use]
    pub const fn times(mut self, n: u64) -> Self {
        self.count = n;
        self
    }

    /// Limits the rule to firing exactly once; equivalent to `times(1)`.
    #[must_use]
    pub const fn once(mut self) -> Self {
        self.count = 1;
        self
    }

    /// Returns `true` when this rule targets `op` and its path filter (if
    /// any) accepts `path`.
    fn matches(&self, op: FaultOp, path: Option<&Path>) -> bool {
        if self.op != op {
            return false;
        }
        // No filter configured: every call of the right operation matches.
        let Some(needle) = self.path_substr.as_deref() else {
            return true;
        };
        // A filter is configured: a path must be present and contain it.
        path.is_some_and(|candidate| candidate.to_string_lossy().contains(needle))
    }
}
/// Shared registry of armed [`FaultRule`]s.
///
/// Rules are stored in the order they were armed and are consulted in that
/// order. All access goes through the internal `spin::Mutex`, so the public
/// methods take `&self`.
#[derive(Default)]
pub struct FaultInjector {
// Armed rules, oldest first; mutated in place as skips/counts are consumed.
rules: spin::Mutex<Vec<FaultRule>>,
}
impl FaultInjector {
    /// Creates an injector with no rules armed.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Arms `rule`, appending it after any rules that are already armed.
    pub fn arm(&self, rule: FaultRule) {
        self.rules.lock().push(rule);
    }

    /// Removes every armed rule.
    pub fn clear(&self) {
        self.rules.lock().clear();
    }

    /// Decides whether a call of `op` on `path` should be faulted.
    ///
    /// Rules are scanned in arming order and the first *live* matching rule
    /// decides the outcome:
    ///
    /// * a rule whose fire count is used up is ignored, so it cannot shadow
    ///   rules armed after it;
    /// * a rule that still has skips left consumes one skip and lets the call
    ///   through (`None`) without consulting later rules;
    /// * otherwise the rule fires: its remaining count is decremented (unless
    ///   it is the `u64::MAX` "unlimited" sentinel) and its fault is returned.
    ///
    /// Returns `None` when no live rule matches.
    fn check(&self, op: FaultOp, path: Option<&Path>) -> Option<Fault> {
        let mut rules = self.rules.lock();
        for rule in rules.iter_mut() {
            // Exhausted rules are dead: skip them rather than returning, so a
            // spent `.once()` rule does not permanently mask later rules.
            if rule.count == 0 || !rule.matches(op, path) {
                continue;
            }
            if rule.skip > 0 {
                rule.skip -= 1;
                return None;
            }
            // `u64::MAX` means "fire forever"; never count it down.
            if rule.count != u64::MAX {
                rule.count -= 1;
            }
            return Some(rule.fault);
        }
        None
    }
}
fn fault_error(kind: io::ErrorKind, op: FaultOp) -> io::Error {
io::Error::new(kind, alloc::format!("injected fault on {op:?}"))
}
fn fault_error_std(kind: io::ErrorKind, op: FaultOp) -> std::io::Error {
fault_error(kind, op).into()
}
/// Filesystem wrapper that consults a [`FaultInjector`] before forwarding
/// selected operations to the wrapped backend `F`.
pub struct FaultFs<F> {
// The real backend that performs every operation that is not faulted.
inner: F,
// Rule registry; shared (via `Arc`) with every file opened through this wrapper.
injector: Arc<FaultInjector>,
}
impl<F: Fs> FaultFs<F> {
    /// Wraps `inner` with a fresh injector that has no rules armed.
    #[must_use]
    pub fn new(inner: F) -> Self {
        Self::with_injector(inner, Arc::new(FaultInjector::new()))
    }

    /// Wraps `inner` using an existing, possibly shared, `injector`.
    #[must_use]
    pub fn with_injector(inner: F, injector: Arc<FaultInjector>) -> Self {
        Self { inner, injector }
    }

    /// Returns another handle to the injector used by this wrapper, so rules
    /// can be armed or cleared from outside.
    #[must_use]
    pub fn injector(&self) -> Arc<FaultInjector> {
        let shared = &self.injector;
        Arc::clone(shared)
    }
}
impl<F: Fs> Fs for FaultFs<F> {
fn open(&self, path: &Path, opts: &FsOpenOptions) -> io::Result<Box<dyn FsFile>> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::Open, Some(path)) {
return Err(fault_error(kind, FaultOp::Open));
}
let inner = self.inner.open(path, opts)?;
Ok(Box::new(FaultFile {
inner,
path: path.to_path_buf(),
injector: Arc::clone(&self.injector),
}))
}
fn create_dir_all(&self, path: &Path) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::CreateDirAll, Some(path)) {
return Err(fault_error(kind, FaultOp::CreateDirAll));
}
self.inner.create_dir_all(path)
}
fn create_dir(&self, path: &Path) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::CreateDir, Some(path)) {
return Err(fault_error(kind, FaultOp::CreateDir));
}
self.inner.create_dir(path)
}
fn read_dir(&self, path: &Path) -> io::Result<Vec<FsDirEntry>> {
self.inner.read_dir(path)
}
fn remove_file(&self, path: &Path) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::RemoveFile, Some(path)) {
return Err(fault_error(kind, FaultOp::RemoveFile));
}
self.inner.remove_file(path)
}
fn remove_dir_all(&self, path: &Path) -> io::Result<()> {
self.inner.remove_dir_all(path)
}
fn rename(&self, from: &Path, to: &Path) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::Rename, Some(to)) {
return Err(fault_error(kind, FaultOp::Rename));
}
self.inner.rename(from, to)
}
fn metadata(&self, path: &Path) -> io::Result<FsMetadata> {
self.inner.metadata(path)
}
fn sync_directory(&self, path: &Path) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SyncDirectory, Some(path)) {
return Err(fault_error(kind, FaultOp::SyncDirectory));
}
self.inner.sync_directory(path)
}
fn sync_directory_with(&self, path: &Path, mode: SyncMode) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SyncDirectory, Some(path)) {
return Err(fault_error(kind, FaultOp::SyncDirectory));
}
self.inner.sync_directory_with(path, mode)
}
fn exists(&self, path: &Path) -> io::Result<bool> {
self.inner.exists(path)
}
fn hard_link(&self, src: &Path, dst: &Path) -> io::Result<()> {
self.inner.hard_link(src, dst)
}
fn backend_id(&self) -> Option<u64> {
self.inner.backend_id()
}
fn volume_id(&self, path: &Path) -> Option<u64> {
self.inner.volume_id(path)
}
fn capabilities(&self, path: &Path) -> FsCapabilities {
self.inner.capabilities(path)
}
fn try_disable_cow(&self, path: &Path) -> io::Result<()> {
self.inner.try_disable_cow(path)
}
fn punch_hole(&self, path: &Path, offset: u64, len: u64) -> io::Result<()> {
self.inner.punch_hole(path, offset, len)
}
fn reflink_file(&self, src: &Path, dst: &Path) -> io::Result<()> {
self.inner.reflink_file(src, dst)
}
fn truncate_file(&self, path: &Path) -> io::Result<()> {
self.inner.truncate_file(path)
}
fn hard_link_count(&self, path: &Path) -> io::Result<u64> {
self.inner.hard_link_count(path)
}
fn available_space(&self, path: &Path) -> io::Result<u64> {
self.inner.available_space(path)
}
}
/// File handle returned by `FaultFs::open`; fault-checks selected operations
/// before forwarding them to the wrapped backend file.
struct FaultFile {
// The backend file that performs the real I/O.
inner: Box<dyn FsFile>,
// Path the file was opened with; used for rule path matching.
path: PathBuf,
// Rule registry shared with the `FaultFs` that opened this file.
injector: Arc<FaultInjector>,
}
impl std::io::Read for FaultFile {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::Read, Some(&self.path)) {
return Err(fault_error_std(kind, FaultOp::Read));
}
self.inner.read(buf)
}
}
impl std::io::Write for FaultFile {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self.injector.check(FaultOp::Write, Some(&self.path)) {
Some(Fault::Error(kind)) => Err(fault_error_std(kind, FaultOp::Write)),
Some(Fault::ShortWrite(n)) => {
let take = n.min(buf.len());
if take == 0 {
return Ok(0);
}
let (head, _) = buf.split_at(take);
self.inner.write_all(head)?;
Ok(take)
}
None => self.inner.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::Flush, Some(&self.path)) {
return Err(fault_error_std(kind, FaultOp::Flush));
}
self.inner.flush()
}
}
// Seeking is never faulted: the call is forwarded to the wrapped file as is.
impl std::io::Seek for FaultFile {
/// Forwards the seek to the wrapped file without consulting the injector.
fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
self.inner.seek(pos)
}
}
impl FsFile for FaultFile {
fn sync_all(&self) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SyncAll, Some(&self.path)) {
return Err(fault_error(kind, FaultOp::SyncAll));
}
self.inner.sync_all()
}
fn sync_data(&self) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SyncData, Some(&self.path)) {
return Err(fault_error(kind, FaultOp::SyncData));
}
self.inner.sync_data()
}
fn sync_all_with(&self, mode: SyncMode) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SyncAll, Some(&self.path)) {
return Err(fault_error(kind, FaultOp::SyncAll));
}
self.inner.sync_all_with(mode)
}
fn sync_data_with(&self, mode: SyncMode) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SyncData, Some(&self.path)) {
return Err(fault_error(kind, FaultOp::SyncData));
}
self.inner.sync_data_with(mode)
}
fn metadata(&self) -> io::Result<FsMetadata> {
self.inner.metadata()
}
fn set_len(&self, size: u64) -> io::Result<()> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::SetLen, Some(&self.path)) {
return Err(fault_error(kind, FaultOp::SetLen));
}
self.inner.set_len(size)
}
fn read_at(&self, buf: &mut [u8], offset: u64) -> io::Result<usize> {
if let Some(Fault::Error(kind)) = self.injector.check(FaultOp::ReadAt, Some(&self.path)) {
return Err(fault_error(kind, FaultOp::ReadAt));
}
self.inner.read_at(buf, offset)
}
fn lock_exclusive(&self) -> io::Result<()> {
self.inner.lock_exclusive()
}
fn try_lock_exclusive(&self) -> io::Result<bool> {
self.inner.try_lock_exclusive()
}
fn hint(&self, hint: FileHint) -> io::Result<()> {
self.inner.hint(hint)
}
}
// Unit tests live in a sibling `tests` module file and are compiled only for
// test builds. The lint allowances below apply to that test code only.
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing,
reason = "test code"
)]
mod tests;