// miden_objects/transaction/chain_mmr.rs
use alloc::{collections::BTreeMap, string::ToString, vec::Vec};

use vm_core::utils::{Deserializable, Serializable};

use crate::{
    block::{BlockHeader, BlockNumber},
    crypto::merkle::{InnerNodeInfo, MmrPeaks, PartialMmr},
    ChainMmrError,
};

11#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct ChainMmr {
27 mmr: PartialMmr,
29 blocks: BTreeMap<BlockNumber, BlockHeader>,
32}
33
34impl ChainMmr {
35 pub fn new(mmr: PartialMmr, blocks: Vec<BlockHeader>) -> Result<Self, ChainMmrError> {
47 let chain_length = mmr.forest();
48
49 let mut block_map = BTreeMap::new();
50 for block in blocks.into_iter() {
51 if block.block_num().as_usize() >= chain_length {
52 return Err(ChainMmrError::block_num_too_big(chain_length, block.block_num()));
53 }
54
55 if block_map.insert(block.block_num(), block).is_some() {
56 return Err(ChainMmrError::duplicate_block(block.block_num()));
57 }
58
59 if !mmr.is_tracked(block.block_num().as_usize()) {
60 return Err(ChainMmrError::untracked_block(block.block_num()));
61 }
62 }
63
64 Ok(Self { mmr, blocks: block_map })
65 }
66
67 pub fn peaks(&self) -> MmrPeaks {
72 self.mmr.peaks()
73 }
74
75 pub fn chain_length(&self) -> BlockNumber {
77 BlockNumber::from(
78 u32::try_from(self.mmr.forest())
79 .expect("chain mmr should never contain more than u32::MAX blocks"),
80 )
81 }
82
83 pub fn contains_block(&self, block_num: BlockNumber) -> bool {
85 self.blocks.contains_key(&block_num)
86 }
87
88 pub fn get_block(&self, block_num: BlockNumber) -> Option<&BlockHeader> {
91 self.blocks.get(&block_num)
92 }
93
94 pub fn add_block(&mut self, block_header: BlockHeader, track: bool) {
107 assert_eq!(block_header.block_num(), self.chain_length());
108 self.mmr.add(block_header.hash(), track);
109 }
110
111 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
117 self.mmr.inner_nodes(
118 self.blocks.values().map(|block| (block.block_num().as_usize(), block.hash())),
119 )
120 }
121}
122
123impl Serializable for ChainMmr {
124 fn write_into<W: miden_crypto::utils::ByteWriter>(&self, target: &mut W) {
125 self.mmr.write_into(target);
126 self.blocks.len().write_into(target);
127 for block in self.blocks.values() {
128 block.write_into(target);
129 }
130 }
131}
132
133impl Deserializable for ChainMmr {
134 fn read_from<R: miden_crypto::utils::ByteReader>(
135 source: &mut R,
136 ) -> Result<Self, miden_crypto::utils::DeserializationError> {
137 let mmr = PartialMmr::read_from(source)?;
138 let block_count = usize::read_from(source)?;
139 let mut blocks = BTreeMap::new();
140 for _ in 0..block_count {
141 let block = BlockHeader::read_from(source)?;
142 blocks.insert(block.block_num(), block);
143 }
144 Ok(Self { mmr, blocks })
145 }
146}
#[cfg(test)]
mod tests {
    use vm_core::utils::{Deserializable, Serializable};

    use super::ChainMmr;
    use crate::{
        alloc::vec::Vec,
        block::{BlockHeader, BlockNumber},
        crypto::merkle::{Mmr, PartialMmr},
        Digest,
    };

    #[test]
    fn test_chain_mmr_add() {
        // Chain MMR over blocks 0..3 (two peaks) with no tracked blocks.
        let mut mmr = Mmr::default();
        for block_num in 0..3u32 {
            mmr.add(int_to_block_header(block_num).hash());
        }
        let partial_mmr: PartialMmr = mmr.peaks().into();
        let mut chain_mmr = ChainMmr::new(partial_mmr, Vec::new()).unwrap();

        // Append blocks 3, 4 and 5 one at a time, tracking each of them. After every append, the
        // opening produced by the chain MMR must match the one from the full MMR.
        for block_num in 3..6u32 {
            let block_header = int_to_block_header(block_num);
            mmr.add(block_header.hash());
            chain_mmr.add_block(block_header, true);

            let position = block_num as usize;
            assert_eq!(
                mmr.open(position).unwrap(),
                chain_mmr.mmr.open(position).unwrap().unwrap()
            );
        }
    }

    #[test]
    fn test_chain_mmr_serialization() {
        let mut mmr = Mmr::default();
        for block_num in 0..3u32 {
            mmr.add(int_to_block_header(block_num).hash());
        }
        let partial_mmr: PartialMmr = mmr.peaks().into();
        let mut chain_mmr = ChainMmr::new(partial_mmr, Vec::new()).unwrap();

        // Round-trip a chain MMR that tracks nothing.
        let deserialized = ChainMmr::read_from_bytes(&chain_mmr.to_bytes()).unwrap();
        assert_eq!(chain_mmr, deserialized);

        // Round-trip again after appending a tracked block, so tracked-path data is serialized.
        chain_mmr.add_block(int_to_block_header(3u32), true);
        let deserialized = ChainMmr::read_from_bytes(&chain_mmr.to_bytes()).unwrap();
        assert_eq!(chain_mmr, deserialized);
    }

    /// Builds a block header with the given number and default values for every other field.
    fn int_to_block_header(block_num: impl Into<BlockNumber>) -> BlockHeader {
        BlockHeader::new(
            0,
            Digest::default(),
            block_num.into(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            Digest::default(),
            0,
        )
    }
}