use super::Header;
use crate::{BufferPool, Error};
use commonware_codec::Encode;
use commonware_utils::{from_hex, hex};
#[cfg(unix)]
use std::path::Path;
use std::{ops::RangeInclusive, path::PathBuf, sync::Arc};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt},
sync::Mutex,
};
#[cfg(not(unix))]
mod fallback;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
async fn sync_dir(path: &Path) -> Result<(), Error> {
let dir = fs::File::open(path).await.map_err(|e| {
Error::BlobOpenFailed(
path.to_string_lossy().to_string(),
"directory".to_string(),
e,
)
})?;
dir.sync_all().await.map_err(|e| {
Error::BlobSyncFailed(
path.to_string_lossy().to_string(),
"directory".to_string(),
e,
)
})
}
/// Configuration for the tokio filesystem-backed [Storage].
#[derive(Clone)]
pub struct Config {
    /// Root directory under which partitions (subdirectories) and blobs
    /// (hex-named files) are stored.
    pub storage_directory: PathBuf,
    /// Maximum buffer size passed to tokio file handles via `set_max_buf_size`.
    pub maximum_buffer_size: usize,
}
impl Config {
pub const fn new(storage_directory: PathBuf, maximum_buffer_size: usize) -> Self {
Self {
storage_directory,
maximum_buffer_size,
}
}
}
/// Tokio filesystem-backed storage implementation. Cheap to clone: all clones
/// share the same lock and buffer pool.
#[derive(Clone)]
pub struct Storage {
    /// Serializes open/remove/scan so directory mutations and their fsyncs
    /// do not interleave across concurrent callers.
    lock: Arc<Mutex<()>>,
    /// Storage root and buffer-size configuration.
    cfg: Config,
    /// Shared buffer pool handed to each opened blob.
    pool: BufferPool,
}
impl Storage {
    /// Creates a new [Storage] from `cfg`, sharing `pool` with every blob it
    /// opens. A fresh (unlocked) operation mutex is allocated per instance.
    pub fn new(cfg: Config, pool: BufferPool) -> Self {
        let lock = Arc::new(Mutex::new(()));
        Self { pool, cfg, lock }
    }
}
impl crate::Storage for Storage {
    // Unix uses a blob built on a std::fs::File (see the into_std conversion
    // below); other platforms use a portable fallback.
    #[cfg(unix)]
    type Blob = unix::Blob;
    #[cfg(not(unix))]
    type Blob = fallback::Blob;

    /// Opens (creating if necessary) the blob `name` in `partition`, returning
    /// the blob handle, its logical size (bytes beyond the header), and the
    /// header version selected or validated against `versions`.
    ///
    /// Durability protocol: a newly created (zero-length) file is fsynced, and
    /// on Unix its parent directory is fsynced too; if the partition directory
    /// itself was just created, the storage root is fsynced as well so the new
    /// directory entry survives a crash.
    async fn open_versioned(
        &self,
        partition: &str,
        name: &[u8],
        versions: RangeInclusive<u16>,
    ) -> Result<(Self::Blob, u64, u16), Error> {
        super::validate_partition_name(partition)?;
        // Serialize against concurrent open/remove/scan on this Storage.
        let _guard = self.lock.lock().await;
        // The on-disk file name is the hex encoding of `name`.
        let path = self.cfg.storage_directory.join(partition).join(hex(name));
        let parent = match path.parent() {
            Some(parent) => parent,
            None => return Err(Error::PartitionCreationFailed(partition.into())),
        };
        // Record whether the partition directory pre-existed: if we create it
        // below, the storage root must also be fsynced (Unix only).
        #[cfg(unix)]
        let parent_existed = parent.exists();
        fs::create_dir_all(parent)
            .await
            .map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
        let mut file = fs::OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(&path)
            .await
            .map_err(|e| Error::BlobOpenFailed(partition.into(), hex(name), e))?;
        let len = file.metadata().await.map_err(|_| Error::ReadFailed)?.len();
        // A zero-length file is treated as newly created: an existing blob
        // always carries at least a header.
        let newly_created = len == 0;
        if newly_created {
            // Sync the (empty) file before its directory entry so the entry
            // never becomes durable ahead of the file contents.
            file.sync_all()
                .await
                .map_err(|e| Error::BlobSyncFailed(partition.into(), hex(name), e))?;
            #[cfg(unix)]
            {
                sync_dir(parent).await?;
                if !parent_existed {
                    sync_dir(&self.cfg.storage_directory).await?;
                }
            }
        }
        file.set_max_buf_size(self.cfg.maximum_buffer_size);
        let (blob_version, logical_size) = if Header::missing(len) {
            // No complete header on disk (new file, or one shorter than the
            // header — see the corrupted-blob test below): write a fresh
            // header and reset logical size to 0.
            let (header, blob_version) = Header::new(&versions);
            file.set_len(Header::SIZE_U64)
                .await
                .map_err(|e| Error::BlobResizeFailed(partition.into(), hex(name), e))?;
            // The cursor of a freshly opened file is at offset 0, so this
            // writes the header at the start of the file.
            file.write_all(&header.encode())
                .await
                .map_err(|_| Error::WriteFailed)?;
            file.sync_all()
                .await
                .map_err(|e| Error::BlobSyncFailed(partition.into(), hex(name), e))?;
            (blob_version, 0)
        } else {
            // Existing blob: read and validate the on-disk header (magic and
            // version within `versions`), deriving the logical size from `len`.
            let mut header_bytes = [0u8; Header::SIZE];
            file.read_exact(&mut header_bytes)
                .await
                .map_err(|_| Error::ReadFailed)?;
            Header::from(header_bytes, len, &versions).map_err(|e| e.into_error(partition, name))?
        };
        #[cfg(unix)]
        {
            // Hand the underlying std file to the Unix blob implementation.
            let file = file.into_std().await;
            Ok((
                Self::Blob::new(partition.into(), name, file, self.pool.clone()),
                logical_size,
                blob_version,
            ))
        }
        #[cfg(not(unix))]
        {
            Ok((
                Self::Blob::new(partition.into(), name, file, self.pool.clone()),
                logical_size,
                blob_version,
            ))
        }
    }

    /// Removes a single blob (`name = Some(..)`) or an entire partition
    /// (`name = None`). On Unix the affected directory is fsynced afterwards
    /// so the deletion is durable.
    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
        super::validate_partition_name(partition)?;
        let _guard = self.lock.lock().await;
        let path = self.cfg.storage_directory.join(partition);
        if let Some(name) = name {
            let blob_path = path.join(hex(name));
            // Any removal failure (not only "not found") is surfaced as
            // BlobMissing, since the underlying error is discarded.
            fs::remove_file(blob_path)
                .await
                .map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
            #[cfg(unix)]
            sync_dir(&path).await?;
        } else {
            fs::remove_dir_all(&path)
                .await
                .map_err(|_| Error::PartitionMissing(partition.into()))?;
            #[cfg(unix)]
            sync_dir(&self.cfg.storage_directory).await?;
        }
        Ok(())
    }

    /// Lists the decoded blob names in `partition`. A non-file directory entry
    /// or a non-hex file name marks the partition corrupt.
    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
        super::validate_partition_name(partition)?;
        let _guard = self.lock.lock().await;
        let path = self.cfg.storage_directory.join(partition);
        let mut entries = fs::read_dir(path)
            .await
            .map_err(|_| Error::PartitionMissing(partition.into()))?;
        let mut blobs = Vec::new();
        while let Some(entry) = entries.next_entry().await.map_err(|_| Error::ReadFailed)? {
            let file_type = entry.file_type().await.map_err(|_| Error::ReadFailed)?;
            if !file_type.is_file() {
                return Err(Error::PartitionCorrupt(partition.into()));
            }
            // NOTE(review): entries whose names are not valid UTF-8 are
            // silently skipped here, while UTF-8-but-non-hex names error as
            // PartitionCorrupt — confirm the asymmetry is intended.
            if let Some(name) = entry.file_name().to_str() {
                let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
                blobs.push(name);
            }
        }
        Ok(blobs)
    }
}
#[cfg(test)]
mod tests {
    use super::{Header, *};
    use crate::{storage::tests::run_storage_tests, Blob, BufferPoolConfig, Storage as _};
    use rand::{Rng as _, SeedableRng};
    use std::env;

    /// Builds a throwaway buffer pool backed by a fresh metrics registry.
    fn test_pool() -> BufferPool {
        BufferPool::new(
            BufferPoolConfig::for_storage(),
            &mut prometheus_client::registry::Registry::default(),
        )
    }

    /// Runs the shared storage conformance suite against a tokio Storage
    /// rooted in a randomly named temp directory.
    #[tokio::test]
    async fn test_storage() {
        let mut rng = rand::rngs::StdRng::from_entropy();
        let storage_directory = env::temp_dir().join(format!("storage_tokio_{}", rng.gen::<u64>()));
        let config = Config::new(storage_directory, 2 * 1024 * 1024);
        let storage = Storage::new(config, test_pool());
        run_storage_tests(storage).await;
    }

    /// Exercises the on-disk header: creation, raw-file layout (magic +
    /// version prefix), logical-vs-raw sizes across write/resize, size
    /// recovery on reopen, and the reset of a too-short (corrupt) file.
    #[tokio::test]
    async fn test_blob_header_handling() {
        let mut rng = rand::rngs::StdRng::from_entropy();
        let storage_directory =
            env::temp_dir().join(format!("storage_tokio_header_{}", rng.gen::<u64>()));
        let config = Config::new(storage_directory.clone(), 2 * 1024 * 1024);
        let storage = Storage::new(config, test_pool());
        let (blob, size) = storage.open("partition", b"test").await.unwrap();
        assert_eq!(size, 0, "new blob should have logical size 0");
        // A brand-new blob is header-only on disk.
        let file_path = storage_directory.join("partition").join(hex(b"test"));
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64,
            "raw file should have 8-byte header"
        );
        // Logical offset 0 maps to raw offset Header::SIZE.
        let data = b"hello world";
        blob.write_at(0, data).await.unwrap();
        blob.sync().await.unwrap();
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(metadata.len(), Header::SIZE_U64 + data.len() as u64);
        // Verify the raw layout: magic, then big-endian runtime version, then data.
        let raw_content = std::fs::read(&file_path).unwrap();
        assert_eq!(&raw_content[..Header::MAGIC_LENGTH], &Header::MAGIC);
        assert_eq!(
            &raw_content[Header::MAGIC_LENGTH..Header::MAGIC_LENGTH + Header::VERSION_LENGTH],
            &Header::RUNTIME_VERSION.to_be_bytes()
        );
        assert_eq!(&raw_content[Header::SIZE..], data);
        let read_buf = blob.read_at(0, data.len()).await.unwrap();
        assert_eq!(read_buf.coalesce(), data);
        // Resizing is in logical bytes; the header is always preserved.
        blob.resize(5).await.unwrap();
        blob.sync().await.unwrap();
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64 + 5,
            "resize(5) should result in 13 raw bytes"
        );
        blob.resize(0).await.unwrap();
        blob.sync().await.unwrap();
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64,
            "resize(0) should leave only header"
        );
        // Reopening must recover the logical size from the raw length.
        blob.write_at(0, b"test data").await.unwrap();
        blob.sync().await.unwrap();
        drop(blob);
        let (blob2, size2) = storage.open("partition", b"test").await.unwrap();
        assert_eq!(size2, 9, "reopened blob should have logical size 9");
        let read_buf = blob2.read_at(0, 9).await.unwrap();
        assert_eq!(read_buf.coalesce(), b"test data");
        drop(blob2);
        // A file shorter than the header is treated as corrupt and reset to a
        // fresh header-only blob on open.
        let corrupted_path = storage_directory.join("partition").join(hex(b"corrupted"));
        std::fs::write(&corrupted_path, vec![0u8; 4]).unwrap();
        let (blob3, size3) = storage.open("partition", b"corrupted").await.unwrap();
        assert_eq!(size3, 0, "corrupted blob should return logical size 0");
        let metadata = std::fs::metadata(&corrupted_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64,
            "corrupted blob should be reset to header-only"
        );
        drop(blob3);
        let _ = std::fs::remove_dir_all(&storage_directory);
    }

    /// A header-sized file with the wrong magic must fail open with a
    /// BlobCorrupt error mentioning "invalid magic" (not be silently reset).
    #[tokio::test]
    async fn test_blob_magic_mismatch() {
        let storage_directory =
            env::temp_dir().join(format!("test_magic_mismatch_{}", rand::random::<u64>()));
        let storage = Storage::new(
            Config {
                storage_directory: storage_directory.clone(),
                maximum_buffer_size: 1024 * 1024,
            },
            test_pool(),
        );
        let partition_path = storage_directory.join("partition");
        std::fs::create_dir_all(&partition_path).unwrap();
        // Full-size header of zeros: long enough to not be "missing", but the
        // magic bytes do not match.
        let bad_magic_path = partition_path.join(hex(b"bad_magic"));
        std::fs::write(&bad_magic_path, vec![0u8; Header::SIZE]).unwrap();
        let result = storage.open("partition", b"bad_magic").await;
        assert!(
            matches!(result, Err(crate::Error::BlobCorrupt(_, _, reason)) if reason.contains("invalid magic"))
        );
        let _ = std::fs::remove_dir_all(&storage_directory);
    }
}