use super::Header;
use crate::{
iouring::{self, should_retry, OpBuffer, OpFd, OpIovecs},
Buf, BufferPool, Error, IoBuf, IoBufs, IoBufsMut,
};
use commonware_codec::Encode;
use commonware_utils::{channel::oneshot, from_hex, hex};
use io_uring::{opcode, types::Fd};
use prometheus_client::registry::Registry;
use std::{
fs::{self, File},
io::{Error as IoError, Read, Seek, SeekFrom, Write},
ops::RangeInclusive,
os::fd::AsRawFd,
path::{Path, PathBuf},
sync::Arc,
};
/// Maximum number of `iovec` entries submitted in one `Writev` operation.
/// `Blob::write_vectored_at` writes longer buffer chains across several
/// operations, each covering at most this many chunks.
const IOVEC_BATCH_SIZE: usize = 32;
fn sync_dir(path: &Path) -> Result<(), Error> {
let dir = File::open(path).map_err(|e| {
Error::BlobOpenFailed(
path.to_string_lossy().to_string(),
"directory".to_string(),
e,
)
})?;
dir.sync_all().map_err(|e| {
Error::BlobSyncFailed(
path.to_string_lossy().to_string(),
"directory".to_string(),
e,
)
})
}
/// Configuration for the io_uring-backed [`Storage`].
#[derive(Clone, Debug)]
pub struct Config {
/// Root directory: each partition is a subdirectory of it, and each blob is a
/// file inside its partition's subdirectory.
pub storage_directory: PathBuf,
/// Settings for the io_uring loop. [`Storage::start`] always overrides
/// `single_issuer` to `true`.
pub iouring_config: iouring::Config,
}
/// Blob storage rooted at a directory. Blob reads, writes, and fsyncs are
/// submitted to an io_uring loop.
///
/// Opening, removing, and listing blobs and partitions uses `std::fs`
/// directly. Clones hold clones of the submitter and buffer pool.
#[derive(Clone)]
pub struct Storage {
/// Root directory containing one subdirectory per partition.
storage_directory: PathBuf,
/// Handle used to submit operations to the io_uring loop.
io_submitter: iouring::Submitter,
/// Buffer pool handed to every opened blob.
pool: BufferPool,
}
impl Storage {
pub fn start(mut cfg: Config, registry: &mut Registry, pool: BufferPool) -> Self {
cfg.iouring_config.single_issuer = true;
let (io_submitter, iouring_loop) = iouring::IoUringLoop::new(cfg.iouring_config, registry);
let storage = Self {
storage_directory: cfg.storage_directory,
io_submitter,
pool,
};
std::thread::spawn(move || iouring_loop.run());
storage
}
}
impl crate::Storage for Storage {
type Blob = Blob;
async fn open_versioned(
&self,
partition: &str,
name: &[u8],
versions: RangeInclusive<u16>,
) -> Result<(Blob, u64, u16), Error> {
super::validate_partition_name(partition)?;
let path = self.storage_directory.join(partition).join(hex(name));
let parent = path
.parent()
.ok_or_else(|| Error::PartitionMissing(partition.into()))?;
let parent_existed = parent.exists();
fs::create_dir_all(parent).map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
let mut file = fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path)
.map_err(|e| Error::BlobOpenFailed(partition.into(), hex(name), e))?;
let raw_len = file.metadata().map_err(|_| Error::ReadFailed)?.len();
let (blob_version, logical_len) = if Header::missing(raw_len) {
let (header, blob_version) = Header::new(&versions);
file.set_len(Header::SIZE_U64)
.map_err(|e| Error::BlobResizeFailed(partition.into(), hex(name), e))?;
file.seek(SeekFrom::Start(0))
.map_err(|_| Error::WriteFailed)?;
file.write_all(&header.encode())
.map_err(|_| Error::WriteFailed)?;
file.sync_all()
.map_err(|e| Error::BlobSyncFailed(partition.into(), hex(name), e))?;
if raw_len == 0 {
sync_dir(parent)?;
if !parent_existed {
sync_dir(&self.storage_directory)?;
}
}
(blob_version, 0)
} else {
file.seek(SeekFrom::Start(0))
.map_err(|_| Error::ReadFailed)?;
let mut header_bytes = [0u8; Header::SIZE];
file.read_exact(&mut header_bytes)
.map_err(|_| Error::ReadFailed)?;
Header::from(header_bytes, raw_len, &versions)
.map_err(|e| e.into_error(partition, name))?
};
let blob = Blob::new(
partition.into(),
name,
file,
self.io_submitter.clone(),
self.pool.clone(),
);
Ok((blob, logical_len, blob_version))
}
async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
super::validate_partition_name(partition)?;
let path = self.storage_directory.join(partition);
if let Some(name) = name {
let blob_path = path.join(hex(name));
fs::remove_file(blob_path)
.map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
sync_dir(&path)?;
} else {
fs::remove_dir_all(&path).map_err(|_| Error::PartitionMissing(partition.into()))?;
sync_dir(&self.storage_directory)?;
}
Ok(())
}
async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
super::validate_partition_name(partition)?;
let path = self.storage_directory.join(partition);
let entries =
std::fs::read_dir(&path).map_err(|_| Error::PartitionMissing(partition.into()))?;
let mut blobs = Vec::new();
for entry in entries {
let entry = entry.map_err(|_| Error::ReadFailed)?;
let file_type = entry.file_type().map_err(|_| Error::ReadFailed)?;
if !file_type.is_file() {
return Err(Error::PartitionCorrupt(partition.into()));
}
if let Some(name) = entry.file_name().to_str() {
let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
blobs.push(name);
}
}
Ok(blobs)
}
}
/// Handle to a single blob file opened by [`Storage`].
///
/// Offsets passed to the [`crate::Blob`] methods are logical. They exclude the
/// on-disk [`Header`], whose size is added before any I/O is issued.
pub struct Blob {
/// Partition the blob lives in (reported in error values).
partition: String,
/// Raw, un-encoded blob name (hex-encoded when reported in error values).
name: Vec<u8>,
/// Underlying file, shared with clones and with in-flight io_uring operations.
file: Arc<File>,
/// Handle used to submit operations to the io_uring loop.
io_submitter: iouring::Submitter,
/// Pool that read buffers are allocated from.
pool: BufferPool,
}
impl Clone for Blob {
fn clone(&self) -> Self {
Self {
partition: self.partition.clone(),
name: self.name.clone(),
file: self.file.clone(),
io_submitter: self.io_submitter.clone(),
pool: self.pool.clone(),
}
}
}
impl Blob {
    /// Wrap an already-opened `file` as the blob `name` in `partition`.
    fn new(
        partition: String,
        name: &[u8],
        file: File,
        io_submitter: iouring::Submitter,
        pool: BufferPool,
    ) -> Self {
        Self {
            partition,
            name: name.to_vec(),
            file: Arc::new(file),
            io_submitter,
            pool,
        }
    }

    /// The underlying file descriptor, wrapped for use in io_uring opcodes.
    fn as_raw_fd(&self) -> Fd {
        Fd(self.file.as_raw_fd())
    }

    /// Write all of `buf` starting at raw file `offset`, resubmitting until
    /// every byte is written. The caller has already added the header size.
    ///
    /// Completions for which `should_retry` is true are resubmitted unchanged.
    /// Short writes continue from where they stopped.
    ///
    /// # Errors
    ///
    /// [`Error::WriteFailed`] if any of the following happens:
    /// - the operation cannot be submitted;
    /// - its completion is lost;
    /// - the kernel returns an error;
    /// - a write makes no progress.
    async fn write_single_at(&self, mut offset: u64, mut buf: IoBuf) -> Result<(), Error> {
        let mut bytes_written = 0;
        let buf_len = buf.len();
        while bytes_written < buf_len {
            // SAFETY: `bytes_written < buf_len`, so the pointer stays within
            // `buf`. `buf` is moved into the operation below and handed back
            // with its completion.
            // NOTE(review): this relies on `IoBuf`'s bytes not moving when the
            // `IoBuf` value itself is moved; confirm.
            let ptr = unsafe { buf.as_ptr().add(bytes_written) };
            // The SQE length field is a `u32`. Clamp instead of truncating with
            // `as`, which would turn e.g. exactly 4 GiB into a 0-byte write that
            // fails below. Later iterations write any remainder.
            let op_len = u32::try_from(buf_len - bytes_written).unwrap_or(u32::MAX);
            let op = opcode::Write::new(self.as_raw_fd(), ptr, op_len)
                .offset(offset as _)
                .build();
            let (sender, receiver) = oneshot::channel();
            self.io_submitter
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(OpBuffer::Write(buf)),
                    fd: Some(OpFd::File(self.file.clone())),
                    iovecs: None,
                })
                .await
                .map_err(|_| Error::WriteFailed)?;
            let (return_value, return_buf) = receiver.await.map_err(|_| Error::WriteFailed)?;
            buf = match return_buf {
                Some(OpBuffer::Write(b)) => b,
                _ => unreachable!("io_uring loop returns the same OpBuffer that was submitted"),
            };
            if should_retry(return_value) {
                continue;
            }
            // Negative results (errors) fail this conversion.
            let op_bytes_written: usize =
                return_value.try_into().map_err(|_| Error::WriteFailed)?;
            // Zero progress would otherwise loop forever.
            if op_bytes_written == 0 {
                return Err(Error::WriteFailed);
            }
            bytes_written += op_bytes_written;
            offset += op_bytes_written as u64;
        }
        Ok(())
    }

    /// Write every remaining byte of `bufs` starting at raw file `offset`,
    /// using `Writev`. Each operation covers at most [`IOVEC_BATCH_SIZE`]
    /// chunks.
    ///
    /// Retry and error behavior matches [`Self::write_single_at`].
    async fn write_vectored_at(&self, mut offset: u64, mut bufs: IoBufs) -> Result<(), Error> {
        while bufs.has_remaining() {
            let (iovecs, iovecs_len) = {
                let max_iovecs = bufs.chunk_count().min(IOVEC_BATCH_SIZE);
                assert!(
                    max_iovecs > 0,
                    "chunk_count should be > 0 if bufs.has_remaining() is true"
                );
                // Placeholder entries; `chunks_vectored` overwrites the ones it fills.
                let mut iovecs: Box<[libc::iovec]> = std::iter::repeat_n(
                    libc::iovec {
                        iov_base: std::ptr::NonNull::<u8>::dangling().as_ptr().cast(),
                        iov_len: 0,
                    },
                    max_iovecs,
                )
                .collect();
                // SAFETY: `std::io::IoSlice` is guaranteed ABI-compatible with
                // `iovec` on Unix, so the boxed iovecs can be viewed as
                // `IoSlice`s for `chunks_vectored` to fill. The slices point into
                // `bufs`, which is moved into the operation together with
                // `iovecs`.
                let io_slices: &mut [std::io::IoSlice<'_>] = unsafe {
                    std::slice::from_raw_parts_mut(
                        iovecs.as_mut_ptr().cast::<std::io::IoSlice<'_>>(),
                        iovecs.len(),
                    )
                };
                let io_slices_len = bufs.chunks_vectored(io_slices);
                assert!(
                    io_slices_len > 0,
                    "chunks_vectored should produce at least one slice when bufs has remaining"
                );
                (OpIovecs::new(iovecs), io_slices_len)
            };
            let op = opcode::Writev::new(self.as_raw_fd(), iovecs.as_ptr(), iovecs_len as _)
                .offset(offset as _)
                .build();
            let (sender, receiver) = oneshot::channel();
            self.io_submitter
                .send(iouring::Op {
                    work: op,
                    sender,
                    buffer: Some(OpBuffer::WriteVectored(bufs)),
                    fd: Some(OpFd::File(self.file.clone())),
                    iovecs: Some(iovecs),
                })
                .await
                .map_err(|_| Error::WriteFailed)?;
            let (return_value, return_bufs) = receiver.await.map_err(|_| Error::WriteFailed)?;
            bufs = match return_bufs {
                Some(OpBuffer::WriteVectored(b)) => b,
                _ => unreachable!("io_uring loop returns the same OpBuffer that was submitted"),
            };
            if should_retry(return_value) {
                continue;
            }
            let op_bytes_written: usize =
                return_value.try_into().map_err(|_| Error::WriteFailed)?;
            if op_bytes_written == 0 {
                return Err(Error::WriteFailed);
            }
            // Consume what the kernel wrote; the next batch starts after it.
            bufs.advance(op_bytes_written);
            offset += op_bytes_written as u64;
        }
        Ok(())
    }
}
impl crate::Blob for Blob {
async fn read_at(&self, offset: u64, len: usize) -> Result<IoBufsMut, Error> {
self.read_at_buf(offset, len, self.pool.alloc(len)).await
}
async fn read_at_buf(
&self,
offset: u64,
len: usize,
bufs: impl Into<IoBufsMut> + Send,
) -> Result<IoBufsMut, Error> {
let mut input_bufs = bufs.into();
unsafe { input_bufs.set_len(len) };
let (mut io_buf, original_bufs) = if input_bufs.is_single() {
(input_bufs.coalesce(), None)
} else {
let tmp = unsafe { self.pool.alloc_len(len) };
(tmp, Some(input_bufs))
};
let mut bytes_read = 0;
let offset = offset
.checked_add(Header::SIZE_U64)
.ok_or(Error::OffsetOverflow)?;
while bytes_read < len {
let ptr = unsafe { io_buf.as_mut_ptr().add(bytes_read) };
let remaining_len = len - bytes_read;
let offset = offset + bytes_read as u64;
let op = opcode::Read::new(self.as_raw_fd(), ptr, remaining_len as _)
.offset(offset as _)
.build();
let (sender, receiver) = oneshot::channel();
self.io_submitter
.send(iouring::Op {
work: op,
sender,
buffer: Some(OpBuffer::Read(io_buf)),
fd: Some(OpFd::File(self.file.clone())),
iovecs: None,
})
.await
.map_err(|_| Error::ReadFailed)?;
let (return_value, return_buf) = receiver.await.map_err(|_| Error::ReadFailed)?;
io_buf = match return_buf {
Some(OpBuffer::Read(b)) => b,
_ => unreachable!("io_uring loop returns the same OpBuffer that was submitted"),
};
if should_retry(return_value) {
continue;
}
let op_bytes_read: usize = return_value.try_into().map_err(|_| Error::ReadFailed)?;
if op_bytes_read == 0 {
return Err(Error::BlobInsufficientLength);
}
bytes_read += op_bytes_read;
}
match original_bufs {
None => Ok(io_buf.into()),
Some(mut bufs) => {
bufs.copy_from_slice(io_buf.as_ref());
Ok(bufs)
}
}
}
async fn write_at(&self, offset: u64, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
let bufs = bufs.into();
let offset = offset
.checked_add(Header::SIZE_U64)
.ok_or(Error::OffsetOverflow)?;
match bufs.try_into_single() {
Ok(buf) => self.write_single_at(offset, buf).await,
Err(bufs) => self.write_vectored_at(offset, bufs).await,
}
}
async fn resize(&self, len: u64) -> Result<(), Error> {
let len = len
.checked_add(Header::SIZE_U64)
.ok_or(Error::OffsetOverflow)?;
self.file.set_len(len).map_err(|e| {
Error::BlobResizeFailed(self.partition.clone(), hex(&self.name), IoError::other(e))
})
}
async fn sync(&self) -> Result<(), Error> {
loop {
let op = opcode::Fsync::new(self.as_raw_fd()).build();
let (sender, receiver) = oneshot::channel();
self.io_submitter
.send(iouring::Op {
work: op,
sender,
buffer: None,
fd: Some(OpFd::File(self.file.clone())),
iovecs: None,
})
.await
.map_err(|_| {
Error::BlobSyncFailed(
self.partition.clone(),
hex(&self.name),
IoError::other("failed to send work"),
)
})?;
let (return_value, _) = receiver.await.map_err(|_| {
Error::BlobSyncFailed(
self.partition.clone(),
hex(&self.name),
IoError::other("failed to read result"),
)
})?;
if should_retry(return_value) {
continue;
}
if return_value < 0 {
return Err(Error::BlobSyncFailed(
self.partition.clone(),
hex(&self.name),
IoError::other(format!("error code: {return_value}")),
));
}
return Ok(());
}
}
}
#[cfg(test)]
mod tests {
    use super::{Header, *};
    use crate::{
        storage::tests::run_storage_tests, Blob, BufferPool, BufferPoolConfig, Storage as _,
    };
    use rand::{Rng as _, SeedableRng as _};
    use std::env;

    /// Build a [`Storage`] rooted at a fresh, randomly named directory under
    /// the system temp dir.
    ///
    /// Returns the storage and that directory, so the test can remove the
    /// directory afterwards.
    fn create_test_storage() -> (Storage, PathBuf) {
        let mut rng = rand::rngs::StdRng::from_entropy();
        let storage_directory =
            env::temp_dir().join(format!("commonware_iouring_storage_{}", rng.gen::<u64>()));
        let pool = BufferPool::new(BufferPoolConfig::for_storage(), &mut Registry::default());
        let storage = Storage::start(
            Config {
                storage_directory: storage_directory.clone(),
                iouring_config: Default::default(),
            },
            &mut Registry::default(),
            pool,
        );
        (storage, storage_directory)
    }

    /// Run the shared storage conformance suite against the io_uring backend.
    #[tokio::test]
    async fn test_iouring_storage() {
        let (storage, storage_directory) = create_test_storage();
        run_storage_tests(storage).await;
        let _ = std::fs::remove_dir_all(storage_directory);
    }

    /// Check that the on-disk header is written and hidden from logical offsets
    /// and lengths.
    ///
    /// Covers write, read, resize, reopen, and the reset of a file too short to
    /// hold a header.
    #[tokio::test]
    async fn test_blob_header_handling() {
        let (storage, storage_directory) = create_test_storage();
        let (blob, size) = storage.open("partition", b"test").await.unwrap();
        assert_eq!(size, 0, "new blob should have logical size 0");
        let file_path = storage_directory.join("partition").join(hex(b"test"));
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64,
            "raw file should have {}-byte header",
            Header::SIZE
        );

        let data = b"hello world";
        blob.write_at(0, data.to_vec()).await.unwrap();
        blob.sync().await.unwrap();
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(metadata.len(), Header::SIZE_U64 + data.len() as u64);
        let raw_content = std::fs::read(&file_path).unwrap();
        assert_eq!(&raw_content[..Header::MAGIC_LENGTH], &Header::MAGIC);
        assert_eq!(
            &raw_content[Header::MAGIC_LENGTH..Header::MAGIC_LENGTH + Header::VERSION_LENGTH],
            &Header::RUNTIME_VERSION.to_be_bytes()
        );
        assert_eq!(&raw_content[Header::SIZE..], data);
        let read_buf = blob.read_at(0, data.len()).await.unwrap().coalesce();
        assert_eq!(read_buf, data);

        blob.resize(5).await.unwrap();
        blob.sync().await.unwrap();
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64 + 5,
            "resize(5) should result in {} raw bytes",
            Header::SIZE_U64 + 5
        );

        blob.resize(0).await.unwrap();
        blob.sync().await.unwrap();
        let metadata = std::fs::metadata(&file_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64,
            "resize(0) should leave only header"
        );

        blob.write_at(0, b"test data".to_vec()).await.unwrap();
        blob.sync().await.unwrap();
        drop(blob);
        let (blob2, size2) = storage.open("partition", b"test").await.unwrap();
        assert_eq!(size2, 9, "reopened blob should have logical size 9");
        let read_buf = blob2.read_at(0, 9).await.unwrap().coalesce();
        assert_eq!(read_buf, b"test data");
        drop(blob2);

        // A file shorter than the header is treated as headerless and reset.
        let corrupted_path = storage_directory.join("partition").join(hex(b"corrupted"));
        std::fs::write(&corrupted_path, vec![0u8; 4]).unwrap();
        let (blob3, size3) = storage.open("partition", b"corrupted").await.unwrap();
        assert_eq!(size3, 0, "corrupted blob should return logical size 0");
        let metadata = std::fs::metadata(&corrupted_path).unwrap();
        assert_eq!(
            metadata.len(),
            Header::SIZE_U64,
            "corrupted blob should be reset to header-only"
        );
        drop(blob3);
        let _ = std::fs::remove_dir_all(&storage_directory);
    }

    /// A full-size header with the wrong magic must be rejected as corrupt
    /// rather than reset.
    #[tokio::test]
    async fn test_blob_magic_mismatch() {
        let (storage, storage_directory) = create_test_storage();
        let partition_path = storage_directory.join("partition");
        std::fs::create_dir_all(&partition_path).unwrap();
        let bad_magic_path = partition_path.join(hex(b"bad_magic"));
        std::fs::write(&bad_magic_path, vec![0u8; Header::SIZE]).unwrap();
        let result = storage.open("partition", b"bad_magic").await;
        assert!(
            matches!(result, Err(crate::Error::BlobCorrupt(_, _, reason)) if reason.contains("invalid magic"))
        );
        let _ = std::fs::remove_dir_all(&storage_directory);
    }
}