use crate::{deterministic::Auditor, Error, IoBufs, IoBufsMut};
use sha2::digest::Update;
use std::sync::Arc;
/// An auditing wrapper around a [crate::Storage] implementation: every
/// storage operation is recorded with a shared [Auditor] before being
/// delegated to the inner implementation.
#[derive(Clone)]
pub struct Storage<S: crate::Storage> {
    // Wrapped storage implementation that performs the actual work.
    inner: S,
    // Shared auditor that hashes a record of each operation; a clone of this
    // handle is embedded in every blob opened through this storage.
    auditor: Arc<Auditor>,
}
impl<S: crate::Storage> Storage<S> {
    /// Wraps `inner` so that every operation is reported to `auditor`
    /// before being forwarded.
    pub const fn new(inner: S, auditor: Arc<Auditor>) -> Self {
        Self { auditor, inner }
    }

    /// Borrows the wrapped storage implementation.
    pub const fn inner(&self) -> &S {
        &self.inner
    }
}
impl<S: crate::Storage> crate::Storage for Storage<S> {
    type Blob = Blob<S::Blob>;

    async fn open_versioned(
        &self,
        partition: &str,
        name: &[u8],
        versions: std::ops::RangeInclusive<u16>,
    ) -> Result<(Self::Blob, u64, u16), Error> {
        // Record the open before delegating. The hash-input order here is
        // load-bearing: identical operation sequences must yield identical
        // auditor states.
        self.auditor.event(b"open", |h| {
            h.update(partition.as_bytes());
            h.update(name);
            h.update(&versions.start().to_be_bytes());
            h.update(&versions.end().to_be_bytes());
        });
        let (inner, len, version) = self
            .inner
            .open_versioned(partition, name, versions)
            .await?;
        // Wrap the returned blob so its operations are audited too.
        Ok((
            Blob {
                auditor: self.auditor.clone(),
                inner,
                partition: partition.into(),
                name: name.to_vec(),
            },
            len,
            version,
        ))
    }

    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
        self.auditor.event(b"remove", |h| {
            h.update(partition.as_bytes());
            // A partition-wide removal (no name) hashes only the partition.
            if let Some(n) = name {
                h.update(n);
            }
        });
        self.inner.remove(partition, name).await
    }

    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
        self.auditor.event(b"scan", |h| {
            h.update(partition.as_bytes());
        });
        self.inner.scan(partition).await
    }
}
/// An auditing wrapper around a [crate::Blob]: every blob operation is
/// recorded with the shared [Auditor] before being delegated to the inner
/// blob.
#[derive(Clone)]
pub struct Blob<B: crate::Blob> {
    // Auditor shared with the parent storage; each operation on this blob is
    // hashed into it.
    auditor: Arc<Auditor>,
    // Partition this blob was opened in; hashed into every audit event.
    partition: String,
    // Name of the blob within its partition; hashed into every audit event.
    name: Vec<u8>,
    // Wrapped blob that performs the actual I/O.
    inner: B,
}
impl<B: crate::Blob> crate::Blob for Blob<B> {
    async fn read_at(&self, offset: u64, len: usize) -> Result<IoBufsMut, Error> {
        // Record the read request (blob identity + coordinates) before
        // delegating.
        self.auditor.event(b"read_at", |h| {
            h.update(self.partition.as_bytes());
            h.update(&self.name);
            h.update(&offset.to_be_bytes());
            // NOTE(review): `len` is a usize, so the number of bytes fed to
            // the hasher is platform-dependent (4 vs 8). Presumably auditor
            // states are only compared within a single platform — confirm if
            // cross-platform determinism is expected.
            h.update(&len.to_be_bytes());
        });
        self.inner.read_at(offset, len).await
    }

    async fn read_at_buf(
        &self,
        offset: u64,
        len: usize,
        bufs: impl Into<IoBufsMut> + Send,
    ) -> Result<IoBufsMut, Error> {
        let bufs = bufs.into();
        // Only the request coordinates are hashed; the destination buffer
        // contents do not affect the audit.
        self.auditor.event(b"read_at_buf", |h| {
            h.update(self.partition.as_bytes());
            h.update(&self.name);
            h.update(&offset.to_be_bytes());
            h.update(&len.to_be_bytes());
        });
        self.inner.read_at_buf(offset, len, bufs).await
    }

    async fn write_at(&self, offset: u64, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
        // Convert first: the payload itself is hashed chunk by chunk, in the
        // caller-provided chunking.
        let bufs = bufs.into();
        self.auditor.event(b"write_at", |h| {
            h.update(self.partition.as_bytes());
            h.update(&self.name);
            h.update(&offset.to_be_bytes());
            bufs.for_each_chunk(|chunk| h.update(chunk));
        });
        self.inner.write_at(offset, bufs).await
    }

    async fn resize(&self, len: u64) -> Result<(), Error> {
        self.auditor.event(b"resize", |h| {
            h.update(self.partition.as_bytes());
            h.update(&self.name);
            h.update(&len.to_be_bytes());
        });
        self.inner.resize(len).await
    }

    async fn sync(&self) -> Result<(), Error> {
        // A sync carries no payload beyond the blob identity.
        self.auditor.event(b"sync", |h| {
            h.update(self.partition.as_bytes());
            h.update(&self.name);
        });
        self.inner.sync().await
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        storage::{
            audited::Storage as AuditedStorage, memory::Storage as MemStorage,
            tests::run_storage_tests,
        },
        Blob as _, BufferPool, BufferPoolConfig, Error, IoBuf, IoBufs, IoBufsMut, Storage as _,
    };
    use commonware_utils::sync::Mutex;
    use std::sync::Arc;

    // Builds a fresh buffer pool for a single test; metrics are registered
    // into a throwaway registry.
    fn test_pool() -> BufferPool {
        BufferPool::new(
            BufferPoolConfig::for_storage(),
            &mut prometheus_client::registry::Registry::default(),
        )
    }

    // Runs the shared storage conformance suite against an audited
    // in-memory storage.
    #[tokio::test]
    async fn test_audited_storage() {
        let inner = MemStorage::new(test_pool());
        let auditor = Arc::new(crate::deterministic::Auditor::default());
        let storage = AuditedStorage::new(inner, auditor.clone());
        run_storage_tests(storage).await;
    }

    // Drives two independent audited storages through the same sequence of
    // operations and asserts their auditors reach identical states after
    // each step — i.e. the audit hash depends only on the operation
    // sequence, not on instance identity.
    #[tokio::test]
    async fn test_audited_storage_combined() {
        use crate::deterministic::Auditor;
        let inner1 = MemStorage::new(test_pool());
        let auditor1 = Arc::new(Auditor::default());
        let storage1 = AuditedStorage::new(inner1, auditor1.clone());
        let inner2 = MemStorage::new(test_pool());
        let auditor2 = Arc::new(Auditor::default());
        let storage2 = AuditedStorage::new(inner2, auditor2.clone());
        let (blob1, _) = storage1.open("partition", b"test_blob").await.unwrap();
        let (blob2, _) = storage2.open("partition", b"test_blob").await.unwrap();
        blob1.write_at(0, b"hello world").await.unwrap();
        blob2.write_at(0, b"hello world").await.unwrap();
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after write"
        );
        // Reads are audited too, so both states must advance in lockstep.
        let read = blob1.read_at(0, 11).await.unwrap();
        assert_eq!(
            read.coalesce(),
            b"hello world",
            "Blob1 content does not match"
        );
        let read = blob2.read_at(0, 11).await.unwrap();
        assert_eq!(
            read.coalesce(),
            b"hello world",
            "Blob2 content does not match"
        );
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after read"
        );
        blob1.resize(5).await.unwrap();
        blob2.resize(5).await.unwrap();
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after resize"
        );
        blob1.sync().await.unwrap();
        blob2.sync().await.unwrap();
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after sync"
        );
        // Dropping a blob records no audit event, so states should stay
        // equal here.
        drop(blob1);
        drop(blob2);
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after drop"
        );
        storage1
            .remove("partition", Some(b"test_blob"))
            .await
            .unwrap();
        storage2
            .remove("partition", Some(b"test_blob"))
            .await
            .unwrap();
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after remove"
        );
        let blobs1 = storage1.scan("partition").await.unwrap();
        let blobs2 = storage2.scan("partition").await.unwrap();
        assert!(
            blobs1.is_empty(),
            "Partition1 should be empty after blob removal"
        );
        assert!(
            blobs2.is_empty(),
            "Partition2 should be empty after blob removal"
        );
        assert_eq!(
            auditor1.state(),
            auditor2.state(),
            "Hashes do not match after scan"
        );
    }

    // Test double that records how many chunks each write_at call received,
    // so the test below can verify the audited wrapper does not coalesce or
    // split the caller's buffers.
    #[derive(Clone)]
    struct RecordingBlob {
        // Chunk count of every write_at call, in call order.
        chunk_counts: Arc<Mutex<Vec<usize>>>,
    }

    impl crate::Blob for RecordingBlob {
        async fn read_at(&self, _offset: u64, _len: usize) -> Result<IoBufsMut, Error> {
            unreachable!("not used in test");
        }
        async fn read_at_buf(
            &self,
            _offset: u64,
            _len: usize,
            _bufs: impl Into<IoBufsMut> + Send,
        ) -> Result<IoBufsMut, Error> {
            unreachable!("not used in test");
        }
        async fn write_at(
            &self,
            _offset: u64,
            bufs: impl Into<IoBufs> + Send,
        ) -> Result<(), Error> {
            // Only the chunk count matters; contents are ignored.
            self.chunk_counts.lock().push(bufs.into().chunk_count());
            Ok(())
        }
        async fn resize(&self, _len: u64) -> Result<(), Error> {
            Ok(())
        }
        async fn sync(&self) -> Result<(), Error> {
            Ok(())
        }
    }

    // The audited wrapper hashes the payload chunk by chunk during write_at;
    // this verifies a 4-chunk buffer still reaches the inner blob as exactly
    // 4 chunks.
    #[tokio::test]
    async fn test_audited_blob_write_preserves_chunking() {
        let chunk_counts = Arc::new(Mutex::new(Vec::new()));
        let blob = super::Blob {
            auditor: Arc::new(crate::deterministic::Auditor::default()),
            partition: "partition".into(),
            name: b"blob".to_vec(),
            inner: RecordingBlob {
                chunk_counts: chunk_counts.clone(),
            },
        };
        blob.write_at(
            0,
            IoBufs::from(vec![
                IoBuf::from(b"a".to_vec()),
                IoBuf::from(b"b".to_vec()),
                IoBuf::from(b"c".to_vec()),
                IoBuf::from(b"d".to_vec()),
            ]),
        )
        .await
        .unwrap();
        assert_eq!(*chunk_counts.lock(), vec![4]);
    }
}