Skip to main content

miden_core/mast/serialization/
info.rs

1use super::{NodeDataOffset, basic_blocks::BasicBlockDataDecoder};
2#[cfg(test)]
3use crate::mast::node::MastNodeExt;
4use crate::{
5    mast::{MastForestContributor, MastNode, MastNodeId, Word, node::MastNodeBuilder},
6    serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
7};
8
9// CONSTANTS
10// ================================================================================================
11
/// Wire tag of a join node entry.
const JOIN: u8 = 0x0;
/// Wire tag of a split node entry.
const SPLIT: u8 = 0x1;
/// Wire tag of a loop node entry.
const LOOP: u8 = 0x2;
/// Wire tag of a basic block node entry.
const BLOCK: u8 = 0x3;
/// Wire tag of a (non-syscall) call node entry.
const CALL: u8 = 0x4;
/// Wire tag of a syscall node entry.
const SYSCALL: u8 = 0x5;
/// Wire tag of a dyn node entry.
const DYN: u8 = 0x6;
/// Wire tag of a dyncall node entry.
const DYNCALL: u8 = 0x7;
/// Wire tag of an external node entry.
const EXTERNAL: u8 = 0x8;
21
22// MAST NODE ENTRIES
23// ================================================================================================
24
/// Fixed-width structural metadata for a serialized [`MastNode`].
///
/// This is the random-access portion of the node table. Digests are intentionally modeled
/// separately so the wire format can move them into dedicated sections.
///
/// Child indices for `Join`, `Split`, `Loop`, `Call`, and `SysCall` are stored inline so random
/// access does not need any extra pointer chasing.
///
/// The serialized representation is always 8 bytes, which keeps the node-entry table fixed-width
/// on the wire. That `u64` holds the variant's tag (its discriminant) in the top 4 bits and the
/// variant's payload in the low 60 bits:
/// - `Join` / `Split` pack their two child ids into 30 bits each, so both ids must be below 2^30
///   for the entry to be serializable;
/// - `Loop`, `Block`, `Call` and `SysCall` store a single `u32`;
/// - `Dyn`, `Dyncall` and `External` are written with an all-zero payload.
///
/// The explicit discriminants below (which require `#[repr(u8)]`) are the wire tags, so changing
/// them changes the serialization format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MastNodeEntry {
    /// Entry for a join node.
    Join {
        /// Raw id of the join node's first child.
        left_child_id: u32,
        /// Raw id of the join node's second child.
        right_child_id: u32,
    } = JOIN,
    /// Entry for a split node.
    Split {
        /// Raw id of the split node's `on_true` child.
        if_branch_id: u32,
        /// Raw id of the split node's `on_false` child.
        else_branch_id: u32,
    } = SPLIT,
    /// Entry for a loop node.
    Loop {
        /// Raw id of the loop body node.
        body_id: u32,
    } = LOOP,
    /// Entry for a basic block node.
    Block {
        /// Offset of this block's operations in the node data, later resolved via
        /// `BasicBlockDataDecoder::decode_operations`.
        ops_offset: u32,
    } = BLOCK,
    /// Entry for a call node that is not a syscall.
    Call {
        /// Raw id of the callee node.
        callee_id: u32,
    } = CALL,
    /// Entry for a call node that is a syscall.
    SysCall {
        /// Raw id of the callee node.
        callee_id: u32,
    } = SYSCALL,
    /// Entry for a dyn node that is not a dyncall.
    Dyn = DYN,
    /// Entry for a dyn node that is a dyncall.
    Dyncall = DYNCALL,
    /// Entry for an external node; it carries no payload and is rebuilt from its digest alone.
    External = EXTERNAL,
}
63
64/// Constructors
65impl MastNodeEntry {
66    /// Serialized byte size of one fixed-width MAST node entry.
67    pub const SERIALIZED_SIZE: usize = 8;
68
69    /// Constructs a new [`MastNodeEntry`] from a [`MastNode`].
70    pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
71        use MastNode::*;
72
73        if !matches!(mast_node, &Block(_)) {
74            debug_assert_eq!(ops_offset, 0);
75        }
76
77        match mast_node {
78            Block(_) => Self::Block { ops_offset },
79            Join(join_node) => Self::Join {
80                left_child_id: join_node.first().0,
81                right_child_id: join_node.second().0,
82            },
83            Split(split_node) => Self::Split {
84                if_branch_id: split_node.on_true().0,
85                else_branch_id: split_node.on_false().0,
86            },
87            Loop(loop_node) => Self::Loop { body_id: loop_node.body().0 },
88            Call(call_node) => {
89                if call_node.is_syscall() {
90                    Self::SysCall { callee_id: call_node.callee().0 }
91                } else {
92                    Self::Call { callee_id: call_node.callee().0 }
93                }
94            },
95            Dyn(dyn_node) => {
96                if dyn_node.is_dyncall() {
97                    Self::Dyncall
98                } else {
99                    Self::Dyn
100                }
101            },
102            External(_) => Self::External,
103        }
104    }
105
106    /// Attempts to convert this [`MastNodeEntry`] into a [`MastNodeBuilder`].
107    ///
108    /// The `node_count` is the total expected number of nodes in the
109    /// [`crate::mast::MastForest`] **after deserialization**.
110    pub fn try_into_mast_node_builder(
111        self,
112        node_count: usize,
113        basic_block_data_decoder: &BasicBlockDataDecoder,
114        digest: Word,
115    ) -> Result<MastNodeBuilder, DeserializationError> {
116        match self {
117            Self::Block { ops_offset } => {
118                let op_batches = basic_block_data_decoder.decode_operations(ops_offset)?;
119                let builder =
120                    crate::mast::node::BasicBlockNodeBuilder::from_op_batches(op_batches, digest);
121                Ok(MastNodeBuilder::BasicBlock(builder))
122            },
123            Self::Join { left_child_id, right_child_id } => {
124                let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
125                let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
126                let builder = crate::mast::node::JoinNodeBuilder::new([left_child, right_child])
127                    .with_digest(digest);
128                Ok(MastNodeBuilder::Join(builder))
129            },
130            Self::Split { if_branch_id, else_branch_id } => {
131                let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
132                let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
133                let builder = crate::mast::node::SplitNodeBuilder::new([if_branch, else_branch])
134                    .with_digest(digest);
135                Ok(MastNodeBuilder::Split(builder))
136            },
137            Self::Loop { body_id } => {
138                let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
139                let builder = crate::mast::node::LoopNodeBuilder::new(body_id).with_digest(digest);
140                Ok(MastNodeBuilder::Loop(builder))
141            },
142            Self::Call { callee_id } => {
143                let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
144                let builder =
145                    crate::mast::node::CallNodeBuilder::new(callee_id).with_digest(digest);
146                Ok(MastNodeBuilder::Call(builder))
147            },
148            Self::SysCall { callee_id } => {
149                let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
150                let builder =
151                    crate::mast::node::CallNodeBuilder::new_syscall(callee_id).with_digest(digest);
152                Ok(MastNodeBuilder::Call(builder))
153            },
154            Self::Dyn => Ok(MastNodeBuilder::Dyn(
155                crate::mast::node::DynNodeBuilder::new_dyn().with_digest(digest),
156            )),
157            Self::Dyncall => Ok(MastNodeBuilder::Dyn(
158                crate::mast::node::DynNodeBuilder::new_dyncall().with_digest(digest),
159            )),
160            Self::External => {
161                Ok(MastNodeBuilder::External(crate::mast::node::ExternalNodeBuilder::new(digest)))
162            },
163        }
164    }
165}
166
167impl Serializable for MastNodeEntry {
168    fn write_into<W: ByteWriter>(&self, target: &mut W) {
169        let discriminant = self.discriminant() as u64;
170        assert!(discriminant <= 0b1111);
171
172        let payload = match *self {
173            Self::Join {
174                left_child_id: left,
175                right_child_id: right,
176            } => Self::encode_u32_pair(left, right),
177            Self::Split {
178                if_branch_id: if_branch,
179                else_branch_id: else_branch,
180            } => Self::encode_u32_pair(if_branch, else_branch),
181            Self::Loop { body_id: body } => Self::encode_u32_payload(body),
182            Self::Block { ops_offset } => Self::encode_u32_payload(ops_offset),
183            Self::Call { callee_id } => Self::encode_u32_payload(callee_id),
184            Self::SysCall { callee_id } => Self::encode_u32_payload(callee_id),
185            Self::Dyn | Self::Dyncall | Self::External => 0,
186        };
187
188        let value = (discriminant << 60) | payload;
189        target.write_u64(value);
190    }
191}
192
193impl Deserializable for MastNodeEntry {
194    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
195        let (discriminant, payload) = {
196            let value = source.read_u64()?;
197
198            let discriminant = (value >> 60) as u8;
199            let payload = value & 0x0f_ff_ff_ff_ff_ff_ff_ff;
200
201            (discriminant, payload)
202        };
203
204        match discriminant {
205            JOIN => {
206                let (left_child_id, right_child_id) = Self::decode_u32_pair(payload);
207                Ok(Self::Join { left_child_id, right_child_id })
208            },
209            SPLIT => {
210                let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload);
211                Ok(Self::Split { if_branch_id, else_branch_id })
212            },
213            LOOP => {
214                let body_id = Self::decode_u32_payload(payload)?;
215                Ok(Self::Loop { body_id })
216            },
217            BLOCK => {
218                let ops_offset = Self::decode_u32_payload(payload)?;
219                Ok(Self::Block { ops_offset })
220            },
221            CALL => {
222                let callee_id = Self::decode_u32_payload(payload)?;
223                Ok(Self::Call { callee_id })
224            },
225            SYSCALL => {
226                let callee_id = Self::decode_u32_payload(payload)?;
227                Ok(Self::SysCall { callee_id })
228            },
229            DYN => Ok(Self::Dyn),
230            DYNCALL => Ok(Self::Dyncall),
231            EXTERNAL => Ok(Self::External),
232            _ => Err(DeserializationError::InvalidValue(format!(
233                "Invalid tag for MAST node: {discriminant}"
234            ))),
235        }
236    }
237
238    /// Returns the fixed serialized size: always 8 bytes (u64).
239    fn min_serialized_size() -> usize {
240        Self::SERIALIZED_SIZE
241    }
242}
243
244/// Serialization helpers
245impl MastNodeEntry {
246    fn discriminant(&self) -> u8 {
247        // SAFETY: This is safe because we have given this enum a primitive representation with
248        // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant.
249        //
250        // See the section on "accessing the numeric value of the discriminant"
251        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
252        unsafe { *<*const _>::from(self).cast::<u8>() }
253    }
254
255    /// Encodes two u32 numbers in the first 60 bits of a `u64`.
256    ///
257    /// # Panics
258    /// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits.
259    fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 {
260        assert!(
261            left_value.leading_zeros() >= 2,
262            "MastNodeEntry::encode_u32_pair: left value doesn't fit in 30 bits: {left_value}",
263        );
264        assert!(
265            right_value.leading_zeros() >= 2,
266            "MastNodeEntry::encode_u32_pair: right value doesn't fit in 30 bits: {right_value}",
267        );
268
269        ((left_value as u64) << 30) | (right_value as u64)
270    }
271
272    fn encode_u32_payload(payload: u32) -> u64 {
273        payload as u64
274    }
275}
276
277/// Deserialization helpers
278impl MastNodeEntry {
279    /// Decodes two `u32` numbers from a 60-bit payload.
280    fn decode_u32_pair(payload: u64) -> (u32, u32) {
281        let left_value = (payload >> 30) as u32;
282        let right_value = (payload & 0x3f_ff_ff_ff) as u32;
283
284        (left_value, right_value)
285    }
286
287    /// Decodes one `u32` number from a 60-bit payload.
288    ///
289    /// Returns an error if the payload doesn't fit in a `u32`.
290    pub fn decode_u32_payload(payload: u64) -> Result<u32, DeserializationError> {
291        payload.try_into().map_err(|_| {
292            DeserializationError::InvalidValue(format!(
293                "Invalid payload: expected to fit in u32, but was {payload}"
294            ))
295        })
296    }
297}
298
299// MAST NODE INFO
300// ================================================================================================
301
/// Logical node metadata combining fixed-width structure and a digest value.
///
/// This is a convenience type for APIs that want both pieces together. The wire format does not
/// require `MastNodeInfo` to appear as one contiguous fixed-width section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MastNodeInfo {
    /// Fixed-width structural part of the node (kind plus inline child ids or ops offset).
    entry: MastNodeEntry,
    /// The node's digest, kept next to `entry` because `MastNodeEntry` does not store it.
    digest: Word,
}
311
312impl MastNodeInfo {
313    /// Constructs a new [`MastNodeInfo`] from a [`MastNode`], along with an `ops_offset`
314    ///
315    /// For non-basic block nodes, `ops_offset` is ignored, and should be set to 0.
316    #[cfg(test)]
317    pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
318        Self {
319            entry: MastNodeEntry::new(mast_node, ops_offset),
320            digest: mast_node.digest(),
321        }
322    }
323
324    /// Attempts to convert this [`MastNodeInfo`] into a [`MastNodeBuilder`].
325    #[cfg(test)]
326    pub fn try_into_mast_node_builder(
327        self,
328        node_count: usize,
329        basic_block_data_decoder: &BasicBlockDataDecoder,
330    ) -> Result<MastNodeBuilder, DeserializationError> {
331        self.entry
332            .try_into_mast_node_builder(node_count, basic_block_data_decoder, self.digest)
333    }
334
335    /// Returns the fixed-width structural node entry.
336    pub fn node_entry(&self) -> MastNodeEntry {
337        self.entry
338    }
339
340    /// Returns the stored node digest.
341    pub fn digest(&self) -> Word {
342        self.digest
343    }
344
345    /// Builds node metadata directly from serialized components.
346    pub(crate) fn from_entry(entry: MastNodeEntry, digest: Word) -> Self {
347        Self { entry, digest }
348    }
349}
350
#[cfg(test)]
impl Serializable for MastNodeInfo {
    /// Writes the 8-byte structural entry, then the digest.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { entry, digest } = self;
        entry.write_into(target);
        digest.write_into(target);
    }
}
358
#[cfg(test)]
impl Deserializable for MastNodeInfo {
    /// Reads the structural entry followed by the digest (the order `write_into` uses).
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let entry = MastNodeEntry::read_from(source)?;
        let digest = Word::read_from(source)?;
        Ok(Self::from_entry(entry, digest))
    }

    /// Returns the minimum serialized size: 8 bytes for `MastNodeEntry` + 32 bytes for `Word`.
    fn min_serialized_size() -> usize {
        let entry_size = MastNodeEntry::min_serialized_size();
        let digest_size = Word::min_serialized_size();
        entry_size + digest_size
    }
}
372
373// TESTS
374// ================================================================================================
375
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::assert_matches;

    use super::*;

    /// Serializes a raw `u64` word without going through `MastNodeEntry`'s encoder, so that
    /// malformed entries can be fed to the decoder.
    fn raw_entry_bytes(value: u64) -> Vec<u8> {
        let mut buffer: Vec<u8> = Vec::new();
        value.write_into(&mut buffer);
        buffer
    }

    #[test]
    fn serialize_deserialize_60_bit_payload() {
        // each child needs 30 bits
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x3f_ff_ff_ff,
            right_child_id: 0x3f_ff_ff_ff,
        };

        let serialized = mast_node_entry.to_bytes();
        let deserialized = MastNodeEntry::read_from_bytes(&serialized).unwrap();

        assert_eq!(mast_node_entry, deserialized);
    }

    #[test]
    fn roundtrip_every_variant() {
        let entries = [
            MastNodeEntry::Join { left_child_id: 1, right_child_id: 2 },
            MastNodeEntry::Split { if_branch_id: 3, else_branch_id: 4 },
            MastNodeEntry::Loop { body_id: 5 },
            // single-value payloads may use the full u32 range
            MastNodeEntry::Block { ops_offset: u32::MAX },
            MastNodeEntry::Call { callee_id: 7 },
            MastNodeEntry::SysCall { callee_id: 8 },
            MastNodeEntry::Dyn,
            MastNodeEntry::Dyncall,
            MastNodeEntry::External,
        ];

        for entry in entries {
            let serialized = entry.to_bytes();
            assert_eq!(serialized.len(), MastNodeEntry::SERIALIZED_SIZE);

            let deserialized = MastNodeEntry::read_from_bytes(&serialized).unwrap();
            assert_eq!(entry, deserialized);
        }
    }

    #[test]
    #[should_panic(expected = "left value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_1() {
        // left child needs 31 bits
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x4f_ff_ff_ff,
            right_child_id: 0x0,
        };

        // must panic
        let _serialized = mast_node_entry.to_bytes();
    }

    #[test]
    #[should_panic(expected = "right value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_2() {
        // right child needs 31 bits
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x0,
            right_child_id: 0x4f_ff_ff_ff,
        };

        // must panic
        let _serialized = mast_node_entry.to_bytes();
    }

    #[test]
    fn deserialize_large_payloads_fails() {
        // Serialized `CALL` with a 33-bit payload
        let serialized = raw_entry_bytes(((CALL as u64) << 60) | (u32::MAX as u64 + 1_u64));

        let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

        assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
    }

    #[test]
    fn deserialize_unknown_tag_fails() {
        // every 4-bit tag above the last defined one must be rejected
        for tag in (u64::from(EXTERNAL) + 1)..=0b1111 {
            let serialized = raw_entry_bytes(tag << 60);

            let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

            assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
        }
    }
}