use super::Header;
use crate::{BufferPool, IoBufs, IoBufsMut};
use commonware_codec::Encode;
use commonware_utils::{
hex,
sync::{Mutex, RwLock},
};
use sha2::{Digest, Sha256};
use std::{collections::BTreeMap, ops::RangeInclusive, sync::Arc};
#[derive(Clone)]
pub struct Storage {
partitions: Arc<Mutex<BTreeMap<String, Partition>>>,
pool: BufferPool,
}
impl Storage {
pub fn new(pool: BufferPool) -> Self {
Self {
partitions: Arc::new(Mutex::new(BTreeMap::new())),
pool,
}
}
}
impl Storage {
pub fn audit(&self) -> [u8; 32] {
let partitions = self.partitions.lock();
let mut hasher = Sha256::new();
for (partition_name, blobs) in partitions.iter() {
for (blob_name, content) in blobs.iter() {
hasher.update(partition_name.as_bytes());
hasher.update(blob_name);
hasher.update(content);
}
}
hasher.finalize().into()
}
}
impl crate::Storage for Storage {
type Blob = Blob;
async fn open_versioned(
&self,
partition: &str,
name: &[u8],
versions: RangeInclusive<u16>,
) -> Result<(Self::Blob, u64, u16), crate::Error> {
super::validate_partition_name(partition)?;
let mut partitions = self.partitions.lock();
let partition_entry = partitions.entry(partition.into()).or_default();
let content = partition_entry.entry(name.into()).or_default();
let raw_len = content.len() as u64;
let (blob_version, logical_len) = if Header::missing(raw_len) {
let (header, blob_version) = Header::new(&versions);
content.clear();
content.extend_from_slice(&header.encode());
(blob_version, 0)
} else {
let mut header_bytes = [0u8; Header::SIZE];
header_bytes.copy_from_slice(&content[..Header::SIZE]);
Header::from(header_bytes, raw_len, &versions)
.map_err(|e| e.into_error(partition, name))?
};
Ok((
Blob::new(
self.partitions.clone(),
partition.into(),
name,
content.clone(),
self.pool.clone(),
),
logical_len,
blob_version,
))
}
async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), crate::Error> {
super::validate_partition_name(partition)?;
let mut partitions = self.partitions.lock();
match name {
Some(name) => {
partitions
.get_mut(partition)
.ok_or(crate::Error::PartitionMissing(partition.into()))?
.remove(name)
.ok_or(crate::Error::BlobMissing(partition.into(), hex(name)))?;
}
None => {
partitions
.remove(partition)
.ok_or(crate::Error::PartitionMissing(partition.into()))?;
}
}
Ok(())
}
async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, crate::Error> {
super::validate_partition_name(partition)?;
let partitions = self.partitions.lock();
let partition = partitions
.get(partition)
.ok_or(crate::Error::PartitionMissing(partition.into()))?;
let mut results = Vec::with_capacity(partition.len());
for name in partition.keys() {
results.push(name.clone());
}
results.sort(); Ok(results)
}
}
type Partition = BTreeMap<Vec<u8>, Vec<u8>>;
#[derive(Clone)]
pub struct Blob {
partitions: Arc<Mutex<BTreeMap<String, Partition>>>,
partition: String,
name: Vec<u8>,
content: Arc<RwLock<Vec<u8>>>,
pool: BufferPool,
}
impl Blob {
fn new(
partitions: Arc<Mutex<BTreeMap<String, Partition>>>,
partition: String,
name: &[u8],
content: Vec<u8>,
pool: BufferPool,
) -> Self {
Self {
partitions,
partition,
name: name.into(),
content: Arc::new(RwLock::new(content)),
pool,
}
}
}
impl crate::Blob for Blob {
async fn read_at(&self, offset: u64, len: usize) -> Result<IoBufsMut, crate::Error> {
self.read_at_buf(offset, len, self.pool.alloc(len)).await
}
async fn read_at_buf(
&self,
offset: u64,
len: usize,
bufs: impl Into<IoBufsMut> + Send,
) -> Result<IoBufsMut, crate::Error> {
let mut bufs = bufs.into();
unsafe { bufs.set_len(len) };
let offset = offset
.checked_add(Header::SIZE_U64)
.ok_or(crate::Error::OffsetOverflow)?;
let offset: usize = offset
.try_into()
.map_err(|_| crate::Error::OffsetOverflow)?;
let content = self.content.read();
let content_len = content.len();
if offset + len > content_len {
return Err(crate::Error::BlobInsufficientLength);
}
bufs.copy_from_slice(&content[offset..offset + len]);
Ok(bufs)
}
async fn write_at(
&self,
offset: u64,
bufs: impl Into<IoBufs> + Send,
) -> Result<(), crate::Error> {
let buf = bufs.into().coalesce();
let offset = offset
.checked_add(Header::SIZE_U64)
.ok_or(crate::Error::OffsetOverflow)?;
let offset: usize = offset
.try_into()
.map_err(|_| crate::Error::OffsetOverflow)?;
let mut content = self.content.write();
let required = offset + buf.len();
if required > content.len() {
content.resize(required, 0);
}
content[offset..offset + buf.len()].copy_from_slice(buf.as_ref());
Ok(())
}
async fn resize(&self, len: u64) -> Result<(), crate::Error> {
let len = len
.checked_add(Header::SIZE_U64)
.ok_or(crate::Error::OffsetOverflow)?;
let len: usize = len.try_into().map_err(|_| crate::Error::OffsetOverflow)?;
let mut content = self.content.write();
content.resize(len, 0);
Ok(())
}
async fn sync(&self) -> Result<(), crate::Error> {
let new_content = self.content.read().clone();
let mut partitions = self.partitions.lock();
let partition = partitions
.get_mut(&self.partition)
.ok_or(crate::Error::PartitionMissing(self.partition.clone()))?;
let content = partition
.get_mut(&self.name)
.ok_or(crate::Error::BlobMissing(
self.partition.clone(),
hex(&self.name),
))?;
*content = new_content;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::{Header, *};
use crate::{storage::tests::run_storage_tests, Blob, BufferPoolConfig, Storage as _};
fn test_pool() -> BufferPool {
BufferPool::new(
BufferPoolConfig::for_storage(),
&mut prometheus_client::registry::Registry::default(),
)
}
#[tokio::test]
async fn test_memory_storage() {
let storage = Storage::new(test_pool());
run_storage_tests(storage).await;
}
#[tokio::test]
async fn test_blob_header_handling() {
let storage = Storage::new(test_pool());
let (blob, size) = storage.open("partition", b"test").await.unwrap();
assert_eq!(size, 0, "new blob should have logical size 0");
{
let partitions = storage.partitions.lock();
let partition = partitions.get("partition").unwrap();
let raw_content = partition.get(&b"test".to_vec()).unwrap();
assert_eq!(
raw_content.len(),
Header::SIZE,
"raw storage should have 8-byte header"
);
}
let data = b"hello world";
blob.write_at(0, data).await.unwrap();
blob.sync().await.unwrap();
{
let partitions = storage.partitions.lock();
let partition = partitions.get("partition").unwrap();
let raw_content = partition.get(&b"test".to_vec()).unwrap();
assert_eq!(raw_content.len(), Header::SIZE + data.len());
assert_eq!(&raw_content[..Header::MAGIC_LENGTH], &Header::MAGIC);
assert_eq!(&raw_content[Header::SIZE..], data);
}
let read_buf = blob.read_at(0, data.len()).await.unwrap();
assert_eq!(read_buf.coalesce(), data);
{
let mut partitions = storage.partitions.lock();
let partition = partitions.get_mut("partition").unwrap();
partition.insert(b"corrupted".to_vec(), vec![0u8; 2]);
}
let (_blob, size) = storage.open("partition", b"corrupted").await.unwrap();
assert_eq!(size, 0, "corrupted blob should return logical size 0");
{
let partitions = storage.partitions.lock();
let partition = partitions.get("partition").unwrap();
let raw_content = partition.get(&b"corrupted".to_vec()).unwrap();
assert_eq!(
raw_content.len(),
Header::SIZE,
"corrupted blob should be reset to header-only"
);
}
}
#[tokio::test]
async fn test_blob_magic_mismatch() {
let storage = Storage::new(test_pool());
{
let mut partitions = storage.partitions.lock();
let partition = partitions.entry("partition".into()).or_default();
partition.insert(b"bad_magic".to_vec(), vec![0u8; Header::SIZE]);
}
let result = storage.open("partition", b"bad_magic").await;
assert!(
matches!(result, Err(crate::Error::BlobCorrupt(_, _, reason)) if reason.contains("invalid magic"))
);
}
}