commonware_storage/mmr/
storage.rs1use crate::mmr::{
5 bitmap::Bitmap,
6 hasher::{source_pos, Hasher, Standard},
7 iterator::{pos_to_height, PeakIterator},
8 journaled::Mmr as JournaledMmr,
9 mem::Mmr as MemMmr,
10 Error,
11};
12use commonware_cryptography::{Digest, Hasher as CHasher};
13use commonware_runtime::{Clock, Metrics, Storage as RStorage};
14use futures::future::try_join_all;
15use std::future::Future;
16
17pub trait Storage<D: Digest>: Send + Sync {
19 fn size(&self) -> u64;
21
22 fn get_node(&self, position: u64) -> impl Future<Output = Result<Option<D>, Error>> + Send;
24}
25
26impl<H: CHasher> Storage<H::Digest> for MemMmr<H>
27where
28 H: CHasher,
29{
30 fn size(&self) -> u64 {
31 self.size()
32 }
33
34 async fn get_node(&self, position: u64) -> Result<Option<H::Digest>, Error> {
35 Ok(MemMmr::get_node(self, position))
36 }
37}
38
39impl<E: RStorage + Clock + Metrics, H: CHasher> Storage<H::Digest> for JournaledMmr<E, H> {
40 fn size(&self) -> u64 {
41 self.size()
42 }
43
44 async fn get_node(&self, position: u64) -> Result<Option<H::Digest>, Error> {
45 self.get_node(position).await
46 }
47}
48
49impl<H: CHasher, const N: usize> Storage<H::Digest> for Bitmap<H, N> {
50 fn size(&self) -> u64 {
51 self.size()
52 }
53
54 async fn get_node(&self, position: u64) -> Result<Option<H::Digest>, Error> {
55 Ok(self.get_node(position))
56 }
57}
58
59pub struct Grafting<'a, H: CHasher, S1: Storage<H::Digest>, S2: Storage<H::Digest>> {
62 peak_tree: &'a S1,
63 base_mmr: &'a S2,
64 height: u32,
65
66 _marker: std::marker::PhantomData<H>,
67}
68
69impl<'a, H: CHasher, S1: Storage<H::Digest>, S2: Storage<H::Digest>> Grafting<'a, H, S1, S2> {
70 pub fn new(peak_tree: &'a S1, base_mmr: &'a S2, height: u32) -> Self {
72 Self {
73 peak_tree,
74 base_mmr,
75 height,
76 _marker: std::marker::PhantomData,
77 }
78 }
79
80 pub async fn root(&self, hasher: &mut Standard<H>) -> Result<H::Digest, Error> {
81 let size = self.size();
82 let peak_futures = PeakIterator::new(size).map(|(peak_pos, _)| self.get_node(peak_pos));
83 let peaks = try_join_all(peak_futures).await?;
84 let unwrapped_peaks = peaks.iter().map(|p| p.as_ref().unwrap());
85 let digest = hasher.root(self.base_mmr.size(), unwrapped_peaks);
86
87 Ok(digest)
88 }
89}
90
91impl<H: CHasher, S1: Storage<H::Digest>, S2: Storage<H::Digest>> Storage<H::Digest>
92 for Grafting<'_, H, S1, S2>
93{
94 fn size(&self) -> u64 {
95 self.base_mmr.size()
96 }
97
98 async fn get_node(&self, pos: u64) -> Result<Option<H::Digest>, Error> {
99 let height = pos_to_height(pos);
100 if height < self.height {
101 return self.base_mmr.get_node(pos).await;
102 }
103
104 let source_pos = source_pos(pos, self.height);
105 let Some(source_pos) = source_pos else {
106 return Ok(None);
107 };
108
109 self.peak_tree.get_node(source_pos).await
110 }
111}