use bytes::{Buf, BufMut};
use commonware_codec::{DecodeExt, FixedSize, Read as CodecRead, Write as CodecWrite};
use commonware_utils::hex;
use std::ops::RangeInclusive;
#[derive(Debug)]
pub(crate) enum HeaderError {
InvalidMagic {
expected: [u8; 4],
found: [u8; 4],
},
UnsupportedRuntimeVersion {
expected: u16,
found: u16,
},
VersionMismatch {
expected: RangeInclusive<u16>,
found: u16,
},
}
impl HeaderError {
pub(crate) fn into_error(self, partition: &str, name: &[u8]) -> crate::Error {
match self {
Self::InvalidMagic { expected, found } => crate::Error::BlobCorrupt(
partition.into(),
hex(name),
format!("invalid magic: expected {expected:?}, found {found:?}"),
),
Self::UnsupportedRuntimeVersion { expected, found } => crate::Error::BlobCorrupt(
partition.into(),
hex(name),
format!("unsupported runtime version: expected {expected}, found {found}"),
),
Self::VersionMismatch { expected, found } => {
crate::Error::BlobVersionMismatch { expected, found }
}
}
}
}
pub mod audited;
#[cfg(feature = "iouring-storage")]
pub mod iouring;
pub mod memory;
pub mod metered;
#[cfg(all(not(target_arch = "wasm32"), not(feature = "iouring-storage")))]
pub mod tokio;
/// Fixed-size blob header recording a magic tag, the header (runtime) format version, and the
/// caller-defined blob version.
///
/// A blob's logical length excludes these bytes (see [`Header::from`]). The encoded layout is
/// `magic` (4 bytes) | `runtime_version` (`u16`) | `blob_version` (`u16`), as produced by the
/// `CodecWrite` impl.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct Header {
    /// Must equal [`Header::MAGIC`] for the header to validate.
    magic: [u8; Self::MAGIC_LENGTH],
    /// Header format version; must equal [`Header::RUNTIME_VERSION`] to validate.
    runtime_version: u16,
    /// Version of the blob's contents, checked against a caller-supplied range.
    pub(crate) blob_version: u16,
}
impl Header {
    /// Length in bytes of the magic prefix.
    pub(crate) const MAGIC_LENGTH: usize = 4;
    /// Length in bytes of each encoded version field (both are `u16`).
    pub(crate) const VERSION_LENGTH: usize = std::mem::size_of::<u16>();
    /// Total encoded size: the magic followed by the runtime and blob version fields.
    ///
    /// Derived from the field widths so it cannot drift from the encoded layout.
    pub(crate) const SIZE: usize = Self::MAGIC_LENGTH + 2 * Self::VERSION_LENGTH;
    /// [`Self::SIZE`] as a `u64`, for arithmetic on raw blob lengths.
    pub(crate) const SIZE_U64: u64 = Self::SIZE as u64;
    /// Magic tag every valid header starts with.
    pub(crate) const MAGIC: [u8; Self::MAGIC_LENGTH] = *b"CWIC";
    /// Header format version written by [`Self::new`] and required by [`Self::validate`].
    pub(crate) const RUNTIME_VERSION: u16 = 0;

    /// Returns `true` when a blob of `raw_len` bytes is too short to contain a header.
    pub(crate) const fn missing(raw_len: u64) -> bool {
        raw_len < Self::SIZE_U64
    }

    /// Builds a header with the current magic and runtime version, using the upper bound of
    /// `versions` as the blob version.
    ///
    /// Returns the header together with the chosen blob version.
    pub(crate) const fn new(versions: &RangeInclusive<u16>) -> (Self, u16) {
        let blob_version = *versions.end();
        let header = Self {
            magic: Self::MAGIC,
            runtime_version: Self::RUNTIME_VERSION,
            blob_version,
        };
        (header, blob_version)
    }

    /// Decodes and validates a header read from a blob.
    ///
    /// On success, returns the stored blob version and the blob's logical length, which is
    /// `raw_len` minus [`Self::SIZE`].
    ///
    /// # Errors
    ///
    /// Returns a [`HeaderError`] if the magic, runtime version, or blob version is invalid
    /// (see [`Self::validate`]).
    ///
    /// # Panics
    ///
    /// The caller must ensure `raw_len >= Self::SIZE` (e.g. by checking [`Self::missing`]
    /// first). A shorter length trips a debug assertion; in release builds the subtraction
    /// would wrap.
    pub(crate) fn from(
        raw_bytes: [u8; Self::SIZE],
        raw_len: u64,
        versions: &RangeInclusive<u16>,
    ) -> Result<(u16, u64), HeaderError> {
        // Decoding exactly `SIZE` bytes cannot run out of input.
        let header: Self = Self::decode(raw_bytes.as_slice())
            .expect("header decode should never fail for correct size input");
        header.validate(versions)?;
        debug_assert!(
            !Self::missing(raw_len),
            "raw blob length {raw_len} is shorter than the header"
        );
        Ok((header.blob_version, raw_len - Self::SIZE_U64))
    }

    /// Checks, in order, the magic, the runtime version, and that the blob version lies in
    /// `blob_versions`.
    ///
    /// # Errors
    ///
    /// Returns the [`HeaderError`] for the first check that fails.
    pub(crate) fn validate(&self, blob_versions: &RangeInclusive<u16>) -> Result<(), HeaderError> {
        if self.magic != Self::MAGIC {
            return Err(HeaderError::InvalidMagic {
                expected: Self::MAGIC,
                found: self.magic,
            });
        }
        if self.runtime_version != Self::RUNTIME_VERSION {
            return Err(HeaderError::UnsupportedRuntimeVersion {
                expected: Self::RUNTIME_VERSION,
                found: self.runtime_version,
            });
        }
        if !blob_versions.contains(&self.blob_version) {
            return Err(HeaderError::VersionMismatch {
                expected: blob_versions.clone(),
                found: self.blob_version,
            });
        }
        Ok(())
    }
}
impl FixedSize for Header {
const SIZE: usize = Self::SIZE;
}
impl CodecWrite for Header {
    /// Appends the header to `buf`: the magic bytes, then the runtime version and the blob
    /// version, each written with `BufMut::put_u16` (big-endian).
    fn write(&self, buf: &mut impl BufMut) {
        let Self {
            magic,
            runtime_version,
            blob_version,
        } = self;
        buf.put_slice(magic);
        buf.put_u16(*runtime_version);
        buf.put_u16(*blob_version);
    }
}
impl CodecRead for Header {
type Cfg = ();
fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
if buf.remaining() < Self::SIZE {
return Err(commonware_codec::Error::EndOfBuffer);
}
let mut magic = [0u8; Self::MAGIC_LENGTH];
buf.copy_to_slice(&mut magic);
let runtime_version = buf.get_u16();
let blob_version = buf.get_u16();
Ok(Self {
magic,
runtime_version,
blob_version,
})
}
}
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Header {
    /// Produces a well-formed header (current magic and runtime version) whose blob version is
    /// drawn from `u`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let blob_version = u.arbitrary::<u16>()?;
        let (header, _) = Self::new(&(blob_version..=blob_version));
        Ok(header)
    }
}
pub fn validate_partition_name(partition: &str) -> Result<(), crate::Error> {
if partition.is_empty()
|| partition
.chars()
.any(|c| !(c.is_ascii_alphanumeric() || ['_', '-'].contains(&c)))
{
return Err(crate::Error::PartitionNameInvalid(partition.into()));
}
Ok(())
}
#[cfg(test)]
pub(crate) mod tests {
use super::{Header, HeaderError};
use crate::{Blob, Storage};
use commonware_codec::{DecodeExt, Encode};
/// `Header::new` stamps the built-in magic and runtime version and takes the blob version from
/// the upper bound of the supplied range.
#[test]
fn test_header_fields() {
    let header = Header::new(&(42..=42)).0;
    assert_eq!(header.magic, Header::MAGIC);
    assert_eq!(header.runtime_version, Header::RUNTIME_VERSION);
    assert_eq!(header.blob_version, 42);
}

/// A blob version inside, or exactly matching, the accepted range validates.
#[test]
fn test_header_validate_success() {
    let header = Header::new(&(5..=5)).0;
    for accepted in [3..=7, 5..=5] {
        assert!(header.validate(&accepted).is_ok());
    }
}

/// A corrupted magic is reported with both the expected and the found bytes.
#[test]
fn test_header_validate_magic_mismatch() {
    let mut header = Header::new(&(5..=5)).0;
    header.magic = *b"XXXX";
    let outcome = header.validate(&(3..=7));
    assert!(matches!(
        outcome,
        Err(HeaderError::InvalidMagic { expected, found })
            if expected == Header::MAGIC && found == *b"XXXX"
    ));
}

/// An unknown runtime version is rejected even when the blob version is acceptable.
#[test]
fn test_header_validate_runtime_version_mismatch() {
    let mut header = Header::new(&(5..=5)).0;
    header.runtime_version = 99;
    let outcome = header.validate(&(3..=7));
    assert!(matches!(
        outcome,
        Err(HeaderError::UnsupportedRuntimeVersion { expected, found })
            if expected == Header::RUNTIME_VERSION && found == 99
    ));
}

/// A blob version outside the accepted range reports that range and the stored version.
#[test]
fn test_header_validate_blob_version_out_of_range() {
    let header = Header::new(&(10..=10)).0;
    let outcome = header.validate(&(3..=7));
    assert!(matches!(
        outcome,
        Err(HeaderError::VersionMismatch { expected, found })
            if expected == (3..=7) && found == 10
    ));
}

/// Encoding then decoding a header yields an identical header.
#[test]
fn test_header_bytes_round_trip() {
    let original = Header::new(&(123..=123)).0;
    let encoded = original.encode();
    let decoded = Header::decode(encoded.as_ref()).unwrap();
    assert_eq!(original, decoded);
}
// Codec conformance tests for `Header`. Gated on the `arbitrary` feature, like the `Arbitrary`
// impl for `Header`, which the suite presumably uses to generate inputs (TODO confirm).
#[cfg(feature = "arbitrary")]
mod conformance {
use super::Header;
use commonware_codec::conformance::CodecConformance;
// Expands into the conformance test cases for `Header`'s codec impls.
commonware_conformance::conformance_tests! {
CodecConformance<Header>
}
}
/// Runs the shared storage test suite against `storage`.
///
/// The tests run one after another against the same storage instance and reuse partition and
/// blob names, so their order matters. For example, `test_scan` expects exactly two blobs in
/// `"partition"`. That holds only because `test_remove` already deleted the `"test_blob"`
/// created by `test_open_and_write`.
pub(crate) async fn run_storage_tests<S>(storage: S)
where
S: Storage + Send + Sync + 'static,
S::Blob: Send + Sync,
{
// Order-sensitive: see the doc comment above before reordering these calls.
test_open_and_write(&storage).await;
test_remove(&storage).await;
test_scan(&storage).await;
test_concurrent_access(&storage).await;
test_large_data(&storage).await;
test_overwrite_data(&storage).await;
test_read_beyond_bound(&storage).await;
test_write_at_large_offset(&storage).await;
test_append_data(&storage).await;
test_sequential_read_write(&storage).await;
test_sequential_chunk_read_write(&storage).await;
test_read_empty_blob(&storage).await;
test_overlapping_writes(&storage).await;
test_resize_then_open(&storage).await;
test_partition_name_validation(&storage).await;
test_blob_version_mismatch(&storage).await;
}
/// Opens a blob, expects it to report length 0, then writes bytes and reads them back.
async fn test_open_and_write<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, initial_len) = storage.open("partition", b"test_blob").await.unwrap();
    assert_eq!(initial_len, 0);

    let payload = b"hello world";
    blob.write_at(payload.to_vec(), 0).await.unwrap();
    let fetched = blob.read_at(vec![0; payload.len()], 0).await.unwrap();
    assert_eq!(
        fetched.as_ref(),
        payload,
        "Blob content does not match expected value"
    );
}
/// Removing a single named blob leaves the partition with no blobs in its scan.
async fn test_remove<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    const PARTITION: &str = "partition";
    const NAME: &[u8] = b"test_blob";

    // Ensure the blob exists, then remove it by name.
    storage.open(PARTITION, NAME).await.unwrap();
    storage.remove(PARTITION, Some(NAME)).await.unwrap();

    let remaining = storage.scan(PARTITION).await.unwrap();
    assert!(remaining.is_empty(), "Blob was not removed as expected");
}
/// Creates two blobs in one partition and checks that a scan lists exactly those two.
async fn test_scan<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let expected = [
        (b"blob1", "Blob1 is missing from scan results"),
        (b"blob2", "Blob2 is missing from scan results"),
    ];
    for (name, _) in expected {
        storage.open("partition", name).await.unwrap();
    }

    let listed = storage.scan("partition").await.unwrap();
    assert_eq!(
        listed.len(),
        expected.len(),
        "Scan did not return the expected number of blobs"
    );
    for (name, missing_msg) in expected {
        assert!(listed.contains(&name.to_vec()), "{missing_msg}");
    }
}
/// Runs a writer task and a reader task on clones of one blob at the same time.
async fn test_concurrent_access<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    const PAYLOAD: &[u8] = b"concurrent write";

    let (blob, _) = storage.open("partition", b"test_blob").await.unwrap();
    // Seed the blob first so the reader finds the payload whether or not the concurrent
    // writer has run yet.
    blob.write_at(PAYLOAD.to_vec(), 0).await.unwrap();

    let writer = blob.clone();
    let write_task = tokio::spawn(async move {
        writer.write_at(PAYLOAD.to_vec(), 0).await.unwrap();
    });
    let reader = blob.clone();
    let read_task =
        tokio::spawn(async move { reader.read_at(vec![0; PAYLOAD.len()], 0).await.unwrap() });

    write_task.await.unwrap();
    let observed = read_task.await.unwrap();
    assert_eq!(observed.as_ref(), PAYLOAD, "Concurrent access failed");
}
/// Round-trips a 10 MiB buffer through a single write and a single read.
async fn test_large_data<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    const LEN: usize = 10 * 1024 * 1024;

    let (blob, _) = storage.open("partition", b"large_blob").await.unwrap();
    let expected = vec![42u8; LEN];
    blob.write_at(expected.clone(), 0).await.unwrap();

    let round_tripped = blob.read_at(vec![0; LEN], 0).await.unwrap();
    assert_eq!(
        round_tripped.as_ref(),
        expected,
        "Large data read/write failed"
    );
}
/// A write that starts inside existing data replaces the overlapped bytes and extends the blob.
async fn test_overwrite_data<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, _) = storage
        .open("test_overwrite_data", b"test_blob")
        .await
        .unwrap();

    // "initial data" spans bytes 0..12. "overwrite" at offset 8 replaces "data" and runs to
    // byte 17.
    let writes: [(&[u8], u64); 2] = [(b"initial data", 0), (b"overwrite", 8)];
    for (bytes, offset) in writes {
        blob.write_at(bytes.to_vec(), offset).await.unwrap();
    }

    let combined = blob.read_at(vec![0; 17], 0).await.unwrap();
    assert_eq!(
        combined.as_ref(),
        b"initial overwrite",
        "Data was not overwritten correctly"
    );
}
/// Reading a range that extends past the written bytes must fail.
async fn test_read_beyond_bound<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, _) = storage
        .open("test_read_beyond_written_data", b"test_blob")
        .await
        .unwrap();
    blob.write_at(b"hello".to_vec(), 0).await.unwrap();

    // Only 5 bytes were written; a 10-byte read at offset 6 lies entirely past them.
    let outcome = blob.read_at(vec![0; 10], 6).await;
    assert!(
        outcome.is_err(),
        "Reading beyond written data should return an error"
    );
}
/// Data written far past offset 0 of an empty blob can be read back at that offset.
async fn test_write_at_large_offset<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    const OFFSET: u64 = 10_000;

    let (blob, _) = storage
        .open("test_write_at_large_offset", b"test_blob")
        .await
        .unwrap();

    let payload = b"offset data";
    blob.write_at(payload.to_vec(), OFFSET).await.unwrap();
    let read = blob.read_at(vec![0; payload.len()], OFFSET).await.unwrap();
    assert_eq!(
        read.as_ref(),
        payload,
        "Data at large offset is incorrect"
    );
}
/// A write placed directly after existing data appends to it.
async fn test_append_data<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, _) = storage
        .open("test_append_data", b"test_blob")
        .await
        .unwrap();

    let first = b"first";
    let second = b"second";
    blob.write_at(first.to_vec(), 0).await.unwrap();
    blob.write_at(second.to_vec(), first.len() as u64)
        .await
        .unwrap();

    let whole = blob
        .read_at(vec![0; first.len() + second.len()], 0)
        .await
        .unwrap();
    assert_eq!(whole.as_ref(), b"firstsecond", "Appended data is incorrect");
}
/// Writes two non-adjacent records, then reads each back from its own offset.
async fn test_sequential_read_write<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, _) = storage.open("partition", b"test_blob").await.unwrap();

    let records: [(&[u8], u64, &str); 2] = [
        (b"first", 0, "Data at offset 0 is incorrect"),
        (b"second", 10, "Data at offset 10 is incorrect"),
    ];
    // All writes happen before any read.
    for (bytes, offset, _) in records {
        blob.write_at(bytes.to_vec(), offset).await.unwrap();
    }
    for (bytes, offset, message) in records {
        let read = blob.read_at(vec![0; bytes.len()], offset).await.unwrap();
        assert_eq!(read.as_ref(), bytes, "{}", message);
    }
}
/// Writes ten contiguous 1 MiB chunks, then reads each back while reusing one buffer.
async fn test_sequential_chunk_read_write<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, _) = storage
        .open("test_large_data_in_chunks", b"large_blob")
        .await
        .unwrap();

    let chunk_size = 1024 * 1024;
    let num_chunks = 10;
    let data = vec![7u8; chunk_size];
    let offset_of = move |i: usize| (i * chunk_size) as u64;

    for i in 0..num_chunks {
        blob.write_at(data.clone(), offset_of(i)).await.unwrap();
    }

    // `read_at` hands the filled buffer back, so one allocation serves every read.
    let mut read = vec![0u8; chunk_size].into();
    for i in 0..num_chunks {
        read = blob.read_at(read, offset_of(i)).await.unwrap();
        assert_eq!(read.as_ref(), data, "Chunk {i} is incorrect");
    }
}
/// Even a one-byte read at offset 0 fails on a blob that has never been written.
async fn test_read_empty_blob<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (empty, _) = storage
        .open("test_read_empty_blob", b"empty_blob")
        .await
        .unwrap();

    let attempt = empty.read_at(vec![0; 1], 0).await;
    assert!(
        attempt.is_err(),
        "Reading from an empty blob should return an error"
    );
}
/// A later write that fully overlaps the tail of an earlier one wins for those bytes.
async fn test_overlapping_writes<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    let (blob, _) = storage
        .open("test_overlapping_writes", b"test_blob")
        .await
        .unwrap();

    // "overlap" spans bytes 0..7. "map" at offset 4 replaces "lap".
    let writes: [(&[u8], u64); 2] = [(b"overlap", 0), (b"map", 4)];
    for (bytes, offset) in writes {
        blob.write_at(bytes.to_vec(), offset).await.unwrap();
    }

    let merged = blob.read_at(vec![0; 7], 0).await.unwrap();
    assert_eq!(
        merged.as_ref(),
        b"overmap",
        "Overlapping writes are incorrect"
    );
}
/// A truncation that has been synced is visible after the blob is dropped and reopened.
async fn test_resize_then_open<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    const PARTITION: &str = "test_resize_then_open";
    const NAME: &[u8] = b"test_blob";

    {
        let (blob, _) = storage.open(PARTITION, NAME).await.unwrap();
        blob.write_at(b"hello world".to_vec(), 0).await.unwrap();
        blob.resize(5).await.unwrap();
        blob.sync().await.unwrap();
    } // the first handle is dropped here, before reopening

    let (reopened, len) = storage.open(PARTITION, NAME).await.unwrap();
    assert_eq!(len, 5, "Blob length after resize is incorrect");
    let read = reopened.read_at(vec![0; 5], 0).await.unwrap();
    assert_eq!(read.as_ref(), b"hello", "Resized data is incorrect");
}
async fn test_partition_name_validation<S>(storage: &S)
where
S: Storage + Send + Sync,
S::Blob: Send + Sync,
{
for valid in [
"partition",
"my_partition",
"my-partition",
"partition123",
"A1",
] {
assert!(
!matches!(
storage.open(valid, b"blob").await,
Err(crate::Error::PartitionNameInvalid(_))
),
"Valid partition name '{valid}' should be accepted by open"
);
assert!(
!matches!(
storage.remove(valid, None).await,
Err(crate::Error::PartitionNameInvalid(_))
),
"Valid partition name '{valid}' should be accepted by remove"
);
assert!(
!matches!(
storage.scan(valid).await,
Err(crate::Error::PartitionNameInvalid(_))
),
"Valid partition name '{valid}' should be accepted by scan"
);
}
for invalid in [
"my/partition",
"my.partition",
"my partition",
"../escape",
"",
] {
assert!(
matches!(
storage.open(invalid, b"blob").await,
Err(crate::Error::PartitionNameInvalid(_))
),
"Invalid partition name '{invalid}' should be rejected by open"
);
assert!(
matches!(
storage.remove(invalid, None).await,
Err(crate::Error::PartitionNameInvalid(_))
),
"Invalid partition name '{invalid}' should be rejected by remove"
);
assert!(
matches!(
storage.scan(invalid).await,
Err(crate::Error::PartitionNameInvalid(_))
),
"Invalid partition name '{invalid}' should be rejected by scan"
);
}
}
/// The version stored at creation is reported on reopen and enforced against the requested
/// range.
async fn test_blob_version_mismatch<S>(storage: &S)
where
    S: Storage + Send + Sync,
    S::Blob: Send + Sync,
{
    const PARTITION: &str = "test_version_mismatch";

    // Creating with the range 1..=1 stores version 1.
    let (blob, _, created_version) = storage
        .open_versioned(PARTITION, b"blob", 1..=1)
        .await
        .unwrap();
    assert_eq!(created_version, 1);
    blob.sync().await.unwrap();
    drop(blob);

    // A wider range that still contains 1 accepts the blob and reports the stored version.
    let (_, _, reopened_version) = storage
        .open_versioned(PARTITION, b"blob", 0..=2)
        .await
        .unwrap();
    assert_eq!(reopened_version, 1);

    // A range that excludes 1 is rejected with the requested range and the stored version.
    let mismatch = storage.open_versioned(PARTITION, b"blob", 2..=3).await;
    assert!(
        matches!(
            mismatch,
            Err(crate::Error::BlobVersionMismatch { expected, found })
                if expected == (2..=3) && found == 1
        ),
        "Expected BlobVersionMismatch error"
    );
}
}