miden_objects/transaction/
chain_mmr.rs1use alloc::collections::BTreeMap;
2
3use crate::{
4 ChainMmrError,
5 block::{BlockHeader, BlockNumber},
6 crypto::merkle::{InnerNodeInfo, MmrPeaks, PartialMmr},
7 utils::serde::{Deserializable, Serializable},
8};
9
10#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct ChainMmr {
26 mmr: PartialMmr,
28 blocks: BTreeMap<BlockNumber, BlockHeader>,
31}
32
33impl ChainMmr {
34 pub fn new(
46 mmr: PartialMmr,
47 blocks: impl IntoIterator<Item = BlockHeader>,
48 ) -> Result<Self, ChainMmrError> {
49 let chain_length = mmr.forest();
50 let mut block_map = BTreeMap::new();
51 for block in blocks {
52 let block_num = block.block_num();
53 if block.block_num().as_usize() >= chain_length {
54 return Err(ChainMmrError::block_num_too_big(chain_length, block_num));
55 }
56
57 if !mmr.is_tracked(block_num.as_usize()) {
58 return Err(ChainMmrError::untracked_block(block_num));
59 }
60
61 if block_map.insert(block_num, block).is_some() {
62 return Err(ChainMmrError::duplicate_block(block_num));
63 }
64 }
65
66 Ok(Self { mmr, blocks: block_map })
67 }
68
69 pub fn mmr(&self) -> &PartialMmr {
74 &self.mmr
75 }
76
77 pub fn peaks(&self) -> MmrPeaks {
79 self.mmr.peaks()
80 }
81
82 pub fn chain_length(&self) -> BlockNumber {
84 BlockNumber::from(
85 u32::try_from(self.mmr.forest())
86 .expect("chain mmr should never contain more than u32::MAX blocks"),
87 )
88 }
89
90 pub fn contains_block(&self, block_num: BlockNumber) -> bool {
92 self.blocks.contains_key(&block_num)
93 }
94
95 pub fn get_block(&self, block_num: BlockNumber) -> Option<&BlockHeader> {
98 self.blocks.get(&block_num)
99 }
100
101 pub fn block_headers(&self) -> impl Iterator<Item = &BlockHeader> {
103 self.blocks.values()
104 }
105
106 pub fn add_block(&mut self, block_header: BlockHeader, track: bool) {
119 assert_eq!(block_header.block_num(), self.chain_length());
120 self.mmr.add(block_header.commitment(), track);
121 }
122
123 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
129 self.mmr.inner_nodes(
130 self.blocks
131 .values()
132 .map(|block| (block.block_num().as_usize(), block.commitment())),
133 )
134 }
135
136 #[cfg(any(feature = "testing", test))]
143 pub fn block_headers_mut(&mut self) -> &mut BTreeMap<BlockNumber, BlockHeader> {
144 &mut self.blocks
145 }
146
147 #[cfg(any(feature = "testing", test))]
151 pub fn partial_mmr_mut(&mut self) -> &mut PartialMmr {
152 &mut self.mmr
153 }
154}
155
156impl Serializable for ChainMmr {
157 fn write_into<W: miden_crypto::utils::ByteWriter>(&self, target: &mut W) {
158 self.mmr.write_into(target);
159 self.blocks.write_into(target);
160 }
161}
162
163impl Deserializable for ChainMmr {
164 fn read_from<R: miden_crypto::utils::ByteReader>(
165 source: &mut R,
166 ) -> Result<Self, miden_crypto::utils::DeserializationError> {
167 let mmr = PartialMmr::read_from(source)?;
168 let blocks = BTreeMap::<BlockNumber, BlockHeader>::read_from(source)?;
169 Ok(Self { mmr, blocks })
170 }
171}
#[cfg(test)]
mod tests {
    use vm_core::utils::{Deserializable, Serializable};

    use super::ChainMmr;
    use crate::{
        Digest,
        alloc::vec::Vec,
        block::{BlockHeader, BlockNumber},
        crypto::merkle::{Mmr, PartialMmr},
    };

    /// Appends blocks to both a full MMR and a chain MMR and checks that the openings of each
    /// newly added (tracked) block agree between the two.
    #[test]
    fn test_chain_mmr_add() {
        // Start from a chain of 3 blocks, none of which are tracked by the chain MMR.
        let mut mmr = Mmr::default();
        for i in 0..3 {
            let block_header = int_to_block_header(i);
            mmr.add(block_header.commitment());
        }
        let partial_mmr: PartialMmr = mmr.peaks().into();
        let mut chain_mmr = ChainMmr::new(partial_mmr, Vec::new()).unwrap();

        // Append blocks 3, 4 and 5 one at a time, tracking each of them.
        for block_num in 3u32..6 {
            let block_header = int_to_block_header(block_num);
            mmr.add(block_header.commitment());
            chain_mmr.add_block(block_header, true);

            assert_eq!(chain_mmr.chain_length(), BlockNumber::from(block_num + 1));
            assert_eq!(
                mmr.open(block_num as usize).unwrap(),
                chain_mmr.mmr.open(block_num as usize).unwrap().unwrap()
            );
        }
    }

    /// Checks that a chain MMR survives a serialization round trip unchanged.
    #[test]
    fn test_chain_mmr_serialization() {
        // Build a chain MMR over 3 blocks with no stored headers.
        let mut mmr = Mmr::default();
        for i in 0..3 {
            let block_header = int_to_block_header(i);
            mmr.add(block_header.commitment());
        }
        let partial_mmr: PartialMmr = mmr.peaks().into();
        let chain_mmr = ChainMmr::new(partial_mmr, Vec::new()).unwrap();

        let bytes = chain_mmr.to_bytes();
        let deserialized = ChainMmr::read_from_bytes(&bytes).unwrap();

        assert_eq!(chain_mmr, deserialized);
    }

    /// Builds a block header whose only non-default input is the block number.
    fn int_to_block_header(block_num: impl Into<BlockNumber>) -> BlockHeader {
        BlockHeader::new(
            0,
            Digest::default(),
            block_num.into(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            0,
        )
    }
}