air_parser/transforms/
constant_propagation.rs

1use std::{
2    collections::{HashMap, HashSet},
3    ops::ControlFlow,
4};
5
6use air_pass::Pass;
7use either::Either::{self, Left, Right};
8use miden_diagnostics::{DiagnosticsHandler, Span, Spanned};
9
10use crate::{
11    ast::{visit::VisitMut, *},
12    sema::{LexicalScope, SemanticAnalysisError},
13    symbols,
14};
15
16/// This pass performs constant propagation on a [Program], replacing all uses of a constant
17/// with the constant itself, converting accesses into constant aggregates with the accessed
18/// value, replacing local variables bound to constants with the constant value, and folding
19/// constant expressions into constant values.
20///
21/// It is expected that the provided [Program] has already been run through semantic analysis,
22/// so it will panic if it encounters invalid constructions to help catch bugs in the semantic
23/// analysis pass, should they exist.
24pub struct ConstantPropagation<'a> {
25    #[allow(unused)]
26    diagnostics: &'a DiagnosticsHandler,
27    global: HashMap<QualifiedIdentifier, Span<ConstantExpr>>,
28    local: LexicalScope<Identifier, Span<ConstantExpr>>,
29    /// The set of identifiers which are live (in use) in the current scope
30    live: HashSet<Identifier>,
31    in_constraint_comprehension: bool,
32    in_list_comprehension: bool,
33}
34impl Pass for ConstantPropagation<'_> {
35    type Input<'a> = Program;
36    type Output<'a> = Program;
37    type Error = SemanticAnalysisError;
38
39    fn run<'a>(&mut self, mut program: Self::Input<'a>) -> Result<Self::Output<'a>, Self::Error> {
40        self.global.reserve(program.constants.len());
41
42        match self.run_visitor(&mut program) {
43            ControlFlow::Continue(()) => Ok(program),
44            ControlFlow::Break(err) => {
45                self.diagnostics.emit(err.clone());
46                Err(err)
47            }
48        }
49    }
50}
51impl<'a> ConstantPropagation<'a> {
52    pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self {
53        Self {
54            diagnostics,
55            global: Default::default(),
56            local: Default::default(),
57            live: Default::default(),
58            in_constraint_comprehension: false,
59            in_list_comprehension: false,
60        }
61    }
62
63    fn run_visitor(&mut self, program: &mut Program) -> ControlFlow<SemanticAnalysisError> {
64        // Record all of the constant declarations
65        for (name, constant) in program.constants.iter() {
66            assert_eq!(
67                self.global
68                    .insert(*name, Span::new(constant.span(), constant.value.clone())),
69                None
70            );
71        }
72
73        // Visit all of the evaluators
74        for evaluator in program.evaluators.values_mut() {
75            self.visit_mut_evaluator_function(evaluator)?;
76        }
77
78        // Visit all of the functions
79        for function in program.functions.values_mut() {
80            self.visit_mut_function(function)?;
81        }
82
83        // Visit all of the buses
84        for bus in program.buses.values_mut() {
85            self.visit_mut_bus(bus)?;
86        }
87
88        // Visit all of the constraints
89        self.visit_mut_boundary_constraints(&mut program.boundary_constraints)?;
90        self.visit_mut_integrity_constraints(&mut program.integrity_constraints)
91    }
92
93    fn try_fold_binary_expr(
94        &mut self,
95        expr: &mut BinaryExpr,
96    ) -> Result<Option<Span<u64>>, SemanticAnalysisError> {
97        // Visit operands first to ensure they are reduced to constants if possible
98        if let ControlFlow::Break(err) = self.visit_mut_scalar_expr(expr.lhs.as_mut()) {
99            return Err(err);
100        }
101        if let ControlFlow::Break(err) = self.visit_mut_scalar_expr(expr.rhs.as_mut()) {
102            return Err(err);
103        }
104        // If both operands are constant, fold
105        try_fold_binary_expr(expr).map_err(SemanticAnalysisError::InvalidExpr)
106    }
107
108    /// When folding a `let`, one of the following can occur:
109    ///
110    /// * The let-bound variable is non-constant, so the entire let must remain, but we
111    ///   can constant-propagate as much of the bound expression and body as possible.
112    /// * The let-bound variable is constant, so once we have constant propagated the body,
113    ///   the let is no longer needed, and one of the following happens:
114    ///   * The `let` terminates with a constant expression, so the entire `let` is replaced
115    ///     with that expression.
116    ///   * The `let` terminates with a non-constant expression, or a constraint, so we inline
117    ///     the let body into the containing block. In the non-constant expression case, we
118    ///     replace the `let` with the last expression in the returned block, since in expression
119    ///     position, we may not have a statement block to inline into.
120    fn try_fold_let_expr(
121        &mut self,
122        expr: &mut Let,
123    ) -> Result<Either<Option<Span<ConstantExpr>>, Vec<Statement>>, SemanticAnalysisError> {
124        // Visit the binding expression first
125        if let ControlFlow::Break(err) = self.visit_mut_expr(&mut expr.value) {
126            return Err(err);
127        }
128
129        // Enter a new lexical scope
130        let prev_live = core::mem::take(&mut self.live);
131        self.local.enter();
132        // If the value is constant, record it in our bindings map
133        let is_constant = expr.value.is_constant();
134        if is_constant {
135            match expr.value {
136                Expr::Const(ref value) => {
137                    self.local.insert(expr.name, value.clone());
138                }
139                Expr::Range(ref range) => {
140                    let span = range.span();
141                    let range = range.to_slice_range();
142                    let vector = range.map(|i| i as u64).collect();
143                    self.local
144                        .insert(expr.name, Span::new(span, ConstantExpr::Vector(vector)));
145                }
146                _ => unreachable!(),
147            }
148        }
149
150        // Visit the let body
151        if let ControlFlow::Break(err) = self.visit_mut_statement_block(&mut expr.body) {
152            return Err(err);
153        }
154
155        // If this let is constant, then the binding is no longer
156        // used in the body after constant propagation, so we can
157        // fold away the let entirely
158        let is_live = self.live.contains(&expr.name);
159        let result = if is_constant && !is_live {
160            match expr.body.last().unwrap() {
161                Statement::Expr(Expr::Const(const_value)) => {
162                    Left(Some(Span::new(expr.span(), const_value.item.clone())))
163                }
164                _ => Right(core::mem::take(&mut expr.body)),
165            }
166        } else {
167            Left(None)
168        };
169
170        // Propagate liveness from the body of the let to its parent scope
171        let mut live = core::mem::take(&mut self.live);
172        live.remove(&expr.name);
173        self.live = &prev_live | &live;
174
175        // Restore the previous scope
176        self.local.exit();
177
178        Ok(result)
179    }
180}
181impl VisitMut<SemanticAnalysisError> for ConstantPropagation<'_> {
182    /// Fold constant expressions
183    fn visit_mut_scalar_expr(
184        &mut self,
185        expr: &mut ScalarExpr,
186    ) -> ControlFlow<SemanticAnalysisError> {
187        match expr {
188            // Expression is already folded
189            ScalarExpr::Const(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
190                ControlFlow::Continue(())
191            }
192            // Need to check if this access is to a constant value, and transform to a constant if so
193            ScalarExpr::SymbolAccess(sym) => {
194                let constant_value = match sym.name {
195                    // Possibly a reference to a constant declaration
196                    ResolvableIdentifier::Resolved(ref qid) => {
197                        self.global.get(qid).cloned().map(|s| (s.span(), s.item))
198                    }
199                    // Possibly a reference to a local bound to a constant
200                    ResolvableIdentifier::Local(ref id) => {
201                        self.local.get(id).cloned().map(|s| (s.span(), s.item))
202                    }
203                    // Other identifiers cannot possibly be constant
204                    _ => None,
205                };
206                if let Some((span, constant_expr)) = constant_value {
207                    match constant_expr {
208                        ConstantExpr::Scalar(value) => {
209                            assert_eq!(sym.access_type, AccessType::Default);
210                            *expr = ScalarExpr::Const(Span::new(span, value));
211                        }
212                        ConstantExpr::Vector(value) => match sym.access_type {
213                            AccessType::Index(idx) => {
214                                *expr = ScalarExpr::Const(Span::new(span, value[idx]));
215                            }
216                            // This access cannot be resolved here, so we need to record the fact
217                            // that there are still live uses of this binding
218                            _ => {
219                                self.live.insert(*sym.name.as_ref());
220                            }
221                        },
222                        ConstantExpr::Matrix(value) => match sym.access_type {
223                            AccessType::Matrix(row, col) => {
224                                *expr = ScalarExpr::Const(Span::new(span, value[row][col]));
225                            }
226                            // This access cannot be resolved here, so we need to record the fact
227                            // that there are still live uses of this binding
228                            _ => {
229                                self.live.insert(*sym.name.as_ref());
230                            }
231                        },
232                    }
233                } else {
234                    // This value is not constant, so there are live uses of this symbol
235                    self.live.insert(*sym.name.as_ref());
236                }
237                ControlFlow::Continue(())
238            }
239            // Fold constant expressions
240            ScalarExpr::Binary(binary_expr) => {
241                match self.try_fold_binary_expr(binary_expr) {
242                    Ok(maybe_folded) => {
243                        if let Some(folded) = maybe_folded {
244                            *expr = ScalarExpr::Const(folded);
245                        }
246                        ControlFlow::Continue(())
247                    }
248                    Err(SemanticAnalysisError::InvalidExpr(
249                        InvalidExprError::NonConstantExponent(_),
250                    )) if self.in_list_comprehension => {
251                        // If we are in a list comprehension, we do not know iterators'
252                        // lengths yet, since loop unrolling happens during MIR passes.
253                        // The check for non-constant exponents in list comprehensions is done
254                        // during lowering from MIR to AIR, so we can safely silence it here.
255                        ControlFlow::Continue(())
256                    }
257                    Err(err) => ControlFlow::Break(err),
258                }
259            }
260            // While calls cannot be constant folded, arguments can be
261            ScalarExpr::Call(call) => self.visit_mut_call(call),
262            // This cannot be constant folded
263            ScalarExpr::BoundedSymbolAccess(_) => ControlFlow::Continue(()),
264            // A let that evaluates to a constant value can be folded to the constant value
265            ScalarExpr::Let(let_expr) => {
266                match self.try_fold_let_expr(let_expr) {
267                    Ok(Left(Some(const_expr))) => {
268                        let span = const_expr.span();
269                        match const_expr.item {
270                            ConstantExpr::Scalar(value) => {
271                                *expr = ScalarExpr::Const(Span::new(span, value));
272                            }
273                            _ => {
274                                self.diagnostics.diagnostic(miden_diagnostics::Severity::Error)
275                                    .with_message("invalid scalar expression")
276                                    .with_primary_label(span, "expected scalar value, but this expression evaluates to an aggregate type")
277                                    .emit();
278                                return ControlFlow::Break(SemanticAnalysisError::Invalid);
279                            }
280                        }
281                    }
282                    Ok(Left(None)) => (),
283                    Ok(Right(mut block)) => match block.pop().unwrap() {
284                        Statement::Let(inner_expr) => {
285                            *let_expr.as_mut() = inner_expr;
286                        }
287                        Statement::Expr(inner_expr) => {
288                            match ScalarExpr::try_from(inner_expr)
289                                .map_err(SemanticAnalysisError::InvalidExpr)
290                            {
291                                Ok(scalar_expr) => {
292                                    *expr = scalar_expr;
293                                }
294                                Err(err) => return ControlFlow::Break(err),
295                            }
296                        }
297                        Statement::Enforce(_)
298                        | Statement::EnforceIf(_, _)
299                        | Statement::EnforceAll(_)
300                        | Statement::BusEnforce(_) => unreachable!(),
301                    },
302                    Err(err) => return ControlFlow::Break(err),
303                }
304                ControlFlow::Continue(())
305            }
306            ScalarExpr::BusOperation(expr) => self.visit_mut_bus_operation(expr),
307        }
308    }
309
310    fn visit_mut_expr(&mut self, expr: &mut Expr) -> ControlFlow<SemanticAnalysisError> {
311        let span = expr.span();
312        match expr {
313            // Already constant
314            Expr::Const(_) => ControlFlow::Continue(()),
315            // Lift to `Expr::Const` if the scalar expression is constant
316            //
317            // We deal with symbol accesses directly, as they may evaluate to an aggregate constant
318            Expr::SymbolAccess(access) => {
319                let constant_value = match access.name {
320                    // Possibly a reference to a constant declaration
321                    ResolvableIdentifier::Resolved(ref qid) => {
322                        self.global.get(qid).cloned().map(|s| (s.span(), s.item))
323                    }
324                    // Possibly a reference to a local bound to a constant
325                    ResolvableIdentifier::Local(ref id) => {
326                        self.local.get(id).cloned().map(|s| (s.span(), s.item))
327                    }
328                    // Other identifiers cannot possibly be constant
329                    _ => None,
330                };
331                if let Some((span, constant_expr)) = constant_value {
332                    match constant_expr {
333                        cexpr @ ConstantExpr::Scalar(_) => {
334                            assert_eq!(access.access_type, AccessType::Default);
335                            *expr = Expr::Const(Span::new(span, cexpr));
336                        }
337                        ConstantExpr::Vector(value) => match access.access_type.clone() {
338                            AccessType::Default => {
339                                *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(value)));
340                            }
341                            AccessType::Slice(range) => {
342                                let range = range.to_slice_range();
343                                let vector = value[range].to_vec();
344                                *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(vector)));
345                            }
346                            AccessType::Index(idx) => {
347                                *expr =
348                                    Expr::Const(Span::new(span, ConstantExpr::Scalar(value[idx])));
349                            }
350                            ref ty => panic!(
351                                "invalid constant reference, expected scalar access, got {ty:?}",
352                            ),
353                        },
354                        ConstantExpr::Matrix(value) => match access.access_type.clone() {
355                            AccessType::Default => {
356                                *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(value)));
357                            }
358                            AccessType::Slice(range) => {
359                                let range = range.to_slice_range();
360                                let matrix = value[range].to_vec();
361                                *expr = Expr::Const(Span::new(span, ConstantExpr::Matrix(matrix)));
362                            }
363                            AccessType::Index(idx) => {
364                                *expr = Expr::Const(Span::new(
365                                    span,
366                                    ConstantExpr::Vector(value[idx].clone()),
367                                ));
368                            }
369                            AccessType::Matrix(row, col) => {
370                                *expr = Expr::Const(Span::new(
371                                    span,
372                                    ConstantExpr::Scalar(value[row][col]),
373                                ));
374                            }
375                        },
376                    }
377                } else {
378                    // This reference is not constant, so we have to record a live use here
379                    self.live.insert(*access.name.as_ref());
380                }
381                ControlFlow::Continue(())
382            }
383            Expr::Call(call) if call.is_builtin() => {
384                self.visit_mut_call(call)?;
385                match call.callee.as_ref().name() {
386                    name @ (symbols::Sum | symbols::Prod) => {
387                        assert_eq!(call.args.len(), 1);
388                        if let Expr::Const(value) = &call.args[0] {
389                            let span = value.span();
390                            match &value.item {
391                                ConstantExpr::Vector(elems) => {
392                                    let folded = if name == symbols::Sum {
393                                        elems.iter().sum::<u64>()
394                                    } else {
395                                        elems.iter().product::<u64>()
396                                    };
397                                    *expr =
398                                        Expr::Const(Span::new(span, ConstantExpr::Scalar(folded)));
399                                }
400                                invalid => {
401                                    panic!("bad argument to list folding builtin: {invalid:#?}")
402                                }
403                            }
404                        }
405                    }
406                    invalid => unimplemented!("unknown builtin function: {invalid}"),
407                }
408                ControlFlow::Continue(())
409            }
410            Expr::Call(call) => self.visit_mut_call(call),
411            Expr::Binary(binary_expr) => match self.try_fold_binary_expr(binary_expr) {
412                Ok(maybe_folded) => {
413                    if let Some(folded) = maybe_folded {
414                        *expr = Expr::Const(Span::new(
415                            folded.span(),
416                            ConstantExpr::Scalar(folded.item),
417                        ));
418                    }
419                    ControlFlow::Continue(())
420                }
421                Err(SemanticAnalysisError::InvalidExpr(InvalidExprError::NonConstantExponent(
422                    _,
423                ))) if self.in_list_comprehension => {
424                    // If we are in a list comprehension, we do not know iterators'
425                    // lengths yet, since loop unrolling happens during MIR passes.
426                    // The check for non-constant exponents in list comprehensions is done
427                    // during lowering from MIR to AIR, so we can safely silence it here.
428                    ControlFlow::Continue(())
429                }
430                Err(err) => ControlFlow::Break(err),
431            },
432            // Ranges are constant
433            Expr::Range(_) => ControlFlow::Continue(()),
434            // Visit vector elements, and promote the vector to `Expr::Const` if possible
435            Expr::Vector(vector) => {
436                if vector.is_empty() {
437                    return ControlFlow::Continue(());
438                }
439
440                let mut is_constant = true;
441                for elem in vector.iter_mut() {
442                    self.visit_mut_expr(elem)?;
443                    is_constant &= elem.is_constant();
444                }
445
446                if is_constant {
447                    let ty = match vector.first().and_then(|e| e.ty()).unwrap() {
448                        Type::Felt => Type::Vector(vector.len()),
449                        Type::Vector(n) => Type::Matrix(vector.len(), n),
450                        _ => unreachable!(),
451                    };
452
453                    let new_expr = match ty {
454                        Type::Vector(_) => ConstantExpr::Vector(
455                            vector
456                                .iter()
457                                .map(|expr| match expr {
458                                    Expr::Const(Span {
459                                        item: ConstantExpr::Scalar(v),
460                                        ..
461                                    }) => *v,
462                                    _ => unreachable!(),
463                                })
464                                .collect(),
465                        ),
466                        Type::Matrix(_, _) => ConstantExpr::Matrix(
467                            vector
468                                .iter()
469                                .map(|expr| match expr {
470                                    Expr::Const(Span {
471                                        item: ConstantExpr::Vector(vs),
472                                        ..
473                                    }) => vs.clone(),
474                                    _ => unreachable!(),
475                                })
476                                .collect(),
477                        ),
478                        _ => unreachable!(),
479                    };
480                    *expr = Expr::Const(Span::new(span, new_expr));
481                }
482                ControlFlow::Continue(())
483            }
484            // Visit matrix elements, and promote the matrix to `Expr::Const` if possible
485            Expr::Matrix(matrix) => {
486                let mut is_constant = true;
487                for row in matrix.iter_mut() {
488                    for column in row.iter_mut() {
489                        self.visit_mut_scalar_expr(column)?;
490                        is_constant &= column.is_constant();
491                    }
492                }
493                if is_constant {
494                    let matrix = ConstantExpr::Matrix(
495                        matrix
496                            .iter()
497                            .map(|row| {
498                                row.iter()
499                                    .map(|col| match col {
500                                        ScalarExpr::Const(elem) => elem.item,
501                                        _ => unreachable!(),
502                                    })
503                                    .collect::<Vec<_>>()
504                            })
505                            .collect(),
506                    );
507                    *expr = Expr::Const(Span::new(span, matrix));
508                }
509                ControlFlow::Continue(())
510            }
511            // Visit list comprehensions and convert to constant if possible
512            Expr::ListComprehension(lc) => {
513                let old_in_lc = core::mem::replace(&mut self.in_list_comprehension, true);
514                let mut has_constant_iterables = true;
515                for iterable in lc.iterables.iter_mut() {
516                    self.visit_mut_expr(iterable)?;
517                    has_constant_iterables &= iterable.is_constant();
518                }
519                // First, fold all other constants inside the body of the comprehension
520                self.visit_mut_scalar_expr(&mut lc.body)?;
521
522                // If we have constant iterables, drive the comprehension, evaluating it at
523                // each step. If any part of the body cannot be compile-time evaluated, then
524                // we bail early, as the comprehension can only be folded if all parts of it
525                // are constant.
526                if !has_constant_iterables {
527                    self.in_list_comprehension = old_in_lc;
528                    return ControlFlow::Continue(());
529                }
530
531                // Start a new lexical scope
532                self.local.enter();
533
534                // All iterables must be the same length, so determine the number of
535                // steps based on the length of the first iterable
536                let max_len = match &lc.iterables[0] {
537                    Expr::Const(Span {
538                        item: ConstantExpr::Vector(elems),
539                        ..
540                    }) => elems.len(),
541                    Expr::Const(Span {
542                        item: ConstantExpr::Matrix(rows),
543                        ..
544                    }) => rows.len(),
545                    Expr::Const(_) => panic!("expected iterable constant, got scalar"),
546                    Expr::Range(range) => range.to_slice_range().len(),
547                    _ => unreachable!(
548                        "expected iterable constant or range, got {:?}",
549                        lc.iterables[0]
550                    ),
551                };
552
553                // Drive the comprehension step-by-step
554                let mut folded = vec![];
555                for step in 0..max_len {
556                    for (binding, iterable) in lc.bindings.iter().copied().zip(lc.iterables.iter())
557                    {
558                        let span = iterable.span();
559                        match iterable {
560                            Expr::Const(Span {
561                                item: ConstantExpr::Vector(elems),
562                                ..
563                            }) => {
564                                let value = ConstantExpr::Scalar(elems[step]);
565                                self.local.insert(binding, Span::new(span, value));
566                            }
567                            Expr::Const(Span {
568                                item: ConstantExpr::Matrix(elems),
569                                ..
570                            }) => {
571                                let value = ConstantExpr::Vector(elems[step].clone());
572                                self.local.insert(binding, Span::new(span, value));
573                            }
574                            Expr::Range(range) => {
575                                let range = range.to_slice_range();
576                                assert!(range.end > range.start + step);
577                                let value = ConstantExpr::Scalar((range.start + step) as u64);
578                                self.local.insert(binding, Span::new(span, value));
579                            }
580                            _ => unreachable!(
581                                "expected iterable constant or range, got {:#?}",
582                                iterable
583                            ),
584                        }
585                    }
586
587                    if let Some(mut selector) = lc.selector.as_ref().cloned() {
588                        self.visit_mut_scalar_expr(&mut selector)?;
589                        match selector {
590                            ScalarExpr::Const(selected) => {
591                                // If the selector returns false on this iteration, go to the next step
592                                if *selected == 0 {
593                                    continue;
594                                }
595                            }
596                            // The selector cannot be evaluated, bail out early
597                            _ => {
598                                self.in_list_comprehension = old_in_lc;
599                                return ControlFlow::Continue(());
600                            }
601                        }
602                    }
603
604                    let mut body = lc.body.as_ref().clone();
605                    self.visit_mut_scalar_expr(&mut body)?;
606
607                    // If the body is constant, store the result in the vector, otherwise we must
608                    // bail because this comprehension cannot be folded
609                    if let ScalarExpr::Const(folded_body) = body {
610                        folded.push(folded_body.item);
611                    } else {
612                        self.in_list_comprehension = old_in_lc;
613                        return ControlFlow::Continue(());
614                    }
615                }
616
617                // Exit lexical scope
618                self.local.exit();
619
620                // If we reach here, the comprehension was expanded to a constant vector
621                *expr = Expr::Const(Span::new(span, ConstantExpr::Vector(folded)));
622                self.in_list_comprehension = old_in_lc;
623                ControlFlow::Continue(())
624            }
625            Expr::Let(let_expr) => {
626                match self.try_fold_let_expr(let_expr) {
627                    Ok(Left(Some(const_expr))) => {
628                        *expr = Expr::Const(Span::new(span, const_expr.item));
629                    }
630                    Ok(Left(None)) => (),
631                    Ok(Right(mut block)) => match block.pop().unwrap() {
632                        Statement::Let(inner_expr) => {
633                            *let_expr.as_mut() = inner_expr;
634                        }
635                        Statement::Expr(inner_expr) => {
636                            *expr = inner_expr;
637                        }
638                        Statement::Enforce(_)
639                        | Statement::EnforceIf(_, _)
640                        | Statement::EnforceAll(_)
641                        | Statement::BusEnforce(_) => unreachable!(),
642                    },
643                    Err(err) => return ControlFlow::Break(err),
644                }
645                ControlFlow::Continue(())
646            }
647            Expr::BusOperation(expr) => self.visit_mut_bus_operation(expr),
648            Expr::Null(_) | Expr::Unconstrained(_) => ControlFlow::Continue(()),
649        }
650    }
651
652    fn visit_mut_statement_block(
653        &mut self,
654        statements: &mut Vec<Statement>,
655    ) -> ControlFlow<SemanticAnalysisError> {
656        let mut current_statement = 0;
657
658        let mut buffer = vec![];
659        while current_statement < statements.len() {
660            let num_statements = statements.len();
661            match &mut statements[current_statement] {
662                Statement::Let(expr) => {
663                    // A `let` may only appear once in a statement block, and must be the
664                    // last statement in the block
665                    assert_eq!(
666                        current_statement,
667                        num_statements - 1,
668                        "let is not in tail position of block"
669                    );
670                    match self.try_fold_let_expr(expr) {
671                        Ok(Left(Some(const_expr))) => {
672                            buffer.push(Statement::Expr(Expr::Const(const_expr)));
673                        }
674                        Ok(Left(None)) => (),
675                        Ok(Right(mut block)) => {
676                            buffer.append(&mut block);
677                        }
678                        Err(err) => return ControlFlow::Break(err),
679                    }
680                }
681                Statement::Enforce(expr) => {
682                    self.visit_mut_enforce(expr)?;
683                }
684                Statement::EnforceAll(expr) => {
685                    self.in_constraint_comprehension = true;
686                    self.visit_mut_list_comprehension(expr)?;
687                    self.in_constraint_comprehension = false;
688                }
689                Statement::Expr(expr) => {
690                    self.visit_mut_expr(expr)?;
691                }
692                Statement::BusEnforce(expr) => {
693                    self.in_constraint_comprehension = true;
694                    self.visit_mut_list_comprehension(expr)?;
695                    self.in_constraint_comprehension = false;
696                }
697                // This statement type is only present in the AST after inlining
698                Statement::EnforceIf(_, _) => unreachable!(),
699            }
700
701            // If we have a non-empty buffer, then we are collapsing a let into the current block,
702            // and that let must have been the last expression in the block, so as soon as we fold
703            // its body into the current block, we're done
704            if buffer.is_empty() {
705                current_statement += 1;
706                continue;
707            }
708
709            // Drop the let statement being folded in to this block
710            statements.pop();
711
712            // Append the buffer
713            statements.append(&mut buffer);
714
715            // We're done
716            break;
717        }
718
719        ControlFlow::Continue(())
720    }
721
722    /// It should not be possible to reach this, as we handle statements at the block level
723    fn visit_mut_statement(&mut self, _: &mut Statement) -> ControlFlow<SemanticAnalysisError> {
724        panic!("unexpectedly reached visit_mut_statement");
725    }
726}
727
728/// This function attempts to folds a binary operator expression into a constant value.
729///
730/// If the operands are both constant, the operator is applied, and if the result does not
731/// overflow/underflow, then `Ok(Some)` is returned with the result of the evaluation.
732///
733/// If the operands are not both constant, or the operation would overflow/underflow, then
734/// `Ok(None)` is returned.
735///
736/// If the operands are constant, or there is some validation error with the expression,
737/// `Err(InvalidExprError)` will be returned.
738pub(crate) fn try_fold_binary_expr(
739    expr: &BinaryExpr,
740) -> Result<Option<Span<u64>>, InvalidExprError> {
741    // If both operands are constant, fold
742    if let (ScalarExpr::Const(l), ScalarExpr::Const(r)) = (expr.lhs.as_ref(), expr.rhs.as_ref()) {
743        let folded = match expr.op {
744            BinaryOp::Add => l.item.checked_add(r.item),
745            BinaryOp::Sub => l.item.checked_sub(r.item),
746            BinaryOp::Mul => l.item.checked_mul(r.item),
747            BinaryOp::Exp => match r.item.try_into() {
748                Ok(exp) => l.item.checked_pow(exp),
749                Err(_) => return Err(InvalidExprError::InvalidExponent(expr.span())),
750            },
751            // This op cannot be folded
752            BinaryOp::Eq => return Ok(None),
753        };
754        Ok(folded.map(|v| Span::new(expr.span(), v)))
755    } else {
756        // If we observe a non-constant power in an exponentiation operation, raise an error
757        if expr.op == BinaryOp::Exp && !expr.rhs.is_constant() {
758            Err(InvalidExprError::NonConstantExponent(expr.rhs.span()))
759        } else {
760            Ok(None)
761        }
762    }
763}