use crate::error::Result;
use std::future::Future;
use std::path::Path;
/// Asynchronous, path-addressed byte storage.
///
/// Implementors must be `Send + Sync`, and every method returns a `Send` future.
/// This file provides two implementations: `TokioLocalFs` (backed by `tokio::fs`)
/// and `MemoryLocalFs` (backed by an in-memory map).
#[allow(clippy::module_name_repetitions)]
pub trait LocalFs: Send + Sync {
/// Resolves to the complete contents stored at `path`.
///
/// # Errors
///
/// The error conditions are implementation-defined; see the individual impls.
// Declared as `-> impl Future + Send` rather than `async fn` so the `Send` bound on
// the returned future is part of the trait contract; hence the lint suppression.
#[allow(clippy::manual_async_fn)]
fn read_all(&self, path: &Path) -> impl Future<Output = Result<Vec<u8>>> + Send;
/// Stores `bytes` at `path`.
///
/// # Errors
///
/// The error conditions are implementation-defined; see the individual impls.
#[allow(clippy::manual_async_fn)]
fn write_all(&self, path: &Path, bytes: &[u8]) -> impl Future<Output = Result<()>> + Send;
/// Resolves to whether something is present at `path`.
///
/// The output is a plain `bool`, so there is no error channel: implementations
/// have to fold any failure into the boolean result.
#[allow(clippy::manual_async_fn)]
fn exists(&self, path: &Path) -> impl Future<Output = bool> + Send;
}
/// `LocalFs` implementation that performs its I/O through `tokio::fs`.
///
/// The type is a stateless unit struct; it is marked `#[non_exhaustive]`, so code
/// outside this crate has to obtain a value through `new()` or `Default`.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioLocalFs;
impl TokioLocalFs {
    /// Creates the (stateless) tokio-backed filesystem handle.
    ///
    /// Usable in `const` contexts.
    #[allow(clippy::new_without_default)]
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl LocalFs for TokioLocalFs {
async fn read_all(&self, path: &Path) -> Result<Vec<u8>> {
tokio::time::timeout(std::time::Duration::from_secs(60), async {
tokio::fs::read(path).await.map_err(|e| {
crate::error::Error::io("failed to read local file", e)
.with("local_path", path.display().to_string())
})
})
.await
.map_err(|_e| crate::error::Error::timeout("local file operation exceeded 60s"))?
}
async fn write_all(&self, path: &Path, bytes: &[u8]) -> Result<()> {
let write_fut = async {
use tokio::io::AsyncWriteExt;
let mut file = tokio::fs::File::create(path).await.map_err(|e| {
crate::error::Error::io("failed to create local file", e)
.with("local_path", path.display().to_string())
})?;
file.write_all(bytes).await.map_err(|e| {
crate::error::Error::io("failed to write local file", e)
.with("local_path", path.display().to_string())
})?;
file.flush().await.map_err(|e| {
crate::error::Error::io("failed to flush local file", e)
.with("local_path", path.display().to_string())
})
};
tokio::time::timeout(std::time::Duration::from_secs(60), write_fut)
.await
.map_err(|_e| crate::error::Error::timeout("local file operation exceeded 60s"))?
}
async fn exists(&self, path: &Path) -> bool {
tokio::fs::metadata(path).await.is_ok()
}
}
/// `LocalFs` implementation that keeps file contents in memory.
///
/// Contents live in a `HashMap` keyed by path, behind an `Arc<RwLock<..>>`.
/// Because the map sits behind an `Arc`, the derived `Clone` produces handles that
/// share the same underlying store.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct MemoryLocalFs {
// Path -> stored bytes; paths are used as map keys exactly as given.
store:
std::sync::Arc<std::sync::RwLock<std::collections::HashMap<std::path::PathBuf, Vec<u8>>>>,
}
impl MemoryLocalFs {
    /// Creates an in-memory filesystem with nothing stored in it.
    #[allow(clippy::new_without_default)]
    #[must_use]
    pub fn new() -> Self {
        // `Arc::default()` builds a fresh `Arc` around an empty, unlocked map.
        let store = std::sync::Arc::default();
        Self { store }
    }
}
impl LocalFs for MemoryLocalFs {
    /// Returns a copy of the bytes stored under `path`.
    ///
    /// # Errors
    ///
    /// * an I/O error ("lock poisoned") when the store's lock is poisoned;
    /// * a not-found error when nothing is stored under `path`.
    async fn read_all(&self, path: &Path) -> Result<Vec<u8>> {
        let guard = self.store.read().map_err(|e| {
            crate::error::Error::io("lock poisoned", std::io::Error::other(e.to_string()))
        })?;
        guard.get(path).cloned().ok_or_else(|| {
            crate::error::Error::not_found(format!("file not found: {}", path.display()))
        })
    }

    /// Stores a copy of `bytes` under `path`, replacing any previous contents.
    ///
    /// # Errors
    ///
    /// Returns an I/O error ("lock poisoned") when the store's lock is poisoned.
    async fn write_all(&self, path: &Path, bytes: &[u8]) -> Result<()> {
        let mut guard = self.store.write().map_err(|e| {
            crate::error::Error::io("lock poisoned", std::io::Error::other(e.to_string()))
        })?;
        guard.insert(path.to_path_buf(), bytes.to_vec());
        Ok(())
    }

    /// Reports whether something is stored under `path`.
    ///
    /// Never panics: a poisoned lock yields `false`. This is consistent with
    /// `read_all`, which would fail for that same store, and with the trait's
    /// `bool`-only return type, which leaves no way to surface an error.
    async fn exists(&self, path: &Path) -> bool {
        self.store
            .read()
            .is_ok_and(|guard| guard.contains_key(path))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    // Bytes written to a path come back unchanged from a subsequent read.
    #[async_test_macros::async_test]
    async fn memory_local_fs_write_then_read() {
        let target = Path::new("/mem/test.txt");
        let store = MemoryLocalFs::new();
        store.write_all(target, b"content").await.unwrap();
        let read_back = store.read_all(target).await.unwrap();
        assert_eq!(read_back, b"content");
    }

    // `exists` flips from false to true once the path has been written.
    #[async_test_macros::async_test]
    async fn memory_local_fs_exists() {
        let target = Path::new("/mem/foo");
        let store = MemoryLocalFs::new();
        assert!(!store.exists(target).await);
        store.write_all(target, b"x").await.unwrap();
        assert!(store.exists(target).await);
    }

    // Reading a path that was never written fails with the NotFound kind.
    #[async_test_macros::async_test]
    async fn memory_local_fs_read_missing_returns_error() {
        let store = MemoryLocalFs::new();
        let outcome = store.read_all(Path::new("/mem/nonexistent")).await;
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().kind, crate::error::ErrorKind::NotFound);
    }

    // A second write to the same path replaces the first.
    #[async_test_macros::async_test]
    async fn memory_local_fs_overwrite() {
        let target = Path::new("/mem/overwrite.txt");
        let store = MemoryLocalFs::new();
        store.write_all(target, b"first").await.unwrap();
        store.write_all(target, b"second").await.unwrap();
        let read_back = store.read_all(target).await.unwrap();
        assert_eq!(read_back, b"second");
    }
}