miden_core/mast/serialization/
mod.rs1use alloc::vec::Vec;
39
40use decorator::{DecoratorDataBuilder, DecoratorInfo};
41use string_table::StringTable;
42use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
43
44use super::{DecoratorId, MastForest, MastNode, MastNodeId};
45use crate::AdviceMap;
46
47mod decorator;
48
49mod info;
50use info::MastNodeInfo;
51
52mod basic_blocks;
53use basic_blocks::{BasicBlockDataBuilder, BasicBlockDataDecoder};
54
55use crate::DecoratorList;
56
57mod string_table;
58
59#[cfg(test)]
60mod tests;
61
/// Signature that every serialized forest starts with: the ASCII bytes `MAST` followed by a
/// NUL byte.
const MAGIC: &[u8; 5] = b"MAST\0";

/// Serialization format version written after [`MAGIC`]; deserialization accepts only this
/// exact value.
const VERSION: [u8; 3] = [0, 0, 0];

/// Offset type for node data.
///
/// NOTE(review): presumably a position inside the encoded basic-block data buffer — confirm
/// against the `basic_blocks` / `info` submodules.
type NodeDataOffset = u32;

/// Offset type for decorator data.
///
/// NOTE(review): presumably a position inside the encoded decorator data buffer — confirm
/// against the `decorator` submodule.
type DecoratorDataOffset = u32;

/// Offset type for string data.
///
/// NOTE(review): presumably a byte position inside the string table's data — confirm against
/// the `string_table` submodule.
type StringDataOffset = usize;

/// Index type for strings.
///
/// NOTE(review): presumably an index of an entry in the string table — confirm against the
/// `string_table` submodule.
type StringIndex = usize;
89
90impl Serializable for MastForest {
94 fn write_into<W: ByteWriter>(&self, target: &mut W) {
95 let mut basic_block_data_builder = BasicBlockDataBuilder::new();
96
97 let mut before_enter_decorators: Vec<(usize, Vec<DecoratorId>)> = Vec::new();
99 let mut after_exit_decorators: Vec<(usize, Vec<DecoratorId>)> = Vec::new();
100
101 let mut basic_block_decorators: Vec<(usize, Vec<(usize, DecoratorId)>)> = Vec::new();
102
103 target.write_bytes(MAGIC);
105 target.write_bytes(&VERSION);
106
107 target.write_usize(self.nodes.len());
109 target.write_usize(self.decorators.len());
110
111 let roots: Vec<u32> = self.roots.iter().map(u32::from).collect();
113 roots.write_into(target);
114
115 let mast_node_infos: Vec<MastNodeInfo> = self
118 .nodes
119 .iter()
120 .enumerate()
121 .map(|(mast_node_id, mast_node)| {
122 if !mast_node.before_enter().is_empty() {
123 before_enter_decorators.push((mast_node_id, mast_node.before_enter().to_vec()));
124 }
125 if !mast_node.after_exit().is_empty() {
126 after_exit_decorators.push((mast_node_id, mast_node.after_exit().to_vec()));
127 }
128
129 let ops_offset = if let MastNode::Block(basic_block) = mast_node {
130 let ops_offset = basic_block_data_builder.encode_basic_block(basic_block);
131
132 basic_block_decorators.push((mast_node_id, basic_block.decorators().clone()));
133
134 ops_offset
135 } else {
136 0
137 };
138
139 MastNodeInfo::new(mast_node, ops_offset)
140 })
141 .collect();
142
143 let basic_block_data = basic_block_data_builder.finalize();
144 basic_block_data.write_into(target);
145
146 for mast_node_info in mast_node_infos {
148 mast_node_info.write_into(target);
149 }
150
151 self.advice_map.write_into(target);
152
153 let mut decorator_data_builder = DecoratorDataBuilder::new();
156 for decorator in &self.decorators {
157 decorator_data_builder.add_decorator(decorator)
158 }
159
160 let (decorator_data, decorator_infos, string_table) = decorator_data_builder.finalize();
161
162 decorator_data.write_into(target);
164 string_table.write_into(target);
165
166 for decorator_info in decorator_infos {
168 decorator_info.write_into(target);
169 }
170
171 basic_block_decorators.write_into(target);
172
173 before_enter_decorators.write_into(target);
175 after_exit_decorators.write_into(target);
176 }
177}
178
179impl Deserializable for MastForest {
180 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
181 read_and_validate_magic(source)?;
182 read_and_validate_version(source)?;
183
184 let node_count = source.read_usize()?;
186 let decorator_count = source.read_usize()?;
187
188 let roots: Vec<u32> = Deserializable::read_from(source)?;
190
191 let basic_block_data: Vec<u8> = Deserializable::read_from(source)?;
193 let mast_node_infos: Vec<MastNodeInfo> = node_infos_iter(source, node_count)
194 .collect::<Result<Vec<MastNodeInfo>, DeserializationError>>()?;
195
196 let advice_map = AdviceMap::read_from(source)?;
197
198 let decorator_data: Vec<u8> = Deserializable::read_from(source)?;
200 let string_table: StringTable = Deserializable::read_from(source)?;
201 let decorator_infos = decorator_infos_iter(source, decorator_count);
202
203 let mut mast_forest = {
205 let mut mast_forest = MastForest::new();
206
207 for decorator_info in decorator_infos {
208 let decorator_info = decorator_info?;
209 let decorator =
210 decorator_info.try_into_decorator(&string_table, &decorator_data)?;
211
212 mast_forest.add_decorator(decorator).map_err(|e| {
213 DeserializationError::InvalidValue(format!(
214 "failed to add decorator to MAST forest while deserializing: {e}",
215 ))
216 })?;
217 }
218
219 let basic_block_data_decoder = BasicBlockDataDecoder::new(&basic_block_data);
221 for mast_node_info in mast_node_infos {
222 let node =
223 mast_node_info.try_into_mast_node(node_count, &basic_block_data_decoder)?;
224
225 mast_forest.add_node(node).map_err(|e| {
226 DeserializationError::InvalidValue(format!(
227 "failed to add node to MAST forest while deserializing: {e}",
228 ))
229 })?;
230 }
231
232 for root in roots {
234 let root = MastNodeId::from_u32_safe(root, &mast_forest)?;
236 mast_forest.make_root(root);
237 }
238
239 mast_forest.advice_map = advice_map;
240
241 mast_forest
242 };
243
244 let basic_block_decorators: Vec<(usize, DecoratorList)> =
245 read_block_decorators(source, &mast_forest)?;
246 for (node_id, decorator_list) in basic_block_decorators {
247 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
248
249 match &mut mast_forest[node_id] {
250 MastNode::Block(basic_block) => {
251 basic_block.set_decorators(decorator_list);
252 },
253 other => {
254 return Err(DeserializationError::InvalidValue(format!(
255 "expected mast node with id {node_id} to be a basic block, found {other:?}"
256 )))
257 },
258 }
259 }
260
261 let before_enter_decorators: Vec<(usize, Vec<DecoratorId>)> =
263 read_before_after_decorators(source, &mast_forest)?;
264 for (node_id, decorator_ids) in before_enter_decorators {
265 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
266 mast_forest.set_before_enter(node_id, decorator_ids);
267 }
268
269 let after_exit_decorators: Vec<(usize, Vec<DecoratorId>)> =
270 read_before_after_decorators(source, &mast_forest)?;
271 for (node_id, decorator_ids) in after_exit_decorators {
272 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
273 mast_forest.set_after_exit(node_id, decorator_ids);
274 }
275
276 Ok(mast_forest)
277 }
278}
279
280fn read_and_validate_magic<R: ByteReader>(source: &mut R) -> Result<[u8; 5], DeserializationError> {
281 let magic: [u8; 5] = source.read_array()?;
282 if magic != *MAGIC {
283 return Err(DeserializationError::InvalidValue(format!(
284 "Invalid magic bytes. Expected '{:?}', got '{:?}'",
285 *MAGIC, magic
286 )));
287 }
288 Ok(magic)
289}
290
291fn read_and_validate_version<R: ByteReader>(
292 source: &mut R,
293) -> Result<[u8; 3], DeserializationError> {
294 let version: [u8; 3] = source.read_array()?;
295 if version != VERSION {
296 return Err(DeserializationError::InvalidValue(format!(
297 "Unsupported version. Got '{version:?}', but only '{VERSION:?}' is supported",
298 )));
299 }
300 Ok(version)
301}
302
303fn read_block_decorators<R: ByteReader>(
304 source: &mut R,
305 mast_forest: &MastForest,
306) -> Result<Vec<(usize, DecoratorList)>, DeserializationError> {
307 let vec_len: usize = source.read()?;
308 let mut out_vec: Vec<_> = Vec::with_capacity(vec_len);
309
310 for _ in 0..vec_len {
311 let node_id: usize = source.read()?;
312
313 let decorator_vec_len: usize = source.read()?;
314 let mut inner_vec: Vec<(usize, DecoratorId)> = Vec::with_capacity(decorator_vec_len);
315 for _ in 0..decorator_vec_len {
316 let op_id: usize = source.read()?;
317 let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?;
318 inner_vec.push((op_id, decorator_id));
319 }
320
321 out_vec.push((node_id, inner_vec));
322 }
323
324 Ok(out_vec)
325}
326
327fn decorator_infos_iter<'a, R>(
328 source: &'a mut R,
329 decorator_count: usize,
330) -> impl Iterator<Item = Result<DecoratorInfo, DeserializationError>> + 'a
331where
332 R: ByteReader + 'a,
333{
334 let mut remaining = decorator_count;
335 core::iter::from_fn(move || {
336 if remaining == 0 {
337 return None;
338 }
339 remaining -= 1;
340 Some(DecoratorInfo::read_from(source))
341 })
342}
343
344fn node_infos_iter<'a, R>(
345 source: &'a mut R,
346 node_count: usize,
347) -> impl Iterator<Item = Result<MastNodeInfo, DeserializationError>> + 'a
348where
349 R: ByteReader + 'a,
350{
351 let mut remaining = node_count;
352 core::iter::from_fn(move || {
353 if remaining == 0 {
354 return None;
355 }
356 remaining -= 1;
357 Some(MastNodeInfo::read_from(source))
358 })
359}
360
361fn read_before_after_decorators<R: ByteReader>(
367 source: &mut R,
368 mast_forest: &MastForest,
369) -> Result<Vec<(usize, Vec<DecoratorId>)>, DeserializationError> {
370 let vec_len: usize = source.read()?;
371 let mut out_vec: Vec<_> = Vec::with_capacity(vec_len);
372
373 for _ in 0..vec_len {
374 let node_id: usize = source.read()?;
375
376 let inner_vec_len: usize = source.read()?;
377 let mut inner_vec: Vec<DecoratorId> = Vec::with_capacity(inner_vec_len);
378 for _ in 0..inner_vec_len {
379 let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?;
380 inner_vec.push(decorator_id);
381 }
382
383 out_vec.push((node_id, inner_vec));
384 }
385
386 Ok(out_vec)
387}