miden_core/mast/serialization/
info.rs1use super::{NodeDataOffset, basic_blocks::BasicBlockDataDecoder};
2#[cfg(test)]
3use crate::mast::node::MastNodeExt;
4use crate::{
5 mast::{MastForestContributor, MastNode, MastNodeId, Word, node::MastNodeBuilder},
6 serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
7};
8
// Variant tags for `MastNodeEntry`. A tag is stored in the top 4 bits of the serialized `u64`
// word, so every tag must stay within `0b1111`; the binary literals make that budget explicit.

/// Tag of [`MastNodeEntry::Join`].
const JOIN: u8 = 0b0000;
/// Tag of [`MastNodeEntry::Split`].
const SPLIT: u8 = 0b0001;
/// Tag of [`MastNodeEntry::Loop`].
const LOOP: u8 = 0b0010;
/// Tag of [`MastNodeEntry::Block`].
const BLOCK: u8 = 0b0011;
/// Tag of [`MastNodeEntry::Call`].
const CALL: u8 = 0b0100;
/// Tag of [`MastNodeEntry::SysCall`].
const SYSCALL: u8 = 0b0101;
/// Tag of [`MastNodeEntry::Dyn`].
const DYN: u8 = 0b0110;
/// Tag of [`MastNodeEntry::Dyncall`].
const DYNCALL: u8 = 0b0111;
/// Tag of [`MastNodeEntry::External`].
const EXTERNAL: u8 = 0b1000;
21
/// Compact, fixed-size description of a single MAST node, as used by the serializer.
///
/// Each variant carries only the `u32` ids/offsets needed to rebuild the node; the node's digest
/// is kept alongside it in `MastNodeInfo`. The explicit discriminants tie every variant to its
/// tag constant (`JOIN` … `EXTERNAL`), which `write_into` places in the top 4 bits of the
/// serialized `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MastNodeEntry {
    /// Join node: ids of its two children (each must fit in 30 bits when serialized).
    Join {
        left_child_id: u32,
        right_child_id: u32,
    } = JOIN,
    /// Split node: ids of the true/false branches (each must fit in 30 bits when serialized).
    Split {
        if_branch_id: u32,
        else_branch_id: u32,
    } = SPLIT,
    /// Loop node: id of the loop body.
    Loop {
        body_id: u32,
    } = LOOP,
    /// Basic block: offset handed to `BasicBlockDataDecoder::decode_operations` to recover
    /// the block's operation batches.
    Block {
        ops_offset: u32,
    } = BLOCK,
    /// Regular call: id of the callee node.
    Call {
        callee_id: u32,
    } = CALL,
    /// Syscall: id of the callee node.
    SysCall {
        callee_id: u32,
    } = SYSCALL,
    /// Dynamic execution node without payload.
    Dyn = DYN,
    /// Dynamic call node without payload.
    Dyncall = DYNCALL,
    /// External node; it is rebuilt from its digest alone.
    External = EXTERNAL,
}
63
64impl MastNodeEntry {
66 pub const SERIALIZED_SIZE: usize = 8;
68
69 pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
71 use MastNode::*;
72
73 if !matches!(mast_node, &Block(_)) {
74 debug_assert_eq!(ops_offset, 0);
75 }
76
77 match mast_node {
78 Block(_) => Self::Block { ops_offset },
79 Join(join_node) => Self::Join {
80 left_child_id: join_node.first().0,
81 right_child_id: join_node.second().0,
82 },
83 Split(split_node) => Self::Split {
84 if_branch_id: split_node.on_true().0,
85 else_branch_id: split_node.on_false().0,
86 },
87 Loop(loop_node) => Self::Loop { body_id: loop_node.body().0 },
88 Call(call_node) => {
89 if call_node.is_syscall() {
90 Self::SysCall { callee_id: call_node.callee().0 }
91 } else {
92 Self::Call { callee_id: call_node.callee().0 }
93 }
94 },
95 Dyn(dyn_node) => {
96 if dyn_node.is_dyncall() {
97 Self::Dyncall
98 } else {
99 Self::Dyn
100 }
101 },
102 External(_) => Self::External,
103 }
104 }
105
106 pub fn try_into_mast_node_builder(
111 self,
112 node_count: usize,
113 basic_block_data_decoder: &BasicBlockDataDecoder,
114 digest: Word,
115 ) -> Result<MastNodeBuilder, DeserializationError> {
116 match self {
117 Self::Block { ops_offset } => {
118 let op_batches = basic_block_data_decoder.decode_operations(ops_offset)?;
119 let builder =
120 crate::mast::node::BasicBlockNodeBuilder::from_op_batches(op_batches, digest);
121 Ok(MastNodeBuilder::BasicBlock(builder))
122 },
123 Self::Join { left_child_id, right_child_id } => {
124 let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
125 let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
126 let builder = crate::mast::node::JoinNodeBuilder::new([left_child, right_child])
127 .with_digest(digest);
128 Ok(MastNodeBuilder::Join(builder))
129 },
130 Self::Split { if_branch_id, else_branch_id } => {
131 let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
132 let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
133 let builder = crate::mast::node::SplitNodeBuilder::new([if_branch, else_branch])
134 .with_digest(digest);
135 Ok(MastNodeBuilder::Split(builder))
136 },
137 Self::Loop { body_id } => {
138 let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
139 let builder = crate::mast::node::LoopNodeBuilder::new(body_id).with_digest(digest);
140 Ok(MastNodeBuilder::Loop(builder))
141 },
142 Self::Call { callee_id } => {
143 let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
144 let builder =
145 crate::mast::node::CallNodeBuilder::new(callee_id).with_digest(digest);
146 Ok(MastNodeBuilder::Call(builder))
147 },
148 Self::SysCall { callee_id } => {
149 let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
150 let builder =
151 crate::mast::node::CallNodeBuilder::new_syscall(callee_id).with_digest(digest);
152 Ok(MastNodeBuilder::Call(builder))
153 },
154 Self::Dyn => Ok(MastNodeBuilder::Dyn(
155 crate::mast::node::DynNodeBuilder::new_dyn().with_digest(digest),
156 )),
157 Self::Dyncall => Ok(MastNodeBuilder::Dyn(
158 crate::mast::node::DynNodeBuilder::new_dyncall().with_digest(digest),
159 )),
160 Self::External => {
161 Ok(MastNodeBuilder::External(crate::mast::node::ExternalNodeBuilder::new(digest)))
162 },
163 }
164 }
165}
166
167impl Serializable for MastNodeEntry {
168 fn write_into<W: ByteWriter>(&self, target: &mut W) {
169 let discriminant = self.discriminant() as u64;
170 assert!(discriminant <= 0b1111);
171
172 let payload = match *self {
173 Self::Join {
174 left_child_id: left,
175 right_child_id: right,
176 } => Self::encode_u32_pair(left, right),
177 Self::Split {
178 if_branch_id: if_branch,
179 else_branch_id: else_branch,
180 } => Self::encode_u32_pair(if_branch, else_branch),
181 Self::Loop { body_id: body } => Self::encode_u32_payload(body),
182 Self::Block { ops_offset } => Self::encode_u32_payload(ops_offset),
183 Self::Call { callee_id } => Self::encode_u32_payload(callee_id),
184 Self::SysCall { callee_id } => Self::encode_u32_payload(callee_id),
185 Self::Dyn | Self::Dyncall | Self::External => 0,
186 };
187
188 let value = (discriminant << 60) | payload;
189 target.write_u64(value);
190 }
191}
192
193impl Deserializable for MastNodeEntry {
194 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
195 let (discriminant, payload) = {
196 let value = source.read_u64()?;
197
198 let discriminant = (value >> 60) as u8;
199 let payload = value & 0x0f_ff_ff_ff_ff_ff_ff_ff;
200
201 (discriminant, payload)
202 };
203
204 match discriminant {
205 JOIN => {
206 let (left_child_id, right_child_id) = Self::decode_u32_pair(payload);
207 Ok(Self::Join { left_child_id, right_child_id })
208 },
209 SPLIT => {
210 let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload);
211 Ok(Self::Split { if_branch_id, else_branch_id })
212 },
213 LOOP => {
214 let body_id = Self::decode_u32_payload(payload)?;
215 Ok(Self::Loop { body_id })
216 },
217 BLOCK => {
218 let ops_offset = Self::decode_u32_payload(payload)?;
219 Ok(Self::Block { ops_offset })
220 },
221 CALL => {
222 let callee_id = Self::decode_u32_payload(payload)?;
223 Ok(Self::Call { callee_id })
224 },
225 SYSCALL => {
226 let callee_id = Self::decode_u32_payload(payload)?;
227 Ok(Self::SysCall { callee_id })
228 },
229 DYN => Ok(Self::Dyn),
230 DYNCALL => Ok(Self::Dyncall),
231 EXTERNAL => Ok(Self::External),
232 _ => Err(DeserializationError::InvalidValue(format!(
233 "Invalid tag for MAST node: {discriminant}"
234 ))),
235 }
236 }
237
238 fn min_serialized_size() -> usize {
240 Self::SERIALIZED_SIZE
241 }
242}
243
244impl MastNodeEntry {
246 fn discriminant(&self) -> u8 {
247 unsafe { *<*const _>::from(self).cast::<u8>() }
253 }
254
255 fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 {
260 assert!(
261 left_value.leading_zeros() >= 2,
262 "MastNodeEntry::encode_u32_pair: left value doesn't fit in 30 bits: {left_value}",
263 );
264 assert!(
265 right_value.leading_zeros() >= 2,
266 "MastNodeEntry::encode_u32_pair: right value doesn't fit in 30 bits: {right_value}",
267 );
268
269 ((left_value as u64) << 30) | (right_value as u64)
270 }
271
272 fn encode_u32_payload(payload: u32) -> u64 {
273 payload as u64
274 }
275}
276
277impl MastNodeEntry {
279 fn decode_u32_pair(payload: u64) -> (u32, u32) {
281 let left_value = (payload >> 30) as u32;
282 let right_value = (payload & 0x3f_ff_ff_ff) as u32;
283
284 (left_value, right_value)
285 }
286
287 pub fn decode_u32_payload(payload: u64) -> Result<u32, DeserializationError> {
291 payload.try_into().map_err(|_| {
292 DeserializationError::InvalidValue(format!(
293 "Invalid payload: expected to fit in u32, but was {payload}"
294 ))
295 })
296 }
297}
298
/// A MAST node as stored by the serializer: its [`MastNodeEntry`] plus the node's digest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MastNodeInfo {
    // Structural description of the node (kind plus child ids / ops offset).
    entry: MastNodeEntry,
    // Digest of the node, kept next to the entry; also passed to the node builder.
    digest: Word,
}
311
312impl MastNodeInfo {
313 #[cfg(test)]
317 pub fn new(mast_node: &MastNode, ops_offset: NodeDataOffset) -> Self {
318 Self {
319 entry: MastNodeEntry::new(mast_node, ops_offset),
320 digest: mast_node.digest(),
321 }
322 }
323
324 #[cfg(test)]
326 pub fn try_into_mast_node_builder(
327 self,
328 node_count: usize,
329 basic_block_data_decoder: &BasicBlockDataDecoder,
330 ) -> Result<MastNodeBuilder, DeserializationError> {
331 self.entry
332 .try_into_mast_node_builder(node_count, basic_block_data_decoder, self.digest)
333 }
334
335 pub fn node_entry(&self) -> MastNodeEntry {
337 self.entry
338 }
339
340 pub fn digest(&self) -> Word {
342 self.digest
343 }
344
345 pub(crate) fn from_entry(entry: MastNodeEntry, digest: Word) -> Self {
347 Self { entry, digest }
348 }
349}
350
#[cfg(test)]
impl Serializable for MastNodeInfo {
    /// Writes the entry word first, then the digest.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { entry, digest } = self;
        entry.write_into(target);
        digest.write_into(target);
    }
}
358
#[cfg(test)]
impl Deserializable for MastNodeInfo {
    /// Reads the entry, then the digest, in the order `write_into` emits them.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Struct-literal fields are evaluated in source order, so the entry is read first.
        Ok(Self {
            entry: MastNodeEntry::read_from(source)?,
            digest: Word::read_from(source)?,
        })
    }

    /// Entry size plus digest size.
    fn min_serialized_size() -> usize {
        Word::min_serialized_size() + MastNodeEntry::min_serialized_size()
    }
}
372
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::assert_matches;

    use super::*;

    /// Both halves of a pair at their 30-bit maximum survive a round-trip.
    #[test]
    fn serialize_deserialize_60_bit_payload() {
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x3f_ff_ff_ff,
            right_child_id: 0x3f_ff_ff_ff,
        };

        let serialized = mast_node_entry.to_bytes();
        let deserialized = MastNodeEntry::read_from_bytes(&serialized).unwrap();

        assert_eq!(mast_node_entry, deserialized);
    }

    /// Every variant round-trips and occupies exactly `SERIALIZED_SIZE` bytes.
    #[test]
    fn serialize_deserialize_all_variants() {
        let entries = [
            MastNodeEntry::Join { left_child_id: 1, right_child_id: 2 },
            MastNodeEntry::Split { if_branch_id: 3, else_branch_id: 4 },
            MastNodeEntry::Loop { body_id: 5 },
            MastNodeEntry::Block { ops_offset: u32::MAX },
            MastNodeEntry::Call { callee_id: 6 },
            MastNodeEntry::SysCall { callee_id: 7 },
            MastNodeEntry::Dyn,
            MastNodeEntry::Dyncall,
            MastNodeEntry::External,
        ];

        for entry in entries {
            let serialized = entry.to_bytes();
            assert_eq!(serialized.len(), MastNodeEntry::SERIALIZED_SIZE);
            assert_eq!(MastNodeEntry::read_from_bytes(&serialized).unwrap(), entry);
        }
    }

    /// A left id above 30 bits must trip the left-value assertion specifically.
    #[test]
    #[should_panic(expected = "left value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_1() {
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x4f_ff_ff_ff,
            right_child_id: 0x0,
        };

        let _serialized = mast_node_entry.to_bytes();
    }

    /// A right id above 30 bits must trip the right-value assertion specifically.
    #[test]
    #[should_panic(expected = "right value doesn't fit in 30 bits")]
    fn serialize_large_payloads_fails_2() {
        let mast_node_entry = MastNodeEntry::Join {
            left_child_id: 0x0,
            right_child_id: 0x4f_ff_ff_ff,
        };

        let _serialized = mast_node_entry.to_bytes();
    }

    /// A single-value payload wider than `u32` is rejected.
    #[test]
    fn deserialize_large_payloads_fails() {
        let serialized = {
            let serialized_value = ((CALL as u64) << 60) | (u32::MAX as u64 + 1_u64);

            let mut serialized_buffer: Vec<u8> = Vec::new();
            serialized_value.write_into(&mut serialized_buffer);

            serialized_buffer
        };

        let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

        assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
    }

    /// Tags outside `JOIN..=EXTERNAL` are rejected.
    #[test]
    fn deserialize_unknown_tag_fails() {
        let mut serialized: Vec<u8> = Vec::new();
        (9_u64 << 60).write_into(&mut serialized);

        let deserialized_result = MastNodeEntry::read_from_bytes(&serialized);

        assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_)));
    }
}