use super::{Fs, FsCapabilities, FsDirEntry, FsFile, FsMetadata, FsOpenOptions};
use crate::io::{self, SeekFrom};
#[cfg(not(feature = "std"))]
use crate::io::{Read, Seek, Write};
use crate::path::{Path, PathBuf};
#[cfg(not(feature = "std"))]
use alloc::borrow::ToOwned;
use alloc::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};
use hashbrown::{HashMap, HashSet};
use spin::{Mutex, RwLock};
/// In-memory [`Fs`] implementation.
///
/// Every field that carries shared state is wrapped in an `Arc`, so clones
/// produced by `#[derive(Clone)]` observe the same files, directories,
/// capacity, and counters.
#[derive(Clone, Debug)]
pub struct MemFs {
    // Directory set, file buffers, and punched-range bookkeeping.
    state: Arc<RwLock<State>>,
    // Value from `next_mem_fs_namespace_id`; reported by `backend_id` and `volume_id`.
    namespace_id: u64,
    // Byte budget consulted by `available_space`; `u64::MAX` means unlimited.
    capacity: Arc<portable_atomic::AtomicU64>,
    // Cumulative count of newly punched bytes; only ever incremented (by `punch_hole`).
    punched_total: Arc<portable_atomic::AtomicU64>,
    // Set once any hole is punched; lets `MemFile` skip taking the state lock for
    // punched-range maintenance until then.
    has_punches: Arc<portable_atomic::AtomicBool>,
    // Value reported as `punch_hole` by `capabilities`.
    punch_hole_supported: Arc<portable_atomic::AtomicBool>,
}
/// Mutable filesystem contents shared by `MemFs` and every open `MemFile`.
#[derive(Debug, Default)]
struct State {
    // File path -> byte buffer. The buffer is `Arc`-shared so open handles keep
    // it alive and see each other's writes.
    files: HashMap<PathBuf, Arc<Mutex<Vec<u8>>>>,
    // Every existing directory path; `MemFs::new` seeds it with "/".
    dirs: HashSet<PathBuf>,
    // File path -> sorted, non-overlapping half-open `(start, end)` byte ranges
    // that were punched and not since overwritten.
    punched: HashMap<PathBuf, Vec<(u64, u64)>>,
}
/// Adds the half-open range `[start, end)` to `ranges`, keeping the list sorted
/// and coalescing any ranges that overlap or touch. Empty ranges are ignored.
fn merge_punched_range(ranges: &mut Vec<(u64, u64)>, start: u64, end: u64) {
    if end <= start {
        return;
    }
    ranges.push((start, end));
    ranges.sort_unstable();
    // `dedup_by` hands us (current, last_kept); returning true folds `current`
    // into `last_kept` and drops it, which merges in place without a new Vec.
    ranges.dedup_by(|current, last_kept| {
        if current.0 <= last_kept.1 {
            last_kept.1 = last_kept.1.max(current.1);
            true
        } else {
            false
        }
    });
}
/// Removes the half-open range `[start, end)` from every range in `ranges`,
/// splitting ranges that straddle it and dropping any that become empty.
/// An empty `[start, end)` leaves `ranges` untouched.
fn subtract_punched_range(ranges: &mut Vec<(u64, u64)>, start: u64, end: u64) {
    if end <= start {
        return;
    }
    // Each range yields a left remainder (before `start`) and a right remainder
    // (after `end`); remainders that do not exist come out empty and are filtered.
    let remaining: Vec<(u64, u64)> = ranges
        .iter()
        .flat_map(|&(lo, hi)| [(lo, hi.min(start)), (lo.max(end), hi)])
        .filter(|&(lo, hi)| lo < hi)
        .collect();
    *ranges = remaining;
}
/// Total number of bytes covered by `ranges` once each range is clipped to the
/// first `len` bytes of the file. Ranges entirely past `len` contribute zero.
fn clipped_punched_len(ranges: &[(u64, u64)], len: u64) -> u64 {
    ranges
        .iter()
        .map(|&(lo, hi)| hi.min(len).saturating_sub(lo))
        .sum()
}
impl MemFs {
    /// Creates an empty filesystem containing only the root directory `/`.
    ///
    /// Capacity starts unlimited (`u64::MAX`), no holes are recorded, and hole
    /// punching is reported as supported by `capabilities`.
    #[must_use]
    pub fn new() -> Self {
        let mut initial = State::default();
        initial.dirs.insert(PathBuf::from("/"));
        let shared_u64 = |value: u64| Arc::new(portable_atomic::AtomicU64::new(value));
        let shared_flag = |value: bool| Arc::new(portable_atomic::AtomicBool::new(value));
        Self {
            state: Arc::new(RwLock::new(initial)),
            namespace_id: next_mem_fs_namespace_id(),
            capacity: shared_u64(u64::MAX),
            punched_total: shared_u64(0),
            has_punches: shared_flag(false),
            punch_hole_supported: shared_flag(true),
        }
    }

    /// Sets the value `capabilities` reports for hole punching.
    ///
    /// This only changes the advertised capability; `punch_hole` itself does
    /// not consult the flag.
    pub fn set_punch_hole_supported(&self, supported: bool) {
        let flag = &self.punch_hole_supported;
        flag.store(supported, portable_atomic::Ordering::Relaxed);
    }

    /// Creates an empty filesystem whose `available_space` is bounded by
    /// `capacity_bytes`.
    #[must_use]
    pub fn with_capacity(capacity_bytes: u64) -> Self {
        let limited = Self::new();
        limited.set_capacity(capacity_bytes);
        limited
    }

    /// Sets the byte budget used by `available_space`; `u64::MAX` means
    /// unlimited.
    pub fn set_capacity(&self, capacity_bytes: u64) {
        let budget = &self.capacity;
        budget.store(capacity_bytes, portable_atomic::Ordering::Relaxed);
    }

    /// Sum over all files of their length minus the bytes currently recorded
    /// as punched (clipped to the file length).
    fn stored_bytes(&self) -> u64 {
        let guard = self.state.read();
        let mut total: u64 = 0;
        for (path, buffer) in &guard.files {
            let file_len = buffer.lock().len() as u64;
            let holes = guard
                .punched
                .get(path)
                .map_or(0, |ranges| clipped_punched_len(ranges, file_len));
            total += file_len - holes;
        }
        total
    }

    /// Cumulative number of bytes newly covered by `punch_hole` calls.
    ///
    /// Nothing in this file decrements the counter, so it is not reduced when
    /// holes are later overwritten, truncated, or removed.
    #[must_use]
    pub fn punched_bytes(&self) -> u64 {
        let counter = &self.punched_total;
        counter.load(portable_atomic::Ordering::Relaxed)
    }
}
/// Returns a fresh process-wide namespace id, starting at 1.
///
/// Backed by a 32-bit counter whose `fetch_add` wraps on overflow, so ids
/// repeat after 2^32 calls.
fn next_mem_fs_namespace_id() -> u64 {
    use core::sync::atomic::{AtomicU32, Ordering::Relaxed};
    static NEXT_ID: AtomicU32 = AtomicU32::new(1);
    let issued = NEXT_ID.fetch_add(1, Relaxed);
    u64::from(issued)
}
impl Default for MemFs {
fn default() -> Self {
Self::new()
}
}
/// Open handle to a file in a `MemFs`.
struct MemFile {
    // Buffer shared with `State::files` and with any other handle to the same file.
    data: Arc<Mutex<Vec<u8>>>,
    // Position used by Read/Write/Seek; may point past the end of `data`.
    cursor: u64,
    // Access flags derived from the `FsOpenOptions` at open time.
    readable: bool,
    writable: bool,
    // When true, every write goes to the current end of the buffer.
    is_append: bool,
    // Shared filesystem state, needed to keep punched ranges in sync with writes.
    state: Arc<RwLock<State>>,
    // Path at open time; may be stale after a rename (see `current_path`).
    path: PathBuf,
    // Shared "any hole punched yet" flag from the owning `MemFs`.
    has_punches: Arc<portable_atomic::AtomicBool>,
}
/// Copies as many bytes as fit from `data[pos..]` into the front of `buf` and
/// returns how many were copied. A `pos` at or past the end copies nothing.
fn copy_from_data(buf: &mut [u8], data: &[u8], pos: usize) -> usize {
    let Some(tail) = data.get(pos..) else {
        return 0;
    };
    let count = core::cmp::min(buf.len(), tail.len());
    if let (Some(target), Some(source)) = (buf.get_mut(..count), tail.get(..count)) {
        target.copy_from_slice(source);
    }
    count
}
impl MemFile {
    /// Copies bytes from the current cursor into `buf` and advances the cursor
    /// by the number of bytes copied.
    ///
    /// A cursor at or past the end of the buffer yields `Ok(0)` (end of file).
    /// This includes cursors too large to fit in `usize`: those are
    /// necessarily past the end of any in-memory `Vec`. Returning `Ok(0)` here
    /// matches `std::io::Cursor` rather than reporting an error.
    ///
    /// # Errors
    /// Fails if the handle was not opened for reading.
    fn read_impl(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if !self.readable {
            return Err(io::Error::other("file not opened for reading"));
        }
        let Ok(pos) = usize::try_from(self.cursor) else {
            // Beyond addressable memory is beyond EOF.
            return Ok(0);
        };
        let n = {
            let data = lock(&self.data)?;
            copy_from_data(buf, &data, pos)
        };
        self.cursor += n as u64;
        Ok(n)
    }

    /// Writes `buf` at the cursor (or at the end of the buffer in append mode),
    /// zero-filling any gap. Afterwards it drops the overwritten span from the
    /// punched ranges and moves the cursor past the written bytes.
    ///
    /// # Errors
    /// Fails if the handle is not writable, or if the target position does not
    /// fit in `usize` or overflows when the length is added.
    fn write_impl(&mut self, buf: &[u8]) -> io::Result<usize> {
        if !self.writable {
            return Err(io::Error::other("file not opened for writing"));
        }
        if buf.is_empty() {
            return Ok(0);
        }
        let mut data = lock(&self.data)?;
        let pos = if self.is_append {
            data.len()
        } else {
            usize::try_from(self.cursor).map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "write position exceeds addressable memory",
                )
            })?
        };
        let end = pos.checked_add(buf.len()).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "write position overflow")
        })?;
        if end > data.len() {
            data.resize(end, 0);
        }
        if let Some(dst) = data.get_mut(pos..end) {
            dst.copy_from_slice(buf);
        }
        // Release the buffer lock before taking the state lock in
        // `drop_punched_overlap`; the two are never held together here.
        drop(data);
        self.drop_punched_overlap(pos as u64, end as u64);
        self.cursor = end as u64;
        Ok(buf.len())
    }

    /// Finds the path that currently maps to this handle's buffer.
    ///
    /// The open-time path is tried first. If it no longer points at the same
    /// buffer (e.g. after a rename), every file is searched by pointer
    /// identity. Returns `None` once no path references the buffer.
    fn current_path(&self, state: &State) -> Option<PathBuf> {
        if state
            .files
            .get(&self.path)
            .is_some_and(|data| Arc::ptr_eq(data, &self.data))
        {
            return Some(self.path.clone());
        }
        state
            .files
            .iter()
            .find(|(_, data)| Arc::ptr_eq(data, &self.data))
            .map(|(path, _)| path.clone())
    }

    /// Applies `update` to the punched ranges recorded for this file, removing
    /// the entry if it ends up empty.
    ///
    /// Does nothing if no hole has ever been punched on the filesystem, if the
    /// buffer is no longer reachable by any path, or if the file has no
    /// recorded ranges.
    fn update_punched(&self, update: impl FnOnce(&mut Vec<(u64, u64)>)) {
        if !self.has_punches.load(portable_atomic::Ordering::Relaxed) {
            return;
        }
        let mut state = self.state.write();
        let Some(path) = self.current_path(&state) else {
            return;
        };
        if let Some(ranges) = state.punched.get_mut(&path) {
            update(ranges);
            if ranges.is_empty() {
                state.punched.remove(&path);
            }
        }
    }

    /// Forgets punched bytes in `[start, end)`; they now hold written data.
    fn drop_punched_overlap(&self, start: u64, end: u64) {
        self.update_punched(|ranges| subtract_punched_range(ranges, start, end));
    }

    /// Trims punched ranges so none extends past `new_len` (after a shrink).
    fn clip_punched_to(&self, new_len: u64) {
        self.update_punched(|ranges| {
            ranges.retain_mut(|(s, e)| {
                *e = (*e).min(new_len);
                *s < *e
            });
        });
    }

    /// Adds a signed `delta` to `base`. Fails if the result would be negative
    /// or would not fit in `u64`.
    fn offset_position(base: u64, delta: i64) -> io::Result<u64> {
        // i128 holds every u64 + i64 sum without overflow.
        let target = i128::from(base) + i128::from(delta);
        if target < 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "seek to negative position",
            ));
        }
        u64::try_from(target).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "seek position overflow")
        })
    }

    /// Moves the cursor and returns its new absolute position. Seeking past
    /// the end is allowed; the buffer only grows on a later write.
    fn seek_impl(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let new_pos = match pos {
            SeekFrom::Start(n) => n,
            SeekFrom::End(n) => {
                let len = {
                    let data = lock(&self.data)?;
                    u64::try_from(data.len()).map_err(|_| {
                        io::Error::other("in-memory file length does not fit in u64")
                    })?
                };
                Self::offset_position(len, n)?
            }
            SeekFrom::Current(n) => Self::offset_position(self.cursor, n)?,
        };
        self.cursor = new_pos;
        Ok(self.cursor)
    }
}
// std builds: forward the std::io traits to MemFile's internal implementations,
// converting the crate's io::Error into std::io::Error.
#[cfg(feature = "std")]
impl std::io::Read for MemFile {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let outcome = Self::read_impl(self, buf);
        outcome.map_err(Into::into)
    }
}
// no_std builds: the crate's own io traits share read_impl's error type.
#[cfg(not(feature = "std"))]
impl Read for MemFile {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Self::read_impl(self, buf)
    }
}
#[cfg(feature = "std")]
impl std::io::Write for MemFile {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let outcome = Self::write_impl(self, buf);
        outcome.map_err(Into::into)
    }
    // Writes land directly in the shared buffer; there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
#[cfg(not(feature = "std"))]
impl Write for MemFile {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Self::write_impl(self, buf)
    }
    // Writes land directly in the shared buffer; there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
#[cfg(feature = "std")]
impl std::io::Seek for MemFile {
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let outcome = Self::seek_impl(self, pos.into());
        outcome.map_err(Into::into)
    }
}
#[cfg(not(feature = "std"))]
impl Seek for MemFile {
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        Self::seek_impl(self, pos)
    }
}
impl FsFile for MemFile {
fn sync_all(&self) -> io::Result<()> {
Ok(())
}
fn sync_data(&self) -> io::Result<()> {
Ok(())
}
fn metadata(&self) -> io::Result<FsMetadata> {
let data = lock(&self.data)?;
Ok(FsMetadata {
len: data.len() as u64,
is_dir: false,
is_file: true,
})
}
fn set_len(&self, size: u64) -> io::Result<()> {
if !self.writable {
return Err(io::Error::other("set_len requires write access"));
}
let new_len = usize::try_from(size).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"set_len size exceeds usize::MAX",
)
})?;
lock(&self.data)?.resize(new_len, 0);
self.clip_punched_to(size);
Ok(())
}
fn read_at(&self, buf: &mut [u8], offset: u64) -> io::Result<usize> {
if !self.readable {
return Err(io::Error::other("read_at requires read access"));
}
let offset = usize::try_from(offset).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"read_at offset exceeds usize::MAX",
)
})?;
let data = lock(&self.data)?;
Ok(copy_from_data(buf, &data, offset))
}
fn lock_exclusive(&self) -> io::Result<()> {
Ok(())
}
fn try_lock_exclusive(&self) -> io::Result<bool> {
Ok(true)
}
}
/// Rejects the empty path with `InvalidInput`; any other path is accepted.
fn ensure_non_empty_path(path: &Path) -> io::Result<()> {
    let has_content = !path.as_os_str().is_empty();
    if has_content {
        Ok(())
    } else {
        Err(io::Error::new(io::ErrorKind::InvalidInput, "empty path"))
    }
}
fn ensure_parent_dir(path: &Path, state: &State) -> io::Result<()> {
if let Some(parent) = path.parent()
&& !parent.as_os_str().is_empty()
&& parent != Path::new("/")
&& !state.dirs.contains(parent)
{
if state.files.contains_key(parent) {
return Err(io::Error::other(format!(
"parent is not a directory: {}",
parent.display()
)));
}
return Err(io::Error::new(
io::ErrorKind::NotFound,
format!("parent directory does not exist: {}", parent.display()),
));
}
Ok(())
}
impl Fs for MemFs {
fn open(&self, path: &Path, opts: &FsOpenOptions) -> io::Result<Box<dyn FsFile>> {
ensure_non_empty_path(path)?;
let mut state = write_state(&self.state)?;
let path = path.to_path_buf();
let wants_write = opts.write || opts.append;
if !opts.read && !wants_write {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"open requires at least read, write, or append access",
));
}
if opts.truncate && opts.append {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"truncate and append cannot be used together",
));
}
if opts.truncate && !opts.write {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"truncate requires write access",
));
}
if (opts.create || opts.create_new) && !wants_write {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"create/create_new requires write or append access",
));
}
ensure_parent_dir(&path, &state)?;
let exists = state.files.contains_key(&path);
let is_dir = state.dirs.contains(&path);
if is_dir && !opts.create && !opts.create_new {
return Err(io::Error::other(format!(
"path is a directory: {}",
path.display()
)));
}
if is_dir && (opts.create || opts.create_new) {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("path is a directory: {}", path.display()),
));
}
if opts.create_new {
if exists {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("file already exists: {}", path.display()),
));
}
let data = Arc::new(Mutex::new(Vec::new()));
state.files.insert(path.clone(), Arc::clone(&data));
return Ok(Box::new(MemFile {
data,
cursor: 0,
readable: opts.read,
writable: opts.write || opts.append,
is_append: opts.append,
state: Arc::clone(&self.state),
path,
has_punches: Arc::clone(&self.has_punches),
}));
}
if exists {
let data = state
.files
.get(&path)
.map(Arc::clone)
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "concurrent removal"))?;
if opts.truncate {
lock(&data)?.clear();
state.punched.remove(&path);
}
let cursor = 0;
Ok(Box::new(MemFile {
data,
cursor,
readable: opts.read,
writable: opts.write || opts.append,
is_append: opts.append,
state: Arc::clone(&self.state),
path,
has_punches: Arc::clone(&self.has_punches),
}))
} else if opts.create {
let data = Arc::new(Mutex::new(Vec::new()));
state.files.insert(path.clone(), Arc::clone(&data));
Ok(Box::new(MemFile {
data,
cursor: 0,
readable: opts.read,
writable: opts.write || opts.append,
is_append: opts.append,
state: Arc::clone(&self.state),
path,
has_punches: Arc::clone(&self.has_punches),
}))
} else {
Err(io::Error::new(
io::ErrorKind::NotFound,
format!("file not found: {}", path.display()),
))
}
}
fn create_dir_all(&self, path: &Path) -> io::Result<()> {
ensure_non_empty_path(path)?;
let mut state = write_state(&self.state)?;
let mut to_create = Vec::new();
let mut current = path.to_path_buf();
loop {
if state.files.contains_key(¤t) {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("path conflicts with existing file: {}", current.display()),
));
}
to_create.push(current.clone());
if !current.pop() || current.as_os_str().is_empty() {
break;
}
}
for dir in to_create {
state.dirs.insert(dir);
}
Ok(())
}
fn create_dir(&self, path: &Path) -> io::Result<()> {
ensure_non_empty_path(path)?;
let mut state = write_state(&self.state)?;
if state.dirs.contains(path) || state.files.contains_key(path) {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("path already exists: {}", path.display()),
));
}
ensure_parent_dir(path, &state)?;
state.dirs.insert(path.to_path_buf());
Ok(())
}
fn read_dir(&self, path: &Path) -> io::Result<Vec<FsDirEntry>> {
let state = read_state(&self.state)?;
if !state.dirs.contains(path) {
if state.files.contains_key(path) {
return Err(io::Error::other(format!(
"not a directory: {}",
path.display()
)));
}
return Err(io::Error::new(
io::ErrorKind::NotFound,
format!("directory not found: {}", path.display()),
));
}
let mut entries = Vec::new();
for file_path in state.files.keys() {
if file_path.parent() == Some(path)
&& let Some(name) = file_path.file_name()
{
#[cfg(feature = "std")]
let file_name = name.to_str().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!(
"non-UTF-8 filename in directory {}: {}",
path.display(),
name.display()
),
)
})?;
#[cfg(not(feature = "std"))]
let file_name = name;
entries.push(FsDirEntry {
path: file_path.clone(),
file_name: file_name.to_owned(),
is_dir: false,
});
}
}
for dir_path in &state.dirs {
if dir_path.parent() == Some(path)
&& dir_path != path
&& let Some(name) = dir_path.file_name()
{
#[cfg(feature = "std")]
let file_name = name.to_str().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!(
"non-UTF-8 filename in directory {}: {}",
path.display(),
name.display()
),
)
})?;
#[cfg(not(feature = "std"))]
let file_name = name;
entries.push(FsDirEntry {
path: dir_path.clone(),
file_name: file_name.to_owned(),
is_dir: true,
});
}
}
Ok(entries)
}
fn remove_file(&self, path: &Path) -> io::Result<()> {
let mut state = write_state(&self.state)?;
if state.dirs.contains(path) {
return Err(io::Error::other(format!(
"cannot remove_file on directory: {}",
path.display()
)));
}
if state.files.remove(path).is_none() {
return Err(io::Error::new(
io::ErrorKind::NotFound,
format!("file not found: {}", path.display()),
));
}
state.punched.remove(path);
Ok(())
}
fn remove_dir_all(&self, path: &Path) -> io::Result<()> {
let mut state = write_state(&self.state)?;
if state.files.contains_key(path) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("path is not a directory: {}", path.display()),
));
}
if !state.dirs.contains(path) {
return Err(io::Error::new(
io::ErrorKind::NotFound,
format!("path not found: {}", path.display()),
));
}
state.files.retain(|p, _| !p.starts_with(path));
state.dirs.retain(|p| !p.starts_with(path));
state.punched.retain(|p, _| !p.starts_with(path));
state.dirs.insert(PathBuf::from("/"));
Ok(())
}
fn rename(&self, from: &Path, to: &Path) -> io::Result<()> {
ensure_non_empty_path(from)?;
ensure_non_empty_path(to)?;
let mut state = write_state(&self.state)?;
ensure_parent_dir(to, &state)?;
if state.dirs.contains(to) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("destination is a directory: {}", to.display()),
));
}
if state.dirs.contains(from) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("path is a directory: {}", from.display()),
));
}
if let Some(data) = state.files.remove(from) {
state.files.insert(to.to_path_buf(), data);
match state.punched.remove(from) {
Some(ranges) => {
state.punched.insert(to.to_path_buf(), ranges);
}
None => {
state.punched.remove(to);
}
}
Ok(())
} else {
Err(io::Error::new(
io::ErrorKind::NotFound,
format!("file not found: {}", from.display()),
))
}
}
fn metadata(&self, path: &Path) -> io::Result<FsMetadata> {
let state = read_state(&self.state)?;
if let Some(data) = state.files.get(path) {
let d = lock(data)?;
Ok(FsMetadata {
len: d.len() as u64,
is_dir: false,
is_file: true,
})
} else if state.dirs.contains(path) {
Ok(FsMetadata {
len: 0,
is_dir: true,
is_file: false,
})
} else {
Err(io::Error::new(
io::ErrorKind::NotFound,
format!("path not found: {}", path.display()),
))
}
}
fn available_space(&self, _path: &Path) -> io::Result<u64> {
let capacity = self.capacity.load(portable_atomic::Ordering::Relaxed);
if capacity == u64::MAX {
Ok(u64::MAX)
} else {
Ok(capacity.saturating_sub(self.stored_bytes()))
}
}
fn sync_directory(&self, path: &Path) -> io::Result<()> {
let state = read_state(&self.state)?;
if !state.dirs.contains(path) {
if state.files.contains_key(path) {
return Err(io::Error::other(format!(
"sync_directory: not a directory: {}",
path.display()
)));
}
return Err(io::Error::new(
io::ErrorKind::NotFound,
format!("sync_directory: path not found: {}", path.display()),
));
}
Ok(())
}
fn exists(&self, path: &Path) -> io::Result<bool> {
let state = read_state(&self.state)?;
Ok(state.files.contains_key(path) || state.dirs.contains(path))
}
fn hard_link(&self, src: &Path, dst: &Path) -> io::Result<()> {
ensure_non_empty_path(src)?;
ensure_non_empty_path(dst)?;
let mut state = write_state(&self.state)?;
ensure_parent_dir(dst, &state)?;
if state.dirs.contains(dst) {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("destination is a directory: {}", dst.display()),
));
}
if state.files.contains_key(dst) {
return Err(io::Error::new(
io::ErrorKind::AlreadyExists,
format!("destination already exists: {}", dst.display()),
));
}
let bytes = {
let src_data = state.files.get(src).ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
format!("source file not found: {}", src.display()),
)
})?;
let guard = lock(src_data)?;
guard.clone()
};
state
.files
.insert(dst.to_path_buf(), Arc::new(Mutex::new(bytes)));
Ok(())
}
fn backend_id(&self) -> Option<u64> {
Some(self.namespace_id)
}
fn volume_id(&self, _path: &Path) -> Option<u64> {
Some(self.namespace_id)
}
fn capabilities(&self, _path: &Path) -> FsCapabilities {
FsCapabilities {
punch_hole: self
.punch_hole_supported
.load(portable_atomic::Ordering::Relaxed),
..FsCapabilities::default()
}
}
fn punch_hole(&self, path: &Path, offset: u64, len: u64) -> io::Result<()> {
let mut state = self.state.write();
let (start, end) = {
let data = state.files.get(path).ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "punch_hole: file not found")
})?;
let mut buf = data.lock();
let file_len = buf.len() as u64;
let start = offset.min(file_len);
let end = offset.saturating_add(len).min(file_len);
if start >= end {
return Ok(());
}
#[expect(
clippy::cast_possible_truncation,
reason = "start/end are clamped to buf.len() (a usize), so they fit usize"
)]
let (s, e) = (start as usize, end as usize);
if let Some(slice) = buf.get_mut(s..e) {
slice.fill(0);
}
(start, end)
};
let ranges = state.punched.entry(path.to_path_buf()).or_default();
let before: u64 = ranges.iter().map(|&(s, e)| e - s).sum();
merge_punched_range(ranges, start, end);
let after: u64 = ranges.iter().map(|&(s, e)| e - s).sum();
self.punched_total
.fetch_add(after - before, portable_atomic::Ordering::Relaxed);
self.has_punches
.store(true, portable_atomic::Ordering::Relaxed);
Ok(())
}
}
/// Acquires `m`. The `io::Result` wrapper never fails today; it keeps call
/// sites `?`-compatible if a fallible lock is ever swapped in.
#[expect(
    clippy::unnecessary_wraps,
    reason = "Result kept for ?-compatible call sites and future fallible-lock parity"
)]
fn lock<T>(m: &Mutex<T>) -> io::Result<impl core::ops::DerefMut<Target = T> + '_> {
    let guard = m.lock();
    Ok(guard)
}
/// Takes a shared read lock on the filesystem state (never fails today).
#[expect(
    clippy::unnecessary_wraps,
    reason = "Result kept for ?-compatible call sites and future fallible-lock parity"
)]
fn read_state(rw: &RwLock<State>) -> io::Result<impl core::ops::Deref<Target = State> + '_> {
    let reader = rw.read();
    Ok(reader)
}
/// Takes an exclusive write lock on the filesystem state (never fails today).
#[expect(
    clippy::unnecessary_wraps,
    reason = "Result kept for ?-compatible call sites and future fallible-lock parity"
)]
fn write_state(rw: &RwLock<State>) -> io::Result<impl core::ops::DerefMut<Target = State> + '_> {
    let writer = rw.write();
    Ok(writer)
}
// Unit tests are declared here but live in a separate `tests` module file.
// Lints that are acceptable in test code are relaxed for the whole module.
#[cfg(test)]
#[expect(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::unnecessary_wraps,
    reason = "test code"
)]
mod tests;