Skip to main content

miden_core/mast/serialization/
info.rs

1use alloc::vec::Vec;
2
3use super::{NodeDataOffset, basic_blocks::BasicBlockDataDecoder};
4#[cfg(test)]
5use crate::mast::node::MastNodeExt;
6use crate::{
7    mast::{MastForestContributor, MastNode, MastNodeId, Word, node::MastNodeBuilder},
8    serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
9    utils::Idx,
10};
11
12// CONSTANTS
13// ================================================================================================
14
// Wire tags for the variants of `MastNodeEntry`. Each tag is the enum discriminant and is stored
// in the top 4 bits of a serialized entry, so every value here must stay <= 0b1111.

/// Tag of a `Join` entry.
const JOIN: u8 = 0x0;
/// Tag of a `Split` entry.
const SPLIT: u8 = 0x1;
/// Tag of a `Loop` entry.
const LOOP: u8 = 0x2;
/// Tag of a `Block` (basic block) entry.
const BLOCK: u8 = 0x3;
/// Tag of a `Call` entry.
const CALL: u8 = 0x4;
/// Tag of a `SysCall` entry.
const SYSCALL: u8 = 0x5;
/// Tag of a `Dyn` entry.
const DYN: u8 = 0x6;
/// Tag of a `Dyncall` entry.
const DYNCALL: u8 = 0x7;
/// Tag of an `External` entry.
const EXTERNAL: u8 = 0x8;
24
25// MAST NODE ENTRIES
26// ================================================================================================
27
28/// Fixed-width structural metadata for a serialized [`MastNode`].
29///
30/// This is the random-access portion of the node table. Digests are intentionally modeled
31/// separately so the wire format can move them into dedicated sections.
32///
33/// Child indices for `Join`, `Split`, `Loop`, `Call`, and `SysCall` are stored inline so random
34/// access does not need any extra pointer chasing.
35///
36/// The serialized representation is always 8 bytes, which keeps the node-entry table fixed-width
37/// on the wire.
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39#[repr(u8)]
40pub enum MastNodeEntry {
41    Join {
42        left_child_id: u32,
43        right_child_id: u32,
44    } = JOIN,
45    Split {
46        if_branch_id: u32,
47        else_branch_id: u32,
48    } = SPLIT,
49    Loop {
50        body_id: u32,
51    } = LOOP,
52    Block {
53        // offset of operations in node data
54        ops_offset: u32,
55    } = BLOCK,
56    Call {
57        callee_id: u32,
58    } = CALL,
59    SysCall {
60        callee_id: u32,
61    } = SYSCALL,
62    Dyn = DYN,
63    Dyncall = DYNCALL,
64    External = EXTERNAL,
65}
66
67/// Constructors
68impl MastNodeEntry {
69    /// Serialized byte size of one fixed-width MAST node entry.
70    pub const SERIALIZED_SIZE: usize = 8;
71
72    /// Constructs a new [`MastNodeEntry`] from a [`MastNode`].
73    pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
74        Self::new_inner(mast_node, ops_offset, None)
75    }
76
77    /// Constructs a new [`MastNodeEntry`] from a [`MastNode`], remapping child IDs when needed.
78    pub fn new_with_id_remap(
79        mast_node: &MastNode,
80        ops_offset: NodeDataOffset,
81        id_remap: Option<&[u32]>,
82    ) -> Self {
83        Self::new_inner(mast_node, ops_offset, id_remap)
84    }
85
86    fn new_inner(
87        mast_node: &MastNode,
88        ops_offset: NodeDataOffset,
89        id_remap: Option<&[u32]>,
90    ) -> Self {
91        use MastNode::*;
92
93        if !matches!(mast_node, &Block(_)) {
94            debug_assert_eq!(ops_offset, 0);
95        }
96
97        let remap_id = |id: MastNodeId| -> u32 {
98            id_remap.and_then(|remap| remap.get(id.to_usize()).copied()).unwrap_or(id.0)
99        };
100
101        match mast_node {
102            Block(_) => Self::Block { ops_offset },
103            Join(join_node) => Self::Join {
104                left_child_id: remap_id(join_node.first()),
105                right_child_id: remap_id(join_node.second()),
106            },
107            Split(split_node) => Self::Split {
108                if_branch_id: remap_id(split_node.on_true()),
109                else_branch_id: remap_id(split_node.on_false()),
110            },
111            Loop(loop_node) => Self::Loop { body_id: remap_id(loop_node.body()) },
112            Call(call_node) => {
113                if call_node.is_syscall() {
114                    Self::SysCall { callee_id: remap_id(call_node.callee()) }
115                } else {
116                    Self::Call { callee_id: remap_id(call_node.callee()) }
117                }
118            },
119            Dyn(dyn_node) => {
120                if dyn_node.is_dyncall() {
121                    Self::Dyncall
122                } else {
123                    Self::Dyn
124                }
125            },
126            External(_) => Self::External,
127        }
128    }
129
130    /// Attempts to convert this [`MastNodeEntry`] into a [`MastNodeBuilder`].
131    ///
132    /// The `node_count` is the total expected number of nodes in the
133    /// [`crate::mast::MastForest`] **after deserialization**.
134    pub fn try_into_mast_node_builder(
135        self,
136        node_count: usize,
137        basic_block_data_decoder: &BasicBlockDataDecoder,
138        digest: Word,
139    ) -> Result<MastNodeBuilder, DeserializationError> {
140        match self {
141            Self::Block { ops_offset } => {
142                let op_batches = basic_block_data_decoder.decode_operations(ops_offset)?;
143                let builder = crate::mast::node::BasicBlockNodeBuilder::from_op_batches(
144                    op_batches,
145                    Vec::new(), // decorators set later
146                    digest,
147                );
148                Ok(MastNodeBuilder::BasicBlock(builder))
149            },
150            Self::Join { left_child_id, right_child_id } => {
151                let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
152                let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
153                let builder = crate::mast::node::JoinNodeBuilder::new([left_child, right_child])
154                    .with_digest(digest);
155                Ok(MastNodeBuilder::Join(builder))
156            },
157            Self::Split { if_branch_id, else_branch_id } => {
158                let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
159                let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
160                let builder = crate::mast::node::SplitNodeBuilder::new([if_branch, else_branch])
161                    .with_digest(digest);
162                Ok(MastNodeBuilder::Split(builder))
163            },
164            Self::Loop { body_id } => {
165                let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
166                let builder = crate::mast::node::LoopNodeBuilder::new(body_id).with_digest(digest);
167                Ok(MastNodeBuilder::Loop(builder))
168            },
169            Self::Call { callee_id } => {
170                let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
171                let builder =
172                    crate::mast::node::CallNodeBuilder::new(callee_id).with_digest(digest);
173                Ok(MastNodeBuilder::Call(builder))
174            },
175            Self::SysCall { callee_id } => {
176                let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
177                let builder =
178                    crate::mast::node::CallNodeBuilder::new_syscall(callee_id).with_digest(digest);
179                Ok(MastNodeBuilder::Call(builder))
180            },
181            Self::Dyn => Ok(MastNodeBuilder::Dyn(
182                crate::mast::node::DynNodeBuilder::new_dyn().with_digest(digest),
183            )),
184            Self::Dyncall => Ok(MastNodeBuilder::Dyn(
185                crate::mast::node::DynNodeBuilder::new_dyncall().with_digest(digest),
186            )),
187            Self::External => {
188                Ok(MastNodeBuilder::External(crate::mast::node::ExternalNodeBuilder::new(digest)))
189            },
190        }
191    }
192}
193
194impl Serializable for MastNodeEntry {
195    fn write_into<W: ByteWriter>(&self, target: &mut W) {
196        let discriminant = self.discriminant() as u64;
197        assert!(discriminant <= 0b1111);
198
199        let payload = match *self {
200            Self::Join {
201                left_child_id: left,
202                right_child_id: right,
203            } => Self::encode_u32_pair(left, right),
204            Self::Split {
205                if_branch_id: if_branch,
206                else_branch_id: else_branch,
207            } => Self::encode_u32_pair(if_branch, else_branch),
208            Self::Loop { body_id: body } => Self::encode_u32_payload(body),
209            Self::Block { ops_offset } => Self::encode_u32_payload(ops_offset),
210            Self::Call { callee_id } => Self::encode_u32_payload(callee_id),
211            Self::SysCall { callee_id } => Self::encode_u32_payload(callee_id),
212            Self::Dyn | Self::Dyncall | Self::External => 0,
213        };
214
215        let value = (discriminant << 60) | payload;
216        target.write_u64(value);
217    }
218}
219
220impl Deserializable for MastNodeEntry {
221    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
222        let (discriminant, payload) = {
223            let value = source.read_u64()?;
224
225            let discriminant = (value >> 60) as u8;
226            let payload = value & 0x0f_ff_ff_ff_ff_ff_ff_ff;
227
228            (discriminant, payload)
229        };
230
231        match discriminant {
232            JOIN => {
233                let (left_child_id, right_child_id) = Self::decode_u32_pair(payload);
234                Ok(Self::Join { left_child_id, right_child_id })
235            },
236            SPLIT => {
237                let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload);
238                Ok(Self::Split { if_branch_id, else_branch_id })
239            },
240            LOOP => {
241                let body_id = Self::decode_u32_payload(payload)?;
242                Ok(Self::Loop { body_id })
243            },
244            BLOCK => {
245                let ops_offset = Self::decode_u32_payload(payload)?;
246                Ok(Self::Block { ops_offset })
247            },
248            CALL => {
249                let callee_id = Self::decode_u32_payload(payload)?;
250                Ok(Self::Call { callee_id })
251            },
252            SYSCALL => {
253                let callee_id = Self::decode_u32_payload(payload)?;
254                Ok(Self::SysCall { callee_id })
255            },
256            DYN => Ok(Self::Dyn),
257            DYNCALL => Ok(Self::Dyncall),
258            EXTERNAL => Ok(Self::External),
259            _ => Err(DeserializationError::InvalidValue(format!(
260                "Invalid tag for MAST node: {discriminant}"
261            ))),
262        }
263    }
264
265    /// Returns the fixed serialized size: always 8 bytes (u64).
266    fn min_serialized_size() -> usize {
267        Self::SERIALIZED_SIZE
268    }
269}
270
271/// Serialization helpers
272impl MastNodeEntry {
273    fn discriminant(&self) -> u8 {
274        // SAFETY: This is safe because we have given this enum a primitive representation with
275        // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant.
276        //
277        // See the section on "accessing the numeric value of the discriminant"
278        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
279        unsafe { *<*const _>::from(self).cast::<u8>() }
280    }
281
282    /// Encodes two u32 numbers in the first 60 bits of a `u64`.
283    ///
284    /// # Panics
285    /// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits.
286    fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 {
287        assert!(
288            left_value.leading_zeros() >= 2,
289            "MastNodeEntry::encode_u32_pair: left value doesn't fit in 30 bits: {left_value}",
290        );
291        assert!(
292            right_value.leading_zeros() >= 2,
293            "MastNodeEntry::encode_u32_pair: right value doesn't fit in 30 bits: {right_value}",
294        );
295
296        ((left_value as u64) << 30) | (right_value as u64)
297    }
298
299    fn encode_u32_payload(payload: u32) -> u64 {
300        payload as u64
301    }
302}
303
304/// Deserialization helpers
305impl MastNodeEntry {
306    /// Decodes two `u32` numbers from a 60-bit payload.
307    fn decode_u32_pair(payload: u64) -> (u32, u32) {
308        let left_value = (payload >> 30) as u32;
309        let right_value = (payload & 0x3f_ff_ff_ff) as u32;
310
311        (left_value, right_value)
312    }
313
314    /// Decodes one `u32` number from a 60-bit payload.
315    ///
316    /// Returns an error if the payload doesn't fit in a `u32`.
317    pub fn decode_u32_payload(payload: u64) -> Result<u32, DeserializationError> {
318        payload.try_into().map_err(|_| {
319            DeserializationError::InvalidValue(format!(
320                "Invalid payload: expected to fit in u32, but was {payload}"
321            ))
322        })
323    }
324}
325
326// MAST NODE INFO
327// ================================================================================================
328
329/// Logical node metadata combining fixed-width structure and a digest value.
330///
331/// This is a convenience type for APIs that want both pieces together. The wire format does not
332/// require `MastNodeInfo` to appear as one contiguous fixed-width section.
333#[derive(Debug, Clone, Copy, PartialEq, Eq)]
334pub struct MastNodeInfo {
335    entry: MastNodeEntry,
336    digest: Word,
337}
338
339impl MastNodeInfo {
340    /// Constructs a new [`MastNodeInfo`] from a [`MastNode`], along with an `ops_offset`
341    ///
342    /// For non-basic block nodes, `ops_offset` is ignored, and should be set to 0.
343    #[cfg(test)]
344    pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
345        Self {
346            entry: MastNodeEntry::new(mast_node, ops_offset),
347            digest: mast_node.digest(),
348        }
349    }
350
351    /// Attempts to convert this [`MastNodeInfo`] into a [`MastNodeBuilder`].
352    #[cfg(test)]
353    pub fn try_into_mast_node_builder(
354        self,
355        node_count: usize,
356        basic_block_data_decoder: &BasicBlockDataDecoder,
357    ) -> Result<MastNodeBuilder, DeserializationError> {
358        self.entry
359            .try_into_mast_node_builder(node_count, basic_block_data_decoder, self.digest)
360    }
361
362    /// Returns the fixed-width structural node entry.
363    pub fn node_entry(&self) -> MastNodeEntry {
364        self.entry
365    }
366
367    /// Returns the stored node digest.
368    pub fn digest(&self) -> Word {
369        self.digest
370    }
371
372    /// Builds node metadata directly from serialized components.
373    pub(crate) fn from_entry(entry: MastNodeEntry, digest: Word) -> Self {
374        Self { entry, digest }
375    }
376}
377
#[cfg(test)]
impl Serializable for MastNodeInfo {
    /// Writes the structural entry first, then the digest.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { entry, digest } = self;
        entry.write_into(target);
        digest.write_into(target);
    }
}
385
#[cfg(test)]
impl Deserializable for MastNodeInfo {
    /// Reads a [`MastNodeEntry`] followed by its digest.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Struct-literal fields are evaluated in the order written: entry first, then digest.
        Ok(Self {
            entry: MastNodeEntry::read_from(source)?,
            digest: Word::read_from(source)?,
        })
    }

    /// Minimum encoded size: one fixed-width entry plus one digest.
    fn min_serialized_size() -> usize {
        let entry_size = MastNodeEntry::min_serialized_size();
        let digest_size = Word::min_serialized_size();
        entry_size + digest_size
    }
}
399
400// TESTS
401// ================================================================================================
402
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use super::*;

    /// Round-trips `entry` through its wire encoding and checks the fixed encoded width.
    fn assert_roundtrip(entry: MastNodeEntry) {
        let serialized = entry.to_bytes();
        assert_eq!(serialized.len(), MastNodeEntry::SERIALIZED_SIZE);

        let deserialized = MastNodeEntry::read_from_bytes(&serialized).unwrap();
        assert_eq!(entry, deserialized);
    }

    #[test]
    fn serialize_deserialize_all_variants() {
        let entries = [
            MastNodeEntry::Join { left_child_id: 1, right_child_id: 2 },
            MastNodeEntry::Split { if_branch_id: 3, else_branch_id: 4 },
            MastNodeEntry::Loop { body_id: 5 },
            // single-value payloads may use the full u32 range
            MastNodeEntry::Block { ops_offset: u32::MAX },
            MastNodeEntry::Call { callee_id: u32::MAX },
            MastNodeEntry::SysCall { callee_id: 7 },
            MastNodeEntry::Dyn,
            MastNodeEntry::Dyncall,
            MastNodeEntry::External,
        ];

        for entry in entries {
            assert_roundtrip(entry);
        }
    }

    #[test]
    fn serialize_deserialize_60_bit_payload() {
        // each child needs 30 bits
        assert_roundtrip(MastNodeEntry::Join {
            left_child_id: 0x3f_ff_ff_ff,
            right_child_id: 0x3f_ff_ff_ff,
        });
    }

    #[test]
    #[should_panic(expected = "left value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_1() {
        // left child needs 31 bits
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x4f_ff_ff_ff,
            right_child_id: 0x0,
        };

        // must panic on the 30-bit limit check
        let _serialized = mast_node_entry.to_bytes();
    }

    #[test]
    #[should_panic(expected = "right value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_2() {
        // right child needs 31 bits
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x0,
            right_child_id: 0x4f_ff_ff_ff,
        };

        // must panic on the 30-bit limit check
        let _serialized = mast_node_entry.to_bytes();
    }

    #[test]
    fn deserialize_large_payloads_fails() {
        // Serialized `CALL` with a 33-bit payload
        let serialized = {
            let serialized_value = ((CALL as u64) << 60) | (u32::MAX as u64 + 1_u64);

            let mut serialized_buffer: Vec<u8> = Vec::new();
            serialized_value.write_into(&mut serialized_buffer);

            serialized_buffer
        };

        let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

        assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
    }

    #[test]
    fn deserialize_unknown_tag_fails() {
        // the first tag value after `EXTERNAL` is not assigned to any variant
        let serialized = {
            let serialized_value: u64 = (u64::from(EXTERNAL) + 1) << 60;

            let mut serialized_buffer: Vec<u8> = Vec::new();
            serialized_value.write_into(&mut serialized_buffer);

            serialized_buffer
        };

        let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

        assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
    }
}