1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt, fingerprint_with_child_fingerprints};
8use crate::{
9 Felt, Word,
10 chiplets::hasher,
11 mast::{MastForest, MastForestError, MastNodeId},
12 operations::opcodes,
13 prettier::PrettyPrint,
14 utils::{Idx, LookupByIdx},
15};
16
/// A MAST node with exactly two children, referred to as `first` and `second`.
///
/// Children are stored as [`MastNodeId`]s, i.e. ids of nodes in the owning [`MastForest`].
/// When the node is built through [`JoinNodeBuilder::build`] without a forced digest, its digest
/// is the merge of the first and second child digests in the [`JoinNode::DOMAIN`] domain.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct JoinNode {
    /// Ids of the `[first, second]` children.
    children: [MastNodeId; 2],
    /// Digest of this node: either computed from the children or supplied explicitly.
    digest: Word,
}
29
impl JoinNode {
    /// Domain separator used when merging the two child digests into this node's digest.
    ///
    /// Its value is the `JOIN` opcode.
    pub const DOMAIN: Felt = Felt::new_unchecked(opcodes::JOIN as u64);
}
35
36impl JoinNode {
38 pub fn first(&self) -> MastNodeId {
40 self.children[0]
41 }
42
43 pub fn second(&self) -> MastNodeId {
46 self.children[1]
47 }
48}
49
50impl JoinNode {
54 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
55 JoinNodePrettyPrint { join_node: self, mast_forest }
56 }
57
58 pub(super) fn to_pretty_print<'a>(
59 &'a self,
60 mast_forest: &'a MastForest,
61 ) -> impl PrettyPrint + 'a {
62 JoinNodePrettyPrint { join_node: self, mast_forest }
63 }
64}
65
/// Rendering adapter that pairs a [`JoinNode`] with the [`MastForest`] holding its children, so
/// both children can be looked up by id and rendered.
struct JoinNodePrettyPrint<'a> {
    /// The node being rendered.
    join_node: &'a JoinNode,
    /// Forest used to resolve the node's child ids.
    mast_forest: &'a MastForest,
}
70
71impl PrettyPrint for JoinNodePrettyPrint<'_> {
72 #[rustfmt::skip]
73 fn render(&self) -> crate::prettier::Document {
74 use crate::prettier::*;
75
76 let first_child =
77 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
78 let second_child =
79 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
80
81 indent(
82 4,
83 const_text("join")
84 + nl()
85 + first_child.render()
86 + nl()
87 + second_child.render(),
88 ) + nl() + const_text("end")
89 }
90}
91
92impl fmt::Display for JoinNodePrettyPrint<'_> {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 use crate::prettier::PrettyPrint;
95 self.pretty_print(f)
96 }
97}
98
#[cfg(test)]
impl JoinNode {
    /// Test-only comparison: two join nodes are semantically equal when they reference the same
    /// first and second children and carry the same digest.
    ///
    /// `_forest` is not consulted. It is presumably accepted for signature parity with the
    /// other node types' `semantic_eq` (TODO confirm).
    pub fn semantic_eq(&self, other: &JoinNode, _forest: &MastForest) -> bool {
        // Children are compared before the digest, matching the original check order.
        self.first() == other.first()
            && self.second() == other.second()
            && self.digest() == other.digest()
    }
}
124
125impl MastNodeExt for JoinNode {
129 fn digest(&self) -> Word {
141 self.digest
142 }
143
144 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
145 Box::new(JoinNode::to_display(self, mast_forest))
146 }
147
148 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
149 Box::new(JoinNode::to_pretty_print(self, mast_forest))
150 }
151
152 fn has_children(&self) -> bool {
153 true
154 }
155
156 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
157 target.push(self.first());
158 target.push(self.second());
159 }
160
161 fn for_each_child<F>(&self, mut f: F)
162 where
163 F: FnMut(MastNodeId),
164 {
165 f(self.first());
166 f(self.second());
167 }
168
169 fn domain(&self) -> Felt {
170 Self::DOMAIN
171 }
172
173 type Builder = JoinNodeBuilder;
174
175 fn to_builder(self, _forest: &MastForest) -> Self::Builder {
176 JoinNodeBuilder::new(self.children).with_digest(self.digest)
177 }
178}
179
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a join node from two arbitrary child ids and an arbitrary four-element digest.
    ///
    /// The child ids are unconstrained, so they need not refer to nodes of any particular forest.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(first, second, raw_digest)| {
                let digest = Word::from(raw_digest.map(Felt::new_unchecked));
                JoinNode { children: [first, second], digest }
            })
            .no_shrink()
            .boxed()
    }
}
209
/// Builder for [`JoinNode`].
///
/// Holds the two child ids and an optional forced digest. When no digest is forced,
/// [`JoinNodeBuilder::build`] computes it from the children's digests in the target forest.
#[derive(Debug)]
pub struct JoinNodeBuilder {
    /// Ids of the `[first, second]` children.
    children: [MastNodeId; 2],
    /// Digest to use instead of the computed one, if set (for example via `with_digest`).
    digest: Option<Word>,
}
217
218impl JoinNodeBuilder {
219 pub fn new(children: [MastNodeId; 2]) -> Self {
221 Self { children, digest: None }
222 }
223
224 pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
226 let forest_len = mast_forest.nodes.len();
227 if self.children[0].to_usize() >= forest_len {
228 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
229 } else if self.children[1].to_usize() >= forest_len {
230 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
231 }
232
233 let digest = if let Some(forced_digest) = self.digest {
235 forced_digest
236 } else {
237 let left_child_hash = mast_forest[self.children[0]].digest();
238 let right_child_hash = mast_forest[self.children[1]].digest();
239
240 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
241 };
242
243 Ok(JoinNode { children: self.children, digest })
244 }
245
246 pub(in crate::mast) fn build_linked(self) -> Result<JoinNode, MastForestError> {
247 Ok(JoinNode {
248 children: self.children,
249 digest: self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?,
250 })
251 }
252}
253
254impl MastForestContributor for JoinNodeBuilder {
255 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
256 let forest_len = forest.nodes.len();
258 if self.children[0].to_usize() >= forest_len {
259 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
260 } else if self.children[1].to_usize() >= forest_len {
261 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
262 }
263
264 let digest = if let Some(forced_digest) = self.digest {
266 forced_digest
267 } else {
268 let left_child_hash = forest[self.children[0]].digest();
269 let right_child_hash = forest[self.children[1]].digest();
270
271 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
272 };
273
274 let node_id = forest
277 .nodes
278 .push(JoinNode { children: self.children, digest }.into())
279 .map_err(|_| MastForestError::TooManyNodes)?;
280
281 Ok(node_id)
282 }
283
284 fn fingerprint_for_node(
285 &self,
286 forest: &MastForest,
287 hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
288 ) -> Result<Word, MastForestError> {
289 let node_digest = if let Some(forced_digest) = self.digest {
290 forced_digest
291 } else {
292 let left_child_hash = forest[self.children[0]].digest();
293 let right_child_hash = forest[self.children[1]].digest();
294
295 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
296 };
297
298 fingerprint_with_child_fingerprints(node_digest, &self.children, forest, hash_by_node_id)
299 }
300
301 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
302 JoinNodeBuilder {
303 children: [
304 *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
305 *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
306 ],
307 digest: self.digest,
308 }
309 }
310
311 fn with_digest(mut self, digest: Word) -> Self {
312 self.digest = Some(digest);
313 self
314 }
315}
316
317impl JoinNodeBuilder {
318 pub(in crate::mast) fn add_to_forest_relaxed(
328 self,
329 forest: &mut MastForest,
330 ) -> Result<MastNodeId, MastForestError> {
331 let Some(digest) = self.digest else {
334 return Err(MastForestError::DigestRequiredForDeserialization);
335 };
336
337 let node_id = forest
340 .nodes
341 .push(JoinNode { children: self.children, digest }.into())
342 .map_err(|_| MastForestError::TooManyNodes)?;
343
344 Ok(node_id)
345 }
346}
347
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder over two arbitrary child ids with no forced digest.
    ///
    /// `JoinNodeBuilderParams` carries no settings, so the parameters are ignored.
    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let child_ids = any::<[MastNodeId; 2]>();
        child_ids.prop_map(JoinNodeBuilder::new).boxed()
    }
}
360
#[cfg(any(test, feature = "arbitrary"))]
/// Parameters for generating arbitrary [`JoinNodeBuilder`]s; currently carries no settings.
#[derive(Clone, Debug, Default)]
pub struct JoinNodeBuilderParams {}