use crate::{
merkle::{
batch,
hasher::Hasher,
mem::{Config as MemConfig, Mem},
Error, Family, Location, Position,
},
metadata::{Config as MConfig, Metadata},
Context,
};
use commonware_codec::DecodeExt;
use commonware_cryptography::Digest;
use commonware_parallel::Strategy;
use commonware_utils::{
sequence::prefixed_u64::U64,
sync::{AsyncMutex, RwLock},
};
use std::sync::Arc;
/// A batch of pending leaves that has not been merkleized yet.
///
/// This is a thin wrapper around [`batch::UnmerkleizedBatch`]; every method forwards to the
/// wrapped batch and re-wraps the result.
pub struct UnmerkleizedBatch<F: Family, D: Digest, S: Strategy> {
    inner: batch::UnmerkleizedBatch<F, D, S>,
}

impl<F: Family, D: Digest, S: Strategy> UnmerkleizedBatch<F, D, S> {
    /// Wraps an already-constructed inner batch.
    pub(crate) const fn wrap(inner: batch::UnmerkleizedBatch<F, D, S>) -> Self {
        Self { inner }
    }

    /// Appends `element` as a new leaf and returns the extended batch.
    ///
    /// `hasher` is forwarded to the inner batch (presumably to derive the leaf digest).
    pub fn add(self, hasher: &impl Hasher<F, Digest = D>, element: &[u8]) -> Self {
        let Self { inner } = self;
        Self::wrap(inner.add(hasher, element))
    }

    /// Appends a leaf whose digest has already been computed by the caller.
    pub fn add_leaf_digest(self, digest: D) -> Self {
        let Self { inner } = self;
        Self::wrap(inner.add_leaf_digest(digest))
    }

    /// Returns the leaf count reported by the inner batch.
    pub fn leaves(&self) -> Location<F> {
        let Self { inner } = self;
        inner.leaves()
    }

    /// Merkleizes the pending leaves on top of `base` and returns the shared merkleized batch.
    pub fn merkleize(
        self,
        base: &Mem<F, D>,
        hasher: &impl Hasher<F, Digest = D>,
    ) -> Arc<batch::MerkleizedBatch<F, D, S>> {
        let Self { inner } = self;
        inner.merkleize(base, hasher)
    }
}
/// Configuration for [`Merkle`].
#[derive(Clone)]
pub struct Config<S: Strategy> {
    /// Name of the metadata partition that stores the persisted slots and generation pointer.
    pub partition: String,
    /// Strategy stored by the structure and cloned into every batch it creates.
    pub strategy: S,
}
/// A compact Merkle structure: the full history is kept only in memory between syncs.
///
/// On disk, only the leaf count and the pinned node digests are persisted. They go into one
/// of two alternating metadata slots, selected by a generation pointer.
pub struct Merkle<F: Family, E: Context, D: Digest, S: Strategy> {
    /// In-memory structure; pruned down to its pins after every successful sync.
    inner: RwLock<Mem<F, D>>,
    /// Backing store for both slots and the generation pointer.
    metadata: AsyncMutex<Metadata<E, U64, Vec<u8>>>,
    /// Serializes `sync_with_witness` and `rewind` so slot flips never interleave.
    sync_lock: AsyncMutex<()>,
    /// Strategy cloned into new batches.
    strategy: S,
    /// Slot (`0` = A, `1` = B) that the persisted generation pointer designates.
    active_slot: RwLock<u8>,
}
/// Key prefix of the generation pointer: a single byte naming the committed slot.
const GEN_PTR_PREFIX: u8 = 0;
/// Key prefix of slot A's leaf count (big-endian `u64`).
const SLOT_A_SIZE_PREFIX: u8 = 1;
/// Key prefix of slot A's pinned node digests, indexed by pin order.
const SLOT_A_NODE_PREFIX: u8 = 2;
/// Key prefix of slot B's leaf count (big-endian `u64`).
const SLOT_B_SIZE_PREFIX: u8 = 3;
/// Key prefix of slot B's pinned node digests, indexed by pin order.
const SLOT_B_NODE_PREFIX: u8 = 4;

/// Maps a slot number to the key prefix of its leaf count.
///
/// Slot `0` selects slot A; every other value selects slot B.
const fn size_prefix(slot: u8) -> u8 {
    match slot {
        0 => SLOT_A_SIZE_PREFIX,
        _ => SLOT_B_SIZE_PREFIX,
    }
}

/// Maps a slot number to the key prefix of its pinned node digests.
///
/// Slot `0` selects slot A; every other value selects slot B.
const fn node_prefix(slot: u8) -> u8 {
    match slot {
        0 => SLOT_A_NODE_PREFIX,
        _ => SLOT_B_NODE_PREFIX,
    }
}
impl<F: Family, E: Context, D: Digest, S: Strategy> Merkle<F, E, D, S> {
    /// Rejects a persisted leaf count that is not a valid [`Location`] for this family.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DataCorrupted`] when `leaves.is_valid()` is false (e.g. a count above
    /// `MAX_LEAVES`).
    const fn validate_persisted_leaves(leaves: Location<F>) -> Result<(), Error<F>> {
        if !leaves.is_valid() {
            return Err(Error::DataCorrupted("slot size exceeds MAX_LEAVES"));
        }
        Ok(())
    }

    /// Reads the generation pointer, i.e. which slot (`0` or `1`) was last committed.
    ///
    /// Returns `Ok(None)` when the key is absent. In this file it is only written by a
    /// successful sync or rewind.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DataCorrupted`] unless the stored value is exactly one byte equal to
    /// `0` or `1`.
    fn read_gen_ptr(metadata: &Metadata<E, U64, Vec<u8>>) -> Result<Option<u8>, Error<F>> {
        let Some(raw) = metadata.get(&U64::new(GEN_PTR_PREFIX, 0)) else {
            return Ok(None);
        };
        if raw.len() != 1 || (raw[0] != 0 && raw[0] != 1) {
            return Err(Error::DataCorrupted("invalid generation pointer"));
        }
        Ok(Some(raw[0]))
    }

    /// Reads the leaf count persisted for `slot`, stored as a big-endian `u64`.
    ///
    /// Returns `Ok(None)` when the slot has no size key.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DataCorrupted`] if the value is not 8 bytes long or is not a valid
    /// location.
    fn read_slot_size(
        metadata: &Metadata<E, U64, Vec<u8>>,
        slot: u8,
    ) -> Result<Option<Location<F>>, Error<F>> {
        let Some(raw) = metadata.get(&U64::new(size_prefix(slot), 0)) else {
            return Ok(None);
        };
        let bytes: [u8; 8] = raw
            .as_slice()
            .try_into()
            .map_err(|_| Error::DataCorrupted("slot size is not 8 bytes"))?;
        let leaves = Location::new(u64::from_be_bytes(bytes));
        Self::validate_persisted_leaves(leaves)?;
        Ok(Some(leaves))
    }

    /// Removes the pinned-node keys of `slot` from `metadata` (in memory; nothing is synced).
    ///
    /// Only indices `0..F::nodes_to_pin(leaves).count()` are removed. `leaves` must therefore
    /// be the count the pins were written for, or stale keys are left behind.
    fn clear_slot_pins(metadata: &mut Metadata<E, U64, Vec<u8>>, slot: u8, leaves: Location<F>) {
        let pin_count = F::nodes_to_pin(leaves).count();
        for i in 0..pin_count {
            metadata.remove(&U64::new(node_prefix(slot), i as u64));
        }
    }

    /// Removes both the pinned-node keys and the size key of `slot` (in memory only).
    fn clear_slot(metadata: &mut Metadata<E, U64, Vec<u8>>, slot: u8, leaves: Location<F>) {
        Self::clear_slot_pins(metadata, slot, leaves);
        metadata.remove(&U64::new(size_prefix(slot), 0));
    }

    /// Loads and decodes the pinned node digests of `slot`, in `F::nodes_to_pin(leaves)` order.
    ///
    /// # Errors
    ///
    /// - [`Error::MissingNode`] with the node's position if a pin key is absent.
    /// - [`Error::DataCorrupted`] if a stored digest fails to decode.
    fn load_slot_pins(
        metadata: &Metadata<E, U64, Vec<u8>>,
        slot: u8,
        leaves: Location<F>,
    ) -> Result<Vec<D>, Error<F>> {
        let mut pinned = Vec::new();
        for (idx, pos) in F::nodes_to_pin(leaves).enumerate() {
            let bytes = metadata
                .get(&U64::new(node_prefix(slot), idx as u64))
                .ok_or(Error::MissingNode(pos))?;
            let digest = D::decode(bytes.as_ref())
                .map_err(|_| Error::DataCorrupted("invalid pinned node"))?;
            pinned.push(digest);
        }
        Ok(pinned)
    }

    /// Opens the structure persisted in `cfg.partition`, or starts empty if nothing is there.
    ///
    /// The active slot comes from the generation pointer and defaults to slot `0`. Its leaf
    /// count defaults to `0` when the slot has no size key, which yields an empty [`Mem`].
    /// Otherwise the in-memory structure is built with no nodes, a pruning boundary of
    /// `leaves`, and the persisted pins.
    ///
    /// # Errors
    ///
    /// Propagates metadata initialization errors, corrupted generation pointer, slot size or
    /// pinned node errors, and any error from [`Mem::init`].
    pub async fn init(context: E, cfg: Config<S>) -> Result<Self, Error<F>> {
        let metadata = Metadata::<_, U64, Vec<u8>>::init(
            context.child("compact_metadata"),
            MConfig {
                partition: cfg.partition,
                codec_config: ((0..).into(), ()),
            },
        )
        .await?;
        let active_slot = Self::read_gen_ptr(&metadata)?.unwrap_or(0);
        let leaves = Self::read_slot_size(&metadata, active_slot)?.unwrap_or(Location::new(0));
        let mem = if leaves == 0 {
            Mem::new()
        } else {
            Mem::init(MemConfig {
                nodes: vec![],
                pruning_boundary: leaves,
                pinned_nodes: Self::load_slot_pins(&metadata, active_slot, leaves)?,
            })?
        };
        Ok(Self {
            inner: RwLock::new(mem),
            metadata: AsyncMutex::new(metadata),
            sync_lock: AsyncMutex::new(()),
            strategy: cfg.strategy,
            active_slot: RwLock::new(active_slot),
        })
    }

    /// Creates a structure from externally supplied compact state: `leaves` plus the digests
    /// pinned at that size. Whatever the partition previously held is discarded.
    ///
    /// The metadata is cleared and the active slot is set to `0`, but nothing is written
    /// here. `clear` is not awaited, so presumably it only affects the in-memory view until
    /// the next `put_sync` (e.g. via [`Self::sync`]).
    /// NOTE(review): confirm callers sync before relying on this state surviving a restart.
    ///
    /// # Errors
    ///
    /// - [`Error::DataCorrupted`] if `leaves` is not a valid location.
    /// - [`Error::InvalidPinnedNodes`] if `pinned_nodes.len()` differs from
    ///   `F::nodes_to_pin(leaves).count()`.
    /// - Metadata initialization or [`Mem::init`] errors.
    pub(crate) async fn init_from_compact_state(
        context: E,
        cfg: Config<S>,
        leaves: Location<F>,
        pinned_nodes: Vec<D>,
    ) -> Result<Self, Error<F>> {
        Self::validate_persisted_leaves(leaves)?;
        if pinned_nodes.len() != F::nodes_to_pin(leaves).count() {
            return Err(Error::InvalidPinnedNodes);
        }
        let mut metadata = Metadata::<_, U64, Vec<u8>>::init(
            context.child("compact_metadata"),
            MConfig {
                partition: cfg.partition,
                codec_config: ((0..).into(), ()),
            },
        )
        .await?;
        // Drop both slots and the generation pointer so the supplied state starts a fresh history.
        metadata.clear();
        let mem = if leaves == 0 {
            Mem::new()
        } else {
            Mem::init(MemConfig {
                nodes: vec![],
                pruning_boundary: leaves,
                pinned_nodes,
            })?
        };
        let merkle = Self {
            inner: RwLock::new(mem),
            metadata: AsyncMutex::new(metadata),
            sync_lock: AsyncMutex::new(()),
            strategy: cfg.strategy,
            active_slot: RwLock::new(0),
        };
        Ok(merkle)
    }

    /// Computes the root of the in-memory state, including batches applied since the last
    /// sync. `inactive_peaks` is forwarded unchanged to the in-memory structure's `root`.
    pub fn root(
        &self,
        hasher: &impl Hasher<F, Digest = D>,
        inactive_peaks: usize,
    ) -> Result<D, Error<F>> {
        self.inner.read().root(hasher, inactive_peaks)
    }

    /// Returns the size (as a [`Position`]) of the in-memory structure.
    pub fn size(&self) -> Position<F> {
        self.inner.read().size()
    }

    /// Returns the number of leaves in the in-memory structure, including unsynced batches.
    pub fn leaves(&self) -> Location<F> {
        self.inner.read().leaves()
    }

    /// Returns the [`Strategy`] supplied in the config.
    pub const fn strategy(&self) -> &S {
        &self.strategy
    }

    /// Returns the slot (`0` or `1`) that the generation pointer currently designates.
    pub(crate) fn active_slot(&self) -> u8 {
        *self.active_slot.read()
    }

    /// Runs `f` against the in-memory structure while holding its read lock.
    pub fn with_mem<R>(&self, f: impl FnOnce(&Mem<F, D>) -> R) -> R {
        let inner = self.inner.read();
        f(&inner)
    }

    /// Starts a new batch on top of the current in-memory state using this structure's strategy.
    pub fn new_batch(&self) -> UnmerkleizedBatch<F, D, S> {
        let inner = self.inner.read();
        UnmerkleizedBatch::wrap(inner.new_batch_with_strategy(self.strategy.clone()))
    }

    /// Snapshots the current in-memory state as a merkleized batch.
    pub(crate) fn to_batch(&self) -> Arc<batch::MerkleizedBatch<F, D, S>> {
        let inner = self.inner.read();
        batch::MerkleizedBatch::from_mem_with_strategy(&inner, self.strategy.clone())
    }

    /// Applies a merkleized batch to the in-memory state. Nothing is persisted until the next
    /// sync.
    ///
    /// `&mut self` guarantees exclusive access, so the lock is bypassed with `get_mut`.
    pub fn apply_batch(&mut self, batch: &batch::MerkleizedBatch<F, D, S>) -> Result<(), Error<F>> {
        self.inner.get_mut().apply_batch(batch)
    }

    /// Returns a copy of the raw metadata value stored under `key`, if any.
    pub(crate) async fn read_metadata_key(&self, key: &U64) -> Option<Vec<u8>> {
        let metadata = self.metadata.lock().await;
        metadata.get(key).cloned()
    }

    /// Persists the current in-memory state into the inactive slot, then commits it by
    /// pointing the generation pointer at that slot.
    ///
    /// Steps:
    /// 1. Under the read lock, capture the leaf count, the digests pinned at that count, and a
    ///    caller-defined witness from `build_witness`.
    /// 2. Under the metadata lock, remove the inactive slot's old pins and write its new size
    ///    and pins. Then let `update` add entries; it receives the target slot and the witness.
    /// 3. `put_sync` the generation pointer. This is the commit point. Presumably `put_sync`
    ///    also flushes the puts from step 2 (confirm against `Metadata`).
    /// 4. Only after the commit succeeds, switch `active_slot` and call `prune_all` on the
    ///    in-memory structure.
    ///
    /// The previously active slot is not modified, so it stays available to [`Self::rewind`].
    ///
    /// # Errors
    ///
    /// Propagates errors from `build_witness`, `update`, slot-size reads and the metadata sync.
    /// NOTE(review): on an error after step 2 has begun, the in-memory metadata keeps the
    /// partially written target slot, because there is no rollback. `active_slot` and the
    /// in-memory structure are left unchanged.
    pub(crate) async fn sync_with_witness<W, R>(
        &self,
        build_witness: impl FnOnce(&Mem<F, D>) -> Result<W, Error<F>>,
        update: impl FnOnce(&mut Metadata<E, U64, Vec<u8>>, u8, W) -> Result<R, Error<F>>,
    ) -> Result<R, Error<F>> {
        let _sync_guard = self.sync_lock.lock().await;
        let current_slot = *self.active_slot.read();
        // Always write into the slot that is not currently committed.
        let target_slot = 1 - current_slot;
        let (leaves, pinned_nodes, witness) = {
            let inner = self.inner.read();
            let leaves = inner.leaves();
            let pinned_nodes = F::nodes_to_pin(leaves)
                .map(|pos| *inner.get_node_unchecked(pos))
                .collect::<Vec<_>>();
            let witness = build_witness(&inner)?;
            (leaves, pinned_nodes, witness)
        };
        let result = {
            let mut metadata = self.metadata.lock().await;
            // The pin count depends on the leaf count, so remove the old pins by the old size
            // before writing the new ones; otherwise surplus keys would linger.
            let old_target_leaves =
                Self::read_slot_size(&metadata, target_slot)?.unwrap_or(Location::new(0));
            Self::clear_slot_pins(&mut metadata, target_slot, old_target_leaves);
            metadata.put(
                U64::new(size_prefix(target_slot), 0),
                leaves.as_u64().to_be_bytes().to_vec(),
            );
            for (idx, digest) in pinned_nodes.iter().enumerate() {
                metadata.put(
                    U64::new(node_prefix(target_slot), idx as u64),
                    digest.to_vec(),
                );
            }
            let result = update(&mut metadata, target_slot, witness)?;
            // Commit point: flipping the generation pointer makes the target slot authoritative.
            metadata
                .put_sync(U64::new(GEN_PTR_PREFIX, 0), vec![target_slot])
                .await?;
            result
        };
        *self.active_slot.write() = target_slot;
        self.inner.write().prune_all();
        Ok(result)
    }

    /// Reverts to the contents of the inactive slot and returns the new active slot.
    ///
    /// The inactive slot normally holds the state committed by the sync before the most
    /// recent one.
    ///
    /// The in-memory structure is rebuilt from the inactive slot before any metadata is
    /// modified, so a failure while loading leaves everything unchanged. The previously active
    /// slot is then cleared and the generation pointer is synced to the restored slot. Any
    /// batches applied since the last sync are discarded. Because the old slot is cleared, a
    /// second consecutive rewind fails with [`Error::RewindBeyondHistory`].
    ///
    /// # Errors
    ///
    /// - [`Error::RewindBeyondHistory`] if the inactive slot has no size key.
    /// - Corruption or missing-node errors while loading the inactive slot, [`Mem::init`]
    ///   errors, and metadata sync errors.
    pub(crate) async fn rewind(&mut self) -> Result<u8, Error<F>> {
        let _sync_guard = self.sync_lock.lock().await;
        let current_slot = *self.active_slot.read();
        let target_slot = 1 - current_slot;
        let (new_leaves, pinned_nodes) = {
            let metadata = self.metadata.lock().await;
            let Some(new_leaves) = Self::read_slot_size(&metadata, target_slot)? else {
                return Err(Error::RewindBeyondHistory);
            };
            let pinned_nodes = if new_leaves == 0 {
                Vec::new()
            } else {
                Self::load_slot_pins(&metadata, target_slot, new_leaves)?
            };
            (new_leaves, pinned_nodes)
        };
        // Build the restored structure before touching metadata so a failure here is side-effect free.
        let new_mem = if new_leaves == 0 {
            Mem::new()
        } else {
            Mem::init(MemConfig {
                nodes: vec![],
                pruning_boundary: new_leaves,
                pinned_nodes,
            })?
        };
        {
            let mut metadata = self.metadata.lock().await;
            let old_current_leaves =
                Self::read_slot_size(&metadata, current_slot)?.unwrap_or(Location::new(0));
            Self::clear_slot(&mut metadata, current_slot, old_current_leaves);
            metadata
                .put_sync(U64::new(GEN_PTR_PREFIX, 0), vec![target_slot])
                .await?;
        }
        *self.inner.write() = new_mem;
        *self.active_slot.write() = target_slot;
        Ok(target_slot)
    }

    /// Persists the current in-memory state with no witness and no extra metadata entries.
    ///
    /// See [`Self::sync_with_witness`] for the slot-flip protocol and error behavior.
    pub async fn sync(&self) -> Result<(), Error<F>> {
        self.sync_with_witness(|_| Ok(()), |_, _, ()| Ok(()))
            .await
            .map(|_| ())
    }

    /// Alias for [`Self::sync`].
    pub async fn commit(&self) -> Result<(), Error<F>> {
        self.sync().await
    }

    /// Consumes the structure and destroys its backing metadata.
    pub async fn destroy(self) -> Result<(), Error<F>> {
        self.metadata.into_inner().destroy().await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        merkle::{hasher::Standard as StandardHasher, mmb, mmr, Bagging::ForwardFold},
        metadata::{Config as MConfig, Metadata},
    };
    use commonware_cryptography::Sha256;
    use commonware_parallel::Sequential;
    use commonware_runtime::{deterministic, Runner as _, Supervisor as _};

    type TestMerkle<F> = Merkle<
        F,
        deterministic::Context,
        <Sha256 as commonware_cryptography::Hasher>::Digest,
        Sequential,
    >;

    /// Opens a test structure on `partition` with a sequential strategy.
    async fn open<F: Family>(context: deterministic::Context, partition: &str) -> TestMerkle<F> {
        TestMerkle::<F>::init(
            context,
            Config {
                partition: partition.into(),
                strategy: Sequential,
            },
        )
        .await
        .unwrap()
    }

    /// Appends `values` as one batch, applies it, and syncs (flipping the active slot).
    async fn append_and_sync<F: Family>(merkle: &mut TestMerkle<F>, values: &[&[u8]]) {
        let hasher = StandardHasher::<Sha256>::new(ForwardFold);
        let batch = {
            let mut b = merkle.new_batch();
            for v in values {
                b = b.add(&hasher, v);
            }
            merkle.with_mem(|mem| b.merkleize(mem, &hasher))
        };
        merkle.apply_batch(&batch).unwrap();
        merkle.sync().await.unwrap();
    }

    async fn assert_reopen_and_continue<F: Family>(
        context: deterministic::Context,
        partition: &str,
    ) {
        let hasher = StandardHasher::<Sha256>::new(ForwardFold);
        let cfg = Config {
            partition: partition.into(),
            strategy: Sequential,
        };
        let mut merkle = TestMerkle::<F>::init(context.child("first"), cfg.clone())
            .await
            .unwrap();
        let batch = {
            let batch = merkle.new_batch().add(&hasher, b"a").add(&hasher, b"b");
            merkle.with_mem(|mem| batch.merkleize(mem, &hasher))
        };
        merkle.apply_batch(&batch).unwrap();
        let root_before = merkle.root(&hasher, 0).unwrap();
        let leaves_before = merkle.leaves();
        merkle.sync().await.unwrap();
        drop(merkle);
        let mut reopened = TestMerkle::<F>::init(context.child("second"), cfg)
            .await
            .unwrap();
        assert_eq!(reopened.root(&hasher, 0).unwrap(), root_before);
        assert_eq!(reopened.leaves(), leaves_before);
        let batch = {
            let batch = reopened.new_batch().add(&hasher, b"c");
            reopened.with_mem(|mem| batch.merkleize(mem, &hasher))
        };
        reopened.apply_batch(&batch).unwrap();
        reopened.sync().await.unwrap();
    }

    #[test]
    fn test_compact_reopen_and_continue_mmr() {
        deterministic::Runner::default().start(|context| async move {
            assert_reopen_and_continue::<mmr::Family>(context, "compact-mmr").await;
        });
    }

    #[test]
    fn test_compact_reopen_and_continue_mmb() {
        deterministic::Runner::default().start(|context| async move {
            assert_reopen_and_continue::<mmb::Family>(context, "compact-mmb").await;
        });
    }

    async fn assert_rewind_restores_prior_state<F: Family>(
        context: deterministic::Context,
        partition: &str,
    ) {
        let hasher = StandardHasher::<Sha256>::new(ForwardFold);
        let mut merkle = open::<F>(context, partition).await;
        append_and_sync(&mut merkle, &[b"a", b"b"]).await;
        let root_after_first = merkle.root(&hasher, 0).unwrap();
        let leaves_after_first = merkle.leaves();
        append_and_sync(&mut merkle, &[b"c"]).await;
        assert_ne!(merkle.root(&hasher, 0).unwrap(), root_after_first);
        merkle.rewind().await.unwrap();
        assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_first);
        assert_eq!(merkle.leaves(), leaves_after_first);
        merkle.destroy().await.unwrap();
    }

    #[test]
    fn test_rewind_restores_prior_state_mmr() {
        deterministic::Runner::default().start(|context| async move {
            assert_rewind_restores_prior_state::<mmr::Family>(context, "rewind-prior-mmr").await;
        });
    }

    #[test]
    fn test_rewind_restores_prior_state_mmb() {
        deterministic::Runner::default().start(|context| async move {
            assert_rewind_restores_prior_state::<mmb::Family>(context, "rewind-prior-mmb").await;
        });
    }

    #[test]
    fn test_rewind_beyond_history_errors() {
        deterministic::Runner::default().start(|context| async move {
            let mut merkle = open::<mmr::Family>(context, "rewind-beyond").await;
            assert!(matches!(
                merkle.rewind().await,
                Err(Error::RewindBeyondHistory)
            ));
            append_and_sync(&mut merkle, &[b"a"]).await;
            assert!(matches!(
                merkle.rewind().await,
                Err(Error::RewindBeyondHistory)
            ));
            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_rewind_discards_uncommitted() {
        deterministic::Runner::default().start(|context| async move {
            let hasher = StandardHasher::<Sha256>::new(ForwardFold);
            let mut merkle = open::<mmr::Family>(context, "rewind-uncommitted").await;
            append_and_sync(&mut merkle, &[b"a"]).await;
            let root_after_one = merkle.root(&hasher, 0).unwrap();
            let leaves_after_one = merkle.leaves();
            append_and_sync(&mut merkle, &[b"b"]).await;
            let root_after_two = merkle.root(&hasher, 0).unwrap();
            let leaves_after_two = merkle.leaves();
            let batch = {
                let b = merkle.new_batch().add(&hasher, b"c");
                merkle.with_mem(|mem| b.merkleize(mem, &hasher))
            };
            merkle.apply_batch(&batch).unwrap();
            assert_ne!(merkle.root(&hasher, 0).unwrap(), root_after_two);
            merkle.rewind().await.unwrap();
            // Rewind restores the slot committed *before* the latest sync. Both the uncommitted
            // "c" and the committed "b" are dropped, so the state must match exactly the one
            // after "a", not merely differ from the one after "b".
            assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_one);
            assert_eq!(merkle.leaves(), leaves_after_one);
            assert_ne!(merkle.root(&hasher, 0).unwrap(), root_after_two);
            assert_ne!(merkle.leaves(), leaves_after_two);
            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_rewind_persists_across_reopen() {
        deterministic::Runner::default().start(|context| async move {
            let hasher = StandardHasher::<Sha256>::new(ForwardFold);
            let partition = "rewind-reopen";
            let cfg = Config {
                partition: partition.into(),
                strategy: Sequential,
            };
            let mut merkle = open::<mmr::Family>(context.child("first"), partition).await;
            append_and_sync(&mut merkle, &[b"a"]).await;
            let root_after_first = merkle.root(&hasher, 0).unwrap();
            append_and_sync(&mut merkle, &[b"b"]).await;
            merkle.rewind().await.unwrap();
            drop(merkle);
            let reopened: TestMerkle<mmr::Family> =
                Merkle::<mmr::Family, _, _, Sequential>::init(context.child("second"), cfg)
                    .await
                    .unwrap();
            assert_eq!(reopened.root(&hasher, 0).unwrap(), root_after_first);
            reopened.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_double_rewind_errors() {
        deterministic::Runner::default().start(|context| async move {
            let mut merkle = open::<mmr::Family>(context, "rewind-double").await;
            append_and_sync(&mut merkle, &[b"a"]).await;
            append_and_sync(&mut merkle, &[b"b"]).await;
            merkle.rewind().await.unwrap();
            assert!(matches!(
                merkle.rewind().await,
                Err(Error::RewindBeyondHistory)
            ));
            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_rewind_then_sync_then_rewind() {
        deterministic::Runner::default().start(|context| async move {
            let hasher = StandardHasher::<Sha256>::new(ForwardFold);
            let mut merkle = open::<mmr::Family>(context, "rewind-resumable").await;
            append_and_sync(&mut merkle, &[b"a"]).await;
            let root_after_first = merkle.root(&hasher, 0).unwrap();
            append_and_sync(&mut merkle, &[b"b"]).await;
            merkle.rewind().await.unwrap();
            assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_first);
            append_and_sync(&mut merkle, &[b"c"]).await;
            let root_abc = merkle.root(&hasher, 0).unwrap();
            assert_ne!(root_abc, root_after_first);
            merkle.rewind().await.unwrap();
            assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_first);
            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_reopen_rejects_invalid_persisted_leaf_count() {
        deterministic::Runner::default().start(|context| async move {
            let partition = "compact-invalid-leaf-count";
            let cfg = Config {
                partition: partition.into(),
                strategy: Sequential,
            };
            let mut merkle = TestMerkle::<mmr::Family>::init(context.child("first"), cfg.clone())
                .await
                .unwrap();
            append_and_sync(&mut merkle, &[b"a"]).await;
            let slot = merkle.active_slot();
            drop(merkle);
            // Tamper with the active slot's size so it exceeds MAX_LEAVES.
            let mut metadata = Metadata::<_, U64, Vec<u8>>::init(
                context.child("tamper"),
                MConfig {
                    partition: partition.into(),
                    codec_config: ((0..).into(), ()),
                },
            )
            .await
            .unwrap();
            metadata
                .put_sync(
                    U64::new(size_prefix(slot), 0),
                    (mmr::Family::MAX_LEAVES.as_u64() + 1)
                        .to_be_bytes()
                        .to_vec(),
                )
                .await
                .unwrap();
            let reopened = TestMerkle::<mmr::Family>::init(context.child("second"), cfg).await;
            assert!(matches!(
                reopened,
                Err(Error::DataCorrupted("slot size exceeds MAX_LEAVES"))
            ));
        });
    }
}