1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8#[cfg(debug_assertions)]
9use crate::mast::MastNode;
10use crate::{
11 Felt, Word,
12 chiplets::hasher,
13 mast::{
14 DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeFingerprint, MastNodeId,
15 },
16 operations::opcodes,
17 prettier::PrettyPrint,
18 utils::{Idx, LookupByIdx},
19};
20
/// A MAST node that joins two child nodes, identified by [`JoinNode::first`] and
/// [`JoinNode::second`].
///
/// When pretty-printed, the node renders as `join <first> <second> end`. Unless a digest is forced
/// through [`JoinNodeBuilder`], the digest is `merge_in_domain([first.digest, second.digest],
/// JoinNode::DOMAIN)`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct JoinNode {
    /// Ids of the first (left) and second (right) child, in that order.
    children: [MastNodeId; 2],
    /// Digest of this node (computed from the children's digests or forced by the builder).
    digest: Word,
    /// Location of this node's `before_enter` / `after_exit` decorators: either owned by the node
    /// (`DecoratorStore::Owned`) or linked to the forest's centralized storage by node id
    /// (`DecoratorStore::Linked`).
    decorator_store: DecoratorStore,
}
34
impl JoinNode {
    /// Domain separator used when merging the two child digests into the join's digest; its value
    /// is the `JOIN` opcode.
    pub const DOMAIN: Felt = Felt::new(opcodes::JOIN as u64);
}
40
41impl JoinNode {
43 pub fn first(&self) -> MastNodeId {
45 self.children[0]
46 }
47
48 pub fn second(&self) -> MastNodeId {
51 self.children[1]
52 }
53}
54
55impl JoinNode {
59 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
60 JoinNodePrettyPrint { join_node: self, mast_forest }
61 }
62
63 pub(super) fn to_pretty_print<'a>(
64 &'a self,
65 mast_forest: &'a MastForest,
66 ) -> impl PrettyPrint + 'a {
67 JoinNodePrettyPrint { join_node: self, mast_forest }
68 }
69}
70
/// Printing adapter that pairs a [`JoinNode`] with the forest needed to resolve its children
/// and decorators.
struct JoinNodePrettyPrint<'a> {
    /// The node being rendered.
    join_node: &'a JoinNode,
    /// Forest used to look up the children (by node id) and decorators (by decorator id).
    mast_forest: &'a MastForest,
}
75
76impl PrettyPrint for JoinNodePrettyPrint<'_> {
77 #[rustfmt::skip]
78 fn render(&self) -> crate::prettier::Document {
79 use crate::prettier::*;
80
81 let pre_decorators = {
82 let mut pre_decorators = self
83 .join_node
84 .before_enter(self.mast_forest)
85 .iter()
86 .map(|&decorator_id| self.mast_forest[decorator_id].render())
87 .reduce(|acc, doc| acc + const_text(" ") + doc)
88 .unwrap_or_default();
89 if !pre_decorators.is_empty() {
90 pre_decorators += nl();
91 }
92
93 pre_decorators
94 };
95
96 let post_decorators = {
97 let mut post_decorators = self
98 .join_node
99 .after_exit(self.mast_forest)
100 .iter()
101 .map(|&decorator_id| self.mast_forest[decorator_id].render())
102 .reduce(|acc, doc| acc + const_text(" ") + doc)
103 .unwrap_or_default();
104 if !post_decorators.is_empty() {
105 post_decorators = nl() + post_decorators;
106 }
107
108 post_decorators
109 };
110
111 let first_child =
112 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
113 let second_child =
114 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
115
116 pre_decorators
117 + indent(
118 4,
119 const_text("join")
120 + nl()
121 + first_child.render()
122 + nl()
123 + second_child.render(),
124 ) + nl() + const_text("end")
125 + post_decorators
126 }
127}
128
129impl fmt::Display for JoinNodePrettyPrint<'_> {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 use crate::prettier::PrettyPrint;
132 self.pretty_print(f)
133 }
134}
135
#[cfg(test)]
impl JoinNode {
    /// Test helper: returns `true` when both nodes have the same children, the same digest, and
    /// the same `before_enter` / `after_exit` decorators as resolved through `forest`.
    ///
    /// The comparisons short-circuit in that order, so decorator lookups only happen once the
    /// structural fields already match.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        self.first() == other.first()
            && self.second() == other.second()
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
171
172impl MastNodeExt for JoinNode {
176 fn digest(&self) -> Word {
188 self.digest
189 }
190
191 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
193 #[cfg(debug_assertions)]
194 self.verify_node_in_forest(forest);
195 self.decorator_store.before_enter(forest)
196 }
197
198 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
200 #[cfg(debug_assertions)]
201 self.verify_node_in_forest(forest);
202 self.decorator_store.after_exit(forest)
203 }
204
205 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
206 Box::new(JoinNode::to_display(self, mast_forest))
207 }
208
209 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
210 Box::new(JoinNode::to_pretty_print(self, mast_forest))
211 }
212
213 fn has_children(&self) -> bool {
214 true
215 }
216
217 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
218 target.push(self.first());
219 target.push(self.second());
220 }
221
222 fn for_each_child<F>(&self, mut f: F)
223 where
224 F: FnMut(MastNodeId),
225 {
226 f(self.first());
227 f(self.second());
228 }
229
230 fn domain(&self) -> Felt {
231 Self::DOMAIN
232 }
233
234 type Builder = JoinNodeBuilder;
235
236 fn to_builder(self, forest: &MastForest) -> Self::Builder {
237 match self.decorator_store {
239 DecoratorStore::Owned { before_enter, after_exit, .. } => {
240 let mut builder = JoinNodeBuilder::new(self.children);
241 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
242 builder
243 },
244 DecoratorStore::Linked { id } => {
245 let before_enter = forest.before_enter_decorators(id).to_vec();
247 let after_exit = forest.after_exit_decorators(id).to_vec();
248 let mut builder = JoinNodeBuilder::new(self.children);
249 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
250 builder
251 },
252 }
253 }
254
255 #[cfg(debug_assertions)]
256 fn verify_node_in_forest(&self, forest: &MastForest) {
257 if let Some(id) = self.decorator_store.linked_id() {
258 let self_ptr = self as *const Self;
260 let forest_node = &forest.nodes[id];
261 let forest_node_ptr = match forest_node {
262 MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
263 _ => panic!("Node type mismatch at {:?}", id),
264 };
265 let self_as_void = self_ptr as *const ();
266 debug_assert_eq!(
267 self_as_void, forest_node_ptr,
268 "Node pointer mismatch: expected node at {:?} to be self",
269 id
270 );
271 }
272 }
273}
274
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a join with arbitrary child ids, an arbitrary (not necessarily consistent)
    /// digest, and a default decorator store. Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(first, second, limbs)| JoinNode {
                children: [first, second],
                digest: Word::from(limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
305
/// Builder for [`JoinNode`].
///
/// Collects the two child ids, optional `before_enter` / `after_exit` decorators, and an optional
/// forced digest. [`JoinNodeBuilder::build`] produces a standalone node that owns its decorators.
/// Adding the builder to a forest through `MastForestContributor::add_to_forest` registers the
/// decorators with the forest and links the node to them.
#[derive(Debug)]
pub struct JoinNodeBuilder {
    /// Ids of the first and second child, in that order.
    children: [MastNodeId; 2],
    /// Decorators attached before the node is entered.
    before_enter: Vec<DecoratorId>,
    /// Decorators attached after the node exits.
    after_exit: Vec<DecoratorId>,
    /// When `Some`, used as the node's digest instead of hashing the children's digests.
    digest: Option<Word>,
}
315
316impl JoinNodeBuilder {
317 pub fn new(children: [MastNodeId; 2]) -> Self {
319 Self {
320 children,
321 before_enter: Vec::new(),
322 after_exit: Vec::new(),
323 digest: None,
324 }
325 }
326
327 pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
329 let forest_len = mast_forest.nodes.len();
330 if self.children[0].to_usize() >= forest_len {
331 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
332 } else if self.children[1].to_usize() >= forest_len {
333 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
334 }
335
336 let digest = if let Some(forced_digest) = self.digest {
338 forced_digest
339 } else {
340 let left_child_hash = mast_forest[self.children[0]].digest();
341 let right_child_hash = mast_forest[self.children[1]].digest();
342
343 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
344 };
345
346 Ok(JoinNode {
347 children: self.children,
348 digest,
349 decorator_store: DecoratorStore::new_owned_with_decorators(
350 self.before_enter,
351 self.after_exit,
352 ),
353 })
354 }
355}
356
357impl MastForestContributor for JoinNodeBuilder {
358 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
359 let forest_len = forest.nodes.len();
361 if self.children[0].to_usize() >= forest_len {
362 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
363 } else if self.children[1].to_usize() >= forest_len {
364 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
365 }
366
367 let digest = if let Some(forced_digest) = self.digest {
369 forced_digest
370 } else {
371 let left_child_hash = forest[self.children[0]].digest();
372 let right_child_hash = forest[self.children[1]].digest();
373
374 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
375 };
376
377 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
379
380 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
382
383 let node_id = forest
386 .nodes
387 .push(
388 JoinNode {
389 children: self.children,
390 digest,
391 decorator_store: DecoratorStore::Linked { id: future_node_id },
392 }
393 .into(),
394 )
395 .map_err(|_| MastForestError::TooManyNodes)?;
396
397 Ok(node_id)
398 }
399
400 fn fingerprint_for_node(
401 &self,
402 forest: &MastForest,
403 hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
404 ) -> Result<MastNodeFingerprint, MastForestError> {
405 crate::mast::node_fingerprint::fingerprint_from_parts(
407 forest,
408 hash_by_node_id,
409 &self.before_enter,
410 &self.after_exit,
411 &self.children,
412 if let Some(forced_digest) = self.digest {
414 forced_digest
415 } else {
416 let left_child_hash = forest[self.children[0]].digest();
417 let right_child_hash = forest[self.children[1]].digest();
418
419 crate::chiplets::hasher::merge_in_domain(
420 &[left_child_hash, right_child_hash],
421 JoinNode::DOMAIN,
422 )
423 },
424 )
425 }
426
427 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
428 JoinNodeBuilder {
429 children: [
430 *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
431 *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
432 ],
433 before_enter: self.before_enter,
434 after_exit: self.after_exit,
435 digest: self.digest,
436 }
437 }
438
439 fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
440 self.before_enter = decorators.into();
441 self
442 }
443
444 fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
445 self.after_exit = decorators.into();
446 self
447 }
448
449 fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
450 self.before_enter.extend(decorators);
451 }
452
453 fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
454 self.after_exit.extend(decorators);
455 }
456
457 fn with_digest(mut self, digest: crate::Word) -> Self {
458 self.digest = Some(digest);
459 self
460 }
461}
462
463impl JoinNodeBuilder {
464 pub(in crate::mast) fn add_to_forest_relaxed(
474 self,
475 forest: &mut MastForest,
476 ) -> Result<MastNodeId, MastForestError> {
477 let Some(digest) = self.digest else {
480 return Err(MastForestError::DigestRequiredForDeserialization);
481 };
482
483 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
484
485 let node_id = forest
488 .nodes
489 .push(
490 JoinNode {
491 children: self.children,
492 digest,
493 decorator_store: DecoratorStore::Linked { id: future_node_id },
494 }
495 .into(),
496 )
497 .map_err(|_| MastForestError::TooManyNodes)?;
498
499 Ok(node_id)
500 }
501}
502
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder with arbitrary children and up to `params.max_decorators` decorators
    /// in each of `before_enter` and `after_exit`, with ids drawn from
    /// `decorator_id_strategy(params.max_decorator_id_u32)`. No digest is forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // The same list strategy is used for both decorator slots.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(children, before_enter, after_exit)| {
                Self::new(children).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
528
/// Parameters for the `Arbitrary` strategy of [`JoinNodeBuilder`].
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Maximum length (inclusive) of each generated decorator list (`before_enter` and
    /// `after_exit`).
    pub max_decorators: usize,
    /// Bound passed to `decorator_id_strategy` when generating decorator ids.
    pub max_decorator_id_u32: u32,
}
536
#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Up to 4 decorators per list, with the decorator-id bound set to 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}