1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
8use super::{MastForestContributor, MastNodeErrorContext, MastNodeExt};
9use crate::{
10 Idx, OPCODE_JOIN,
11 chiplets::hasher,
12 mast::{
13 DecoratedOpLink, DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode,
14 MastNodeId,
15 },
16 prettier::PrettyPrint,
17};
18
/// A MAST node with exactly two children, referred to as `first` and `second`.
///
/// Unless a digest is forced through the builder (`JoinNodeBuilder::with_digest`), the digest of
/// a `JoinNode` is `hasher::merge_in_domain([first.digest, second.digest], JoinNode::DOMAIN)`.
///
/// NOTE(review): the children are presumably executed in order, `first` then `second` — confirm
/// against the processor.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct JoinNode {
    /// Ids of the two child nodes: index 0 is `first`, index 1 is `second`.
    children: [MastNodeId; 2],
    /// Digest of this node, fixed when the node is built.
    digest: Word,
    /// Node-level decorators. They are either held inline (`Owned`) or looked up in the owning
    /// forest by node id (`Linked`).
    decorator_store: DecoratorStore,
}
32
33impl JoinNode {
35 pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
37}
38
39impl JoinNode {
41 pub fn first(&self) -> MastNodeId {
43 self.children[0]
44 }
45
46 pub fn second(&self) -> MastNodeId {
49 self.children[1]
50 }
51}
52
53impl MastNodeErrorContext for JoinNode {
54 fn decorators<'a>(
55 &'a self,
56 forest: &'a MastForest,
57 ) -> impl Iterator<Item = DecoratedOpLink> + 'a {
58 let before_enter = self.decorator_store.before_enter(forest);
60 let after_exit = self.decorator_store.after_exit(forest);
61
62 before_enter
64 .iter()
65 .map(|&deco_id| (0, deco_id))
66 .chain(after_exit.iter().map(|&deco_id| (1, deco_id)))
67 }
68}
69
70impl JoinNode {
74 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
75 JoinNodePrettyPrint { join_node: self, mast_forest }
76 }
77
78 pub(super) fn to_pretty_print<'a>(
79 &'a self,
80 mast_forest: &'a MastForest,
81 ) -> impl PrettyPrint + 'a {
82 JoinNodePrettyPrint { join_node: self, mast_forest }
83 }
84}
85
86struct JoinNodePrettyPrint<'a> {
87 join_node: &'a JoinNode,
88 mast_forest: &'a MastForest,
89}
90
91impl PrettyPrint for JoinNodePrettyPrint<'_> {
92 #[rustfmt::skip]
93 fn render(&self) -> crate::prettier::Document {
94 use crate::prettier::*;
95
96 let pre_decorators = {
97 let mut pre_decorators = self
98 .join_node
99 .before_enter(self.mast_forest)
100 .iter()
101 .map(|&decorator_id| self.mast_forest[decorator_id].render())
102 .reduce(|acc, doc| acc + const_text(" ") + doc)
103 .unwrap_or_default();
104 if !pre_decorators.is_empty() {
105 pre_decorators += nl();
106 }
107
108 pre_decorators
109 };
110
111 let post_decorators = {
112 let mut post_decorators = self
113 .join_node
114 .after_exit(self.mast_forest)
115 .iter()
116 .map(|&decorator_id| self.mast_forest[decorator_id].render())
117 .reduce(|acc, doc| acc + const_text(" ") + doc)
118 .unwrap_or_default();
119 if !post_decorators.is_empty() {
120 post_decorators = nl() + post_decorators;
121 }
122
123 post_decorators
124 };
125
126 let first_child =
127 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
128 let second_child =
129 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
130
131 pre_decorators
132 + indent(
133 4,
134 const_text("join")
135 + nl()
136 + first_child.render()
137 + nl()
138 + second_child.render(),
139 ) + nl() + const_text("end")
140 + post_decorators
141 }
142}
143
144impl fmt::Display for JoinNodePrettyPrint<'_> {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 use crate::prettier::PrettyPrint;
147 self.pretty_print(f)
148 }
149}
150
#[cfg(test)]
impl JoinNode {
    /// Test helper. Returns `true` when `self` and `other` match on all of the following:
    ///
    /// - both child ids;
    /// - digest;
    /// - `before_enter` and `after_exit` decorator lists, resolved through `forest`.
    ///
    /// The checks run in that order and stop at the first mismatch.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        self.first() == other.first()
            && self.second() == other.second()
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
186
187impl MastNodeExt for JoinNode {
191 fn digest(&self) -> Word {
203 self.digest
204 }
205
206 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
208 #[cfg(debug_assertions)]
209 self.verify_node_in_forest(forest);
210 self.decorator_store.before_enter(forest)
211 }
212
213 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
215 #[cfg(debug_assertions)]
216 self.verify_node_in_forest(forest);
217 self.decorator_store.after_exit(forest)
218 }
219
220 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
221 Box::new(JoinNode::to_display(self, mast_forest))
222 }
223
224 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
225 Box::new(JoinNode::to_pretty_print(self, mast_forest))
226 }
227
228 fn has_children(&self) -> bool {
229 true
230 }
231
232 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
233 target.push(self.first());
234 target.push(self.second());
235 }
236
237 fn for_each_child<F>(&self, mut f: F)
238 where
239 F: FnMut(MastNodeId),
240 {
241 f(self.first());
242 f(self.second());
243 }
244
245 fn domain(&self) -> Felt {
246 Self::DOMAIN
247 }
248
249 type Builder = JoinNodeBuilder;
250
251 fn to_builder(self, forest: &MastForest) -> Self::Builder {
252 match self.decorator_store {
254 DecoratorStore::Owned { before_enter, after_exit, .. } => {
255 let mut builder = JoinNodeBuilder::new(self.children);
256 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
257 builder
258 },
259 DecoratorStore::Linked { id } => {
260 let before_enter = forest.before_enter_decorators(id).to_vec();
262 let after_exit = forest.after_exit_decorators(id).to_vec();
263 let mut builder = JoinNodeBuilder::new(self.children);
264 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
265 builder
266 },
267 }
268 }
269
270 #[cfg(debug_assertions)]
271 fn verify_node_in_forest(&self, forest: &MastForest) {
272 if let Some(id) = self.decorator_store.linked_id() {
273 let self_ptr = self as *const Self;
275 let forest_node = &forest.nodes[id];
276 let forest_node_ptr = match forest_node {
277 MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
278 _ => panic!("Node type mismatch at {:?}", id),
279 };
280 let self_as_void = self_ptr as *const ();
281 debug_assert_eq!(
282 self_as_void, forest_node_ptr,
283 "Node pointer mismatch: expected node at {:?} to be self",
284 id
285 );
286 }
287 }
288}
289
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates nodes with arbitrary child ids and an arbitrary digest. The digest is not
    /// derived from the children, and the decorator store is always `DecoratorStore::default()`.
    /// Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use crate::Felt;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(first, second, limbs)| Self {
                children: [first, second],
                digest: Word::from(limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
320
/// Builder for [`JoinNode`].
///
/// It collects the two child ids, the node-level decorators and an optional forced digest. The
/// node can be produced in two ways:
///
/// - standalone, via `build`, which stores the decorators inline;
/// - directly inside a forest, via `add_to_forest`, which registers the decorators with the
///   forest.
#[derive(Debug)]
pub struct JoinNodeBuilder {
    /// Ids of the `first` and `second` children, in that order.
    children: [MastNodeId; 2],
    /// Decorator ids for the node's `before_enter` list.
    before_enter: Vec<DecoratorId>,
    /// Decorator ids for the node's `after_exit` list.
    after_exit: Vec<DecoratorId>,
    /// Digest to use instead of the one computed from the children, if set.
    digest: Option<Word>,
}
330
331impl JoinNodeBuilder {
332 pub fn new(children: [MastNodeId; 2]) -> Self {
334 Self {
335 children,
336 before_enter: Vec::new(),
337 after_exit: Vec::new(),
338 digest: None,
339 }
340 }
341
342 pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
344 let forest_len = mast_forest.nodes.len();
345 if self.children[0].to_usize() >= forest_len {
346 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
347 } else if self.children[1].to_usize() >= forest_len {
348 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
349 }
350
351 let digest = if let Some(forced_digest) = self.digest {
353 forced_digest
354 } else {
355 let left_child_hash = mast_forest[self.children[0]].digest();
356 let right_child_hash = mast_forest[self.children[1]].digest();
357
358 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
359 };
360
361 Ok(JoinNode {
362 children: self.children,
363 digest,
364 decorator_store: DecoratorStore::new_owned_with_decorators(
365 self.before_enter,
366 self.after_exit,
367 ),
368 })
369 }
370}
371
372impl MastForestContributor for JoinNodeBuilder {
373 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
374 let forest_len = forest.nodes.len();
376 if self.children[0].to_usize() >= forest_len {
377 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
378 } else if self.children[1].to_usize() >= forest_len {
379 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
380 }
381
382 let digest = if let Some(forced_digest) = self.digest {
384 forced_digest
385 } else {
386 let left_child_hash = forest[self.children[0]].digest();
387 let right_child_hash = forest[self.children[1]].digest();
388
389 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
390 };
391
392 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
394
395 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
397
398 let node_id = forest
401 .nodes
402 .push(
403 JoinNode {
404 children: self.children,
405 digest,
406 decorator_store: DecoratorStore::Linked { id: future_node_id },
407 }
408 .into(),
409 )
410 .map_err(|_| MastForestError::TooManyNodes)?;
411
412 Ok(node_id)
413 }
414
415 fn fingerprint_for_node(
416 &self,
417 forest: &MastForest,
418 hash_by_node_id: &impl crate::LookupByIdx<MastNodeId, crate::mast::MastNodeFingerprint>,
419 ) -> Result<crate::mast::MastNodeFingerprint, MastForestError> {
420 crate::mast::node_fingerprint::fingerprint_from_parts(
422 forest,
423 hash_by_node_id,
424 &self.before_enter,
425 &self.after_exit,
426 &self.children,
427 if let Some(forced_digest) = self.digest {
429 forced_digest
430 } else {
431 let left_child_hash = forest[self.children[0]].digest();
432 let right_child_hash = forest[self.children[1]].digest();
433
434 crate::chiplets::hasher::merge_in_domain(
435 &[left_child_hash, right_child_hash],
436 JoinNode::DOMAIN,
437 )
438 },
439 )
440 }
441
442 fn remap_children(
443 self,
444 remapping: &impl crate::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeId>,
445 ) -> Self {
446 JoinNodeBuilder {
447 children: [
448 *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
449 *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
450 ],
451 before_enter: self.before_enter,
452 after_exit: self.after_exit,
453 digest: self.digest,
454 }
455 }
456
457 fn with_before_enter(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
458 self.before_enter = decorators.into();
459 self
460 }
461
462 fn with_after_exit(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
463 self.after_exit = decorators.into();
464 self
465 }
466
467 fn append_before_enter(
468 &mut self,
469 decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
470 ) {
471 self.before_enter.extend(decorators);
472 }
473
474 fn append_after_exit(
475 &mut self,
476 decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
477 ) {
478 self.after_exit.extend(decorators);
479 }
480
481 fn with_digest(mut self, digest: crate::Word) -> Self {
482 self.digest = Some(digest);
483 self
484 }
485}
486
487impl JoinNodeBuilder {
488 pub(in crate::mast) fn add_to_forest_relaxed(
498 self,
499 forest: &mut MastForest,
500 ) -> Result<MastNodeId, MastForestError> {
501 let Some(digest) = self.digest else {
504 return Err(MastForestError::DigestRequiredForDeserialization);
505 };
506
507 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
508
509 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
511
512 let node_id = forest
515 .nodes
516 .push(
517 JoinNode {
518 children: self.children,
519 digest,
520 decorator_store: DecoratorStore::Linked { id: future_node_id },
521 }
522 .into(),
523 )
524 .map_err(|_| MastForestError::TooManyNodes)?;
525
526 Ok(node_id)
527 }
528}
529
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with arbitrary child ids and two independent decorator lists.
    ///
    /// Each list (`before_enter` and `after_exit`) has between `0` and `params.max_decorators`
    /// entries, inclusive. No digest is forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let JoinNodeBuilderParams { max_decorators, max_decorator_id_u32 } = params;
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(max_decorator_id_u32),
                0..=max_decorators,
            )
        };

        (any::<[crate::mast::MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(children, before_enter, after_exit)| {
                Self::new(children).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
555
/// Parameters for the `Arbitrary` implementation of `JoinNodeBuilder`.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Inclusive upper bound on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Passed to `decorator_id_strategy`. It is presumably the largest decorator id value to
    /// generate — TODO confirm.
    pub max_decorator_id_u32: u32,
}

#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Default bounds: at most 4 decorators per list, with the decorator id parameter set to 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}