commonware_storage/mmr/
storage.rs

1//! Defines the abstraction allowing MMRs with differing backends and representations to be
2//! uniformly accessed.
3
4use crate::mmr::{
5    bitmap::Bitmap,
6    hasher::{source_pos, Hasher, Standard},
7    iterator::{pos_to_height, PeakIterator},
8    journaled::Mmr as JournaledMmr,
9    mem::Mmr as MemMmr,
10    Error,
11};
12use commonware_cryptography::{Digest, Hasher as CHasher};
13use commonware_runtime::{Clock, Metrics, Storage as RStorage};
14use futures::future::try_join_all;
15use std::future::Future;
16
17/// A trait for accessing MMR digests from storage.
18pub trait Storage<D: Digest>: Send + Sync {
19    /// Return the number of elements in the MMR.
20    fn size(&self) -> u64;
21
22    /// Return the specified node of the MMR if it exists & hasn't been pruned.
23    fn get_node(&self, position: u64) -> impl Future<Output = Result<Option<D>, Error>> + Send;
24}
25
26impl<H: CHasher> Storage<H::Digest> for MemMmr<H>
27where
28    H: CHasher,
29{
30    fn size(&self) -> u64 {
31        self.size()
32    }
33
34    async fn get_node(&self, position: u64) -> Result<Option<H::Digest>, Error> {
35        Ok(MemMmr::get_node(self, position))
36    }
37}
38
39impl<E: RStorage + Clock + Metrics, H: CHasher> Storage<H::Digest> for JournaledMmr<E, H> {
40    fn size(&self) -> u64 {
41        self.size()
42    }
43
44    async fn get_node(&self, position: u64) -> Result<Option<H::Digest>, Error> {
45        self.get_node(position).await
46    }
47}
48
49impl<H: CHasher, const N: usize> Storage<H::Digest> for Bitmap<H, N> {
50    fn size(&self) -> u64 {
51        self.size()
52    }
53
54    async fn get_node(&self, position: u64) -> Result<Option<H::Digest>, Error> {
55        Ok(self.get_node(position))
56    }
57}
58
59/// A [Storage] implementation that makes grafted trees look like a single MMR for conveniently
60/// generating inclusion proofs.
61pub struct Grafting<'a, H: CHasher, S1: Storage<H::Digest>, S2: Storage<H::Digest>> {
62    peak_tree: &'a S1,
63    base_mmr: &'a S2,
64    height: u32,
65
66    _marker: std::marker::PhantomData<H>,
67}
68
69impl<'a, H: CHasher, S1: Storage<H::Digest>, S2: Storage<H::Digest>> Grafting<'a, H, S1, S2> {
70    /// Creates a new [Grafting] Storage instance.
71    pub fn new(peak_tree: &'a S1, base_mmr: &'a S2, height: u32) -> Self {
72        Self {
73            peak_tree,
74            base_mmr,
75            height,
76            _marker: std::marker::PhantomData,
77        }
78    }
79
80    pub async fn root(&self, hasher: &mut Standard<H>) -> Result<H::Digest, Error> {
81        let size = self.size();
82        let peak_futures = PeakIterator::new(size).map(|(peak_pos, _)| self.get_node(peak_pos));
83        let peaks = try_join_all(peak_futures).await?;
84        let unwrapped_peaks = peaks.iter().map(|p| p.as_ref().unwrap());
85        let digest = hasher.root(self.base_mmr.size(), unwrapped_peaks);
86
87        Ok(digest)
88    }
89}
90
91impl<H: CHasher, S1: Storage<H::Digest>, S2: Storage<H::Digest>> Storage<H::Digest>
92    for Grafting<'_, H, S1, S2>
93{
94    fn size(&self) -> u64 {
95        self.base_mmr.size()
96    }
97
98    async fn get_node(&self, pos: u64) -> Result<Option<H::Digest>, Error> {
99        let height = pos_to_height(pos);
100        if height < self.height {
101            return self.base_mmr.get_node(pos).await;
102        }
103
104        let source_pos = source_pos(pos, self.height);
105        let Some(source_pos) = source_pos else {
106            return Ok(None);
107        };
108
109        self.peak_tree.get_node(source_pos).await
110    }
111}