use std::io::{self, Read, Seek, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Minimal metadata a [`Vfs`] backend reports for a path.
#[derive(Clone, Copy, Debug)]
pub struct VfsMetadata {
    /// Length in bytes as reported by the backend.
    pub len: u64,
    /// Whether the path refers to a regular file (as opposed to a directory).
    pub is_file: bool,
}
/// Writable, seekable file handle handed out by [`Vfs::open_write`].
///
/// On top of `Write + Seek`, a backend must support resizing and syncing,
/// mirroring the methods of the same name on `std::fs::File`.
pub trait VfsWrite
where
    Self: Write + Seek + Send,
{
    /// Resizes the file to exactly `len` bytes.
    fn set_len(&mut self, len: u64) -> io::Result<()>;
    /// Persists everything written so far to the backing store.
    fn sync_all(&mut self) -> io::Result<()>;
}
/// Readable, seekable file handle handed out by [`Vfs::open_read`].
///
/// Blanket-implemented for every `Read + Seek + Send` type (including unsized
/// ones), so any such reader can be boxed as `dyn VfsRead`.
pub trait VfsRead: Read + Seek + Send {}

impl<R> VfsRead for R where R: Read + Seek + Send + ?Sized {}
/// Backend-agnostic filesystem operations.
///
/// Implemented in this module by [`StdFs`] (delegates to `std::fs`) and by
/// [`InMemoryFs`] (keeps all state in memory).
pub trait Vfs
where
    Self: Send + Sync,
{
    /// Opens `path` for writing and returns a seekable handle.
    ///
    /// Both backends in this module preserve existing contents rather than
    /// truncating them.
    fn open_write(&self, path: &Path) -> io::Result<Box<dyn VfsWrite>>;
    /// Opens the existing file at `path` for reading.
    fn open_read(&self, path: &Path) -> io::Result<Box<dyn VfsRead>>;
    /// Creates the directory `path` together with any missing ancestors.
    fn create_dir_all(&self, path: &Path) -> io::Result<()>;
    /// Removes the file at `path`.
    fn remove_file(&self, path: &Path) -> io::Result<()>;
    /// Removes the directory at `path`.
    fn remove_dir(&self, path: &Path) -> io::Result<()>;
    /// Lists the immediate children of `path` as full paths (order is backend-dependent).
    fn read_dir(&self, path: &Path) -> io::Result<Vec<PathBuf>>;
    /// Returns `true` if a file or directory exists at `path`.
    fn exists(&self, path: &Path) -> bool;
    /// Returns the length and file/directory kind of `path`.
    fn metadata(&self, path: &Path) -> io::Result<VfsMetadata>;
}
/// [`Vfs`] backend that forwards every operation to `std::fs`.
///
/// It is a stateless unit struct, so copies are free.
#[derive(Clone, Copy, Debug, Default)]
pub struct StdFs;

impl StdFs {
    /// Returns a `StdFs`; being `const`, it can initialize constants and statics.
    #[must_use]
    pub const fn new() -> Self {
        StdFs
    }
}
/// Lets a real `std::fs::File` serve as a [`VfsWrite`] handle.
impl VfsWrite for std::fs::File {
/// Delegates to the inherent `std::fs::File::set_len`.
fn set_len(&mut self, len: u64) -> io::Result<()> {
// Fully qualified on purpose: `self` is `&mut File`, so a plain
// `self.set_len(len)` would resolve to this trait method (receiver
// `&mut File`) before the inherent `&self` method, and recurse forever.
std::fs::File::set_len(self, len)
}
/// Delegates to the inherent `std::fs::File::sync_all`.
fn sync_all(&mut self) -> io::Result<()> {
// Fully qualified for the same reason as `set_len` above.
std::fs::File::sync_all(self)
}
}
impl Vfs for StdFs {
    /// Opens `path` write-only, creating it if missing and keeping any existing bytes.
    fn open_write(&self, path: &Path) -> io::Result<Box<dyn VfsWrite>> {
        let mut options = std::fs::OpenOptions::new();
        options.write(true).create(true).truncate(false);
        let handle: Box<dyn VfsWrite> = Box::new(options.open(path)?);
        Ok(handle)
    }

    /// Opens `path` read-only via `std::fs::File::open`.
    fn open_read(&self, path: &Path) -> io::Result<Box<dyn VfsRead>> {
        let file = std::fs::File::open(path)?;
        Ok(Box::new(file))
    }

    /// Forwards to `std::fs::create_dir_all`.
    fn create_dir_all(&self, path: &Path) -> io::Result<()> {
        std::fs::create_dir_all(path)
    }

    /// Forwards to `std::fs::remove_file`.
    fn remove_file(&self, path: &Path) -> io::Result<()> {
        std::fs::remove_file(path)
    }

    /// Forwards to `std::fs::remove_dir`.
    fn remove_dir(&self, path: &Path) -> io::Result<()> {
        std::fs::remove_dir(path)
    }

    /// Collects the full path of every entry; stops at the first entry error.
    fn read_dir(&self, path: &Path) -> io::Result<Vec<PathBuf>> {
        std::fs::read_dir(path)?
            .map(|entry| entry.map(|e| e.path()))
            .collect()
    }

    /// Forwards to `Path::exists`.
    fn exists(&self, path: &Path) -> bool {
        Path::exists(path)
    }

    /// Reads `std::fs::metadata` and keeps only the length and file-ness.
    fn metadata(&self, path: &Path) -> io::Result<VfsMetadata> {
        let meta = std::fs::metadata(path)?;
        let len = meta.len();
        let is_file = meta.is_file();
        Ok(VfsMetadata { len, is_file })
    }
}
/// Committed contents of one in-memory file.
#[derive(Clone, Debug, Default)]
struct MemFile {
    bytes: Vec<u8>,
}

/// Shared state behind an [`InMemoryFs`]: file contents keyed by path, plus
/// the set of recorded directories (both kept in sorted order).
#[derive(Debug, Default)]
struct MemFsInner {
    files: std::collections::BTreeMap<PathBuf, MemFile>,
    dirs: std::collections::BTreeSet<PathBuf>,
}

/// In-memory [`Vfs`] implementation.
///
/// State lives behind an `Arc<Mutex<_>>`, so clones are cheap handles to the
/// *same* filesystem and may be used from several threads.
#[derive(Clone, Debug, Default)]
pub struct InMemoryFs {
    inner: Arc<Mutex<MemFsInner>>,
}
impl InMemoryFs {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn snapshot_files(&self) -> std::collections::BTreeMap<PathBuf, Vec<u8>> {
let inner = self.inner.lock().unwrap();
inner
.files
.iter()
.map(|(k, v)| (k.clone(), v.bytes.clone()))
.collect()
}
#[must_use]
pub fn snapshot_dirs(&self) -> std::collections::BTreeSet<PathBuf> {
self.inner.lock().unwrap().dirs.clone()
}
#[must_use]
pub fn read_file(&self, path: &Path) -> Option<Vec<u8>> {
self.inner
.lock()
.unwrap()
.files
.get(path)
.map(|f| f.bytes.clone())
}
}
/// Write handle into an [`InMemoryFs`].
///
/// Writes land in a private buffer (`bytes`); `commit` copies that buffer
/// into the shared file map under `path`.
struct MemFileHandle {
    inner: Arc<Mutex<MemFsInner>>,
    path: PathBuf,
    cursor: u64,
    bytes: Vec<u8>,
}

impl MemFileHandle {
    /// Publishes the private buffer to the shared map, creating the entry if absent.
    fn commit(&self) {
        let mut state = self.inner.lock().unwrap();
        let file = state.files.entry(self.path.clone()).or_default();
        file.bytes.clone_from(&self.bytes);
    }
}

impl Drop for MemFileHandle {
    // Publishing on drop means buffered writes are never lost, even without an explicit flush.
    fn drop(&mut self) {
        MemFileHandle::commit(self);
    }
}
impl Write for MemFileHandle {
    /// Writes `buf` at the cursor and advances the cursor past it.
    ///
    /// If the cursor is beyond the current end, the gap is zero-filled first.
    /// An empty `buf` is a no-op that never extends the file. This matches a
    /// zero-length `write` on a real regular file, which has no effect.
    ///
    /// # Errors
    /// `InvalidInput` if the cursor or the resulting end offset does not fit in `usize`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        let pos = usize::try_from(self.cursor)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "cursor overflow"))?;
        let end = pos
            .checked_add(buf.len())
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "write past usize"))?;
        if end > self.bytes.len() {
            self.bytes.resize(end, 0);
        }
        self.bytes[pos..end].copy_from_slice(buf);
        // usize -> u64 is lossless on every platform Rust supports.
        self.cursor = end as u64;
        Ok(buf.len())
    }

    /// Publishes the buffered contents to the shared file map.
    fn flush(&mut self) -> io::Result<()> {
        self.commit();
        Ok(())
    }
}
impl Seek for MemFileHandle {
    /// Moves the write cursor.
    ///
    /// Positions past the end are allowed; a later write zero-fills the gap.
    ///
    /// # Errors
    /// `InvalidInput` if the target position is negative or exceeds `u64::MAX`.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        // i128 holds any u64 base plus any i64 delta without overflowing.
        let target: i128 = match pos {
            io::SeekFrom::Start(n) => i128::from(n),
            io::SeekFrom::Current(d) => i128::from(self.cursor) + i128::from(d),
            io::SeekFrom::End(d) => i128::from(self.bytes.len() as u64) + i128::from(d),
        };
        if target < 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "negative seek position",
            ));
        }
        // Reject instead of `target as u64`, which would silently wrap values above u64::MAX.
        // On error the cursor is left untouched.
        self.cursor = u64::try_from(target).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "seek position overflows u64")
        })?;
        Ok(self.cursor)
    }
}
impl VfsWrite for MemFileHandle {
    /// Truncates or zero-extends the buffer to `len` bytes, then publishes it.
    ///
    /// The cursor is deliberately left unchanged, as `std::fs::File::set_len`
    /// does. After a shrink the cursor may therefore sit past the end, and the
    /// next write zero-fills up to it, matching the `StdFs` backend.
    ///
    /// # Errors
    /// `InvalidInput` if `len` does not fit in `usize`.
    fn set_len(&mut self, len: u64) -> io::Result<()> {
        let new_len = usize::try_from(len)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "len overflow"))?;
        self.bytes.resize(new_len, 0);
        self.commit();
        Ok(())
    }

    /// Publishes the buffered contents to the shared file map.
    fn sync_all(&mut self) -> io::Result<()> {
        self.commit();
        Ok(())
    }
}
/// Read-only cursor over a private copy of an in-memory file's bytes.
struct MemReadHandle {
    bytes: Vec<u8>,
    cursor: u64,
}

impl Read for MemReadHandle {
    /// Copies as many bytes as fit from the cursor into `buf` and advances the
    /// cursor. Returns `Ok(0)` once the cursor is at or past the end.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // A cursor too large for usize is necessarily past the end of the buffer.
        let pos = usize::try_from(self.cursor).unwrap_or(usize::MAX);
        if pos >= self.bytes.len() {
            return Ok(0);
        }
        let available = &self.bytes[pos..];
        let n = available.len().min(buf.len());
        buf[..n].copy_from_slice(&available[..n]);
        self.cursor += n as u64;
        Ok(n)
    }
}

impl Seek for MemReadHandle {
    /// Moves the read cursor; positions past the end are allowed and read as EOF.
    ///
    /// # Errors
    /// `InvalidInput` if the target position is negative or exceeds `u64::MAX`.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        // i128 holds any u64 base plus any i64 delta without overflowing.
        let target: i128 = match pos {
            io::SeekFrom::Start(n) => i128::from(n),
            io::SeekFrom::Current(d) => i128::from(self.cursor) + i128::from(d),
            io::SeekFrom::End(d) => i128::from(self.bytes.len() as u64) + i128::from(d),
        };
        if target < 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "negative seek position",
            ));
        }
        // Reject instead of `target as u64`, which would silently wrap values above u64::MAX.
        // On error the cursor is left untouched.
        self.cursor = u64::try_from(target).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "seek position overflows u64")
        })?;
        Ok(self.cursor)
    }
}
impl Vfs for InMemoryFs {
    /// Opens `path` for writing without truncating it.
    ///
    /// If no file is stored at `path`, an empty one is registered immediately.
    /// `exists`/`metadata`/`read_dir` therefore see it while the handle is still
    /// open, matching `StdFs`, which opens with `create(true)`. The handle works
    /// on a private copy of the bytes and publishes it on flush, `set_len`,
    /// `sync_all` and drop.
    fn open_write(&self, path: &Path) -> io::Result<Box<dyn VfsWrite>> {
        let bytes = {
            let mut inner = self.inner.lock().unwrap();
            let file = inner.files.entry(path.to_path_buf()).or_default();
            file.bytes.clone()
        };
        Ok(Box::new(MemFileHandle {
            inner: Arc::clone(&self.inner),
            path: path.to_path_buf(),
            cursor: 0,
            bytes,
        }))
    }

    /// Opens `path` for reading over a snapshot of its current committed bytes.
    ///
    /// # Errors
    /// `NotFound` if no file is stored at `path`.
    fn open_read(&self, path: &Path) -> io::Result<Box<dyn VfsRead>> {
        let inner = self.inner.lock().unwrap();
        let Some(f) = inner.files.get(path) else {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        };
        let bytes = f.bytes.clone();
        Ok(Box::new(MemReadHandle { bytes, cursor: 0 }))
    }

    /// Records `path` and every prefix of it (built component by component) as directories.
    fn create_dir_all(&self, path: &Path) -> io::Result<()> {
        let mut inner = self.inner.lock().unwrap();
        let mut cur = PathBuf::new();
        for comp in path.components() {
            cur.push(comp);
            inner.dirs.insert(cur.clone());
        }
        Ok(())
    }

    /// Removes the file stored at `path`.
    ///
    /// # Errors
    /// `NotFound` if no file is stored at `path`.
    fn remove_file(&self, path: &Path) -> io::Result<()> {
        let mut inner = self.inner.lock().unwrap();
        if inner.files.remove(path).is_none() {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        }
        Ok(())
    }

    /// Removes the empty directory `path`.
    ///
    /// # Errors
    /// - `NotFound` if `path` is not a recorded directory.
    /// - An `Other` error if any file or directory still lives beneath it.
    ///   This mirrors `std::fs::remove_dir`, which refuses non-empty directories
    ///   rather than orphaning their contents.
    fn remove_dir(&self, path: &Path) -> io::Result<()> {
        let mut inner = self.inner.lock().unwrap();
        if !inner.dirs.contains(path) {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        }
        // `Path::starts_with` compares whole components, so "/pq" is not beneath "/p".
        let is_beneath = |p: &PathBuf| p.as_path() != path && p.starts_with(path);
        if inner.files.keys().any(is_beneath) || inner.dirs.iter().any(is_beneath) {
            return Err(io::Error::other("directory not empty"));
        }
        inner.dirs.remove(path);
        Ok(())
    }

    /// Lists the files and directories whose parent is exactly `path`.
    ///
    /// Files come first, then directories, each group in sorted path order.
    ///
    /// # Errors
    /// `NotFound` if `path` is neither a recorded directory nor the parent of any file.
    fn read_dir(&self, path: &Path) -> io::Result<Vec<PathBuf>> {
        let inner = self.inner.lock().unwrap();
        if !inner.dirs.contains(path) && !inner.files.keys().any(|p| p.parent() == Some(path)) {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        }
        let children: Vec<PathBuf> = inner
            .files
            .keys()
            .chain(inner.dirs.iter())
            .filter(|p| p.parent() == Some(path))
            .cloned()
            .collect();
        Ok(children)
    }

    /// Returns `true` if a file or a recorded directory exists at `path`.
    fn exists(&self, path: &Path) -> bool {
        let inner = self.inner.lock().unwrap();
        inner.files.contains_key(path) || inner.dirs.contains(path)
    }

    /// Reports a file's byte length, or length 0 for a recorded directory.
    ///
    /// # Errors
    /// `NotFound` if nothing is stored at `path`.
    fn metadata(&self, path: &Path) -> io::Result<VfsMetadata> {
        let inner = self.inner.lock().unwrap();
        if let Some(f) = inner.files.get(path) {
            Ok(VfsMetadata {
                len: f.bytes.len() as u64,
                is_file: true,
            })
        } else if inner.dirs.contains(path) {
            Ok(VfsMetadata {
                len: 0,
                is_file: false,
            })
        } else {
            Err(io::Error::from(io::ErrorKind::NotFound))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::SeekFrom;

    /// Borrows a string literal as a `&Path`.
    fn at(s: &str) -> &Path {
        Path::new(s)
    }

    #[test]
    fn memfs_write_then_read_round_trips() {
        let fs = InMemoryFs::new();
        let target = at("/a/b.txt");
        let mut writer = fs.open_write(target).unwrap();
        writer.write_all(b"hello").unwrap();
        writer.flush().unwrap();
        drop(writer);
        let contents = fs.read_file(target).unwrap();
        assert_eq!(contents, b"hello");
    }

    #[test]
    fn memfs_seek_write_extends_with_zeros() {
        let fs = InMemoryFs::new();
        let mut writer = fs.open_write(at("/x")).unwrap();
        writer.seek(SeekFrom::Start(8)).unwrap();
        writer.write_all(b"AB").unwrap();
        writer.flush().unwrap();
        let contents = fs.read_file(at("/x")).unwrap();
        assert_eq!(contents.len(), 10);
        let (head, tail) = contents.split_at(8);
        assert!(head.iter().all(|&b| b == 0));
        assert_eq!(tail, b"AB");
    }

    #[test]
    fn memfs_set_len_truncates() {
        let fs = InMemoryFs::new();
        let mut writer = fs.open_write(at("/t")).unwrap();
        writer.write_all(b"abcdef").unwrap();
        writer.set_len(3).unwrap();
        drop(writer);
        assert_eq!(fs.read_file(at("/t")).as_deref(), Some(&b"abc"[..]));
    }

    #[test]
    fn memfs_remove_file_clears_entry() {
        let fs = InMemoryFs::new();
        let target = at("/r");
        drop(fs.open_write(target).unwrap());
        assert!(fs.exists(target));
        fs.remove_file(target).unwrap();
        assert!(!fs.exists(target));
    }

    #[test]
    fn memfs_create_dir_all_records_each_ancestor() {
        let fs = InMemoryFs::new();
        fs.create_dir_all(at("/a/b/c")).unwrap();
        let dirs = fs.snapshot_dirs();
        for expected in ["/a", "/a/b", "/a/b/c"] {
            assert!(dirs.contains(at(expected)), "missing {expected}");
        }
    }

    #[test]
    fn memfs_read_dir_enumerates_children() {
        let fs = InMemoryFs::new();
        fs.create_dir_all(at("/p")).unwrap();
        for name in ["/p/a", "/p/b"] {
            drop(fs.open_write(at(name)).unwrap());
        }
        assert_eq!(fs.read_dir(at("/p")).unwrap().len(), 2);
    }

    #[test]
    fn stdfs_round_trip_against_tempdir() {
        let dir = tempfile::tempdir().unwrap();
        let fs = StdFs::new();
        let file = dir.path().join("hello.txt");
        {
            let mut writer = fs.open_write(&file).unwrap();
            writer.write_all(b"world").unwrap();
            writer.flush().unwrap();
        }
        let mut contents = Vec::new();
        fs.open_read(&file)
            .unwrap()
            .read_to_end(&mut contents)
            .unwrap();
        assert_eq!(contents, b"world");
    }
}