miden_core/mast/node/
loop_node.rs1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::prettier::PrettyPrint;
6
7use super::MastNodeExt;
8use crate::{
9 OPCODE_LOOP,
10 chiplets::hasher,
11 mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
12};
13
14#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct LoopNode {
25 body: MastNodeId,
26 digest: Word,
27 before_enter: Vec<DecoratorId>,
28 after_exit: Vec<DecoratorId>,
29}
30
31impl LoopNode {
33 pub const DOMAIN: Felt = Felt::new(OPCODE_LOOP as u64);
35}
36
37impl LoopNode {
39 pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
41 if body.as_usize() >= mast_forest.nodes.len() {
42 return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len()));
43 }
44 let digest = {
45 let body_hash = mast_forest[body].digest();
46
47 hasher::merge_in_domain(&[body_hash, Word::default()], Self::DOMAIN)
48 };
49
50 Ok(Self {
51 body,
52 digest,
53 before_enter: Vec::new(),
54 after_exit: Vec::new(),
55 })
56 }
57
58 pub fn new_unsafe(body: MastNodeId, digest: Word) -> Self {
61 Self {
62 body,
63 digest,
64 before_enter: Vec::new(),
65 after_exit: Vec::new(),
66 }
67 }
68}
69
70impl LoopNode {
71 pub fn digest(&self) -> Word {
82 self.digest
83 }
84
85 pub fn body(&self) -> MastNodeId {
87 self.body
88 }
89
90 pub fn before_enter(&self) -> &[DecoratorId] {
92 &self.before_enter
93 }
94
95 pub fn after_exit(&self) -> &[DecoratorId] {
97 &self.after_exit
98 }
99}
100
101impl LoopNode {
104 pub fn remap_children(&self, remapping: &Remapping) -> Self {
105 let mut node = self.clone();
106 node.body = node.body.remap(remapping);
107 node
108 }
109
110 pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
112 self.before_enter.extend_from_slice(decorator_ids);
113 }
114
115 pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
117 self.after_exit.extend_from_slice(decorator_ids);
118 }
119
120 pub fn remove_decorators(&mut self) {
122 self.before_enter.truncate(0);
123 self.after_exit.truncate(0);
124 }
125}
126
127impl MastNodeExt for LoopNode {
128 fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
129 self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
130 }
131}
132
133impl LoopNode {
137 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
138 LoopNodePrettyPrint { loop_node: self, mast_forest }
139 }
140
141 pub(super) fn to_pretty_print<'a>(
142 &'a self,
143 mast_forest: &'a MastForest,
144 ) -> impl PrettyPrint + 'a {
145 LoopNodePrettyPrint { loop_node: self, mast_forest }
146 }
147}
148
149struct LoopNodePrettyPrint<'a> {
150 loop_node: &'a LoopNode,
151 mast_forest: &'a MastForest,
152}
153
154impl crate::prettier::PrettyPrint for LoopNodePrettyPrint<'_> {
155 fn render(&self) -> crate::prettier::Document {
156 use crate::prettier::*;
157
158 let pre_decorators = {
159 let mut pre_decorators = self
160 .loop_node
161 .before_enter()
162 .iter()
163 .map(|&decorator_id| self.mast_forest[decorator_id].render())
164 .reduce(|acc, doc| acc + const_text(" ") + doc)
165 .unwrap_or_default();
166 if !pre_decorators.is_empty() {
167 pre_decorators += nl();
168 }
169
170 pre_decorators
171 };
172
173 let post_decorators = {
174 let mut post_decorators = self
175 .loop_node
176 .after_exit()
177 .iter()
178 .map(|&decorator_id| self.mast_forest[decorator_id].render())
179 .reduce(|acc, doc| acc + const_text(" ") + doc)
180 .unwrap_or_default();
181 if !post_decorators.is_empty() {
182 post_decorators = nl() + post_decorators;
183 }
184
185 post_decorators
186 };
187
188 let loop_body = self.mast_forest[self.loop_node.body].to_pretty_print(self.mast_forest);
189
190 pre_decorators
191 + indent(4, const_text("while.true") + nl() + loop_body.render())
192 + nl()
193 + const_text("end")
194 + post_decorators
195 }
196}
197
198impl fmt::Display for LoopNodePrettyPrint<'_> {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 use crate::prettier::PrettyPrint;
201 self.pretty_print(f)
202 }
203}