miden_core/mast/serialization/
info.rs1use alloc::vec::Vec;
2
3use super::{NodeDataOffset, basic_blocks::BasicBlockDataDecoder};
4#[cfg(test)]
5use crate::mast::node::MastNodeExt;
6use crate::{
7 mast::{MastForestContributor, MastNode, MastNodeId, Word, node::MastNodeBuilder},
8 serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
9 utils::Idx,
10};
11
// Tags identifying which `MastNodeEntry` variant a serialized entry holds.
//
// The tag is written into the top 4 bits of the 64-bit entry word, so every value here must
// stay at or below `0b1111`.

/// Tag of a join node entry.
const JOIN: u8 = 0;
/// Tag of a split node entry.
const SPLIT: u8 = 1;
/// Tag of a loop node entry.
const LOOP: u8 = 2;
/// Tag of a basic block entry.
const BLOCK: u8 = 3;
/// Tag of a call node entry (non-syscall).
const CALL: u8 = 4;
/// Tag of a syscall node entry.
const SYSCALL: u8 = 5;
/// Tag of a dyn node entry.
const DYN: u8 = 6;
/// Tag of a dyncall node entry.
const DYNCALL: u8 = 7;
/// Tag of an external node entry.
const EXTERNAL: u8 = 8;
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39#[repr(u8)]
40pub enum MastNodeEntry {
41 Join {
42 left_child_id: u32,
43 right_child_id: u32,
44 } = JOIN,
45 Split {
46 if_branch_id: u32,
47 else_branch_id: u32,
48 } = SPLIT,
49 Loop {
50 body_id: u32,
51 } = LOOP,
52 Block {
53 ops_offset: u32,
55 } = BLOCK,
56 Call {
57 callee_id: u32,
58 } = CALL,
59 SysCall {
60 callee_id: u32,
61 } = SYSCALL,
62 Dyn = DYN,
63 Dyncall = DYNCALL,
64 External = EXTERNAL,
65}
66
67impl MastNodeEntry {
69 pub const SERIALIZED_SIZE: usize = 8;
71
72 pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
74 Self::new_inner(mast_node, ops_offset, None)
75 }
76
77 pub fn new_with_id_remap(
79 mast_node: &MastNode,
80 ops_offset: NodeDataOffset,
81 id_remap: Option<&[u32]>,
82 ) -> Self {
83 Self::new_inner(mast_node, ops_offset, id_remap)
84 }
85
86 fn new_inner(
87 mast_node: &MastNode,
88 ops_offset: NodeDataOffset,
89 id_remap: Option<&[u32]>,
90 ) -> Self {
91 use MastNode::*;
92
93 if !matches!(mast_node, &Block(_)) {
94 debug_assert_eq!(ops_offset, 0);
95 }
96
97 let remap_id = |id: MastNodeId| -> u32 {
98 id_remap.and_then(|remap| remap.get(id.to_usize()).copied()).unwrap_or(id.0)
99 };
100
101 match mast_node {
102 Block(_) => Self::Block { ops_offset },
103 Join(join_node) => Self::Join {
104 left_child_id: remap_id(join_node.first()),
105 right_child_id: remap_id(join_node.second()),
106 },
107 Split(split_node) => Self::Split {
108 if_branch_id: remap_id(split_node.on_true()),
109 else_branch_id: remap_id(split_node.on_false()),
110 },
111 Loop(loop_node) => Self::Loop { body_id: remap_id(loop_node.body()) },
112 Call(call_node) => {
113 if call_node.is_syscall() {
114 Self::SysCall { callee_id: remap_id(call_node.callee()) }
115 } else {
116 Self::Call { callee_id: remap_id(call_node.callee()) }
117 }
118 },
119 Dyn(dyn_node) => {
120 if dyn_node.is_dyncall() {
121 Self::Dyncall
122 } else {
123 Self::Dyn
124 }
125 },
126 External(_) => Self::External,
127 }
128 }
129
130 pub fn try_into_mast_node_builder(
135 self,
136 node_count: usize,
137 basic_block_data_decoder: &BasicBlockDataDecoder,
138 digest: Word,
139 ) -> Result<MastNodeBuilder, DeserializationError> {
140 match self {
141 Self::Block { ops_offset } => {
142 let op_batches = basic_block_data_decoder.decode_operations(ops_offset)?;
143 let builder = crate::mast::node::BasicBlockNodeBuilder::from_op_batches(
144 op_batches,
145 Vec::new(), digest,
147 );
148 Ok(MastNodeBuilder::BasicBlock(builder))
149 },
150 Self::Join { left_child_id, right_child_id } => {
151 let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
152 let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
153 let builder = crate::mast::node::JoinNodeBuilder::new([left_child, right_child])
154 .with_digest(digest);
155 Ok(MastNodeBuilder::Join(builder))
156 },
157 Self::Split { if_branch_id, else_branch_id } => {
158 let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
159 let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
160 let builder = crate::mast::node::SplitNodeBuilder::new([if_branch, else_branch])
161 .with_digest(digest);
162 Ok(MastNodeBuilder::Split(builder))
163 },
164 Self::Loop { body_id } => {
165 let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
166 let builder = crate::mast::node::LoopNodeBuilder::new(body_id).with_digest(digest);
167 Ok(MastNodeBuilder::Loop(builder))
168 },
169 Self::Call { callee_id } => {
170 let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
171 let builder =
172 crate::mast::node::CallNodeBuilder::new(callee_id).with_digest(digest);
173 Ok(MastNodeBuilder::Call(builder))
174 },
175 Self::SysCall { callee_id } => {
176 let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
177 let builder =
178 crate::mast::node::CallNodeBuilder::new_syscall(callee_id).with_digest(digest);
179 Ok(MastNodeBuilder::Call(builder))
180 },
181 Self::Dyn => Ok(MastNodeBuilder::Dyn(
182 crate::mast::node::DynNodeBuilder::new_dyn().with_digest(digest),
183 )),
184 Self::Dyncall => Ok(MastNodeBuilder::Dyn(
185 crate::mast::node::DynNodeBuilder::new_dyncall().with_digest(digest),
186 )),
187 Self::External => {
188 Ok(MastNodeBuilder::External(crate::mast::node::ExternalNodeBuilder::new(digest)))
189 },
190 }
191 }
192}
193
194impl Serializable for MastNodeEntry {
195 fn write_into<W: ByteWriter>(&self, target: &mut W) {
196 let discriminant = self.discriminant() as u64;
197 assert!(discriminant <= 0b1111);
198
199 let payload = match *self {
200 Self::Join {
201 left_child_id: left,
202 right_child_id: right,
203 } => Self::encode_u32_pair(left, right),
204 Self::Split {
205 if_branch_id: if_branch,
206 else_branch_id: else_branch,
207 } => Self::encode_u32_pair(if_branch, else_branch),
208 Self::Loop { body_id: body } => Self::encode_u32_payload(body),
209 Self::Block { ops_offset } => Self::encode_u32_payload(ops_offset),
210 Self::Call { callee_id } => Self::encode_u32_payload(callee_id),
211 Self::SysCall { callee_id } => Self::encode_u32_payload(callee_id),
212 Self::Dyn | Self::Dyncall | Self::External => 0,
213 };
214
215 let value = (discriminant << 60) | payload;
216 target.write_u64(value);
217 }
218}
219
220impl Deserializable for MastNodeEntry {
221 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
222 let (discriminant, payload) = {
223 let value = source.read_u64()?;
224
225 let discriminant = (value >> 60) as u8;
226 let payload = value & 0x0f_ff_ff_ff_ff_ff_ff_ff;
227
228 (discriminant, payload)
229 };
230
231 match discriminant {
232 JOIN => {
233 let (left_child_id, right_child_id) = Self::decode_u32_pair(payload);
234 Ok(Self::Join { left_child_id, right_child_id })
235 },
236 SPLIT => {
237 let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload);
238 Ok(Self::Split { if_branch_id, else_branch_id })
239 },
240 LOOP => {
241 let body_id = Self::decode_u32_payload(payload)?;
242 Ok(Self::Loop { body_id })
243 },
244 BLOCK => {
245 let ops_offset = Self::decode_u32_payload(payload)?;
246 Ok(Self::Block { ops_offset })
247 },
248 CALL => {
249 let callee_id = Self::decode_u32_payload(payload)?;
250 Ok(Self::Call { callee_id })
251 },
252 SYSCALL => {
253 let callee_id = Self::decode_u32_payload(payload)?;
254 Ok(Self::SysCall { callee_id })
255 },
256 DYN => Ok(Self::Dyn),
257 DYNCALL => Ok(Self::Dyncall),
258 EXTERNAL => Ok(Self::External),
259 _ => Err(DeserializationError::InvalidValue(format!(
260 "Invalid tag for MAST node: {discriminant}"
261 ))),
262 }
263 }
264
265 fn min_serialized_size() -> usize {
267 Self::SERIALIZED_SIZE
268 }
269}
270
271impl MastNodeEntry {
273 fn discriminant(&self) -> u8 {
274 unsafe { *<*const _>::from(self).cast::<u8>() }
280 }
281
282 fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 {
287 assert!(
288 left_value.leading_zeros() >= 2,
289 "MastNodeEntry::encode_u32_pair: left value doesn't fit in 30 bits: {left_value}",
290 );
291 assert!(
292 right_value.leading_zeros() >= 2,
293 "MastNodeEntry::encode_u32_pair: right value doesn't fit in 30 bits: {right_value}",
294 );
295
296 ((left_value as u64) << 30) | (right_value as u64)
297 }
298
299 fn encode_u32_payload(payload: u32) -> u64 {
300 payload as u64
301 }
302}
303
304impl MastNodeEntry {
306 fn decode_u32_pair(payload: u64) -> (u32, u32) {
308 let left_value = (payload >> 30) as u32;
309 let right_value = (payload & 0x3f_ff_ff_ff) as u32;
310
311 (left_value, right_value)
312 }
313
314 pub fn decode_u32_payload(payload: u64) -> Result<u32, DeserializationError> {
318 payload.try_into().map_err(|_| {
319 DeserializationError::InvalidValue(format!(
320 "Invalid payload: expected to fit in u32, but was {payload}"
321 ))
322 })
323 }
324}
325
326#[derive(Debug, Clone, Copy, PartialEq, Eq)]
334pub struct MastNodeInfo {
335 entry: MastNodeEntry,
336 digest: Word,
337}
338
339impl MastNodeInfo {
340 #[cfg(test)]
344 pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
345 Self {
346 entry: MastNodeEntry::new(mast_node, ops_offset),
347 digest: mast_node.digest(),
348 }
349 }
350
351 #[cfg(test)]
353 pub fn try_into_mast_node_builder(
354 self,
355 node_count: usize,
356 basic_block_data_decoder: &BasicBlockDataDecoder,
357 ) -> Result<MastNodeBuilder, DeserializationError> {
358 self.entry
359 .try_into_mast_node_builder(node_count, basic_block_data_decoder, self.digest)
360 }
361
362 pub fn node_entry(&self) -> MastNodeEntry {
364 self.entry
365 }
366
367 pub fn digest(&self) -> Word {
369 self.digest
370 }
371
372 pub(crate) fn from_entry(entry: MastNodeEntry, digest: Word) -> Self {
374 Self { entry, digest }
375 }
376}
377
#[cfg(test)]
impl Serializable for MastNodeInfo {
    /// Writes the entry first, immediately followed by the digest.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { entry, digest } = self;
        entry.write_into(target);
        digest.write_into(target);
    }
}
385
#[cfg(test)]
impl Deserializable for MastNodeInfo {
    /// Reads the entry and then the digest, mirroring the order used when writing.
    ///
    /// # Errors
    /// Propagates any error from reading either component.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Struct fields are evaluated in the order written, so the entry is read first.
        Ok(Self {
            entry: MastNodeEntry::read_from(source)?,
            digest: Word::read_from(source)?,
        })
    }

    /// The minimum size is the sum of the minimum sizes of the entry and the digest.
    fn min_serialized_size() -> usize {
        let entry_size = MastNodeEntry::min_serialized_size();
        let digest_size = Word::min_serialized_size();
        entry_size + digest_size
    }
}
399
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use super::*;

    /// A pair payload whose two halves are both at the 30-bit maximum must round-trip unchanged.
    #[test]
    fn serialize_deserialize_60_bit_payload() {
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x3f_ff_ff_ff,
            right_child_id: 0x3f_ff_ff_ff,
        };

        let serialized = mast_node_entry.to_bytes();
        let deserialized = MastNodeEntry::read_from_bytes(&serialized).unwrap();

        assert_eq!(mast_node_entry, deserialized);
    }

    /// Serializing a pair whose left id needs more than 30 bits must trip the left-value range
    /// check. The expected message pins the test to that assertion rather than to any panic.
    #[test]
    #[should_panic(expected = "left value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_1() {
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x4f_ff_ff_ff,
            right_child_id: 0x0,
        };

        let _serialized = mast_node_entry.to_bytes();
    }

    /// Serializing a pair whose right id needs more than 30 bits must trip the right-value range
    /// check. The expected message pins the test to that assertion rather than to any panic.
    #[test]
    #[should_panic(expected = "right value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_2() {
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x0,
            right_child_id: 0x4f_ff_ff_ff,
        };

        let _serialized = mast_node_entry.to_bytes();
    }

    /// A single-value payload one past `u32::MAX` must be rejected when deserializing.
    #[test]
    fn deserialize_large_payloads_fails() {
        let serialized = {
            // Hand-craft an entry word: CALL tag in the top 4 bits, oversized payload below.
            let serialized_value = ((CALL as u64) << 60) | (u32::MAX as u64 + 1_u64);

            let mut serialized_buffer: Vec<u8> = Vec::new();
            serialized_value.write_into(&mut serialized_buffer);

            serialized_buffer
        };

        let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

        assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
    }
}