air_parser/transforms/constant_propagation.rs
1use std::{
2 collections::{HashMap, HashSet},
3 ops::ControlFlow,
4};
5
6use air_pass::Pass;
7use either::Either::{self, Left, Right};
8use miden_diagnostics::{DiagnosticsHandler, Span, Spanned};
9
10use crate::{
11 ast::{visit::VisitMut, *},
12 sema::{LexicalScope, SemanticAnalysisError},
13 symbols,
14};
15
16/// This pass performs constant propagation on a [Program], replacing all uses of a constant
17/// with the constant itself, converting accesses into constant aggregates with the accessed
18/// value, replacing local variables bound to constants with the constant value, and folding
19/// constant expressions into constant values.
20///
21/// It is expected that the provided [Program] has already been run through semantic analysis,
22/// so it will panic if it encounters invalid constructions to help catch bugs in the semantic
23/// analysis pass, should they exist.
24pub struct ConstantPropagation<'a> {
25 #[allow(unused)]
26 diagnostics: &'a DiagnosticsHandler,
27 global: HashMap<QualifiedIdentifier, Span<ConstantExpr>>,
28 local: LexicalScope<Identifier, Span<ConstantExpr>>,
29 /// The set of identifiers which are live (in use) in the current scope
30 live: HashSet<Identifier>,
31 in_constraint_comprehension: bool,
32 in_list_comprehension: bool,
33}
34impl Pass for ConstantPropagation<'_> {
35 type Input<'a> = Program;
36 type Output<'a> = Program;
37 type Error = SemanticAnalysisError;
38
39 fn run<'a>(&mut self, mut program: Self::Input<'a>) -> Result<Self::Output<'a>, Self::Error> {
40 self.global.reserve(program.constants.len());
41
42 match self.run_visitor(&mut program) {
43 ControlFlow::Continue(()) => Ok(program),
44 ControlFlow::Break(err) => {
45 self.diagnostics.emit(err.clone());
46 Err(err)
47 }
48 }
49 }
50}
51impl<'a> ConstantPropagation<'a> {
52 pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self {
53 Self {
54 diagnostics,
55 global: Default::default(),
56 local: Default::default(),
57 live: Default::default(),
58 in_constraint_comprehension: false,
59 in_list_comprehension: false,
60 }
61 }
62
63 fn run_visitor(&mut self, program: &mut Program) -> ControlFlow<SemanticAnalysisError> {
64 // Record all of the constant declarations
65 for (name, constant) in program.constants.iter() {
66 assert_eq!(
67 self.global
68 .insert(*name, Span::new(constant.span(), constant.value.clone())),
69 None
70 );
71 }
72
73 // Visit all of the evaluators
74 for evaluator in program.evaluators.values_mut() {
75 self.visit_mut_evaluator_function(evaluator)?;
76 }
77
78 // Visit all of the functions
79 for function in program.functions.values_mut() {
80 self.visit_mut_function(function)?;
81 }
82
83 // Visit all of the buses
84 for bus in program.buses.values_mut() {
85 self.visit_mut_bus(bus)?;
86 }
87
88 // Visit all of the constraints
89 self.visit_mut_boundary_constraints(&mut program.boundary_constraints)?;
90 self.visit_mut_integrity_constraints(&mut program.integrity_constraints)
91 }
92
93 fn try_fold_binary_expr(
94 &mut self,
95 expr: &mut BinaryExpr,
96 ) -> Result<Option<Span<u64>>, SemanticAnalysisError> {
97 // Visit operands first to ensure they are reduced to constants if possible
98 if let ControlFlow::Break(err) = self.visit_mut_scalar_expr(expr.lhs.as_mut()) {
99 return Err(err);
100 }
101 if let ControlFlow::Break(err) = self.visit_mut_scalar_expr(expr.rhs.as_mut()) {
102 return Err(err);
103 }
104 // If both operands are constant, fold
105 try_fold_binary_expr(expr).map_err(SemanticAnalysisError::InvalidExpr)
106 }
107
108 /// When folding a `let`, one of the following can occur:
109 ///
110 /// * The let-bound variable is non-constant, so the entire let must remain, but we
111 /// can constant-propagate as much of the bound expression and body as possible.
112 /// * The let-bound variable is constant, so once we have constant propagated the body,
113 /// the let is no longer needed, and one of the following happens:
114 /// * The `let` terminates with a constant expression, so the entire `let` is replaced
115 /// with that expression.
116 /// * The `let` terminates with a non-constant expression, or a constraint, so we inline
117 /// the let body into the containing block. In the non-constant expression case, we
118 /// replace the `let` with the last expression in the returned block, since in expression
119 /// position, we may not have a statement block to inline into.
120 fn try_fold_let_expr(
121 &mut self,
122 expr: &mut Let,
123 ) -> Result<Either<Option<Span<ConstantExpr>>, Vec<Statement>>, SemanticAnalysisError> {
124 // Visit the binding expression first
125 if let ControlFlow::Break(err) = self.visit_mut_expr(&mut expr.value) {
126 return Err(err);
127 }
128
129 // Enter a new lexical scope
130 let prev_live = core::mem::take(&mut self.live);
131 self.local.enter();
132 // If the value is constant, record it in our bindings map
133 let is_constant = expr.value.is_constant();
134 if is_constant {
135 match expr.value {
136 Expr::Const(ref value) => {
137 self.local.insert(expr.name, value.clone());
138 }
139 Expr::Range(ref range) => {
140 let span = range.span();
141 let range = range.to_slice_range();
142 let vector = range.map(|i| i as u64).collect();
143 self.local
144 .insert(expr.name, Span::new(span, ConstantExpr::Vector(vector)));
145 }
146 _ => unreachable!(),
147 }
148 }
149
150 // Visit the let body
151 if let ControlFlow::Break(err) = self.visit_mut_statement_block(&mut expr.body) {
152 return Err(err);
153 }
154
155 // If this let is constant, then the binding is no longer
156 // used in the body after constant propagation, so we can
157 // fold away the let entirely
158 let is_live = self.live.contains(&expr.name);
159 let result = if is_constant && !is_live {
160 match expr.body.last().unwrap() {
161 Statement::Expr(Expr::Const(const_value)) => {
162 Left(Some(Span::new(expr.span(), const_value.item.clone())))
163 }
164 _ => Right(core::mem::take(&mut expr.body)),
165 }
166 } else {
167 Left(None)
168 };
169
170 // Propagate liveness from the body of the let to its parent scope
171 let mut live = core::mem::take(&mut self.live);
172 live.remove(&expr.name);
173 self.live = &prev_live | &live;
174
175 // Restore the previous scope
176 self.local.exit();
177
178 Ok(result)
179 }
180}
181impl VisitMut<SemanticAnalysisError> for ConstantPropagation<'_> {
182 /// Fold constant expressions
183 fn visit_mut_scalar_expr(
184 &mut self,
185 expr: &mut ScalarExpr,
186 ) -> ControlFlow<SemanticAnalysisError> {
187 match expr {
188 // Expression is already folded
189 ScalarExpr::Const(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
190 ControlFlow::Continue(())
191 }
192 // Need to check if this access is to a constant value, and transform to a constant if so
193 ScalarExpr::SymbolAccess(sym) => {
194 let constant_value = match sym.name {
195 // Possibly a reference to a constant declaration
196 ResolvableIdentifier::Resolved(ref qid) => {
197 self.global.get(qid).cloned().map(|s| (s.span(), s.item))
198 }
199 // Possibly a reference to a local bound to a constant
200 ResolvableIdentifier::Local(ref id) => {
201 self.local.get(id).cloned().map(|s| (s.span(), s.item))
202 }
203 // Other identifiers cannot possibly be constant
204 _ => None,
205 };
206 if let Some((span, constant_expr)) = constant_value {
207 match constant_expr {
208 ConstantExpr::Scalar(value) => {
209 assert_eq!(sym.access_type, AccessType::Default);
210 *expr = ScalarExpr::Const(Span::new(span, value));
211 }
212 ConstantExpr::Vector(value) => match sym.access_type {
213 AccessType::Index(idx) => {
214 *expr = ScalarExpr::Const(Span::new(span, value[idx]));
215 }
216 // This access cannot be resolved here, so we need to record the fact
217 // that there are still live uses of this binding
218 _ => {
219 self.live.insert(*sym.name.as_ref());
220 }
221 },
222 ConstantExpr::Matrix(value) => match sym.access_type {
223 AccessType::Matrix(row, col) => {
224 *expr = ScalarExpr::Const(Span::new(span, value[row][col]));
225 }
226 // This access cannot be resolved here, so we need to record the fact
227 // that there are still live uses of this binding
228 _ => {
229 self.live.insert(*sym.name.as_ref());
230 }
231 },
232 }
233 } else {
234 // This value is not constant, so there are live uses of this symbol
235 self.live.insert(*sym.name.as_ref());
236 }
237 ControlFlow::Continue(())
238 }
239 // Fold constant expressions
240 ScalarExpr::Binary(binary_expr) => {
241 match self.try_fold_binary_expr(binary_expr) {
242 Ok(maybe_folded) => {
243 if let Some(folded) = maybe_folded {
244 *expr = ScalarExpr::Const(folded);
245 }
246 ControlFlow::Continue(())
247 }
248 Err(SemanticAnalysisError::InvalidExpr(
249 InvalidExprError::NonConstantExponent(_),
250 )) if self.in_list_comprehension => {
251 // If we are in a list comprehension, we do not know iterators'
252 // lengths yet, since loop unrolling happens during MIR passes.
253 // The check for non-constant exponents in list comprehensions is done
254 // during lowering from MIR to AIR, so we can safely silence it here.
255 ControlFlow::Continue(())
256 }
257 Err(err) => ControlFlow::Break(err),
258 }
259 }
260 // While calls cannot be constant folded, arguments can be
261 ScalarExpr::Call(call) => self.visit_mut_call(call),
262 // This cannot be constant folded
263 ScalarExpr::BoundedSymbolAccess(_) => ControlFlow::Continue(()),
264 // A let that evaluates to a constant value can be folded to the constant value
265 ScalarExpr::Let(let_expr) => {
266 match self.try_fold_let_expr(let_expr) {
267 Ok(Left(Some(const_expr))) => {
268 let span = const_expr.span();
269 match const_expr.item {
270 ConstantExpr::Scalar(value) => {
271 *expr = ScalarExpr::Const(Span::new(span, value));
272 }
273 _ => {
274 self.diagnostics.diagnostic(miden_diagnostics::Severity::Error)
275 .with_message("invalid scalar expression")
276 .with_primary_label(span, "expected scalar value, but this expression evaluates to an aggregate type")
277 .emit();
278 return ControlFlow::Break(SemanticAnalysisError::Invalid);
279 }
280 }
281 }
282 Ok(Left(None)) => (),
283 Ok(Right(mut block)) => match block.pop().unwrap() {
284 Statement::Let(inner_expr) => {
285 *let_expr.as_mut() = inner_expr;
286 }
287 Statement::Expr(inner_expr) => {
288 match ScalarExpr::try_from(inner_expr)
289 .map_err(SemanticAnalysisError::InvalidExpr)
290 {
291 Ok(scalar_expr) => {
292 *expr = scalar_expr;
293 }
294 Err(err) => return ControlFlow::Break(err),
295 }
296 }
297 Statement::Enforce(_)
298 | Statement::EnforceIf(_, _)
299 | Statement::EnforceAll(_)
300 | Statement::BusEnforce(_) => unreachable!(),
301 },
302 Err(err) => return ControlFlow::Break(err),
303 }
304 ControlFlow::Continue(())
305 }
306 ScalarExpr::BusOperation(expr) => self.visit_mut_bus_operation(expr),
307 }
308 }
309
310 fn visit_mut_expr(&mut self, expr: &mut Expr) -> ControlFlow<SemanticAnalysisError> {
311 let span = expr.span();
312 match expr {
313 // Already constant
314 Expr::Const(_) => ControlFlow::Continue(()),
315 // Lift to `Expr::Const` if the scalar expression is constant
316 //
317 // We deal with symbol accesses directly, as they may evaluate to an aggregate constant
318 Expr::SymbolAccess(access) => {
319 let constant_value = match access.name {
320 // Possibly a reference to a constant declaration
321 ResolvableIdentifier::Resolved(ref qid) => {
322 self.global.get(qid).cloned().map(|s| (s.span(), s.item))
323 }
324 // Possibly a reference to a local bound to a constant
325 ResolvableIdentifier::Local(ref id) => {
326 self.local.get(id).cloned().map(|s| (s.span(), s.item))
327 }
328 // Other identifiers cannot possibly be constant
329 _ => None,
330 };
331 if let Some((span, constant_expr)) = constant_value {
332 match constant_expr {
333 cexpr @ ConstantExpr::Scalar(_) => {
334 assert_eq!(access.access_type, AccessType::Default);
335 *expr = Expr::Const(Span::new(span, cexpr));
336 }
337 ConstantExpr::Vector(value) => match access.access_type.clone() {
338 AccessType::Default => {
339 *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(value)));
340 }
341 AccessType::Slice(range) => {
342 let range = range.to_slice_range();
343 let vector = value[range].to_vec();
344 *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(vector)));
345 }
346 AccessType::Index(idx) => {
347 *expr =
348 Expr::Const(Span::new(span, ConstantExpr::Scalar(value[idx])));
349 }
350 ref ty => panic!(
351 "invalid constant reference, expected scalar access, got {ty:?}",
352 ),
353 },
354 ConstantExpr::Matrix(value) => match access.access_type.clone() {
355 AccessType::Default => {
356 *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(value)));
357 }
358 AccessType::Slice(range) => {
359 let range = range.to_slice_range();
360 let matrix = value[range].to_vec();
361 *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(matrix)));
362 }
363 AccessType::Index(idx) => {
364 *expr = Expr::Const(Span::new(
365 span,
366 ConstantExpr::Vector(value[idx].clone()),
367 ));
368 }
369 AccessType::Matrix(row, col) => {
370 *expr = Expr::Const(Span::new(
371 span,
372 ConstantExpr::Scalar(value[row][col]),
373 ));
374 }
375 },
376 }
377 } else {
378 // This reference is not constant, so we have to record a live use here
379 self.live.insert(*access.name.as_ref());
380 }
381 ControlFlow::Continue(())
382 }
383 Expr::Call(call) if call.is_builtin() => {
384 self.visit_mut_call(call)?;
385 match call.callee.as_ref().name() {
386 name @ (symbols::Sum | symbols::Prod) => {
387 assert_eq!(call.args.len(), 1);
388 if let Expr::Const(value) = &call.args[0] {
389 let span = value.span();
390 match &value.item {
391 ConstantExpr::Vector(elems) => {
392 let folded = if name == symbols::Sum {
393 elems.iter().sum::<u64>()
394 } else {
395 elems.iter().product::<u64>()
396 };
397 *expr =
398 Expr::Const(Span::new(span, ConstantExpr::Scalar(folded)));
399 }
400 invalid => {
401 panic!("bad argument to list folding builtin: {invalid:#?}")
402 }
403 }
404 }
405 }
406 invalid => unimplemented!("unknown builtin function: {invalid}"),
407 }
408 ControlFlow::Continue(())
409 }
410 Expr::Call(call) => self.visit_mut_call(call),
411 Expr::Binary(binary_expr) => match self.try_fold_binary_expr(binary_expr) {
412 Ok(maybe_folded) => {
413 if let Some(folded) = maybe_folded {
414 *expr = Expr::Const(Span::new(
415 folded.span(),
416 ConstantExpr::Scalar(folded.item),
417 ));
418 }
419 ControlFlow::Continue(())
420 }
421 Err(SemanticAnalysisError::InvalidExpr(InvalidExprError::NonConstantExponent(
422 _,
423 ))) if self.in_list_comprehension => {
424 // If we are in a list comprehension, we do not know iterators'
425 // lengths yet, since loop unrolling happens during MIR passes.
426 // The check for non-constant exponents in list comprehensions is done
427 // during lowering from MIR to AIR, so we can safely silence it here.
428 ControlFlow::Continue(())
429 }
430 Err(err) => ControlFlow::Break(err),
431 },
432 // Ranges are constant
433 Expr::Range(_) => ControlFlow::Continue(()),
434 // Visit vector elements, and promote the vector to `Expr::Const` if possible
435 Expr::Vector(vector) => {
436 if vector.is_empty() {
437 return ControlFlow::Continue(());
438 }
439
440 let mut is_constant = true;
441 for elem in vector.iter_mut() {
442 self.visit_mut_expr(elem)?;
443 is_constant &= elem.is_constant();
444 }
445
446 if is_constant {
447 let ty = match vector.first().and_then(|e| e.ty()).unwrap() {
448 Type::Felt => Type::Vector(vector.len()),
449 Type::Vector(n) => Type::Matrix(vector.len(), n),
450 _ => unreachable!(),
451 };
452
453 let new_expr = match ty {
454 Type::Vector(_) => ConstantExpr::Vector(
455 vector
456 .iter()
457 .map(|expr| match expr {
458 Expr::Const(Span {
459 item: ConstantExpr::Scalar(v),
460 ..
461 }) => *v,
462 _ => unreachable!(),
463 })
464 .collect(),
465 ),
466 Type::Matrix(_, _) => ConstantExpr::Matrix(
467 vector
468 .iter()
469 .map(|expr| match expr {
470 Expr::Const(Span {
471 item: ConstantExpr::Vector(vs),
472 ..
473 }) => vs.clone(),
474 _ => unreachable!(),
475 })
476 .collect(),
477 ),
478 _ => unreachable!(),
479 };
480 *expr = Expr::Const(Span::new(span, new_expr));
481 }
482 ControlFlow::Continue(())
483 }
484 // Visit matrix elements, and promote the matrix to `Expr::Const` if possible
485 Expr::Matrix(matrix) => {
486 let mut is_constant = true;
487 for row in matrix.iter_mut() {
488 for column in row.iter_mut() {
489 self.visit_mut_scalar_expr(column)?;
490 is_constant &= column.is_constant();
491 }
492 }
493 if is_constant {
494 let matrix = ConstantExpr::Matrix(
495 matrix
496 .iter()
497 .map(|row| {
498 row.iter()
499 .map(|col| match col {
500 ScalarExpr::Const(elem) => elem.item,
501 _ => unreachable!(),
502 })
503 .collect::<Vec<_>>()
504 })
505 .collect(),
506 );
507 *expr = Expr::Const(Span::new(span, matrix));
508 }
509 ControlFlow::Continue(())
510 }
511 // Visit list comprehensions and convert to constant if possible
512 Expr::ListComprehension(lc) => {
513 let old_in_lc = core::mem::replace(&mut self.in_list_comprehension, true);
514 let mut has_constant_iterables = true;
515 for iterable in lc.iterables.iter_mut() {
516 self.visit_mut_expr(iterable)?;
517 has_constant_iterables &= iterable.is_constant();
518 }
519 // First, fold all other constants inside the body of the comprehension
520 self.visit_mut_scalar_expr(&mut lc.body)?;
521
522 // If we have constant iterables, drive the comprehension, evaluating it at
523 // each step. If any part of the body cannot be compile-time evaluated, then
524 // we bail early, as the comprehension can only be folded if all parts of it
525 // are constant.
526 if !has_constant_iterables {
527 self.in_list_comprehension = old_in_lc;
528 return ControlFlow::Continue(());
529 }
530
531 // Start a new lexical scope
532 self.local.enter();
533
534 // All iterables must be the same length, so determine the number of
535 // steps based on the length of the first iterable
536 let max_len = match &lc.iterables[0] {
537 Expr::Const(Span {
538 item: ConstantExpr::Vector(elems),
539 ..
540 }) => elems.len(),
541 Expr::Const(Span {
542 item: ConstantExpr::Matrix(rows),
543 ..
544 }) => rows.len(),
545 Expr::Const(_) => panic!("expected iterable constant, got scalar"),
546 Expr::Range(range) => range.to_slice_range().len(),
547 _ => unreachable!(
548 "expected iterable constant or range, got {:?}",
549 lc.iterables[0]
550 ),
551 };
552
553 // Drive the comprehension step-by-step
554 let mut folded = vec![];
555 for step in 0..max_len {
556 for (binding, iterable) in lc.bindings.iter().copied().zip(lc.iterables.iter())
557 {
558 let span = iterable.span();
559 match iterable {
560 Expr::Const(Span {
561 item: ConstantExpr::Vector(elems),
562 ..
563 }) => {
564 let value = ConstantExpr::Scalar(elems[step]);
565 self.local.insert(binding, Span::new(span, value));
566 }
567 Expr::Const(Span {
568 item: ConstantExpr::Matrix(elems),
569 ..
570 }) => {
571 let value = ConstantExpr::Vector(elems[step].clone());
572 self.local.insert(binding, Span::new(span, value));
573 }
574 Expr::Range(range) => {
575 let range = range.to_slice_range();
576 assert!(range.end > range.start + step);
577 let value = ConstantExpr::Scalar((range.start + step) as u64);
578 self.local.insert(binding, Span::new(span, value));
579 }
580 _ => unreachable!(
581 "expected iterable constant or range, got {:#?}",
582 iterable
583 ),
584 }
585 }
586
587 if let Some(mut selector) = lc.selector.as_ref().cloned() {
588 self.visit_mut_scalar_expr(&mut selector)?;
589 match selector {
590 ScalarExpr::Const(selected) => {
591 // If the selector returns false on this iteration, go to the next step
592 if *selected == 0 {
593 continue;
594 }
595 }
596 // The selector cannot be evaluated, bail out early
597 _ => {
598 self.in_list_comprehension = old_in_lc;
599 return ControlFlow::Continue(());
600 }
601 }
602 }
603
604 let mut body = lc.body.as_ref().clone();
605 self.visit_mut_scalar_expr(&mut body)?;
606
607 // If the body is constant, store the result in the vector, otherwise we must
608 // bail because this comprehension cannot be folded
609 if let ScalarExpr::Const(folded_body) = body {
610 folded.push(folded_body.item);
611 } else {
612 self.in_list_comprehension = old_in_lc;
613 return ControlFlow::Continue(());
614 }
615 }
616
617 // Exit lexical scope
618 self.local.exit();
619
620 // If we reach here, the comprehension was expanded to a constant vector
621 *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(folded)));
622 self.in_list_comprehension = old_in_lc;
623 ControlFlow::Continue(())
624 }
625 Expr::Let(let_expr) => {
626 match self.try_fold_let_expr(let_expr) {
627 Ok(Left(Some(const_expr))) => {
628 *expr = Expr::Const(Span::new(span, const_expr.item));
629 }
630 Ok(Left(None)) => (),
631 Ok(Right(mut block)) => match block.pop().unwrap() {
632 Statement::Let(inner_expr) => {
633 *let_expr.as_mut() = inner_expr;
634 }
635 Statement::Expr(inner_expr) => {
636 *expr = inner_expr;
637 }
638 Statement::Enforce(_)
639 | Statement::EnforceIf(_, _)
640 | Statement::EnforceAll(_)
641 | Statement::BusEnforce(_) => unreachable!(),
642 },
643 Err(err) => return ControlFlow::Break(err),
644 }
645 ControlFlow::Continue(())
646 }
647 Expr::BusOperation(expr) => self.visit_mut_bus_operation(expr),
648 Expr::Null(_) | Expr::Unconstrained(_) => ControlFlow::Continue(()),
649 }
650 }
651
652 fn visit_mut_statement_block(
653 &mut self,
654 statements: &mut Vec<Statement>,
655 ) -> ControlFlow<SemanticAnalysisError> {
656 let mut current_statement = 0;
657
658 let mut buffer = vec![];
659 while current_statement < statements.len() {
660 let num_statements = statements.len();
661 match &mut statements[current_statement] {
662 Statement::Let(expr) => {
663 // A `let` may only appear once in a statement block, and must be the
664 // last statement in the block
665 assert_eq!(
666 current_statement,
667 num_statements - 1,
668 "let is not in tail position of block"
669 );
670 match self.try_fold_let_expr(expr) {
671 Ok(Left(Some(const_expr))) => {
672 buffer.push(Statement::Expr(Expr::Const(const_expr)));
673 }
674 Ok(Left(None)) => (),
675 Ok(Right(mut block)) => {
676 buffer.append(&mut block);
677 }
678 Err(err) => return ControlFlow::Break(err),
679 }
680 }
681 Statement::Enforce(expr) => {
682 self.visit_mut_enforce(expr)?;
683 }
684 Statement::EnforceAll(expr) => {
685 self.in_constraint_comprehension = true;
686 self.visit_mut_list_comprehension(expr)?;
687 self.in_constraint_comprehension = false;
688 }
689 Statement::Expr(expr) => {
690 self.visit_mut_expr(expr)?;
691 }
692 Statement::BusEnforce(expr) => {
693 self.in_constraint_comprehension = true;
694 self.visit_mut_list_comprehension(expr)?;
695 self.in_constraint_comprehension = false;
696 }
697 // This statement type is only present in the AST after inlining
698 Statement::EnforceIf(_, _) => unreachable!(),
699 }
700
701 // If we have a non-empty buffer, then we are collapsing a let into the current block,
702 // and that let must have been the last expression in the block, so as soon as we fold
703 // its body into the current block, we're done
704 if buffer.is_empty() {
705 current_statement += 1;
706 continue;
707 }
708
709 // Drop the let statement being folded in to this block
710 statements.pop();
711
712 // Append the buffer
713 statements.append(&mut buffer);
714
715 // We're done
716 break;
717 }
718
719 ControlFlow::Continue(())
720 }
721
722 /// It should not be possible to reach this, as we handle statements at the block level
723 fn visit_mut_statement(&mut self, _: &mut Statement) -> ControlFlow<SemanticAnalysisError> {
724 panic!("unexpectedly reached visit_mut_statement");
725 }
726}
727
728/// This function attempts to folds a binary operator expression into a constant value.
729///
730/// If the operands are both constant, the operator is applied, and if the result does not
731/// overflow/underflow, then `Ok(Some)` is returned with the result of the evaluation.
732///
733/// If the operands are not both constant, or the operation would overflow/underflow, then
734/// `Ok(None)` is returned.
735///
736/// If the operands are constant, or there is some validation error with the expression,
737/// `Err(InvalidExprError)` will be returned.
738pub(crate) fn try_fold_binary_expr(
739 expr: &BinaryExpr,
740) -> Result<Option<Span<u64>>, InvalidExprError> {
741 // If both operands are constant, fold
742 if let (ScalarExpr::Const(l), ScalarExpr::Const(r)) = (expr.lhs.as_ref(), expr.rhs.as_ref()) {
743 let folded = match expr.op {
744 BinaryOp::Add => l.item.checked_add(r.item),
745 BinaryOp::Sub => l.item.checked_sub(r.item),
746 BinaryOp::Mul => l.item.checked_mul(r.item),
747 BinaryOp::Exp => match r.item.try_into() {
748 Ok(exp) => l.item.checked_pow(exp),
749 Err(_) => return Err(InvalidExprError::InvalidExponent(expr.span())),
750 },
751 // This op cannot be folded
752 BinaryOp::Eq => return Ok(None),
753 };
754 Ok(folded.map(|v| Span::new(expr.span(), v)))
755 } else {
756 // If we observe a non-constant power in an exponentiation operation, raise an error
757 if expr.op == BinaryOp::Exp && !expr.rhs.is_constant() {
758 Err(InvalidExprError::NonConstantExponent(expr.rhs.span()))
759 } else {
760 Ok(None)
761 }
762 }
763}