1use crate::{
23 merkle::{
24 batch,
25 hasher::Hasher,
26 mem::{Config as MemConfig, Mem},
27 Error, Family, Location, Position,
28 },
29 metadata::{Config as MConfig, Metadata},
30 Context,
31};
32use commonware_codec::DecodeExt;
33use commonware_cryptography::Digest;
34use commonware_parallel::Strategy;
35use commonware_utils::{
36 sequence::prefixed_u64::U64,
37 sync::{AsyncMutex, RwLock},
38};
39use std::sync::Arc;
40
41pub struct UnmerkleizedBatch<F: Family, D: Digest, S: Strategy> {
43 inner: batch::UnmerkleizedBatch<F, D, S>,
44}
45
46impl<F: Family, D: Digest, S: Strategy> UnmerkleizedBatch<F, D, S> {
47 pub(crate) const fn wrap(inner: batch::UnmerkleizedBatch<F, D, S>) -> Self {
49 Self { inner }
50 }
51
52 pub fn add(self, hasher: &impl Hasher<F, Digest = D>, element: &[u8]) -> Self {
54 Self {
55 inner: self.inner.add(hasher, element),
56 }
57 }
58
59 pub fn add_leaf_digest(self, digest: D) -> Self {
61 Self {
62 inner: self.inner.add_leaf_digest(digest),
63 }
64 }
65
66 pub fn leaves(&self) -> Location<F> {
68 self.inner.leaves()
69 }
70
71 pub fn merkleize(
73 self,
74 base: &Mem<F, D>,
75 hasher: &impl Hasher<F, Digest = D>,
76 ) -> Arc<batch::MerkleizedBatch<F, D, S>> {
77 self.inner.merkleize(base, hasher)
78 }
79}
80
81#[derive(Clone)]
83pub struct Config<S: Strategy> {
84 pub partition: String,
86
87 pub strategy: S,
89}
90
91pub struct Merkle<F: Family, E: Context, D: Digest, S: Strategy> {
93 inner: RwLock<Mem<F, D>>,
94 metadata: AsyncMutex<Metadata<E, U64, Vec<u8>>>,
95 sync_lock: AsyncMutex<()>,
96 strategy: S,
97 active_slot: RwLock<u8>,
100}
101
/// Key prefix of the generation pointer: a single byte (0 or 1) naming the active slot.
const GEN_PTR_PREFIX: u8 = 0;
/// Key prefix of slot A's (slot 0) persisted leaf count.
const SLOT_A_SIZE_PREFIX: u8 = 1;
/// Key prefix of slot A's pinned node digests, indexed by pin position.
const SLOT_A_NODE_PREFIX: u8 = 2;
/// Key prefix of slot B's (slot 1) persisted leaf count.
const SLOT_B_SIZE_PREFIX: u8 = 3;
/// Key prefix of slot B's pinned node digests, indexed by pin position.
const SLOT_B_NODE_PREFIX: u8 = 4;

/// Returns the leaf-count key prefix for `slot`.
///
/// Slot 0 selects slot A; every other value selects slot B.
const fn size_prefix(slot: u8) -> u8 {
    match slot {
        0 => SLOT_A_SIZE_PREFIX,
        _ => SLOT_B_SIZE_PREFIX,
    }
}

/// Returns the pinned-node key prefix for `slot`.
///
/// Slot 0 selects slot A; every other value selects slot B.
const fn node_prefix(slot: u8) -> u8 {
    match slot {
        0 => SLOT_A_NODE_PREFIX,
        _ => SLOT_B_NODE_PREFIX,
    }
}
126
127impl<F: Family, E: Context, D: Digest, S: Strategy> Merkle<F, E, D, S> {
128 const fn validate_persisted_leaves(leaves: Location<F>) -> Result<(), Error<F>> {
129 if !leaves.is_valid() {
130 return Err(Error::DataCorrupted("slot size exceeds MAX_LEAVES"));
131 }
132 Ok(())
133 }
134
135 fn read_gen_ptr(metadata: &Metadata<E, U64, Vec<u8>>) -> Result<Option<u8>, Error<F>> {
137 let Some(raw) = metadata.get(&U64::new(GEN_PTR_PREFIX, 0)) else {
138 return Ok(None);
139 };
140 if raw.len() != 1 || (raw[0] != 0 && raw[0] != 1) {
141 return Err(Error::DataCorrupted("invalid generation pointer"));
142 }
143 Ok(Some(raw[0]))
144 }
145
146 fn read_slot_size(
148 metadata: &Metadata<E, U64, Vec<u8>>,
149 slot: u8,
150 ) -> Result<Option<Location<F>>, Error<F>> {
151 let Some(raw) = metadata.get(&U64::new(size_prefix(slot), 0)) else {
152 return Ok(None);
153 };
154 let bytes: [u8; 8] = raw
155 .as_slice()
156 .try_into()
157 .map_err(|_| Error::DataCorrupted("slot size is not 8 bytes"))?;
158 let leaves = Location::new(u64::from_be_bytes(bytes));
159 Self::validate_persisted_leaves(leaves)?;
160 Ok(Some(leaves))
161 }
162
163 fn clear_slot_pins(metadata: &mut Metadata<E, U64, Vec<u8>>, slot: u8, leaves: Location<F>) {
165 let pin_count = F::nodes_to_pin(leaves).count();
166 for i in 0..pin_count {
167 metadata.remove(&U64::new(node_prefix(slot), i as u64));
168 }
169 }
170
171 fn clear_slot(metadata: &mut Metadata<E, U64, Vec<u8>>, slot: u8, leaves: Location<F>) {
174 Self::clear_slot_pins(metadata, slot, leaves);
175 metadata.remove(&U64::new(size_prefix(slot), 0));
176 }
177
178 fn load_slot_pins(
179 metadata: &Metadata<E, U64, Vec<u8>>,
180 slot: u8,
181 leaves: Location<F>,
182 ) -> Result<Vec<D>, Error<F>> {
183 let mut pinned = Vec::new();
184 for (idx, pos) in F::nodes_to_pin(leaves).enumerate() {
185 let bytes = metadata
186 .get(&U64::new(node_prefix(slot), idx as u64))
187 .ok_or(Error::MissingNode(pos))?;
188 let digest = D::decode(bytes.as_ref())
189 .map_err(|_| Error::DataCorrupted("invalid pinned node"))?;
190 pinned.push(digest);
191 }
192 Ok(pinned)
193 }
194
195 pub async fn init(context: E, cfg: Config<S>) -> Result<Self, Error<F>> {
197 let metadata = Metadata::<_, U64, Vec<u8>>::init(
198 context.child("compact_metadata"),
199 MConfig {
200 partition: cfg.partition,
201 codec_config: ((0..).into(), ()),
202 },
203 )
204 .await?;
205
206 let active_slot = Self::read_gen_ptr(&metadata)?.unwrap_or(0);
207 let leaves = Self::read_slot_size(&metadata, active_slot)?.unwrap_or(Location::new(0));
208 let mem = if leaves == 0 {
209 Mem::new()
210 } else {
211 Mem::init(MemConfig {
212 nodes: vec![],
213 pruning_boundary: leaves,
214 pinned_nodes: Self::load_slot_pins(&metadata, active_slot, leaves)?,
215 })?
216 };
217
218 Ok(Self {
219 inner: RwLock::new(mem),
220 metadata: AsyncMutex::new(metadata),
221 sync_lock: AsyncMutex::new(()),
222 strategy: cfg.strategy,
223 active_slot: RwLock::new(active_slot),
224 })
225 }
226
227 pub(crate) async fn init_from_compact_state(
242 context: E,
243 cfg: Config<S>,
244 leaves: Location<F>,
245 pinned_nodes: Vec<D>,
246 ) -> Result<Self, Error<F>> {
247 Self::validate_persisted_leaves(leaves)?;
248 if pinned_nodes.len() != F::nodes_to_pin(leaves).count() {
249 return Err(Error::InvalidPinnedNodes);
250 }
251
252 let mut metadata = Metadata::<_, U64, Vec<u8>>::init(
253 context.child("compact_metadata"),
254 MConfig {
255 partition: cfg.partition,
256 codec_config: ((0..).into(), ()),
257 },
258 )
259 .await?;
260 metadata.clear();
261
262 let mem = if leaves == 0 {
263 Mem::new()
264 } else {
265 Mem::init(MemConfig {
266 nodes: vec![],
267 pruning_boundary: leaves,
268 pinned_nodes,
269 })?
270 };
271
272 let merkle = Self {
273 inner: RwLock::new(mem),
274 metadata: AsyncMutex::new(metadata),
275 sync_lock: AsyncMutex::new(()),
276 strategy: cfg.strategy,
277 active_slot: RwLock::new(0),
278 };
279 Ok(merkle)
280 }
281
282 pub fn root(
284 &self,
285 hasher: &impl Hasher<F, Digest = D>,
286 inactive_peaks: usize,
287 ) -> Result<D, Error<F>> {
288 self.inner.read().root(hasher, inactive_peaks)
289 }
290
291 pub fn size(&self) -> Position<F> {
293 self.inner.read().size()
294 }
295
296 pub fn leaves(&self) -> Location<F> {
298 self.inner.read().leaves()
299 }
300
301 pub const fn strategy(&self) -> &S {
303 &self.strategy
304 }
305
306 pub(crate) fn active_slot(&self) -> u8 {
308 *self.active_slot.read()
309 }
310
311 pub fn with_mem<R>(&self, f: impl FnOnce(&Mem<F, D>) -> R) -> R {
313 let inner = self.inner.read();
314 f(&inner)
315 }
316
317 pub fn new_batch(&self) -> UnmerkleizedBatch<F, D, S> {
319 let inner = self.inner.read();
320 UnmerkleizedBatch::wrap(inner.new_batch_with_strategy(self.strategy.clone()))
321 }
322
323 pub(crate) fn to_batch(&self) -> Arc<batch::MerkleizedBatch<F, D, S>> {
325 let inner = self.inner.read();
326 batch::MerkleizedBatch::from_mem_with_strategy(&inner, self.strategy.clone())
327 }
328
329 pub fn apply_batch(&mut self, batch: &batch::MerkleizedBatch<F, D, S>) -> Result<(), Error<F>> {
331 self.inner.get_mut().apply_batch(batch)
332 }
333
334 pub(crate) async fn read_metadata_key(&self, key: &U64) -> Option<Vec<u8>> {
337 let metadata = self.metadata.lock().await;
338 metadata.get(key).cloned()
339 }
340
341 pub(crate) async fn sync_with_witness<W, R>(
358 &self,
359 build_witness: impl FnOnce(&Mem<F, D>) -> Result<W, Error<F>>,
360 update: impl FnOnce(&mut Metadata<E, U64, Vec<u8>>, u8, W) -> Result<R, Error<F>>,
361 ) -> Result<R, Error<F>> {
362 let _sync_guard = self.sync_lock.lock().await;
363
364 let current_slot = *self.active_slot.read();
365 let target_slot = 1 - current_slot;
366
367 let (leaves, pinned_nodes, witness) = {
368 let inner = self.inner.read();
369 let leaves = inner.leaves();
370 let pinned_nodes = F::nodes_to_pin(leaves)
371 .map(|pos| *inner.get_node_unchecked(pos))
372 .collect::<Vec<_>>();
373 let witness = build_witness(&inner)?;
374 (leaves, pinned_nodes, witness)
375 };
376
377 let result = {
378 let mut metadata = self.metadata.lock().await;
379 let old_target_leaves =
380 Self::read_slot_size(&metadata, target_slot)?.unwrap_or(Location::new(0));
381 Self::clear_slot_pins(&mut metadata, target_slot, old_target_leaves);
382 metadata.put(
383 U64::new(size_prefix(target_slot), 0),
384 leaves.as_u64().to_be_bytes().to_vec(),
385 );
386 for (idx, digest) in pinned_nodes.iter().enumerate() {
387 metadata.put(
388 U64::new(node_prefix(target_slot), idx as u64),
389 digest.to_vec(),
390 );
391 }
392 let result = update(&mut metadata, target_slot, witness)?;
393 metadata
394 .put_sync(U64::new(GEN_PTR_PREFIX, 0), vec![target_slot])
395 .await?;
396 result
397 };
398
399 *self.active_slot.write() = target_slot;
400 self.inner.write().prune_all();
401 Ok(result)
402 }
403
404 pub(crate) async fn rewind(&mut self) -> Result<u8, Error<F>> {
415 let _sync_guard = self.sync_lock.lock().await;
416
417 let current_slot = *self.active_slot.read();
418 let target_slot = 1 - current_slot;
419
420 let (new_leaves, pinned_nodes) = {
421 let metadata = self.metadata.lock().await;
422 let Some(new_leaves) = Self::read_slot_size(&metadata, target_slot)? else {
423 return Err(Error::RewindBeyondHistory);
424 };
425 let pinned_nodes = if new_leaves == 0 {
426 Vec::new()
427 } else {
428 Self::load_slot_pins(&metadata, target_slot, new_leaves)?
429 };
430 (new_leaves, pinned_nodes)
431 };
432
433 let new_mem = if new_leaves == 0 {
435 Mem::new()
436 } else {
437 Mem::init(MemConfig {
438 nodes: vec![],
439 pruning_boundary: new_leaves,
440 pinned_nodes,
441 })?
442 };
443
444 {
451 let mut metadata = self.metadata.lock().await;
452 let old_current_leaves =
453 Self::read_slot_size(&metadata, current_slot)?.unwrap_or(Location::new(0));
454 Self::clear_slot(&mut metadata, current_slot, old_current_leaves);
455 metadata
456 .put_sync(U64::new(GEN_PTR_PREFIX, 0), vec![target_slot])
457 .await?;
458 }
459
460 *self.inner.write() = new_mem;
461 *self.active_slot.write() = target_slot;
462 Ok(target_slot)
463 }
464
465 pub async fn sync(&self) -> Result<(), Error<F>> {
467 self.sync_with_witness(|_| Ok(()), |_, _, ()| Ok(()))
468 .await
469 .map(|_| ())
470 }
471
472 pub async fn commit(&self) -> Result<(), Error<F>> {
474 self.sync().await
475 }
476
477 pub async fn destroy(self) -> Result<(), Error<F>> {
479 self.metadata.into_inner().destroy().await?;
480 Ok(())
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        merkle::{hasher::Standard as StandardHasher, mmb, mmr, Bagging::ForwardFold},
        metadata::{Config as MConfig, Metadata},
    };
    use commonware_cryptography::Sha256;
    use commonware_parallel::Sequential;
    use commonware_runtime::{deterministic, Runner as _, Supervisor as _};

    /// Merkle instantiation used by every test: deterministic runtime, SHA-256 digests,
    /// sequential strategy; the family `F` is chosen per test.
    type TestMerkle<F> = Merkle<
        F,
        deterministic::Context,
        <Sha256 as commonware_cryptography::Hasher>::Digest,
        Sequential,
    >;

    /// Opens (or reopens) a structure stored in `partition`, panicking on error.
    async fn open<F: Family>(context: deterministic::Context, partition: &str) -> TestMerkle<F> {
        TestMerkle::<F>::init(
            context,
            Config {
                partition: partition.into(),
                strategy: Sequential,
            },
        )
        .await
        .unwrap()
    }

    /// Adds `values` as leaves in one batch, applies it, then syncs (one slot flip).
    async fn append_and_sync<F: Family>(merkle: &mut TestMerkle<F>, values: &[&[u8]]) {
        let hasher = StandardHasher::<Sha256>::new(ForwardFold);
        let batch = {
            let mut b = merkle.new_batch();
            for v in values {
                b = b.add(&hasher, v);
            }
            merkle.with_mem(|mem| b.merkleize(mem, &hasher))
        };
        merkle.apply_batch(&batch).unwrap();
        merkle.sync().await.unwrap();
    }

    /// Syncs two leaves, reopens, checks root and leaf count survived, then checks the
    /// reopened (pruned) structure can accept and sync another leaf.
    async fn assert_reopen_and_continue<F: Family>(
        context: deterministic::Context,
        partition: &str,
    ) {
        let hasher = StandardHasher::<Sha256>::new(ForwardFold);
        let cfg = Config {
            partition: partition.into(),
            strategy: Sequential,
        };

        let mut merkle = TestMerkle::<F>::init(context.child("first"), cfg.clone())
            .await
            .unwrap();
        let batch = {
            let batch = merkle.new_batch().add(&hasher, b"a").add(&hasher, b"b");
            merkle.with_mem(|mem| batch.merkleize(mem, &hasher))
        };
        merkle.apply_batch(&batch).unwrap();
        let root_before = merkle.root(&hasher, 0).unwrap();
        let leaves_before = merkle.leaves();
        merkle.sync().await.unwrap();
        drop(merkle);

        let mut reopened = TestMerkle::<F>::init(context.child("second"), cfg)
            .await
            .unwrap();
        assert_eq!(reopened.root(&hasher, 0).unwrap(), root_before);
        assert_eq!(reopened.leaves(), leaves_before);

        // Appending on top of the pinned-only state must still work after reopen.
        let batch = {
            let batch = reopened.new_batch().add(&hasher, b"c");
            reopened.with_mem(|mem| batch.merkleize(mem, &hasher))
        };
        reopened.apply_batch(&batch).unwrap();
        reopened.sync().await.unwrap();
    }

    #[test]
    fn test_compact_reopen_and_continue_mmr() {
        deterministic::Runner::default().start(|context| async move {
            assert_reopen_and_continue::<mmr::Family>(context, "compact-mmr").await;
        });
    }

    #[test]
    fn test_compact_reopen_and_continue_mmb() {
        deterministic::Runner::default().start(|context| async move {
            assert_reopen_and_continue::<mmb::Family>(context, "compact-mmb").await;
        });
    }

    /// After two syncs, rewind must restore exactly the root and leaf count of the first.
    async fn assert_rewind_restores_prior_state<F: Family>(
        context: deterministic::Context,
        partition: &str,
    ) {
        let hasher = StandardHasher::<Sha256>::new(ForwardFold);
        let mut merkle = open::<F>(context, partition).await;

        append_and_sync(&mut merkle, &[b"a", b"b"]).await;
        let root_after_first = merkle.root(&hasher, 0).unwrap();
        let leaves_after_first = merkle.leaves();

        append_and_sync(&mut merkle, &[b"c"]).await;
        assert_ne!(merkle.root(&hasher, 0).unwrap(), root_after_first);

        merkle.rewind().await.unwrap();
        assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_first);
        assert_eq!(merkle.leaves(), leaves_after_first);

        merkle.destroy().await.unwrap();
    }

    #[test]
    fn test_rewind_restores_prior_state_mmr() {
        deterministic::Runner::default().start(|context| async move {
            assert_rewind_restores_prior_state::<mmr::Family>(context, "rewind-prior-mmr").await;
        });
    }

    #[test]
    fn test_rewind_restores_prior_state_mmb() {
        deterministic::Runner::default().start(|context| async move {
            assert_rewind_restores_prior_state::<mmb::Family>(context, "rewind-prior-mmb").await;
        });
    }

    #[test]
    fn test_rewind_beyond_history_errors() {
        deterministic::Runner::default().start(|context| async move {
            let mut merkle = open::<mmr::Family>(context, "rewind-beyond").await;
            // Fresh store: the inactive slot has never been written.
            assert!(matches!(
                merkle.rewind().await,
                Err(Error::RewindBeyondHistory)
            ));
            append_and_sync(&mut merkle, &[b"a"]).await;
            // One sync filled only one slot; the other (now inactive) slot is still empty,
            // so there is no earlier commit to rewind to.
            assert!(matches!(
                merkle.rewind().await,
                Err(Error::RewindBeyondHistory)
            ));
            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_rewind_discards_uncommitted() {
        deterministic::Runner::default().start(|context| async move {
            let hasher = StandardHasher::<Sha256>::new(ForwardFold);
            let mut merkle = open::<mmr::Family>(context, "rewind-uncommitted").await;

            append_and_sync(&mut merkle, &[b"a"]).await;
            append_and_sync(&mut merkle, &[b"b"]).await;
            let root_after_two = merkle.root(&hasher, 0).unwrap();
            let leaves_after_two = merkle.leaves();

            // Apply `c` without syncing: only the in-memory structure changes.
            let batch = {
                let b = merkle.new_batch().add(&hasher, b"c");
                merkle.with_mem(|mem| b.merkleize(mem, &hasher))
            };
            merkle.apply_batch(&batch).unwrap();
            assert_ne!(merkle.root(&hasher, 0).unwrap(), root_after_two);

            // Rewind loads the inactive slot, which holds the state from the *first* sync
            // (`a` only). So not just the uncommitted `c` but also the last commit (`b`) is
            // dropped, and the result differs from the two-leaf state.
            merkle.rewind().await.unwrap();
            assert_ne!(merkle.root(&hasher, 0).unwrap(), root_after_two);
            assert_ne!(merkle.leaves(), leaves_after_two);

            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_rewind_persists_across_reopen() {
        deterministic::Runner::default().start(|context| async move {
            let hasher = StandardHasher::<Sha256>::new(ForwardFold);
            let partition = "rewind-reopen";
            let cfg = Config {
                partition: partition.into(),
                strategy: Sequential,
            };

            let mut merkle = open::<mmr::Family>(context.child("first"), partition).await;
            append_and_sync(&mut merkle, &[b"a"]).await;
            let root_after_first = merkle.root(&hasher, 0).unwrap();
            append_and_sync(&mut merkle, &[b"b"]).await;
            merkle.rewind().await.unwrap();
            drop(merkle);

            // The rewind wrote the generation pointer with `put_sync`, so a reopen must land
            // on the rewound state.
            let reopened: TestMerkle<mmr::Family> =
                Merkle::<mmr::Family, _, _, Sequential>::init(context.child("second"), cfg)
                    .await
                    .unwrap();
            assert_eq!(reopened.root(&hasher, 0).unwrap(), root_after_first);
            reopened.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_double_rewind_errors() {
        deterministic::Runner::default().start(|context| async move {
            let mut merkle = open::<mmr::Family>(context, "rewind-double").await;
            append_and_sync(&mut merkle, &[b"a"]).await;
            append_and_sync(&mut merkle, &[b"b"]).await;
            merkle.rewind().await.unwrap();
            // The first rewind cleared the slot it left, so nothing remains to rewind to.
            assert!(matches!(
                merkle.rewind().await,
                Err(Error::RewindBeyondHistory)
            ));
            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_rewind_then_sync_then_rewind() {
        deterministic::Runner::default().start(|context| async move {
            let hasher = StandardHasher::<Sha256>::new(ForwardFold);
            let mut merkle = open::<mmr::Family>(context, "rewind-resumable").await;

            append_and_sync(&mut merkle, &[b"a"]).await;
            let root_after_first = merkle.root(&hasher, 0).unwrap();
            append_and_sync(&mut merkle, &[b"b"]).await;
            merkle.rewind().await.unwrap();
            assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_first);

            // Syncing after a rewind refills the slot the rewind cleared, making the `a`-only
            // slot rewindable again.
            append_and_sync(&mut merkle, &[b"c"]).await;
            let root_abc = merkle.root(&hasher, 0).unwrap();
            assert_ne!(root_abc, root_after_first);
            merkle.rewind().await.unwrap();
            assert_eq!(merkle.root(&hasher, 0).unwrap(), root_after_first);

            merkle.destroy().await.unwrap();
        });
    }

    #[test]
    fn test_reopen_rejects_invalid_persisted_leaf_count() {
        deterministic::Runner::default().start(|context| async move {
            let partition = "compact-invalid-leaf-count";
            let cfg = Config {
                partition: partition.into(),
                strategy: Sequential,
            };

            let mut merkle = TestMerkle::<mmr::Family>::init(context.child("first"), cfg.clone())
                .await
                .unwrap();
            append_and_sync(&mut merkle, &[b"a"]).await;
            let slot = merkle.active_slot();
            drop(merkle);

            // Tamper with the active slot's size directly through the metadata store,
            // writing a leaf count one past MAX_LEAVES.
            let mut metadata = Metadata::<_, U64, Vec<u8>>::init(
                context.child("tamper"),
                MConfig {
                    partition: partition.into(),
                    codec_config: ((0..).into(), ()),
                },
            )
            .await
            .unwrap();
            metadata
                .put_sync(
                    U64::new(size_prefix(slot), 0),
                    (mmr::Family::MAX_LEAVES.as_u64() + 1)
                        .to_be_bytes()
                        .to_vec(),
                )
                .await
                .unwrap();

            let reopened = TestMerkle::<mmr::Family>::init(context.child("second"), cfg).await;
            assert!(matches!(
                reopened,
                Err(Error::DataCorrupted("slot size exceeds MAX_LEAVES"))
            ));
        });
    }
}
768}