use super::Header;
use crate::{
iouring::{self, should_retry},
Error,
};
use commonware_codec::Encode;
use commonware_utils::{from_hex, hex, StableBuf};
use futures::{
channel::{mpsc, oneshot},
executor::block_on,
SinkExt as _,
};
use io_uring::{opcode, types};
use prometheus_client::registry::Registry;
use std::{
fs::{self, File},
io::{Error as IoError, Read, Seek, SeekFrom, Write},
ops::RangeInclusive,
os::fd::AsRawFd,
path::{Path, PathBuf},
sync::Arc,
};
fn sync_dir(path: &Path) -> Result<(), Error> {
let dir = File::open(path).map_err(|e| {
Error::BlobOpenFailed(
path.to_string_lossy().to_string(),
"directory".to_string(),
e,
)
})?;
dir.sync_all().map_err(|e| {
Error::BlobSyncFailed(
path.to_string_lossy().to_string(),
"directory".to_string(),
e,
)
})
}
/// Configuration for the io_uring-backed [`Storage`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Root directory; each partition is created as a subdirectory of it.
    pub storage_directory: PathBuf,
    /// Settings for the io_uring loop started by [`Storage::start`].
    ///
    /// Note: `Storage::start` overwrites `single_issuer` with `true`.
    pub iouring_config: iouring::Config,
}
#[derive(Clone)]
pub struct Storage {
storage_directory: PathBuf,
io_sender: mpsc::Sender<iouring::Op>,
}
impl Storage {
pub fn start(mut cfg: Config, registry: &mut Registry) -> Self {
let (io_sender, receiver) = mpsc::channel::<iouring::Op>(cfg.iouring_config.size as usize);
let storage = Self {
storage_directory: cfg.storage_directory.clone(),
io_sender,
};
let metrics = Arc::new(iouring::Metrics::new(registry));
cfg.iouring_config.single_issuer = true;
std::thread::spawn(|| block_on(iouring::run(cfg.iouring_config, metrics, receiver)));
storage
}
}
impl crate::Storage for Storage {
type Blob = Blob;
async fn open_versioned(
&self,
partition: &str,
name: &[u8],
versions: RangeInclusive<u16>,
) -> Result<(Blob, u64, u16), Error> {
super::validate_partition_name(partition)?;
let path = self.storage_directory.join(partition).join(hex(name));
let parent = path
.parent()
.ok_or_else(|| Error::PartitionMissing(partition.into()))?;
let parent_existed = parent.exists();
fs::create_dir_all(parent).map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
let mut file = fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path)
.map_err(|e| Error::BlobOpenFailed(partition.into(), hex(name), e))?;
let raw_len = file.metadata().map_err(|_| Error::ReadFailed)?.len();
let (blob_version, logical_len) = if Header::missing(raw_len) {
let (header, blob_version) = Header::new(&versions);
file.set_len(Header::SIZE_U64)
.map_err(|e| Error::BlobResizeFailed(partition.into(), hex(name), e))?;
file.seek(SeekFrom::Start(0))
.map_err(|_| Error::WriteFailed)?;
file.write_all(&header.encode())
.map_err(|_| Error::WriteFailed)?;
file.sync_all()
.map_err(|e| Error::BlobSyncFailed(partition.into(), hex(name), e))?;
if raw_len == 0 {
sync_dir(parent)?;
if !parent_existed {
sync_dir(&self.storage_directory)?;
}
}
(blob_version, 0)
} else {
file.seek(SeekFrom::Start(0))
.map_err(|_| Error::ReadFailed)?;
let mut header_bytes = [0u8; Header::SIZE];
file.read_exact(&mut header_bytes)
.map_err(|_| Error::ReadFailed)?;
Header::from(header_bytes, raw_len, &versions)
.map_err(|e| e.into_error(partition, name))?
};
let blob = Blob::new(partition.into(), name, file, self.io_sender.clone());
Ok((blob, logical_len, blob_version))
}
async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
super::validate_partition_name(partition)?;
let path = self.storage_directory.join(partition);
if let Some(name) = name {
let blob_path = path.join(hex(name));
fs::remove_file(blob_path)
.map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
sync_dir(&path)?;
} else {
fs::remove_dir_all(&path).map_err(|_| Error::PartitionMissing(partition.into()))?;
sync_dir(&self.storage_directory)?;
}
Ok(())
}
async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
super::validate_partition_name(partition)?;
let path = self.storage_directory.join(partition);
let entries =
std::fs::read_dir(&path).map_err(|_| Error::PartitionMissing(partition.into()))?;
let mut blobs = Vec::new();
for entry in entries {
let entry = entry.map_err(|_| Error::ReadFailed)?;
let file_type = entry.file_type().map_err(|_| Error::ReadFailed)?;
if !file_type.is_file() {
return Err(Error::PartitionCorrupt(partition.into()));
}
if let Some(name) = entry.file_name().to_str() {
let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
blobs.push(name);
}
}
Ok(blobs)
}
}
/// Handle to a single blob file. Reads, writes and fsyncs are submitted to the
/// io_uring thread.
///
/// Clones share the same open file (through the `Arc`) and the same io_uring thread.
#[derive(Clone)]
pub struct Blob {
    /// Partition containing the blob; used in error reports.
    partition: String,
    /// Raw blob name; its hex encoding is the on-disk file name.
    name: Vec<u8>,
    /// Open file whose descriptor the io_uring operations target.
    file: Arc<File>,
    /// Submission channel to the io_uring thread.
    io_sender: mpsc::Sender<iouring::Op>,
}
impl Blob {
    /// Wraps an already-open `file` as the blob `name` of `partition`, submitting its
    /// I/O through `io_sender`.
    fn new(
        partition: String,
        name: &[u8],
        file: File,
        io_sender: mpsc::Sender<iouring::Op>,
    ) -> Self {
        let file = Arc::new(file);
        let name = Vec::from(name);
        Self {
            partition,
            name,
            file,
            io_sender,
        }
    }
}
/// Clamps a remaining byte count to the largest length that a single io_uring
/// read/write submission can describe.
///
/// `opcode::Read::new` and `opcode::Write::new` take a `u32` length. The read/write
/// loops below resubmit for whatever remains.
fn clamp_op_len(remaining: usize) -> u32 {
    u32::try_from(remaining).unwrap_or(u32::MAX)
}

impl crate::Blob for Blob {
    /// Reads exactly `buf.len()` bytes starting at logical `offset` (the header is
    /// skipped) and returns the filled buffer.
    ///
    /// Short reads continue where they stopped. Completions for which `should_retry`
    /// holds are resubmitted.
    ///
    /// # Errors
    ///
    /// - `OffsetOverflow` if the physical byte range would overflow `u64`.
    /// - `BlobInsufficientLength` if a read returns 0 bytes before the buffer is full.
    /// - `ReadFailed` if the io_uring thread cannot be reached or the read fails.
    async fn read_at(
        &self,
        buf: impl Into<StableBuf> + Send,
        offset: u64,
    ) -> Result<StableBuf, Error> {
        let mut buf = buf.into();
        let fd = types::Fd(self.file.as_raw_fd());
        let buf_len = buf.len();
        let mut io_sender = self.io_sender.clone();
        let offset = offset
            .checked_add(Header::SIZE_U64)
            .ok_or(Error::OffsetOverflow)?;
        // Reject ranges whose end wraps, so `offset + bytes_read` below cannot overflow.
        offset
            .checked_add(buf_len as u64)
            .ok_or(Error::OffsetOverflow)?;
        let mut bytes_read = 0;
        while bytes_read < buf_len {
            // SAFETY: `bytes_read < buf_len`, so the pointer stays inside `buf`. `buf` is
            // moved into the op and handed back only with its completion, keeping the
            // allocation alive (and, as a `StableBuf`, in place) while the kernel fills it.
            let ptr = unsafe { buf.as_mut_ptr().add(bytes_read) };
            // Clamp instead of `as u32`, which would silently truncate huge buffers.
            let len = clamp_op_len(buf_len - bytes_read);
            let pos = offset + bytes_read as u64;
            let op = opcode::Read::new(fd, ptr, len).offset(pos as _).build();
            let (sender, receiver) = oneshot::channel();
            io_sender
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(buf),
                })
                .await
                .map_err(|_| Error::ReadFailed)?;
            let (result, got_buf) = receiver.await.map_err(|_| Error::ReadFailed)?;
            buf = got_buf.expect("io_uring loop returns the buffer submitted with the op");
            if should_retry(result) {
                continue;
            }
            // Negative results (errors) fail the conversion.
            let op_bytes_read: usize = result.try_into().map_err(|_| Error::ReadFailed)?;
            if op_bytes_read == 0 {
                return Err(Error::BlobInsufficientLength);
            }
            bytes_read += op_bytes_read;
        }
        Ok(buf)
    }

    /// Writes all of `buf` starting at logical `offset` (the header is skipped).
    ///
    /// Short writes continue where they stopped. Completions for which `should_retry`
    /// holds are resubmitted.
    ///
    /// # Errors
    ///
    /// - `OffsetOverflow` if the physical byte range would overflow `u64`.
    /// - `WriteFailed` in any of these cases:
    ///   - the io_uring thread cannot be reached,
    ///   - the write fails,
    ///   - the write makes no progress (returns 0 bytes).
    async fn write_at(&self, buf: impl Into<StableBuf> + Send, offset: u64) -> Result<(), Error> {
        let mut buf = buf.into();
        let fd = types::Fd(self.file.as_raw_fd());
        let buf_len = buf.len();
        let mut io_sender = self.io_sender.clone();
        let offset = offset
            .checked_add(Header::SIZE_U64)
            .ok_or(Error::OffsetOverflow)?;
        // Reject ranges whose end wraps, so `offset + bytes_written` below cannot overflow.
        offset
            .checked_add(buf_len as u64)
            .ok_or(Error::OffsetOverflow)?;
        let mut bytes_written = 0;
        while bytes_written < buf_len {
            // SAFETY: `bytes_written < buf_len`, so the pointer stays inside `buf`. `buf` is
            // moved into the op and handed back only with its completion, keeping the
            // allocation alive (and, as a `StableBuf`, in place) while the kernel reads it.
            let ptr = unsafe { buf.as_mut_ptr().add(bytes_written) } as *const u8;
            // Clamp instead of `as u32`, which would silently truncate huge buffers.
            let len = clamp_op_len(buf_len - bytes_written);
            let pos = offset + bytes_written as u64;
            let op = opcode::Write::new(fd, ptr, len).offset(pos as _).build();
            let (sender, receiver) = oneshot::channel();
            io_sender
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(buf),
                })
                .await
                .map_err(|_| Error::WriteFailed)?;
            let (return_value, got_buf) = receiver.await.map_err(|_| Error::WriteFailed)?;
            buf = got_buf.expect("io_uring loop returns the buffer submitted with the op");
            if should_retry(return_value) {
                continue;
            }
            // Negative results (errors) fail the conversion.
            let op_bytes_written: usize =
                return_value.try_into().map_err(|_| Error::WriteFailed)?;
            // A zero-byte completion makes no progress; resubmitting would spin forever.
            if op_bytes_written == 0 {
                return Err(Error::WriteFailed);
            }
            bytes_written += op_bytes_written;
        }
        Ok(())
    }

    /// Sets the logical length of the blob to `len` (the header is preserved), using a
    /// synchronous `set_len` on the file.
    ///
    /// # Errors
    ///
    /// - `OffsetOverflow` if `len` plus the header overflows `u64`.
    /// - `BlobResizeFailed` carrying the underlying I/O error if `set_len` fails.
    async fn resize(&self, len: u64) -> Result<(), Error> {
        let len = len
            .checked_add(Header::SIZE_U64)
            .ok_or(Error::OffsetOverflow)?;
        // Pass the io::Error through unchanged so its `ErrorKind` is preserved.
        self.file
            .set_len(len)
            .map_err(|e| Error::BlobResizeFailed(self.partition.clone(), hex(&self.name), e))
    }

    /// Fsyncs the blob's file via io_uring, resubmitting while `should_retry` holds.
    ///
    /// # Errors
    ///
    /// `BlobSyncFailed` in any of these cases:
    /// - the op cannot be sent,
    /// - no result is received,
    /// - the fsync returns a negative code.
    async fn sync(&self) -> Result<(), Error> {
        let fd = types::Fd(self.file.as_raw_fd());
        loop {
            let op = opcode::Fsync::new(fd).build();
            let (sender, receiver) = oneshot::channel();
            self.io_sender
                .clone()
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: None,
                })
                .await
                .map_err(|_| {
                    Error::BlobSyncFailed(
                        self.partition.clone(),
                        hex(&self.name),
                        IoError::other("failed to send work"),
                    )
                })?;
            let (return_value, _) = receiver.await.map_err(|_| {
                Error::BlobSyncFailed(
                    self.partition.clone(),
                    hex(&self.name),
                    IoError::other("failed to read result"),
                )
            })?;
            if should_retry(return_value) {
                continue;
            }
            if return_value < 0 {
                return Err(Error::BlobSyncFailed(
                    self.partition.clone(),
                    hex(&self.name),
                    IoError::other(format!("error code: {return_value}")),
                ));
            }
            return Ok(());
        }
    }
}
// Tests for the io_uring storage backend.
#[cfg(test)]
mod tests {
use super::{Header, *};
use crate::{storage::tests::run_storage_tests, Blob, Storage as _};
use rand::{Rng as _, SeedableRng as _};
use std::env;
// Starts a `Storage` rooted at a randomly named path under the system temp dir and
// returns it together with that path. The directory itself is not created here; it
// is created on first open, and each test removes it at the end.
fn create_test_storage() -> (Storage, PathBuf) {
let mut rng = rand::rngs::StdRng::from_entropy();
let storage_directory =
env::temp_dir().join(format!("commonware_iouring_storage_{}", rng.gen::<u64>()));
let storage = Storage::start(
Config {
storage_directory: storage_directory.clone(),
iouring_config: Default::default(),
},
&mut Registry::default(),
);
(storage, storage_directory)
}
// Runs the shared storage test suite against this backend.
#[tokio::test]
async fn test_iouring_storage() {
let (storage, storage_directory) = create_test_storage();
run_storage_tests(storage).await;
let _ = std::fs::remove_dir_all(storage_directory);
}
// Checks that the on-disk file starts with a header (magic + runtime version) and
// that write/read/resize/reopen all use logical offsets that exclude it.
#[tokio::test]
async fn test_blob_header_handling() {
let (storage, storage_directory) = create_test_storage();
// A freshly opened blob is logically empty but physically header-sized.
let (blob, size) = storage.open("partition", b"test").await.unwrap();
assert_eq!(size, 0, "new blob should have logical size 0");
let file_path = storage_directory.join("partition").join(hex(b"test"));
let metadata = std::fs::metadata(&file_path).unwrap();
assert_eq!(
metadata.len(),
Header::SIZE_U64,
"raw file should have 8-byte header"
);
// Logical offset 0 lands right after the header on disk.
let data = b"hello world";
blob.write_at(data.to_vec(), 0).await.unwrap();
blob.sync().await.unwrap();
let metadata = std::fs::metadata(&file_path).unwrap();
assert_eq!(metadata.len(), Header::SIZE_U64 + data.len() as u64);
// Inspect raw bytes: magic, then big-endian runtime version, then the data.
let raw_content = std::fs::read(&file_path).unwrap();
assert_eq!(&raw_content[..Header::MAGIC_LENGTH], &Header::MAGIC);
assert_eq!(
&raw_content[Header::MAGIC_LENGTH..Header::MAGIC_LENGTH + Header::VERSION_LENGTH],
&Header::RUNTIME_VERSION.to_be_bytes()
);
assert_eq!(&raw_content[Header::SIZE..], data);
let read_buf = blob.read_at(vec![0u8; data.len()], 0).await.unwrap();
assert_eq!(read_buf.as_ref(), data);
// Resizing is logical too: the header is always kept.
blob.resize(5).await.unwrap();
blob.sync().await.unwrap();
let metadata = std::fs::metadata(&file_path).unwrap();
assert_eq!(
metadata.len(),
Header::SIZE_U64 + 5,
"resize(5) should result in 13 raw bytes"
);
blob.resize(0).await.unwrap();
blob.sync().await.unwrap();
let metadata = std::fs::metadata(&file_path).unwrap();
assert_eq!(
metadata.len(),
Header::SIZE_U64,
"resize(0) should leave only header"
);
// Reopening reports the logical size recovered from the existing header.
blob.write_at(b"test data".to_vec(), 0).await.unwrap();
blob.sync().await.unwrap();
drop(blob);
let (blob2, size2) = storage.open("partition", b"test").await.unwrap();
assert_eq!(size2, 9, "reopened blob should have logical size 9");
let read_buf = blob2.read_at(vec![0u8; 9], 0).await.unwrap();
assert_eq!(read_buf.as_ref(), b"test data");
drop(blob2);
// A file shorter than the header is treated as header-less and reset on open.
let corrupted_path = storage_directory.join("partition").join(hex(b"corrupted"));
std::fs::write(&corrupted_path, vec![0u8; 4]).unwrap();
let (blob3, size3) = storage.open("partition", b"corrupted").await.unwrap();
assert_eq!(size3, 0, "corrupted blob should return logical size 0");
let metadata = std::fs::metadata(&corrupted_path).unwrap();
assert_eq!(
metadata.len(),
Header::SIZE_U64,
"corrupted blob should be reset to header-only"
);
drop(blob3);
let _ = std::fs::remove_dir_all(&storage_directory);
}
// A header-sized file of zero bytes has a wrong magic and must be rejected as
// `BlobCorrupt` (not silently reset like a too-short file).
#[tokio::test]
async fn test_blob_magic_mismatch() {
let (storage, storage_directory) = create_test_storage();
let partition_path = storage_directory.join("partition");
std::fs::create_dir_all(&partition_path).unwrap();
let bad_magic_path = partition_path.join(hex(b"bad_magic"));
std::fs::write(&bad_magic_path, vec![0u8; Header::SIZE]).unwrap();
let result = storage.open("partition", b"bad_magic").await;
assert!(
matches!(result, Err(crate::Error::BlobCorrupt(_, _, reason)) if reason.contains("invalid magic"))
);
let _ = std::fs::remove_dir_all(&storage_directory);
}
}