miden_core/mast/serialization/
mod.rs1use alloc::{
42 collections::BTreeMap,
43 string::{String, ToString},
44 sync::Arc,
45 vec::Vec,
46};
47
48use decorator::{DecoratorDataBuilder, DecoratorInfo};
49use string_table::StringTable;
50
51use super::{DecoratorId, MastForest, MastNode, MastNodeId};
52use crate::{
53 AdviceMap,
54 utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
55};
56
57mod decorator;
58
59mod info;
60use info::MastNodeInfo;
61
62mod basic_blocks;
63use basic_blocks::{BasicBlockDataBuilder, BasicBlockDataDecoder};
64
65use crate::DecoratorList;
66
67mod string_table;
68
69#[cfg(test)]
70mod tests;
71
/// Offset into the encoded basic-block data. Presumably marks where a node's operations begin
/// (the `basic_blocks`/`info` submodules use it) — TODO confirm against those modules.
type NodeDataOffset = u32;

/// Offset into the encoded decorator data (used by the `decorator` submodule).
type DecoratorDataOffset = u32;

/// Offset into the serialized string-table data (used by the `string_table` submodule).
type StringDataOffset = usize;

/// Index of an entry in the serialized string table.
type StringIndex = usize;

/// Magic bytes that open every serialized MAST forest: `"MAST"` followed by a NUL byte.
///
/// The deserializer rejects any input whose first five bytes differ from these.
const MAGIC: &[u8; 5] = b"MAST\0";

/// Serialization format version written right after [`MAGIC`].
///
/// The deserializer only accepts input carrying exactly this version.
const VERSION: [u8; 3] = [0; 3];
99
100impl Serializable for MastForest {
104 fn write_into<W: ByteWriter>(&self, target: &mut W) {
105 let mut basic_block_data_builder = BasicBlockDataBuilder::new();
106
107 let mut before_enter_decorators: Vec<(usize, Vec<DecoratorId>)> = Vec::new();
109 let mut after_exit_decorators: Vec<(usize, Vec<DecoratorId>)> = Vec::new();
110
111 let mut basic_block_decorators: Vec<(usize, Vec<(usize, DecoratorId)>)> = Vec::new();
112
113 target.write_bytes(MAGIC);
115 target.write_bytes(&VERSION);
116
117 target.write_usize(self.nodes.len());
119 target.write_usize(self.decorators.len());
120
121 let roots: Vec<u32> = self.roots.iter().map(u32::from).collect();
123 roots.write_into(target);
124
125 let mast_node_infos: Vec<MastNodeInfo> = self
128 .nodes
129 .iter()
130 .enumerate()
131 .map(|(mast_node_id, mast_node)| {
132 if !mast_node.before_enter().is_empty() {
133 before_enter_decorators.push((mast_node_id, mast_node.before_enter().to_vec()));
134 }
135 if !mast_node.after_exit().is_empty() {
136 after_exit_decorators.push((mast_node_id, mast_node.after_exit().to_vec()));
137 }
138
139 let ops_offset = if let MastNode::Block(basic_block) = mast_node {
140 let ops_offset = basic_block_data_builder.encode_basic_block(basic_block);
141
142 basic_block_decorators.push((mast_node_id, basic_block.decorators().clone()));
143
144 ops_offset
145 } else {
146 0
147 };
148
149 MastNodeInfo::new(mast_node, ops_offset)
150 })
151 .collect();
152
153 let basic_block_data = basic_block_data_builder.finalize();
154 basic_block_data.write_into(target);
155
156 for mast_node_info in mast_node_infos {
158 mast_node_info.write_into(target);
159 }
160
161 self.advice_map.write_into(target);
162 let error_codes: BTreeMap<u64, String> =
163 self.error_codes.iter().map(|(k, v)| (*k, v.to_string())).collect();
164 error_codes.write_into(target);
165
166 let mut decorator_data_builder = DecoratorDataBuilder::new();
169 for decorator in &self.decorators {
170 decorator_data_builder.add_decorator(decorator)
171 }
172
173 let (decorator_data, decorator_infos, string_table) = decorator_data_builder.finalize();
174
175 decorator_data.write_into(target);
177 string_table.write_into(target);
178
179 for decorator_info in decorator_infos {
181 decorator_info.write_into(target);
182 }
183
184 basic_block_decorators.write_into(target);
185
186 before_enter_decorators.write_into(target);
188 after_exit_decorators.write_into(target);
189 }
190}
191
192impl Deserializable for MastForest {
193 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
194 read_and_validate_magic(source)?;
195 read_and_validate_version(source)?;
196
197 let node_count = source.read_usize()?;
199 let decorator_count = source.read_usize()?;
200
201 let roots: Vec<u32> = Deserializable::read_from(source)?;
203
204 let basic_block_data: Vec<u8> = Deserializable::read_from(source)?;
206 let mast_node_infos: Vec<MastNodeInfo> = node_infos_iter(source, node_count)
207 .collect::<Result<Vec<MastNodeInfo>, DeserializationError>>()?;
208
209 let advice_map = AdviceMap::read_from(source)?;
210
211 let error_codes: BTreeMap<u64, String> = Deserializable::read_from(source)?;
212 let error_codes: BTreeMap<u64, Arc<str>> =
213 error_codes.into_iter().map(|(k, v)| (k, Arc::from(v))).collect();
214
215 let decorator_data: Vec<u8> = Deserializable::read_from(source)?;
217 let string_table: StringTable = Deserializable::read_from(source)?;
218 let decorator_infos = decorator_infos_iter(source, decorator_count);
219
220 let mut mast_forest = {
222 let mut mast_forest = MastForest::new();
223
224 for decorator_info in decorator_infos {
225 let decorator_info = decorator_info?;
226 let decorator =
227 decorator_info.try_into_decorator(&string_table, &decorator_data)?;
228
229 mast_forest.add_decorator(decorator).map_err(|e| {
230 DeserializationError::InvalidValue(format!(
231 "failed to add decorator to MAST forest while deserializing: {e}",
232 ))
233 })?;
234 }
235
236 let basic_block_data_decoder = BasicBlockDataDecoder::new(&basic_block_data);
238 for mast_node_info in mast_node_infos {
239 let node =
240 mast_node_info.try_into_mast_node(node_count, &basic_block_data_decoder)?;
241
242 mast_forest.add_node(node).map_err(|e| {
243 DeserializationError::InvalidValue(format!(
244 "failed to add node to MAST forest while deserializing: {e}",
245 ))
246 })?;
247 }
248
249 for root in roots {
251 let root = MastNodeId::from_u32_safe(root, &mast_forest)?;
253 mast_forest.make_root(root);
254 }
255
256 mast_forest.advice_map = advice_map;
257
258 mast_forest
259 };
260
261 let basic_block_decorators: Vec<(usize, DecoratorList)> =
262 read_block_decorators(source, &mast_forest)?;
263 for (node_id, decorator_list) in basic_block_decorators {
264 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
265
266 match &mut mast_forest[node_id] {
267 MastNode::Block(basic_block) => {
268 basic_block.set_decorators(decorator_list);
269 },
270 other => {
271 return Err(DeserializationError::InvalidValue(format!(
272 "expected mast node with id {node_id} to be a basic block, found {other:?}"
273 )));
274 },
275 }
276 }
277
278 let before_enter_decorators: Vec<(usize, Vec<DecoratorId>)> =
280 read_before_after_decorators(source, &mast_forest)?;
281 for (node_id, decorator_ids) in before_enter_decorators {
282 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
283 mast_forest.append_before_enter(node_id, &decorator_ids);
284 }
285
286 let after_exit_decorators: Vec<(usize, Vec<DecoratorId>)> =
287 read_before_after_decorators(source, &mast_forest)?;
288 for (node_id, decorator_ids) in after_exit_decorators {
289 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
290 mast_forest.append_after_exit(node_id, &decorator_ids);
291 }
292
293 mast_forest.error_codes = error_codes;
294
295 Ok(mast_forest)
296 }
297}
298
299fn read_and_validate_magic<R: ByteReader>(source: &mut R) -> Result<[u8; 5], DeserializationError> {
300 let magic: [u8; 5] = source.read_array()?;
301 if magic != *MAGIC {
302 return Err(DeserializationError::InvalidValue(format!(
303 "Invalid magic bytes. Expected '{:?}', got '{:?}'",
304 *MAGIC, magic
305 )));
306 }
307 Ok(magic)
308}
309
310fn read_and_validate_version<R: ByteReader>(
311 source: &mut R,
312) -> Result<[u8; 3], DeserializationError> {
313 let version: [u8; 3] = source.read_array()?;
314 if version != VERSION {
315 return Err(DeserializationError::InvalidValue(format!(
316 "Unsupported version. Got '{version:?}', but only '{VERSION:?}' is supported",
317 )));
318 }
319 Ok(version)
320}
321
322fn read_block_decorators<R: ByteReader>(
323 source: &mut R,
324 mast_forest: &MastForest,
325) -> Result<Vec<(usize, DecoratorList)>, DeserializationError> {
326 let vec_len: usize = source.read()?;
327 let mut out_vec: Vec<_> = Vec::with_capacity(vec_len);
328
329 for _ in 0..vec_len {
330 let node_id: usize = source.read()?;
331
332 let decorator_vec_len: usize = source.read()?;
333 let mut inner_vec: Vec<(usize, DecoratorId)> = Vec::with_capacity(decorator_vec_len);
334 for _ in 0..decorator_vec_len {
335 let op_id: usize = source.read()?;
336 let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?;
337 inner_vec.push((op_id, decorator_id));
338 }
339
340 out_vec.push((node_id, inner_vec));
341 }
342
343 Ok(out_vec)
344}
345
346fn decorator_infos_iter<'a, R>(
347 source: &'a mut R,
348 decorator_count: usize,
349) -> impl Iterator<Item = Result<DecoratorInfo, DeserializationError>> + 'a
350where
351 R: ByteReader + 'a,
352{
353 let mut remaining = decorator_count;
354 core::iter::from_fn(move || {
355 if remaining == 0 {
356 return None;
357 }
358 remaining -= 1;
359 Some(DecoratorInfo::read_from(source))
360 })
361}
362
363fn node_infos_iter<'a, R>(
364 source: &'a mut R,
365 node_count: usize,
366) -> impl Iterator<Item = Result<MastNodeInfo, DeserializationError>> + 'a
367where
368 R: ByteReader + 'a,
369{
370 let mut remaining = node_count;
371 core::iter::from_fn(move || {
372 if remaining == 0 {
373 return None;
374 }
375 remaining -= 1;
376 Some(MastNodeInfo::read_from(source))
377 })
378}
379
380fn read_before_after_decorators<R: ByteReader>(
386 source: &mut R,
387 mast_forest: &MastForest,
388) -> Result<Vec<(usize, Vec<DecoratorId>)>, DeserializationError> {
389 let vec_len: usize = source.read()?;
390 let mut out_vec: Vec<_> = Vec::with_capacity(vec_len);
391
392 for _ in 0..vec_len {
393 let node_id: usize = source.read()?;
394
395 let inner_vec_len: usize = source.read()?;
396 let mut inner_vec: Vec<DecoratorId> = Vec::with_capacity(inner_vec_len);
397 for _ in 0..inner_vec_len {
398 let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?;
399 inner_vec.push(decorator_id);
400 }
401
402 out_vec.push((node_id, inner_vec));
403 }
404
405 Ok(out_vec)
406}