miden_core/mast/node/
join_node.rs1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
8use super::{MastNodeErrorContext, MastNodeExt};
9use crate::{
10 OPCODE_JOIN,
11 chiplets::hasher,
12 mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
13 prettier::PrettyPrint,
14};
15
/// A MAST node with exactly two children, `first` and `second`.
///
/// It is rendered as `join <first> <second> end`, and the children are presumably executed
/// in that order (not enforced by this type). The node's digest is computed in
/// [`JoinNode::new`] by merging the two children's digests in the [`JoinNode::DOMAIN`]
/// hashing domain.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinNode {
    /// IDs of the two child nodes; index 0 is `first`, index 1 is `second`.
    children: [MastNodeId; 2],
    /// Digest of this node: computed by [`JoinNode::new`], or taken as given by
    /// [`JoinNode::new_unsafe`].
    digest: Word,
    /// Decorators attached before entering this node. With `serde`, this field defaults to
    /// empty when missing on deserialization and is omitted from serialized output when empty.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    before_enter: Vec<DecoratorId>,
    /// Decorators attached after exiting this node; same serde behavior as `before_enter`.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    after_exit: Vec<DecoratorId>,
}
31
impl JoinNode {
    /// Hashing domain for join nodes: the `OPCODE_JOIN` opcode as a field element.
    ///
    /// It is passed to `hasher::merge_in_domain` when [`JoinNode::new`] computes the digest,
    /// and it is returned by `MastNodeExt::domain`.
    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
}
37
38impl JoinNode {
40 pub fn new(
42 children: [MastNodeId; 2],
43 mast_forest: &MastForest,
44 ) -> Result<Self, MastForestError> {
45 let forest_len = mast_forest.nodes.len();
46 if children[0].as_usize() >= forest_len {
47 return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
48 } else if children[1].as_usize() >= forest_len {
49 return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
50 }
51 let digest = {
52 let left_child_hash = mast_forest[children[0]].digest();
53 let right_child_hash = mast_forest[children[1]].digest();
54
55 hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
56 };
57
58 Ok(Self {
59 children,
60 digest,
61 before_enter: Vec::new(),
62 after_exit: Vec::new(),
63 })
64 }
65
66 pub fn new_unsafe(children: [MastNodeId; 2], digest: Word) -> Self {
69 Self {
70 children,
71 digest,
72 before_enter: Vec::new(),
73 after_exit: Vec::new(),
74 }
75 }
76}
77
78impl JoinNode {
80 pub fn first(&self) -> MastNodeId {
82 self.children[0]
83 }
84
85 pub fn second(&self) -> MastNodeId {
88 self.children[1]
89 }
90}
91
92impl MastNodeErrorContext for JoinNode {
93 fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
94 self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
95 }
96}
97
98impl JoinNode {
102 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
103 JoinNodePrettyPrint { join_node: self, mast_forest }
104 }
105
106 pub(super) fn to_pretty_print<'a>(
107 &'a self,
108 mast_forest: &'a MastForest,
109 ) -> impl PrettyPrint + 'a {
110 JoinNodePrettyPrint { join_node: self, mast_forest }
111 }
112}
113
114struct JoinNodePrettyPrint<'a> {
115 join_node: &'a JoinNode,
116 mast_forest: &'a MastForest,
117}
118
119impl PrettyPrint for JoinNodePrettyPrint<'_> {
120 #[rustfmt::skip]
121 fn render(&self) -> crate::prettier::Document {
122 use crate::prettier::*;
123
124 let pre_decorators = {
125 let mut pre_decorators = self
126 .join_node
127 .before_enter()
128 .iter()
129 .map(|&decorator_id| self.mast_forest[decorator_id].render())
130 .reduce(|acc, doc| acc + const_text(" ") + doc)
131 .unwrap_or_default();
132 if !pre_decorators.is_empty() {
133 pre_decorators += nl();
134 }
135
136 pre_decorators
137 };
138
139 let post_decorators = {
140 let mut post_decorators = self
141 .join_node
142 .after_exit()
143 .iter()
144 .map(|&decorator_id| self.mast_forest[decorator_id].render())
145 .reduce(|acc, doc| acc + const_text(" ") + doc)
146 .unwrap_or_default();
147 if !post_decorators.is_empty() {
148 post_decorators = nl() + post_decorators;
149 }
150
151 post_decorators
152 };
153
154 let first_child =
155 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
156 let second_child =
157 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
158
159 pre_decorators
160 + indent(
161 4,
162 const_text("join")
163 + nl()
164 + first_child.render()
165 + nl()
166 + second_child.render(),
167 ) + nl() + const_text("end")
168 + post_decorators
169 }
170}
171
172impl fmt::Display for JoinNodePrettyPrint<'_> {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 use crate::prettier::PrettyPrint;
175 self.pretty_print(f)
176 }
177}
178
179impl MastNodeExt for JoinNode {
183 fn digest(&self) -> Word {
195 self.digest
196 }
197
198 fn before_enter(&self) -> &[DecoratorId] {
200 &self.before_enter
201 }
202
203 fn after_exit(&self) -> &[DecoratorId] {
205 &self.after_exit
206 }
207 fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
209 self.before_enter.extend_from_slice(decorator_ids);
210 }
211
212 fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
214 self.after_exit.extend_from_slice(decorator_ids);
215 }
216
217 fn remove_decorators(&mut self) {
219 self.before_enter.truncate(0);
220 self.after_exit.truncate(0);
221 }
222
223 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
224 Box::new(JoinNode::to_display(self, mast_forest))
225 }
226
227 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
228 Box::new(JoinNode::to_pretty_print(self, mast_forest))
229 }
230
231 fn remap_children(&self, remapping: &Remapping) -> Self {
232 let mut node = self.clone();
233 node.children[0] = node.children[0].remap(remapping);
234 node.children[1] = node.children[1].remap(remapping);
235 node
236 }
237
238 fn has_children(&self) -> bool {
239 true
240 }
241
242 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
243 target.push(self.first());
244 target.push(self.second());
245 }
246
247 fn domain(&self) -> Felt {
248 Self::DOMAIN
249 }
250}