1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8use crate::{
9 Felt, Word,
10 chiplets::hasher,
11 mast::{
12 DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode, MastNodeFingerprint,
13 MastNodeId,
14 },
15 operations::OPCODE_JOIN,
16 prettier::PrettyPrint,
17 utils::{Idx, LookupByIdx},
18};
19
/// A MAST node with exactly two children, `first` and `second`.
///
/// Unless a digest is forced via `JoinNodeBuilder::with_digest`, the node's digest is
/// `hasher::merge_in_domain(&[first_digest, second_digest], JoinNode::DOMAIN)`, computed once
/// when the node is built and cached in `digest`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct JoinNode {
    /// Child node ids, stored as `[first, second]`.
    children: [MastNodeId; 2],
    /// Cached commitment to this node.
    digest: Word,
    /// Before-enter / after-exit decorators: either held inline (`Owned`) or looked up in the
    /// forest by node id (`Linked`).
    decorator_store: DecoratorStore,
}
33
impl JoinNode {
    /// Domain separator passed to `hasher::merge_in_domain` when hashing the two children's
    /// digests into this node's digest. Derived from the `JOIN` operation code.
    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
}
39
40impl JoinNode {
42 pub fn first(&self) -> MastNodeId {
44 self.children[0]
45 }
46
47 pub fn second(&self) -> MastNodeId {
50 self.children[1]
51 }
52}
53
54impl JoinNode {
58 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
59 JoinNodePrettyPrint { join_node: self, mast_forest }
60 }
61
62 pub(super) fn to_pretty_print<'a>(
63 &'a self,
64 mast_forest: &'a MastForest,
65 ) -> impl PrettyPrint + 'a {
66 JoinNodePrettyPrint { join_node: self, mast_forest }
67 }
68}
69
70struct JoinNodePrettyPrint<'a> {
71 join_node: &'a JoinNode,
72 mast_forest: &'a MastForest,
73}
74
75impl PrettyPrint for JoinNodePrettyPrint<'_> {
76 #[rustfmt::skip]
77 fn render(&self) -> crate::prettier::Document {
78 use crate::prettier::*;
79
80 let pre_decorators = {
81 let mut pre_decorators = self
82 .join_node
83 .before_enter(self.mast_forest)
84 .iter()
85 .map(|&decorator_id| self.mast_forest[decorator_id].render())
86 .reduce(|acc, doc| acc + const_text(" ") + doc)
87 .unwrap_or_default();
88 if !pre_decorators.is_empty() {
89 pre_decorators += nl();
90 }
91
92 pre_decorators
93 };
94
95 let post_decorators = {
96 let mut post_decorators = self
97 .join_node
98 .after_exit(self.mast_forest)
99 .iter()
100 .map(|&decorator_id| self.mast_forest[decorator_id].render())
101 .reduce(|acc, doc| acc + const_text(" ") + doc)
102 .unwrap_or_default();
103 if !post_decorators.is_empty() {
104 post_decorators = nl() + post_decorators;
105 }
106
107 post_decorators
108 };
109
110 let first_child =
111 self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
112 let second_child =
113 self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
114
115 pre_decorators
116 + indent(
117 4,
118 const_text("join")
119 + nl()
120 + first_child.render()
121 + nl()
122 + second_child.render(),
123 ) + nl() + const_text("end")
124 + post_decorators
125 }
126}
127
128impl fmt::Display for JoinNodePrettyPrint<'_> {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 use crate::prettier::PrettyPrint;
131 self.pretty_print(f)
132 }
133}
134
#[cfg(test)]
impl JoinNode {
    /// Test helper: returns `true` when `self` and `other` have the same children, the same
    /// digest, and the same before-enter and after-exit decorator lists as resolved through
    /// `forest`.
    ///
    /// Checks run in that order and stop at the first mismatch. In debug builds the decorator
    /// lookups also run `verify_node_in_forest`, which panics if a forest-linked node is not the
    /// instance stored in `forest`.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        self.first() == other.first()
            && self.second() == other.second()
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
170
171impl MastNodeExt for JoinNode {
175 fn digest(&self) -> Word {
187 self.digest
188 }
189
190 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
192 #[cfg(debug_assertions)]
193 self.verify_node_in_forest(forest);
194 self.decorator_store.before_enter(forest)
195 }
196
197 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
199 #[cfg(debug_assertions)]
200 self.verify_node_in_forest(forest);
201 self.decorator_store.after_exit(forest)
202 }
203
204 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
205 Box::new(JoinNode::to_display(self, mast_forest))
206 }
207
208 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
209 Box::new(JoinNode::to_pretty_print(self, mast_forest))
210 }
211
212 fn has_children(&self) -> bool {
213 true
214 }
215
216 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
217 target.push(self.first());
218 target.push(self.second());
219 }
220
221 fn for_each_child<F>(&self, mut f: F)
222 where
223 F: FnMut(MastNodeId),
224 {
225 f(self.first());
226 f(self.second());
227 }
228
229 fn domain(&self) -> Felt {
230 Self::DOMAIN
231 }
232
233 type Builder = JoinNodeBuilder;
234
235 fn to_builder(self, forest: &MastForest) -> Self::Builder {
236 match self.decorator_store {
238 DecoratorStore::Owned { before_enter, after_exit, .. } => {
239 let mut builder = JoinNodeBuilder::new(self.children);
240 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
241 builder
242 },
243 DecoratorStore::Linked { id } => {
244 let before_enter = forest.before_enter_decorators(id).to_vec();
246 let after_exit = forest.after_exit_decorators(id).to_vec();
247 let mut builder = JoinNodeBuilder::new(self.children);
248 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
249 builder
250 },
251 }
252 }
253
254 #[cfg(debug_assertions)]
255 fn verify_node_in_forest(&self, forest: &MastForest) {
256 if let Some(id) = self.decorator_store.linked_id() {
257 let self_ptr = self as *const Self;
259 let forest_node = &forest.nodes[id];
260 let forest_node_ptr = match forest_node {
261 MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
262 _ => panic!("Node type mismatch at {:?}", id),
263 };
264 let self_as_void = self_ptr as *const ();
265 debug_assert_eq!(
266 self_as_void, forest_node_ptr,
267 "Node pointer mismatch: expected node at {:?} to be self",
268 id
269 );
270 }
271 }
272}
273
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates join nodes with arbitrary child ids, an arbitrary digest (not derived from the
    /// children) and no decorators. Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(first, second, limbs)| JoinNode {
                children: [first, second],
                digest: Word::from(limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
304
/// Builder for [`JoinNode`].
///
/// Collects the two children, optional before-enter / after-exit decorators and an optional
/// forced digest. The node is produced either standalone via `build` (decorators owned by the
/// node) or directly inside a forest via `add_to_forest` (decorators linked to the forest).
#[derive(Debug)]
pub struct JoinNodeBuilder {
    /// Child ids as `[first, second]`.
    children: [MastNodeId; 2],
    /// Decorator ids attached before the node is entered.
    before_enter: Vec<DecoratorId>,
    /// Decorator ids attached after the node exits.
    after_exit: Vec<DecoratorId>,
    /// When `Some`, used as the digest instead of hashing the children's digests.
    digest: Option<Word>,
}
314
315impl JoinNodeBuilder {
316 pub fn new(children: [MastNodeId; 2]) -> Self {
318 Self {
319 children,
320 before_enter: Vec::new(),
321 after_exit: Vec::new(),
322 digest: None,
323 }
324 }
325
326 pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
328 let forest_len = mast_forest.nodes.len();
329 if self.children[0].to_usize() >= forest_len {
330 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
331 } else if self.children[1].to_usize() >= forest_len {
332 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
333 }
334
335 let digest = if let Some(forced_digest) = self.digest {
337 forced_digest
338 } else {
339 let left_child_hash = mast_forest[self.children[0]].digest();
340 let right_child_hash = mast_forest[self.children[1]].digest();
341
342 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
343 };
344
345 Ok(JoinNode {
346 children: self.children,
347 digest,
348 decorator_store: DecoratorStore::new_owned_with_decorators(
349 self.before_enter,
350 self.after_exit,
351 ),
352 })
353 }
354}
355
356impl MastForestContributor for JoinNodeBuilder {
357 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
358 let forest_len = forest.nodes.len();
360 if self.children[0].to_usize() >= forest_len {
361 return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
362 } else if self.children[1].to_usize() >= forest_len {
363 return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
364 }
365
366 let digest = if let Some(forced_digest) = self.digest {
368 forced_digest
369 } else {
370 let left_child_hash = forest[self.children[0]].digest();
371 let right_child_hash = forest[self.children[1]].digest();
372
373 hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
374 };
375
376 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
378
379 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
381
382 let node_id = forest
385 .nodes
386 .push(
387 JoinNode {
388 children: self.children,
389 digest,
390 decorator_store: DecoratorStore::Linked { id: future_node_id },
391 }
392 .into(),
393 )
394 .map_err(|_| MastForestError::TooManyNodes)?;
395
396 Ok(node_id)
397 }
398
399 fn fingerprint_for_node(
400 &self,
401 forest: &MastForest,
402 hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
403 ) -> Result<MastNodeFingerprint, MastForestError> {
404 crate::mast::node_fingerprint::fingerprint_from_parts(
406 forest,
407 hash_by_node_id,
408 &self.before_enter,
409 &self.after_exit,
410 &self.children,
411 if let Some(forced_digest) = self.digest {
413 forced_digest
414 } else {
415 let left_child_hash = forest[self.children[0]].digest();
416 let right_child_hash = forest[self.children[1]].digest();
417
418 crate::chiplets::hasher::merge_in_domain(
419 &[left_child_hash, right_child_hash],
420 JoinNode::DOMAIN,
421 )
422 },
423 )
424 }
425
426 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
427 JoinNodeBuilder {
428 children: [
429 *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
430 *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
431 ],
432 before_enter: self.before_enter,
433 after_exit: self.after_exit,
434 digest: self.digest,
435 }
436 }
437
438 fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
439 self.before_enter = decorators.into();
440 self
441 }
442
443 fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
444 self.after_exit = decorators.into();
445 self
446 }
447
448 fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
449 self.before_enter.extend(decorators);
450 }
451
452 fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
453 self.after_exit.extend(decorators);
454 }
455
456 fn with_digest(mut self, digest: crate::Word) -> Self {
457 self.digest = Some(digest);
458 self
459 }
460}
461
462impl JoinNodeBuilder {
463 pub(in crate::mast) fn add_to_forest_relaxed(
473 self,
474 forest: &mut MastForest,
475 ) -> Result<MastNodeId, MastForestError> {
476 let Some(digest) = self.digest else {
479 return Err(MastForestError::DigestRequiredForDeserialization);
480 };
481
482 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
483
484 let node_id = forest
487 .nodes
488 .push(
489 JoinNode {
490 children: self.children,
491 digest,
492 decorator_store: DecoratorStore::Linked { id: future_node_id },
493 }
494 .into(),
495 )
496 .map_err(|_| MastForestError::TooManyNodes)?;
497
498 Ok(node_id)
499 }
500}
501
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with arbitrary child ids and two independent decorator lists
    /// (before-enter and after-exit). Each list holds `0..=params.max_decorators` ids drawn
    /// from `decorator_id_strategy(params.max_decorator_id_u32)`. No digest is forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let JoinNodeBuilderParams { max_decorators, max_decorator_id_u32 } = params;
        let decorator_list = move || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(max_decorator_id_u32),
                0..=max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(children, before_enter, after_exit)| {
                JoinNodeBuilder::new(children)
                    .with_before_enter(before_enter)
                    .with_after_exit(after_exit)
            })
            .boxed()
    }
}
527
/// Parameters for the `Arbitrary` implementation of `JoinNodeBuilder`.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Inclusive upper bound on the length of each generated decorator list
    /// (before-enter and after-exit).
    pub max_decorators: usize,
    /// Bound handed to the decorator-id strategy when generating decorator ids.
    pub max_decorator_id_u32: u32,
}

#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Small defaults: at most 4 decorators per list and a decorator-id bound of 10.
    fn default() -> Self {
        const MAX_DECORATORS: usize = 4;
        const MAX_DECORATOR_ID: u32 = 10;

        JoinNodeBuilderParams {
            max_decorator_id_u32: MAX_DECORATOR_ID,
            max_decorators: MAX_DECORATORS,
        }
    }
}