use super::{
FileHint, Fs, FsCapabilities, FsDirEntry, FsFile, FsMetadata, FsOpenOptions, SyncMode,
};
use crate::io;
use crate::path::{Path, PathBuf};
use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use hashbrown::{HashMap, HashSet};
/// Bookkeeping shared between a [`CrashFs`] and every file handle it opens.
#[derive(Default)]
struct CrashState {
/// Byte image, per path, that `CrashFs::crash` writes back over the file.
/// Entries are recorded on first touch, on copy tracking, and on sync.
durable: HashMap<PathBuf, Vec<u8>>,
/// Paths that have been opened for modification, copied to, or otherwise
/// marked as changed; a touched path with no durable image is removed by
/// `CrashFs::crash`.
touched: HashSet<PathBuf>,
}
/// An [`Fs`] wrapper that forwards every operation to an inner filesystem
/// while recording enough state to roll files back with [`CrashFs::crash`].
///
/// Both fields are reference-counted, so clones share the same inner
/// filesystem and the same tracking state.
#[derive(Clone)]
pub struct CrashFs {
/// The filesystem that actually performs every operation.
inner: Arc<dyn Fs>,
/// Durable images and touched paths, shared with opened file handles.
state: Arc<spin::Mutex<CrashState>>,
}
impl CrashFs {
    /// Wraps `inner`, taking ownership of it, with empty tracking state.
    #[must_use]
    pub fn new<F: Fs>(inner: F) -> Self {
        Self::from_shared(Arc::new(inner))
    }

    /// Wraps an already shared filesystem with empty tracking state.
    #[must_use]
    pub fn from_shared(inner: Arc<dyn Fs>) -> Self {
        Self {
            inner,
            state: Arc::new(spin::Mutex::new(CrashState::default())),
        }
    }

    /// Returns a new handle to the wrapped filesystem, bypassing all tracking.
    #[must_use]
    pub fn inner(&self) -> Arc<dyn Fs> {
        Arc::clone(&self.inner)
    }

    /// Rolls the inner filesystem back to the recorded durable state.
    ///
    /// Every path with a durable image is rewritten with exactly that image;
    /// every touched path without one is removed (a file that is already gone
    /// is not an error). Afterwards the touched set is reset to the set of
    /// paths that have a durable image.
    ///
    /// Each path is processed once: a path that is both touched and durable is
    /// restored a single time rather than once per set it appears in.
    ///
    /// # Panics
    ///
    /// Panics if a file cannot be reopened or rewritten during rollback, or if
    /// removing an un-synced file fails for any reason other than `NotFound`.
    pub fn crash(&self) {
        let mut guard = self.state.lock();
        // Reborrow once so the two maps can be borrowed independently below.
        let state = &mut *guard;

        // Files with a durable image go back to exactly that image.
        for (path, bytes) in &state.durable {
            self.restore_durable(path, bytes);
        }

        // Files that were touched but never gained a durable image disappear.
        for path in &state.touched {
            if state.durable.contains_key(path) {
                continue;
            }
            match self.inner.remove_file(path) {
                Ok(()) => {}
                Err(e) if e.kind() == io::ErrorKind::NotFound => {}
                Err(e) => {
                    panic!("crash(): removing un-synced {} failed: {e}", path.display())
                }
            }
        }

        state.touched = state.durable.keys().cloned().collect();
    }

    /// Overwrites `path` on the inner filesystem with `bytes`, creating or
    /// truncating the file as needed.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or fully written.
    fn restore_durable(&self, path: &Path, bytes: &[u8]) {
        let mut file = self
            .inner
            .open(
                path,
                &FsOpenOptions::new().write(true).create(true).truncate(true),
            )
            .unwrap_or_else(|e| {
                panic!(
                    "crash(): reopening {} for rollback failed: {e}",
                    path.display()
                )
            });
        std::io::Write::write_all(&mut file, bytes).unwrap_or_else(|e| {
            panic!(
                "crash(): rewriting durable image of {} failed: {e}",
                path.display()
            )
        });
    }

    /// Reads the full current contents of `path` from the inner filesystem.
    ///
    /// Returns `Ok(None)` when opening fails with `NotFound`; any other open
    /// or read error is propagated.
    fn read_baseline(&self, path: &Path) -> io::Result<Option<Vec<u8>>> {
        let mut f = match self.inner.open(path, &FsOpenOptions::new().read(true)) {
            Ok(f) => f,
            Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e),
        };
        let mut buf = Vec::new();
        std::io::Read::read_to_end(&mut f, &mut buf)?;
        Ok(Some(buf))
    }

    /// Marks `path` as touched, recording its current contents as the durable
    /// image if this is the first touch and the file exists.
    ///
    /// The baseline is read without holding the state lock, so two callers can
    /// both see "not yet touched". The decision is therefore re-made under the
    /// lock: only the caller that actually inserts the path into the touched
    /// set stores its baseline. A late caller, whose read may already reflect
    /// the winner's modifications, never overwrites the captured image.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the baseline other than `NotFound`;
    /// in that case the path is left unmarked.
    fn capture_first_touch(&self, path: &Path) -> io::Result<()> {
        let key = path.to_path_buf();

        let already_touched = self.state.lock().touched.contains(&key);
        if already_touched {
            return Ok(());
        }

        // I/O happens with the spin lock released.
        let baseline = self.read_baseline(path)?;

        let mut state = self.state.lock();
        // `insert` returns true only for the caller that wins the first touch.
        if state.touched.insert(key.clone()) {
            if let Some(bytes) = baseline {
                state.durable.insert(key, bytes);
            }
        }
        Ok(())
    }

    /// Records tracking state for `dst` after it was produced from `src`.
    ///
    /// `dst` inherits `src`'s durable image when one is recorded. If `src` has
    /// no image and was never touched, its current contents are read and used
    /// instead. If `src` was touched but has no image, `dst` gets none either.
    /// In every case `dst` is marked as touched.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading `src` other than `NotFound`.
    fn track_copy(&self, src: &Path, dst: &Path) -> io::Result<()> {
        let (src_durable, src_touched) = {
            let state = self.state.lock();
            (state.durable.get(src).cloned(), state.touched.contains(src))
        };
        let dst_image = match src_durable {
            Some(bytes) => Some(bytes),
            None if !src_touched => self.read_baseline(src)?,
            None => None,
        };
        let mut state = self.state.lock();
        if let Some(bytes) = dst_image {
            state.durable.insert(dst.to_path_buf(), bytes);
        }
        state.touched.insert(dst.to_path_buf());
        Ok(())
    }
}
impl Fs for CrashFs {
    /// Opens `path` on the inner filesystem and wraps the handle in a
    /// `CrashFile`.
    ///
    /// If any of `write`, `create`, `create_new`, `append` or `truncate` is
    /// set, first-touch capture runs before the inner open.
    fn open(&self, path: &Path, opts: &FsOpenOptions) -> io::Result<Box<dyn FsFile>> {
        let may_modify =
            opts.write || opts.create || opts.create_new || opts.append || opts.truncate;
        if may_modify {
            self.capture_first_touch(path)?;
        }

        let handle = self.inner.open(path, opts)?;
        let wrapped = CrashFile {
            inner: handle,
            path: path.to_path_buf(),
            fs: Arc::clone(&self.inner),
            state: Arc::clone(&self.state),
        };
        Ok(Box::new(wrapped))
    }

    /// Forwards to the inner filesystem; directories are not tracked.
    fn create_dir_all(&self, path: &Path) -> io::Result<()> {
        self.inner.create_dir_all(path)
    }

    /// Forwards to the inner filesystem; directories are not tracked.
    fn create_dir(&self, path: &Path) -> io::Result<()> {
        self.inner.create_dir(path)
    }

    /// Forwards to the inner filesystem.
    fn read_dir(&self, path: &Path) -> io::Result<Vec<FsDirEntry>> {
        self.inner.read_dir(path)
    }

    /// Removes `path` on the inner filesystem, then drops its durable image
    /// and touched mark. Tracking is left untouched if the removal fails.
    fn remove_file(&self, path: &Path) -> io::Result<()> {
        self.inner.remove_file(path)?;

        let mut tracking = self.state.lock();
        tracking.durable.remove(path);
        tracking.touched.remove(path);
        Ok(())
    }

    /// Removes the tree at `path` on the inner filesystem, then drops tracking
    /// for every recorded path that starts with `path`.
    fn remove_dir_all(&self, path: &Path) -> io::Result<()> {
        self.inner.remove_dir_all(path)?;

        let mut tracking = self.state.lock();
        tracking.durable.retain(|tracked, _| !tracked.starts_with(path));
        tracking.touched.retain(|tracked| !tracked.starts_with(path));
        Ok(())
    }

    /// Renames on the inner filesystem, then moves `from`'s tracking to `to`.
    ///
    /// Whatever was recorded for `to` is discarded; `to` ends up with the
    /// durable image and touched mark of `from`, or with neither if `from`
    /// had none.
    fn rename(&self, from: &Path, to: &Path) -> io::Result<()> {
        self.inner.rename(from, to)?;

        let mut tracking = self.state.lock();

        // Inserting replaces any previous image of `to`; with nothing to move,
        // the previous image is dropped explicitly.
        match tracking.durable.remove(from) {
            Some(image) => {
                tracking.durable.insert(to.to_path_buf(), image);
            }
            None => {
                tracking.durable.remove(to);
            }
        }

        if tracking.touched.remove(from) {
            tracking.touched.insert(to.to_path_buf());
        } else {
            tracking.touched.remove(to);
        }
        Ok(())
    }

    /// Forwards to the inner filesystem.
    fn metadata(&self, path: &Path) -> io::Result<FsMetadata> {
        self.inner.metadata(path)
    }

    /// Forwards to the inner filesystem; no tracking state changes.
    fn sync_directory(&self, path: &Path) -> io::Result<()> {
        self.inner.sync_directory(path)
    }

    /// Forwards to the inner filesystem; no tracking state changes.
    fn sync_directory_with(&self, path: &Path, mode: SyncMode) -> io::Result<()> {
        self.inner.sync_directory_with(path, mode)
    }

    /// Forwards to the inner filesystem.
    fn exists(&self, path: &Path) -> io::Result<bool> {
        self.inner.exists(path)
    }

    /// Creates the link on the inner filesystem, then records tracking for
    /// `dst` derived from `src`.
    fn hard_link(&self, src: &Path, dst: &Path) -> io::Result<()> {
        self.inner.hard_link(src, dst)?;
        self.track_copy(src, dst)
    }

    /// Forwards to the inner filesystem.
    fn backend_id(&self) -> Option<u64> {
        self.inner.backend_id()
    }

    /// Forwards to the inner filesystem.
    fn volume_id(&self, path: &Path) -> Option<u64> {
        self.inner.volume_id(path)
    }

    /// Forwards to the inner filesystem.
    fn capabilities(&self, path: &Path) -> FsCapabilities {
        self.inner.capabilities(path)
    }

    /// Forwards to the inner filesystem.
    fn try_disable_cow(&self, path: &Path) -> io::Result<()> {
        self.inner.try_disable_cow(path)
    }

    /// Runs first-touch capture for `path`, then punches the hole on the
    /// inner filesystem.
    fn punch_hole(&self, path: &Path, offset: u64, len: u64) -> io::Result<()> {
        self.capture_first_touch(path)?;
        self.inner.punch_hole(path, offset, len)
    }

    /// Reflinks on the inner filesystem, then records tracking for `dst`
    /// derived from `src`.
    fn reflink_file(&self, src: &Path, dst: &Path) -> io::Result<()> {
        self.inner.reflink_file(src, dst)?;
        self.track_copy(src, dst)
    }

    /// Runs first-touch capture for `path`, then truncates it on the inner
    /// filesystem.
    fn truncate_file(&self, path: &Path) -> io::Result<()> {
        self.capture_first_touch(path)?;
        self.inner.truncate_file(path)
    }

    /// Forwards to the inner filesystem.
    fn hard_link_count(&self, path: &Path) -> io::Result<u64> {
        self.inner.hard_link_count(path)
    }

    /// Forwards to the inner filesystem.
    fn available_space(&self, path: &Path) -> io::Result<u64> {
        self.inner.available_space(path)
    }
}
/// File handle returned by `CrashFs::open`; forwards I/O to the inner handle
/// and records a durable image whenever a sync succeeds.
struct CrashFile {
/// Handle obtained from the inner filesystem.
inner: Box<dyn FsFile>,
/// Path the file was opened with; used as the key for its durable image.
path: PathBuf,
/// The inner (unwrapped) filesystem, used to re-read the file on sync.
fs: Arc<dyn Fs>,
/// Tracking state shared with the owning `CrashFs`.
state: Arc<spin::Mutex<CrashState>>,
}
impl CrashFile {
    /// Re-reads the file at `self.path` through the inner filesystem and
    /// stores the bytes as that path's durable image, replacing any earlier
    /// one.
    ///
    /// # Errors
    ///
    /// Propagates any failure to open or read the file; the previous image is
    /// left in place in that case.
    fn snapshot(&self) -> io::Result<()> {
        let mut reader = self.fs.open(&self.path, &FsOpenOptions::new().read(true))?;

        let mut image = Vec::new();
        std::io::Read::read_to_end(&mut reader, &mut image)?;

        let mut tracking = self.state.lock();
        tracking.durable.insert(self.path.clone(), image);
        Ok(())
    }
}
// Reads go straight to the inner handle; they never affect tracking state.
impl std::io::Read for CrashFile {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.inner.read(buf)
}
}
// Writes and flushes go straight to the inner handle. Neither updates the
// durable image; only the sync methods of `FsFile` take a snapshot.
impl std::io::Write for CrashFile {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.inner.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.inner.flush()
}
}
// Seeking is forwarded to the inner handle unchanged.
impl std::io::Seek for CrashFile {
fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
self.inner.seek(pos)
}
}
impl FsFile for CrashFile {
    /// Syncs the inner handle and, on success, snapshots the file's contents
    /// as its durable image.
    fn sync_all(&self) -> io::Result<()> {
        self.inner.sync_all().and_then(|()| self.snapshot())
    }

    /// Syncs the inner handle's data and, on success, snapshots the file's
    /// contents as its durable image.
    fn sync_data(&self) -> io::Result<()> {
        self.inner.sync_data().and_then(|()| self.snapshot())
    }

    /// Like `sync_all`, forwarding `mode` to the inner handle.
    fn sync_all_with(&self, mode: SyncMode) -> io::Result<()> {
        self.inner
            .sync_all_with(mode)
            .and_then(|()| self.snapshot())
    }

    /// Like `sync_data`, forwarding `mode` to the inner handle.
    fn sync_data_with(&self, mode: SyncMode) -> io::Result<()> {
        self.inner
            .sync_data_with(mode)
            .and_then(|()| self.snapshot())
    }

    /// Forwards to the inner handle.
    fn metadata(&self) -> io::Result<FsMetadata> {
        self.inner.metadata()
    }

    /// Forwards to the inner handle; the durable image is not updated.
    fn set_len(&self, size: u64) -> io::Result<()> {
        self.inner.set_len(size)
    }

    /// Forwards to the inner handle.
    fn read_at(&self, buf: &mut [u8], offset: u64) -> io::Result<usize> {
        self.inner.read_at(buf, offset)
    }

    /// Forwards to the inner handle.
    fn lock_exclusive(&self) -> io::Result<()> {
        self.inner.lock_exclusive()
    }

    /// Forwards to the inner handle.
    fn try_lock_exclusive(&self) -> io::Result<bool> {
        self.inner.try_lock_exclusive()
    }

    /// Forwards to the inner handle.
    fn hint(&self, hint: FileHint) -> io::Result<()> {
        self.inner.hint(hint)
    }
}
// Unit tests live in a separate module file and are compiled only for test
// builds. The clippy `unwrap_used` / `expect_used` lints are expected to fire
// there and are accepted for test code.
#[cfg(test)]
#[expect(clippy::unwrap_used, clippy::expect_used, reason = "test code")]
mod tests;