miden_core/mast/node/
join_node.rs1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5
6use super::MastNodeExt;
7use crate::{
8 OPCODE_JOIN,
9 chiplets::hasher,
10 mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
11 prettier::PrettyPrint,
12};
13
14#[derive(Debug, Clone, PartialEq, Eq)]
20pub struct JoinNode {
21 children: [MastNodeId; 2],
22 digest: RpoDigest,
23 before_enter: Vec<DecoratorId>,
24 after_exit: Vec<DecoratorId>,
25}
26
27impl JoinNode {
29 pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
31}
32
33impl JoinNode {
35 pub fn new(
37 children: [MastNodeId; 2],
38 mast_forest: &MastForest,
39 ) -> Result<Self, MastForestError> {
40 let forest_len = mast_forest.nodes.len();
41 if children[0].as_usize() >= forest_len {
42 return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
43 } else if children[1].as_usize() >= forest_len {
44 return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
45 }
46 let digest = {
47 let left_child_hash = mast_forest[children[0]].digest();
48 let right_child_hash = mast_forest[children[1]].digest();
49
50 hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
51 };
52
53 Ok(Self {
54 children,
55 digest,
56 before_enter: Vec::new(),
57 after_exit: Vec::new(),
58 })
59 }
60
61 pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
64 Self {
65 children,
66 digest,
67 before_enter: Vec::new(),
68 after_exit: Vec::new(),
69 }
70 }
71}
72
73impl JoinNode {
75 pub fn digest(&self) -> RpoDigest {
87 self.digest
88 }
89
90 pub fn first(&self) -> MastNodeId {
92 self.children[0]
93 }
94
95 pub fn second(&self) -> MastNodeId {
98 self.children[1]
99 }
100
101 pub fn before_enter(&self) -> &[DecoratorId] {
103 &self.before_enter
104 }
105
106 pub fn after_exit(&self) -> &[DecoratorId] {
108 &self.after_exit
109 }
110}
111
112impl JoinNode {
114 pub fn remap_children(&self, remapping: &Remapping) -> Self {
115 let mut node = self.clone();
116 node.children[0] = node.children[0].remap(remapping);
117 node.children[1] = node.children[1].remap(remapping);
118 node
119 }
120
121 pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
123 self.before_enter.extend_from_slice(decorator_ids);
124 }
125
126 pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
128 self.after_exit.extend_from_slice(decorator_ids);
129 }
130}
131
132impl MastNodeExt for JoinNode {
133 fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
134 self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
135 }
136}
137
138impl JoinNode {
142 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
143 JoinNodePrettyPrint { join_node: self, mast_forest }
144 }
145
146 pub(super) fn to_pretty_print<'a>(
147 &'a self,
148 mast_forest: &'a MastForest,
149 ) -> impl PrettyPrint + 'a {
150 JoinNodePrettyPrint { join_node: self, mast_forest }
151 }
152}
153
154struct JoinNodePrettyPrint<'a> {
155 join_node: &'a JoinNode,
156 mast_forest: &'a MastForest,
157}
158
159impl PrettyPrint for JoinNodePrettyPrint<'_> {
160 #[rustfmt::skip]
161 fn render(&self) -> crate::prettier::Document {
162 use crate::prettier::*;
163
164 let pre_decorators = {
165 let mut pre_decorators = self
166 .join_node
167 .before_enter()
168 .iter()
169 .map(|&decorator_id| self.mast_forest[decorator_id].render())
170 .reduce(|acc, doc| acc + const_text(" ") + doc)
171 .unwrap_or_default();
172 if !pre_decorators.is_empty() {
173 pre_decorators += nl();
174 }
175
176 pre_decorators
177 };
178
179 let post_decorators = {
180 let mut post_decorators = self
181 .join_node
182 .after_exit()
183 .iter()
184 .map(|&decorator_id| self.mast_forest[decorator_id].render())
185 .reduce(|acc, doc| acc + const_text(" ") + doc)
186 .unwrap_or_default();
187 if !post_decorators.is_empty() {
188 post_decorators = nl() + post_decorators;
189 }
190
191 post_decorators
192 };
193
194 let first_child =
195 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
196 let second_child =
197 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
198
199 pre_decorators
200 + indent(
201 4,
202 const_text("join")
203 + nl()
204 + first_child.render()
205 + nl()
206 + second_child.render(),
207 ) + nl() + const_text("end")
208 + post_decorators
209 }
210}
211
212impl fmt::Display for JoinNodePrettyPrint<'_> {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 use crate::prettier::PrettyPrint;
215 self.pretty_print(f)
216 }
217}