miden_core/mast/node/
join_node.rs1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{hash::rpo::RpoDigest, Felt};
5
6use crate::{
7 chiplets::hasher,
8 mast::{DecoratorId, MastForest, MastForestError, MastNodeId},
9 prettier::PrettyPrint,
10 OPCODE_JOIN,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct JoinNode {
20 children: [MastNodeId; 2],
21 digest: RpoDigest,
22 before_enter: Vec<DecoratorId>,
23 after_exit: Vec<DecoratorId>,
24}
25
26impl JoinNode {
28 pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
30}
31
32impl JoinNode {
34 pub fn new(
36 children: [MastNodeId; 2],
37 mast_forest: &MastForest,
38 ) -> Result<Self, MastForestError> {
39 let forest_len = mast_forest.nodes.len();
40 if children[0].as_usize() >= forest_len {
41 return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
42 } else if children[1].as_usize() >= forest_len {
43 return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
44 }
45 let digest = {
46 let left_child_hash = mast_forest[children[0]].digest();
47 let right_child_hash = mast_forest[children[1]].digest();
48
49 hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
50 };
51
52 Ok(Self {
53 children,
54 digest,
55 before_enter: Vec::new(),
56 after_exit: Vec::new(),
57 })
58 }
59
60 pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
63 Self {
64 children,
65 digest,
66 before_enter: Vec::new(),
67 after_exit: Vec::new(),
68 }
69 }
70}
71
72impl JoinNode {
74 pub fn digest(&self) -> RpoDigest {
86 self.digest
87 }
88
89 pub fn first(&self) -> MastNodeId {
91 self.children[0]
92 }
93
94 pub fn second(&self) -> MastNodeId {
97 self.children[1]
98 }
99
100 pub fn before_enter(&self) -> &[DecoratorId] {
102 &self.before_enter
103 }
104
105 pub fn after_exit(&self) -> &[DecoratorId] {
107 &self.after_exit
108 }
109}
110
111impl JoinNode {
113 pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
115 self.before_enter = decorator_ids;
116 }
117
118 pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
120 self.after_exit = decorator_ids;
121 }
122}
123
124impl JoinNode {
128 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
129 JoinNodePrettyPrint { join_node: self, mast_forest }
130 }
131
132 pub(super) fn to_pretty_print<'a>(
133 &'a self,
134 mast_forest: &'a MastForest,
135 ) -> impl PrettyPrint + 'a {
136 JoinNodePrettyPrint { join_node: self, mast_forest }
137 }
138}
139
140struct JoinNodePrettyPrint<'a> {
141 join_node: &'a JoinNode,
142 mast_forest: &'a MastForest,
143}
144
145impl PrettyPrint for JoinNodePrettyPrint<'_> {
146 #[rustfmt::skip]
147 fn render(&self) -> crate::prettier::Document {
148 use crate::prettier::*;
149
150 let pre_decorators = {
151 let mut pre_decorators = self
152 .join_node
153 .before_enter()
154 .iter()
155 .map(|&decorator_id| self.mast_forest[decorator_id].render())
156 .reduce(|acc, doc| acc + const_text(" ") + doc)
157 .unwrap_or_default();
158 if !pre_decorators.is_empty() {
159 pre_decorators += nl();
160 }
161
162 pre_decorators
163 };
164
165 let post_decorators = {
166 let mut post_decorators = self
167 .join_node
168 .after_exit()
169 .iter()
170 .map(|&decorator_id| self.mast_forest[decorator_id].render())
171 .reduce(|acc, doc| acc + const_text(" ") + doc)
172 .unwrap_or_default();
173 if !post_decorators.is_empty() {
174 post_decorators = nl() + post_decorators;
175 }
176
177 post_decorators
178 };
179
180 let first_child =
181 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
182 let second_child =
183 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
184
185 pre_decorators
186 + indent(
187 4,
188 const_text("join")
189 + nl()
190 + first_child.render()
191 + nl()
192 + second_child.render(),
193 ) + nl() + const_text("end")
194 + post_decorators
195 }
196}
197
198impl fmt::Display for JoinNodePrettyPrint<'_> {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 use crate::prettier::PrettyPrint;
201 self.pretty_print(f)
202 }
203}