miden_core/mast/serialization/
mod.rs1use alloc::{
42 collections::BTreeMap,
43 string::{String, ToString},
44 sync::Arc,
45 vec::Vec,
46};
47
48use decorator::{DecoratorDataBuilder, DecoratorInfo};
49use string_table::StringTable;
50use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
51
52use super::{DecoratorId, MastForest, MastNode, MastNodeId};
53use crate::AdviceMap;
54
55mod decorator;
56
57mod info;
58use info::MastNodeInfo;
59
60mod basic_blocks;
61use basic_blocks::{BasicBlockDataBuilder, BasicBlockDataDecoder};
62
63use crate::DecoratorList;
64
65mod string_table;
66
67#[cfg(test)]
68mod tests;
69
/// Magic bytes that every serialized [`MastForest`] starts with; checked by
/// `read_and_validate_magic` before anything else is decoded.
const MAGIC: &[u8; 5] = b"MAST\0";

/// Version of the serialization format emitted by `write_into`.
///
/// `read_and_validate_version` rejects any input whose version differs from this value, so it
/// should be bumped whenever the byte layout changes.
const VERSION: [u8; 3] = [0, 0, 0];

// Offset/index aliases. None of them is referenced directly in this file; presumably they are
// used by the `decorator`, `info`, `basic_blocks` and `string_table` submodules through
// `super::` (TODO confirm).

/// Offset used to locate a node's data in the encoded node/basic-block data section
/// (unit assumed to be bytes — TODO confirm).
type NodeDataOffset = u32;

/// Offset used to locate a decorator's payload in the encoded decorator data section.
type DecoratorDataOffset = u32;

/// Offset used to locate a string's bytes inside the string table data.
type StringDataOffset = usize;

/// Index of an entry in the string table.
type StringIndex = usize;
97
98impl Serializable for MastForest {
102 fn write_into<W: ByteWriter>(&self, target: &mut W) {
103 let mut basic_block_data_builder = BasicBlockDataBuilder::new();
104
105 let mut before_enter_decorators: Vec<(usize, Vec<DecoratorId>)> = Vec::new();
107 let mut after_exit_decorators: Vec<(usize, Vec<DecoratorId>)> = Vec::new();
108
109 let mut basic_block_decorators: Vec<(usize, Vec<(usize, DecoratorId)>)> = Vec::new();
110
111 target.write_bytes(MAGIC);
113 target.write_bytes(&VERSION);
114
115 target.write_usize(self.nodes.len());
117 target.write_usize(self.decorators.len());
118
119 let roots: Vec<u32> = self.roots.iter().map(u32::from).collect();
121 roots.write_into(target);
122
123 let mast_node_infos: Vec<MastNodeInfo> = self
126 .nodes
127 .iter()
128 .enumerate()
129 .map(|(mast_node_id, mast_node)| {
130 if !mast_node.before_enter().is_empty() {
131 before_enter_decorators.push((mast_node_id, mast_node.before_enter().to_vec()));
132 }
133 if !mast_node.after_exit().is_empty() {
134 after_exit_decorators.push((mast_node_id, mast_node.after_exit().to_vec()));
135 }
136
137 let ops_offset = if let MastNode::Block(basic_block) = mast_node {
138 let ops_offset = basic_block_data_builder.encode_basic_block(basic_block);
139
140 basic_block_decorators.push((mast_node_id, basic_block.decorators().clone()));
141
142 ops_offset
143 } else {
144 0
145 };
146
147 MastNodeInfo::new(mast_node, ops_offset)
148 })
149 .collect();
150
151 let basic_block_data = basic_block_data_builder.finalize();
152 basic_block_data.write_into(target);
153
154 for mast_node_info in mast_node_infos {
156 mast_node_info.write_into(target);
157 }
158
159 self.advice_map.write_into(target);
160 let error_codes: BTreeMap<u64, String> =
161 self.error_codes.iter().map(|(k, v)| (*k, v.to_string())).collect();
162 error_codes.write_into(target);
163
164 let mut decorator_data_builder = DecoratorDataBuilder::new();
167 for decorator in &self.decorators {
168 decorator_data_builder.add_decorator(decorator)
169 }
170
171 let (decorator_data, decorator_infos, string_table) = decorator_data_builder.finalize();
172
173 decorator_data.write_into(target);
175 string_table.write_into(target);
176
177 for decorator_info in decorator_infos {
179 decorator_info.write_into(target);
180 }
181
182 basic_block_decorators.write_into(target);
183
184 before_enter_decorators.write_into(target);
186 after_exit_decorators.write_into(target);
187 }
188}
189
190impl Deserializable for MastForest {
191 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
192 read_and_validate_magic(source)?;
193 read_and_validate_version(source)?;
194
195 let node_count = source.read_usize()?;
197 let decorator_count = source.read_usize()?;
198
199 let roots: Vec<u32> = Deserializable::read_from(source)?;
201
202 let basic_block_data: Vec<u8> = Deserializable::read_from(source)?;
204 let mast_node_infos: Vec<MastNodeInfo> = node_infos_iter(source, node_count)
205 .collect::<Result<Vec<MastNodeInfo>, DeserializationError>>()?;
206
207 let advice_map = AdviceMap::read_from(source)?;
208
209 let error_codes: BTreeMap<u64, String> = Deserializable::read_from(source)?;
210 let error_codes: BTreeMap<u64, Arc<str>> =
211 error_codes.into_iter().map(|(k, v)| (k, Arc::from(v))).collect();
212
213 let decorator_data: Vec<u8> = Deserializable::read_from(source)?;
215 let string_table: StringTable = Deserializable::read_from(source)?;
216 let decorator_infos = decorator_infos_iter(source, decorator_count);
217
218 let mut mast_forest = {
220 let mut mast_forest = MastForest::new();
221
222 for decorator_info in decorator_infos {
223 let decorator_info = decorator_info?;
224 let decorator =
225 decorator_info.try_into_decorator(&string_table, &decorator_data)?;
226
227 mast_forest.add_decorator(decorator).map_err(|e| {
228 DeserializationError::InvalidValue(format!(
229 "failed to add decorator to MAST forest while deserializing: {e}",
230 ))
231 })?;
232 }
233
234 let basic_block_data_decoder = BasicBlockDataDecoder::new(&basic_block_data);
236 for mast_node_info in mast_node_infos {
237 let node =
238 mast_node_info.try_into_mast_node(node_count, &basic_block_data_decoder)?;
239
240 mast_forest.add_node(node).map_err(|e| {
241 DeserializationError::InvalidValue(format!(
242 "failed to add node to MAST forest while deserializing: {e}",
243 ))
244 })?;
245 }
246
247 for root in roots {
249 let root = MastNodeId::from_u32_safe(root, &mast_forest)?;
251 mast_forest.make_root(root);
252 }
253
254 mast_forest.advice_map = advice_map;
255
256 mast_forest
257 };
258
259 let basic_block_decorators: Vec<(usize, DecoratorList)> =
260 read_block_decorators(source, &mast_forest)?;
261 for (node_id, decorator_list) in basic_block_decorators {
262 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
263
264 match &mut mast_forest[node_id] {
265 MastNode::Block(basic_block) => {
266 basic_block.set_decorators(decorator_list);
267 },
268 other => {
269 return Err(DeserializationError::InvalidValue(format!(
270 "expected mast node with id {node_id} to be a basic block, found {other:?}"
271 )));
272 },
273 }
274 }
275
276 let before_enter_decorators: Vec<(usize, Vec<DecoratorId>)> =
278 read_before_after_decorators(source, &mast_forest)?;
279 for (node_id, decorator_ids) in before_enter_decorators {
280 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
281 mast_forest.append_before_enter(node_id, &decorator_ids);
282 }
283
284 let after_exit_decorators: Vec<(usize, Vec<DecoratorId>)> =
285 read_before_after_decorators(source, &mast_forest)?;
286 for (node_id, decorator_ids) in after_exit_decorators {
287 let node_id = MastNodeId::from_usize_safe(node_id, &mast_forest)?;
288 mast_forest.append_after_exit(node_id, &decorator_ids);
289 }
290
291 mast_forest.error_codes = error_codes;
292
293 Ok(mast_forest)
294 }
295}
296
297fn read_and_validate_magic<R: ByteReader>(source: &mut R) -> Result<[u8; 5], DeserializationError> {
298 let magic: [u8; 5] = source.read_array()?;
299 if magic != *MAGIC {
300 return Err(DeserializationError::InvalidValue(format!(
301 "Invalid magic bytes. Expected '{:?}', got '{:?}'",
302 *MAGIC, magic
303 )));
304 }
305 Ok(magic)
306}
307
308fn read_and_validate_version<R: ByteReader>(
309 source: &mut R,
310) -> Result<[u8; 3], DeserializationError> {
311 let version: [u8; 3] = source.read_array()?;
312 if version != VERSION {
313 return Err(DeserializationError::InvalidValue(format!(
314 "Unsupported version. Got '{version:?}', but only '{VERSION:?}' is supported",
315 )));
316 }
317 Ok(version)
318}
319
320fn read_block_decorators<R: ByteReader>(
321 source: &mut R,
322 mast_forest: &MastForest,
323) -> Result<Vec<(usize, DecoratorList)>, DeserializationError> {
324 let vec_len: usize = source.read()?;
325 let mut out_vec: Vec<_> = Vec::with_capacity(vec_len);
326
327 for _ in 0..vec_len {
328 let node_id: usize = source.read()?;
329
330 let decorator_vec_len: usize = source.read()?;
331 let mut inner_vec: Vec<(usize, DecoratorId)> = Vec::with_capacity(decorator_vec_len);
332 for _ in 0..decorator_vec_len {
333 let op_id: usize = source.read()?;
334 let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?;
335 inner_vec.push((op_id, decorator_id));
336 }
337
338 out_vec.push((node_id, inner_vec));
339 }
340
341 Ok(out_vec)
342}
343
344fn decorator_infos_iter<'a, R>(
345 source: &'a mut R,
346 decorator_count: usize,
347) -> impl Iterator<Item = Result<DecoratorInfo, DeserializationError>> + 'a
348where
349 R: ByteReader + 'a,
350{
351 let mut remaining = decorator_count;
352 core::iter::from_fn(move || {
353 if remaining == 0 {
354 return None;
355 }
356 remaining -= 1;
357 Some(DecoratorInfo::read_from(source))
358 })
359}
360
361fn node_infos_iter<'a, R>(
362 source: &'a mut R,
363 node_count: usize,
364) -> impl Iterator<Item = Result<MastNodeInfo, DeserializationError>> + 'a
365where
366 R: ByteReader + 'a,
367{
368 let mut remaining = node_count;
369 core::iter::from_fn(move || {
370 if remaining == 0 {
371 return None;
372 }
373 remaining -= 1;
374 Some(MastNodeInfo::read_from(source))
375 })
376}
377
378fn read_before_after_decorators<R: ByteReader>(
384 source: &mut R,
385 mast_forest: &MastForest,
386) -> Result<Vec<(usize, Vec<DecoratorId>)>, DeserializationError> {
387 let vec_len: usize = source.read()?;
388 let mut out_vec: Vec<_> = Vec::with_capacity(vec_len);
389
390 for _ in 0..vec_len {
391 let node_id: usize = source.read()?;
392
393 let inner_vec_len: usize = source.read()?;
394 let mut inner_vec: Vec<DecoratorId> = Vec::with_capacity(inner_vec_len);
395 for _ in 0..inner_vec_len {
396 let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?;
397 inner_vec.push(decorator_id);
398 }
399
400 out_vec.push((node_id, inner_vec));
401 }
402
403 Ok(out_vec)
404}