air_parser/transforms/inlining.rs

1use std::{
2    collections::{BTreeMap, HashMap, HashSet, VecDeque},
3    ops::ControlFlow,
4    vec,
5};
6
7use air_pass::Pass;
8use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned};
9
10use crate::{
11    ast::{visit::VisitMut, *},
12    sema::{BindingType, LexicalScope, SemanticAnalysisError},
13    symbols,
14};
15
16use super::constant_propagation;
17
18/// This pass performs the following transformations on a [Program]:
19///
20/// * Monomorphizing and inlining evaluators/functions at their call sites
21/// * Unrolling constraint comprehensions into a sequence of scalar constraints
22/// * Unrolling list comprehensions into a tree of `let` statements which end in
23///   a vector expression (the implicit result of the tree). Each iteration of the
24///   unrolled comprehension is reified as a value and bound to a variable so that
25///   other transformations may refer to it directly.
26/// * Rewriting aliases of top-level declarations to refer to those declarations directly
27/// * Removing let-bound variables which are unused, which is also used to clean up
28///   after the aliasing rewrite mentioned above.
29///
30/// The trickiest transformation comes with inlining the body of evaluators at their
31/// call sites, as evaluator parameter lists can arbitrarily destructure/regroup columns
32/// provided as arguments for each trace segment. This means that columns can be passed
33/// in a variety of configurations as arguments, and the patterns expressed in the evaluator
34/// parameter list can arbitrarily reconfigure them for use in the evaluator body.
35///
36/// For example, let's say you call an evaluator `foo` with three columns, passed as individual
37/// bindings, like so: `foo([a, b, c])`. Let's further assume that the evaluator signature
38/// is defined as `ev foo([x[2], y])`. While you might expect that this would be an error,
39/// and that the caller would need to provide the columns in the same configuration, that
40/// is not the case. Instead, `a` and `b` are implicitly re-bound as a vector of trace column
41/// bindings for use in the function body. There is further no requirement that `a` and `b`
42/// are consecutive bindings either, as long as they are from the same trace segment. During
43/// compilation however, accesses to individual elements of the vector will be rewritten to use
44/// the correct binding in the caller after inlining, e.g. an access like `x[1]` becomes `b`.
45///
46/// This pass accomplishes three goals:
47///
48/// * Remove all function abstractions from the program
49/// * Remove all comprehensions from the program
50/// * Inline all constraints into the integrity and boundary constraints sections
51/// * Make all references to top-level declarations concrete
52///
53/// When done, it should be impossible for there to be any invalid trace column references.
54///
55/// It is expected that the provided [Program] has already been run through semantic analysis
56/// and constant propagation, so a number of assumptions are made with regard to what syntax can
57/// be observed at this stage of compilation (e.g. no references to constant declarations, no
58/// undefined variables, expressions are well-typed, etc.).
59pub struct Inlining<'a> {
60    // This may be unused for now, but it's helpful to assume its needed in case we want it in the future
61    #[allow(unused)]
62    diagnostics: &'a DiagnosticsHandler,
63    /// The name of the root module
64    root: Identifier,
65    /// The global trace segment configuration
66    trace: Vec<TraceSegment>,
67    /// The public_inputs declaration
68    public_inputs: BTreeMap<Identifier, PublicInput>,
69    /// All local/global bindings in scope
70    bindings: LexicalScope<Identifier, BindingType>,
71    /// The values of all let-bound variables in scope
72    let_bound: LexicalScope<Identifier, Expr>,
73    /// All items which must be referenced fully-qualified, namely periodic columns at this point
74    imported: HashMap<QualifiedIdentifier, BindingType>,
75    /// All evaluator functions in the program
76    evaluators: HashMap<QualifiedIdentifier, EvaluatorFunction>,
77    /// All pure functions in the program
78    functions: HashMap<QualifiedIdentifier, Function>,
79    /// A set of identifiers for which accesses should be rewritten.
80    ///
81    /// When an identifier is in this set, it means it is a local alias for a trace column,
82    /// and should be rewritten based on the current `BindingType` associated with the alias
83    /// identifier in `bindings`.
84    rewrites: HashSet<Identifier>,
85    /// The call stack during expansion of a function call.
86    ///
87    /// Each time we begin to expand a call, we check if it is already present on the call
88    /// stack, and if so, raise a diagnostic due to infinite recursion. If not, the callee
89    /// is pushed on the stack while we expand its body. When we finish expanding the body
90    /// of the callee, we pop it off this stack, and proceed as usual.
91    call_stack: Vec<QualifiedIdentifier>,
92    in_comprehension_constraint: bool,
93    next_ident_lc: usize,
94    next_ident: usize,
95}
96impl Pass for Inlining<'_> {
97    type Input<'a> = Program;
98    type Output<'a> = Program;
99    type Error = SemanticAnalysisError;
100
101    fn run<'a>(&mut self, mut program: Self::Input<'a>) -> Result<Self::Output<'a>, Self::Error> {
102        self.root = program.name;
103        self.evaluators = program
104            .evaluators
105            .iter()
106            .map(|(k, v)| (*k, v.clone()))
107            .collect();
108
109        self.functions = program
110            .functions
111            .iter()
112            .map(|(k, v)| (*k, v.clone()))
113            .collect();
114
115        // We'll be referencing the trace configuration during inlining, so keep a copy of it
116        self.trace.clone_from(&program.trace_columns);
117        // And the public inputs
118        self.public_inputs.clone_from(&program.public_inputs);
119
120        // Add all of the local bindings visible in the root module, except for
121        // constants and periodic columns, which by this point have been rewritten
122        // to use fully-qualified names (or in the case of constants, have been
123        // eliminated entirely)
124        //
125        // Trace first..
126        for segment in program.trace_columns.iter() {
127            self.bindings.insert(
128                segment.name,
129                BindingType::TraceColumn(TraceBinding {
130                    span: segment.name.span(),
131                    segment: segment.id,
132                    name: Some(segment.name),
133                    offset: 0,
134                    size: segment.size,
135                    ty: Type::Vector(segment.size),
136                }),
137            );
138            for binding in segment.bindings.iter().copied() {
139                self.bindings.insert(
140                    binding.name.unwrap(),
141                    BindingType::TraceColumn(TraceBinding {
142                        span: segment.name.span(),
143                        segment: segment.id,
144                        name: binding.name,
145                        offset: binding.offset,
146                        size: binding.size,
147                        ty: binding.ty,
148                    }),
149                );
150            }
151        }
152        // Public inputs..
153        for input in program.public_inputs.values() {
154            self.bindings.insert(
155                input.name(),
156                BindingType::PublicInput(Type::Vector(input.size())),
157            );
158        }
159        // For periodic columns, we register the imported item, but do not add any to the local bindings.
160        for (name, periodic) in program.periodic_columns.iter() {
161            let binding_ty = BindingType::PeriodicColumn(periodic.values.len());
162            self.imported.insert(*name, binding_ty);
163        }
164
165        // The root of the inlining process is the integrity_constraints and
166        // boundary_constraints blocks. Function calls in inlined functions are
167        // inlined at the same time as the parent
168        self.expand_boundary_constraints(&mut program.boundary_constraints)?;
169        self.expand_integrity_constraints(&mut program.integrity_constraints)?;
170
171        Ok(program)
172    }
173}
174impl<'a> Inlining<'a> {
175    pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self {
176        Self {
177            diagnostics,
178            root: Identifier::new(SourceSpan::UNKNOWN, crate::symbols::Main),
179            trace: vec![],
180            public_inputs: Default::default(),
181            bindings: Default::default(),
182            let_bound: Default::default(),
183            imported: Default::default(),
184            evaluators: Default::default(),
185            functions: Default::default(),
186            rewrites: Default::default(),
187            in_comprehension_constraint: false,
188            call_stack: vec![],
189            next_ident_lc: 0,
190            next_ident: 0,
191        }
192    }
193
194    /// Generate a new variable
195    ///
196    /// This is only used when expanding list comprehensions, so we use a special prefix for
197    /// these generated identifiers to make it clear what they were expanded from.
198    fn get_next_ident_lc(&mut self, span: SourceSpan) -> Identifier {
199        let id = self.next_ident_lc;
200        self.next_ident_lc += 1;
201        Identifier::new(span, crate::Symbol::intern(format!("%lc{id}")))
202    }
203
204    fn get_next_ident(&mut self, span: SourceSpan) -> Identifier {
205        let id = self.next_ident;
206        self.next_ident += 1;
207        Identifier::new(span, crate::Symbol::intern(format!("%{id}")))
208    }
209
210    /// Inline/expand all of the statements in the `boundary_constraints` section
211    fn expand_boundary_constraints(
212        &mut self,
213        body: &mut Vec<Statement>,
214    ) -> Result<(), SemanticAnalysisError> {
215        // Save the current bindings set, as we're entering a new lexical scope
216        self.bindings.enter();
217        // Visit all of the statements, check variable usage, and track referenced imports
218        self.expand_statement_block(body)?;
219        // Restore the original lexical scope
220        self.bindings.exit();
221
222        Ok(())
223    }
224
225    /// Inline/expand all of the statements in the `integrity_constraints` section
226    fn expand_integrity_constraints(
227        &mut self,
228        body: &mut Vec<Statement>,
229    ) -> Result<(), SemanticAnalysisError> {
230        // Save the current bindings set, as we're entering a new lexical scope
231        self.bindings.enter();
232        // Visit all of the statements, check variable usage, and track referenced imports
233        self.expand_statement_block(body)?;
234        // Restore the original lexical scope
235        self.bindings.exit();
236
237        Ok(())
238    }
239
240    /// Expand a block of statements by visiting each statement front-to-back
241    fn expand_statement_block(
242        &mut self,
243        statements: &mut Vec<Statement>,
244    ) -> Result<(), SemanticAnalysisError> {
245        // This conversion is free, and gives us a natural way to treat the block as a queue
246        let mut buffer: VecDeque<Statement> = core::mem::take(statements).into();
247        // Visit each statement, appending the resulting expansion to the original vector
248        while let Some(statement) = buffer.pop_front() {
249            let mut expanded = self.expand_statement(statement)?;
250            if expanded.is_empty() {
251                continue;
252            }
253            statements.append(&mut expanded);
254        }
255
256        Ok(())
257    }
258
259    /// Expand a single statement into one or more statements which are fully-expanded
260    fn expand_statement(
261        &mut self,
262        statement: Statement,
263    ) -> Result<Vec<Statement>, SemanticAnalysisError> {
264        match statement {
265            // Expanding a let requires special treatment, as let-bound values may be inlined as a block
266            // of statements, which requires us to rewrite the `let` into a `let` tree
267            Statement::Let(expr) => self.expand_let(expr),
268            // A call to an evaluator function is expanded by inlining the function itself at the call site
269            Statement::Enforce(ScalarExpr::Call(call)) => self.expand_evaluator_callsite(call),
270            // Constraints are inlined by expanding the constraint expression
271            Statement::Enforce(expr) => self.expand_constraint(expr),
272            // Constraint comprehensions are inlined by unrolling the comprehension into a sequence of constraints
273            Statement::EnforceAll(expr) => {
274                let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, true);
275                let result = self.expand_comprehension(expr);
276                self.in_comprehension_constraint = in_cc;
277                result
278            }
279            // Conditional constraints are expanded like regular constraints, except the selector is applied
280            // to all constraints in the expansion.
281            Statement::EnforceIf(expr, mut selector) => {
282                let mut statements = match expr {
283                    ScalarExpr::Call(call) => self.expand_evaluator_callsite(call)?,
284                    expr => self.expand_constraint(expr)?,
285                };
286                self.rewrite_scalar_expr(&mut selector)?;
287                // We need to make sure the selector is applied to all constraints in the expansion
288                for statement in statements.iter_mut() {
289                    let mut visitor = ApplyConstraintSelector {
290                        selector: &selector,
291                    };
292                    if let ControlFlow::Break(err) = visitor.visit_mut_statement(statement) {
293                        return Err(err);
294                    }
295                }
296                Ok(statements)
297            }
298            // Expresssions containing function calls require expansion via inlining, otherwise
299            // all other expression types are introduced during inlining and are thus already expanded,
300            // but we must still visit them to apply rewrites.
301            Statement::Expr(expr) => match self.expand_expr(expr)? {
302                Expr::Let(let_expr) => Ok(vec![Statement::Let(*let_expr)]),
303                expr => Ok(vec![Statement::Expr(expr)]),
304            },
305            Statement::BusEnforce(_) => {
306                self.diagnostics
307                    .diagnostic(Severity::Error)
308                    .with_message("buses are not implemented for this Pipeline")
309                    .emit();
310                Err(SemanticAnalysisError::Invalid)
311            }
312        }
313    }
314
315    fn expand_expr(&mut self, expr: Expr) -> Result<Expr, SemanticAnalysisError> {
316        match expr {
317            Expr::Vector(mut elements) => {
318                let elems = Vec::with_capacity(elements.len());
319                for elem in core::mem::replace(&mut elements.item, elems) {
320                    elements.push(self.expand_expr(elem)?);
321                }
322                Ok(Expr::Vector(elements))
323            }
324            Expr::Matrix(mut rows) => {
325                for row in rows.iter_mut() {
326                    let cols = Vec::with_capacity(row.len());
327                    for col in core::mem::replace(row, cols) {
328                        row.push(self.expand_scalar_expr(col)?);
329                    }
330                }
331                Ok(Expr::Matrix(rows))
332            }
333            Expr::Binary(expr) => self.expand_binary_expr(expr),
334            Expr::Call(expr) => self.expand_call(expr),
335            Expr::ListComprehension(expr) => {
336                let mut block = self.expand_comprehension(expr)?;
337                assert_eq!(block.len(), 1);
338                Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr)
339            }
340            Expr::Let(expr) => {
341                let mut block = self.expand_let(*expr)?;
342                assert_eq!(block.len(), 1);
343                Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr)
344            }
345            expr @ (Expr::Const(_) | Expr::Range(_) | Expr::SymbolAccess(_)) => Ok(expr),
346            Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => {
347                self.diagnostics
348                    .diagnostic(Severity::Error)
349                    .with_message("buses are not implemented for this Pipeline")
350                    .emit();
351                Err(SemanticAnalysisError::Invalid)
352            }
353        }
354    }
355
356    /// Let expressions are expanded using the following rules:
357    ///
358    /// * The let-bound expression is expanded first. If it expands to a statement block and
359    ///   not an expression, the block is inlined in place of the let being expanded, and the
360    ///   rest of the expansion takes place at the end of the block; replacing the last statement
361    ///   in the block. If the last statement in the block was an expression, it is treated as
362    ///   the let-bound value. If the last statement in the block was another `let` however, then
363    ///   we recursively walk down the let tree until we reach the bottom, which must always be
364    ///   an expression statement.
365    ///
366    /// * The body is expanded in-place after the previous step has been completed.
367    ///
368    /// * If a let-bound variable is an alias for a declaration, we replace all uses
369    ///   of the variable with direct references to the declaration, making the let-bound
370    ///   variable dead
371    ///
372    /// * If a let-bound variable is dead (i.e. has no references), then the let is elided,
373    ///   by replacing it with the result of expanding its body
374    fn expand_let(&mut self, expr: Let) -> Result<Vec<Statement>, SemanticAnalysisError> {
375        let span = expr.span();
376        let name = expr.name;
377        let body = expr.body;
378
379        // Visit the let-bound expression first, since it determines how the rest of the process goes
380        let value = match expr.value {
381            // When expanding a call in this context, we're expecting a single
382            // statement of either `Expr` or `Let` type, as calls to pure functions
383            // can never contain constraints.
384            Expr::Call(call) => self.expand_call(call)?,
385            // Same as above, but for list comprehensions.
386            //
387            // The rules for expansion are the same.
388            Expr::ListComprehension(lc) => {
389                let mut expanded = self.expand_comprehension(lc)?;
390                match expanded.pop().unwrap() {
391                    Statement::Let(let_expr) => Expr::Let(Box::new(let_expr)),
392                    Statement::Expr(expr) => expr,
393                    Statement::Enforce(_)
394                    | Statement::EnforceIf(_, _)
395                    | Statement::EnforceAll(_)
396                    | Statement::BusEnforce(_) => unreachable!(),
397                }
398            }
399            // The operands of a binary expression can contain function calls, so we must ensure
400            // that we expand the operands as needed, and then proceed with expanding the let.
401            Expr::Binary(expr) => self.expand_binary_expr(expr)?,
402            // Other expressions we visit just to expand rewrites
403            mut expr => {
404                self.rewrite_expr(&mut expr)?;
405                expr
406            }
407        };
408
409        let expr = Let {
410            span,
411            name,
412            value,
413            body,
414        };
415
416        self.expand_let_tree(expr)
417    }
418
419    /// This is only expected to be called on a let tree which is guaranteed to only have
420    /// simple values as let-bound expressions, i.e. the `value` of the `Let` requires no
421    /// expansion or rewrites. You should use `expand_let` in general.
422    fn expand_let_tree(&mut self, mut expr: Let) -> Result<Vec<Statement>, SemanticAnalysisError> {
423        // Start new lexical scope for the body
424        self.bindings.enter();
425        self.let_bound.enter();
426        let prev_rewrites = self.rewrites.clone();
427
428        // Register the binding
429        let binding_ty = self.expr_binding_type(&expr.value).unwrap();
430
431        // If this let is a vector of trace column bindings, then we can
432        // elide the let, and rewrite all uses of the let-bound variable
433        // to the respective elements of the vector
434        let inline_body = binding_ty.is_trace_binding();
435        if inline_body {
436            self.rewrites.insert(expr.name);
437        }
438        self.bindings.insert(expr.name, binding_ty);
439        self.let_bound.insert(expr.name, expr.value.clone());
440
441        // Visit the let body
442        self.expand_statement_block(&mut expr.body)?;
443
444        // Restore the original lexical scope
445        self.bindings.exit();
446        self.let_bound.exit();
447        self.rewrites = prev_rewrites;
448
449        // If we're inlining the body, return the body block as the result;
450        // otherwise re-wrap the `let` as the sole statement in the resulting block
451        if inline_body {
452            Ok(expr.body)
453        } else {
454            Ok(vec![Statement::Let(expr)])
455        }
456    }
457
458    /// Expand a call to a pure function (including builtin list folding functions)
459    fn expand_call(&mut self, mut call: Call) -> Result<Expr, SemanticAnalysisError> {
460        if call.is_builtin() {
461            match call.callee.as_ref().name() {
462                symbols::Sum => {
463                    assert_eq!(call.args.len(), 1);
464                    self.expand_fold(BinaryOp::Add, call.args.pop().unwrap())
465                }
466                symbols::Prod => {
467                    assert_eq!(call.args.len(), 1);
468                    self.expand_fold(BinaryOp::Mul, call.args.pop().unwrap())
469                }
470                other => unimplemented!("unhandled builtin: {}", other),
471            }
472        } else {
473            self.expand_function_callsite(call)
474        }
475    }
476
477    fn expand_scalar_expr(
478        &mut self,
479        expr: ScalarExpr,
480    ) -> Result<ScalarExpr, SemanticAnalysisError> {
481        match expr {
482            ScalarExpr::Binary(expr) if expr.has_block_like_expansion() => {
483                self.expand_binary_expr(expr).and_then(|expr| {
484                    ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr)
485                })
486            }
487            ScalarExpr::Call(lhs) => self.expand_call(lhs).and_then(|expr| {
488                ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr)
489            }),
490            mut expr => {
491                self.rewrite_scalar_expr(&mut expr)?;
492                Ok(expr)
493            }
494        }
495    }
496
497    fn expand_binary_expr(&mut self, expr: BinaryExpr) -> Result<Expr, SemanticAnalysisError> {
498        let span = expr.span();
499        let op = expr.op;
500        let lhs = self.expand_scalar_expr(*expr.lhs)?;
501        let rhs = self.expand_scalar_expr(*expr.rhs)?;
502
503        Ok(Expr::Binary(BinaryExpr {
504            span,
505            op,
506            lhs: Box::new(lhs),
507            rhs: Box::new(rhs),
508        }))
509    }
510
511    /// Expand a list folding operation (e.g. sum/prod) over an expression of aggregate type into an equivalent expression tree
512    fn expand_fold(&mut self, op: BinaryOp, list: Expr) -> Result<Expr, SemanticAnalysisError> {
513        let span = list.span();
514        match list {
515            Expr::Vector(mut elems) => self.expand_vector_fold(span, op, &mut elems),
516            Expr::ListComprehension(lc) => {
517                // Expand the comprehension, but ensure we don't treat it like a comprehension constraint
518                let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, false);
519                let mut expanded = self.expand_comprehension(lc)?;
520                self.in_comprehension_constraint = in_cc;
521                // Apply the fold to the expanded comprehension in the bottom of the let tree
522                with_let_result(self, &mut expanded, |inliner, value| {
523                    match value {
524                        // The result value of expanding a comprehension _must_ be a vector
525                        Expr::Vector(elems) => {
526                            // We're going to replace the vector binding with the fold
527                            let folded = inliner.expand_vector_fold(span, op, elems)?;
528                            *value = folded;
529                            Ok(None)
530                        }
531                        _ => unreachable!(),
532                    }
533                })?;
534                match expanded.pop().unwrap() {
535                    Statement::Expr(expr) => Ok(expr),
536                    Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))),
537                    Statement::Enforce(_)
538                    | Statement::EnforceIf(_, _)
539                    | Statement::EnforceAll(_)
540                    | Statement::BusEnforce(_) => unreachable!(),
541                }
542            }
543            Expr::SymbolAccess(ref access) => {
544                match self.let_bound.get(access.name.as_ref()).cloned() {
545                    Some(expr) => self.expand_fold(op, expr),
546                    None => match self.access_binding_type(access) {
547                        Ok(BindingType::TraceColumn(tb)) => {
548                            let mut vector = vec![];
549                            for i in 0..tb.size {
550                                vector.push(Expr::SymbolAccess(
551                                    access.access(AccessType::Index(i)).unwrap(),
552                                ));
553                            }
554                            self.expand_vector_fold(span, op, &mut vector)
555                        }
556                        Ok(_) | Err(_) => unimplemented!(),
557                    },
558                }
559            }
560            // Constant propagation will have already folded calls to list-folding builtins
561            // with constant arguments, so we should panic if we ever see one here
562            Expr::Const(_) => panic!("expected constant to have been folded"),
563            // All other invalid expressions should have been caught by now
564            invalid => panic!("invalid argument to list folding builtin: {invalid:#?}"),
565        }
566    }
567
568    /// Expand a list folding operation (e.g. sum/prod) over a vector into an equivalent expression tree
569    fn expand_vector_fold(
570        &mut self,
571        span: SourceSpan,
572        op: BinaryOp,
573        vector: &mut Vec<Expr>,
574    ) -> Result<Expr, SemanticAnalysisError> {
575        // To expand this fold, we simply produce a nested sequence of BinaryExpr
576        let mut elems = vector.drain(..);
577        let mut acc = elems.next().unwrap();
578        self.rewrite_expr(&mut acc)?;
579        let mut acc: ScalarExpr = acc.try_into().map_err(SemanticAnalysisError::InvalidExpr)?;
580        for mut elem in elems {
581            self.rewrite_expr(&mut elem)?;
582            let elem: ScalarExpr = elem.try_into().expect("invalid scalar expr");
583            let new_acc = ScalarExpr::Binary(BinaryExpr::new(span, op, acc, elem));
584            acc = new_acc;
585        }
586        acc.try_into().map_err(SemanticAnalysisError::InvalidExpr)
587    }
588
589    fn expand_constraint(
590        &mut self,
591        constraint: ScalarExpr,
592    ) -> Result<Vec<Statement>, SemanticAnalysisError> {
593        // The constraint itself must be an equality at this point, as evaluator
594        // calls are handled separately in `expand_statement`
595        match constraint {
596            ScalarExpr::Binary(BinaryExpr {
597                op: BinaryOp::Eq,
598                lhs,
599                rhs,
600                span,
601            }) => {
602                let lhs = self.expand_scalar_expr(*lhs)?;
603                let rhs = self.expand_scalar_expr(*rhs)?;
604
605                Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr {
606                    span,
607                    op: BinaryOp::Eq,
608                    lhs: Box::new(lhs),
609                    rhs: Box::new(rhs),
610                }))])
611            }
612            invalid => unreachable!("unexpected constraint node: {:#?}", invalid),
613        }
614    }
615
616    /// This function rewrites expressions which contain accesses for which rewrites have been registered.
617    fn rewrite_expr(&mut self, expr: &mut Expr) -> Result<(), SemanticAnalysisError> {
618        match expr {
619            Expr::Const(_) | Expr::Range(_) => return Ok(()),
620            Expr::Vector(elems) => {
621                for elem in elems.iter_mut() {
622                    self.rewrite_expr(elem)?;
623                }
624            }
625            Expr::Matrix(rows) => {
626                for row in rows.iter_mut() {
627                    for col in row.iter_mut() {
628                        self.rewrite_scalar_expr(col)?;
629                    }
630                }
631            }
632            Expr::Binary(binary_expr) => {
633                self.rewrite_scalar_expr(binary_expr.lhs.as_mut())?;
634                self.rewrite_scalar_expr(binary_expr.rhs.as_mut())?;
635            }
636            Expr::SymbolAccess(access) => {
637                if let Some(rewrite) = self.get_trace_access_rewrite(access) {
638                    *access = rewrite;
639                }
640            }
641            Expr::Call(call) => {
642                for arg in call.args.iter_mut() {
643                    self.rewrite_expr(arg)?;
644                }
645            }
646            // Comprehension rewrites happen when they are expanded, but we do visit the iterables now
647            Expr::ListComprehension(lc) => {
648                for expr in lc.iterables.iter_mut() {
649                    self.rewrite_expr(expr)?;
650                }
651            }
652            Expr::Let(let_expr) => {
653                let mut next = Some(let_expr.as_mut());
654                while let Some(next_let) = next.take() {
655                    self.rewrite_expr(&mut next_let.value)?;
656                    match next_let.body.last_mut().unwrap() {
657                        Statement::Let(inner) => {
658                            next = Some(inner);
659                        }
660                        Statement::Expr(expr) => {
661                            self.rewrite_expr(expr)?;
662                        }
663                        Statement::Enforce(_)
664                        | Statement::EnforceIf(_, _)
665                        | Statement::EnforceAll(_)
666                        | Statement::BusEnforce(_) => unreachable!(),
667                    }
668                }
669            }
670            Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => {
671                self.diagnostics
672                    .diagnostic(Severity::Error)
673                    .with_message("buses are not implemented for this Pipeline")
674                    .emit();
675                return Err(SemanticAnalysisError::Invalid);
676            }
677        }
678        Ok(())
679    }
680
681    /// This function rewrites scalar expressions which contain accesses for which rewrites have been registered.
682    fn rewrite_scalar_expr(&mut self, expr: &mut ScalarExpr) -> Result<(), SemanticAnalysisError> {
683        match expr {
684            ScalarExpr::Const(_) => Ok(()),
685            ScalarExpr::SymbolAccess(access)
686            | ScalarExpr::BoundedSymbolAccess(BoundedSymbolAccess { column: access, .. }) => {
687                if let Some(rewrite) = self.get_trace_access_rewrite(access) {
688                    *access = rewrite;
689                }
690                Ok(())
691            }
692            ScalarExpr::Binary(BinaryExpr { op, lhs, rhs, .. }) => {
693                self.rewrite_scalar_expr(lhs.as_mut())?;
694                self.rewrite_scalar_expr(rhs.as_mut())?;
695                match op {
696                    BinaryOp::Exp if !rhs.is_constant() => Err(SemanticAnalysisError::InvalidExpr(
697                        InvalidExprError::NonConstantExponent(rhs.span()),
698                    )),
699                    _ => Ok(()),
700                }
701            }
702            ScalarExpr::Call(expr) => {
703                for arg in expr.args.iter_mut() {
704                    self.rewrite_expr(arg)?;
705                }
706                Ok(())
707            }
708            ScalarExpr::Let(let_expr) => {
709                let mut next = Some(let_expr.as_mut());
710                while let Some(next_let) = next.take() {
711                    self.rewrite_expr(&mut next_let.value)?;
712                    match next_let.body.last_mut().unwrap() {
713                        Statement::Let(inner) => {
714                            next = Some(inner);
715                        }
716                        Statement::Expr(expr) => {
717                            self.rewrite_expr(expr)?;
718                        }
719                        Statement::Enforce(_)
720                        | Statement::EnforceIf(_, _)
721                        | Statement::EnforceAll(_)
722                        | Statement::BusEnforce(_) => unreachable!(),
723                    }
724                }
725                Ok(())
726            }
727            ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
728                self.diagnostics
729                    .diagnostic(Severity::Error)
730                    .with_message("buses are not implemented for this Pipeline")
731                    .emit();
732                Err(SemanticAnalysisError::Invalid)
733            }
734        }
735    }
736
737    /// This function expands a comprehension into a sequence of statements.
738    ///
739    /// This is done using abstract interpretation. By this point in the compilation process,
740    /// all iterables should have been typed and have known static sizes. Some iterables may even
741    /// be constant, such as in the case of ranges. Because of this, we are able to "unroll" the
742    /// comprehension, evaluating the effective value of all iterable bindings at each iteration,
743    /// and rewriting the comprehension body accordingly.
744    ///
745    /// Depending on whether this is a standard list comprehension, or a constraint comprehension,
746    /// the expansion is, respectively:
747    ///
748    /// * A tree of let statements (using generated variables), where each let binds the value of a
749    ///   single iteration of the comprehension. The body of the final let, and thus the effective
750    ///   value of the entire tree, is a vector containing all of the bindings in the evaluation
751    ///   order of the comprehension.
752    /// * A flat list of constraint statements
753    fn expand_comprehension(
754        &mut self,
755        mut expr: ListComprehension,
756    ) -> Result<Vec<Statement>, SemanticAnalysisError> {
757        // Lift any function calls in iterable position out of the comprehension,
758        // binding the result of those calls via `let`. Rewrite the iterable as
759        // a symbol access to the newly-bound variable.
760        //
761        // NOTE: The actual expansion of the lifted iterables occurs after we expand
762        // the comprehension, so that we can place the expanded comprehension in the
763        // body of the final let
764        let mut lifted_bindings = vec![];
765        let mut lifted = vec![];
766        for param in expr.iterables.iter_mut() {
767            if !matches!(param, Expr::Call(_)) {
768                continue;
769            }
770
771            let span = param.span();
772            let name = self.get_next_ident(span);
773            let ty = match param {
774                Expr::Call(Call { callee, .. }) => {
775                    let callee = callee
776                        .resolved()
777                        .expect("callee should have been resolved by now");
778                    self.functions[&callee].return_type
779                }
780                _ => unsafe { core::hint::unreachable_unchecked() },
781            };
782            let param = core::mem::replace(
783                param,
784                Expr::SymbolAccess(SymbolAccess {
785                    span,
786                    name: ResolvableIdentifier::Local(name),
787                    access_type: AccessType::Default,
788                    offset: 0,
789                    ty: Some(ty),
790                }),
791            );
792            match param {
793                Expr::Call(call) => {
794                    lifted_bindings.push((name, BindingType::Local(ty)));
795                    lifted.push((name, call));
796                }
797                _ => unsafe { core::hint::unreachable_unchecked() },
798            }
799        }
800
801        // Get the number of iterations in this comprehension
802        let Type::Vector(num_iterations) = expr.ty.unwrap() else {
803            panic!("invalid comprehension type");
804        };
805
806        // Step the iterables for each iteration, giving each it's own lexical scope
807        let mut statement_groups = vec![];
808        for i in 0..num_iterations {
809            self.bindings.enter();
810            // Ensure any lifted iterables are in scope for the expansion of this iteration
811            for (name, binding_ty) in lifted_bindings.iter() {
812                self.bindings.insert(*name, binding_ty.clone());
813            }
814            let expansion = self.expand_comprehension_iteration(&expr, i)?;
815            // An expansion can be empty if a constraint selector with a constant selector expression
816            // evaluates to false (allowing us to elide the constraint for that iteration entirely).
817            if !expansion.is_empty() {
818                statement_groups.push(expansion);
819            }
820            self.bindings.exit();
821        }
822
823        // At this point, we have one or more statement groups, representing the expansions
824        // of each iteration of the comprehension. Additionally, we may have a set of lifted
825        // iterables which we need to bind (and expand) "around" the expansion of the comprehension
826        // itself.
827        //
828        // In short, we must take this list of statement groups, and flatten/treeify it. Once
829        // a let binding is introduced into scope, all subsequent statements must occur in the body
830        // of that let, forming a tree. Consecutive statements which introduce no new bindings do
831        // not require any nesting, resulting in the groups containing those statements being flattened.
832        //
833        // Lastly, whether this is a list or constraint comprehension determines if we will also be
834        // constructing a vector from the values produced by each iteration, and returning it as the
835        // result of the comprehension itself.
836        let span = expr.span();
837        if self.in_comprehension_constraint {
838            Ok(statement_groups.into_iter().flatten().collect())
839        } else {
840            // For list comprehensions, we must emit a let tree that binds each iteration,
841            // and ensure that the expansion of the iteration itself is properly nested so
842            // that the lexical scope of all bound variables is correct. This is more complex
843            // than the constraint comprehension case, as we must emit a single expression
844            // representing the entire expansion of the comprehension as an aggregate, whereas
845            // constraints produce no results.
846
847            // Generate a new variable name for each element in the comprehension
848            let symbols = statement_groups
849                .iter()
850                .map(|_| self.get_next_ident_lc(span))
851                .collect::<Vec<_>>();
852            // Generate the list of elements for the vector which is to be the result of the let-tree
853            let vars = statement_groups
854                .iter()
855                .zip(symbols.iter().copied())
856                .map(|(group, name)| {
857                    // The type of these statements must be known by now
858                    let ty = match group.last().unwrap() {
859                        Statement::Expr(value) => value.ty(),
860                        Statement::Let(nested) => nested.ty(),
861                        stmt => unreachable!(
862                            "unexpected statement type in comprehension body: {}",
863                            stmt.display(0)
864                        ),
865                    };
866                    Expr::SymbolAccess(SymbolAccess {
867                        span,
868                        name: ResolvableIdentifier::Local(name),
869                        access_type: AccessType::Default,
870                        offset: 0,
871                        ty,
872                    })
873                })
874                .collect();
875            // Construct the let tree by visiting the statements bottom-up
876            let acc = vec![Statement::Expr(Expr::Vector(Span::new(span, vars)))];
877            let expanded = statement_groups.into_iter().zip(symbols).try_rfold(
878                acc,
879                |acc, (mut group, name)| {
880                    match group.pop().unwrap() {
881                        // If the current statement is an expression, it represents the value of this
882                        // iteration of the comprehension, and we must generate a let to bind it, using
883                        // the accumulator expression as the body
884                        Statement::Expr(expr) => {
885                            group.push(Statement::Let(Let::new(span, name, expr, acc)));
886                        }
887                        // If the current statement is a `let`-tree, we need to generate a new `let` at
888                        // the bottom of the tree, which binds the result expression as the value of the
889                        // generated `let`, and uses the accumulator as the body
890                        Statement::Let(mut wrapper) => {
891                            with_let_result(self, &mut wrapper.body, move |_, value| {
892                                let value = core::mem::replace(
893                                    value,
894                                    Expr::Const(Span::new(span, ConstantExpr::Scalar(0))),
895                                );
896                                Ok(Some(Statement::Let(Let::new(span, name, value, acc))))
897                            })?;
898                            group.push(Statement::Let(wrapper));
899                        }
900                        _ => unreachable!(),
901                    }
902                    Ok::<_, SemanticAnalysisError>(group)
903                },
904            )?;
905            // Lastly, construct the let tree for the lifted iterables, placing the expanded
906            // comprehension at the bottom of that tree.
907            lifted.into_iter().try_rfold(expanded, |acc, (name, call)| {
908                let span = call.span();
909                match self.expand_call(call)? {
910                    Expr::Let(mut wrapper) => {
911                        with_let_result(self, &mut wrapper.body, move |_, value| {
912                            let value = core::mem::replace(
913                                value,
914                                Expr::Const(Span::new(span, ConstantExpr::Scalar(0))),
915                            );
916                            Ok(Some(Statement::Let(Let::new(span, name, value, acc))))
917                        })?;
918                        Ok(vec![Statement::Let(*wrapper)])
919                    }
920                    expr => Ok(vec![Statement::Let(Let::new(span, name, expr, acc))]),
921                }
922            })
923        }
924    }
925
    /// Expands iteration number `index` of the comprehension `lc` into a sequence of statements.
    ///
    /// For each iterable, the value it yields at this iteration is computed and the
    /// corresponding binding is registered in `self.bindings`. All references to those bindings
    /// in a clone of the comprehension body (and in the selector, if any) are replaced with
    /// those values. The resulting statement is then expanded with `expand_statement`.
    ///
    /// Returns an empty vector when the selector is a constant zero (the iteration is elided).
    ///
    /// Panics if `lc` has a selector outside of a constraint comprehension, or if an iterable is
    /// a call, binary expression, comprehension, `let`, or bus-related expression. Calls are
    /// expected to have been lifted by the caller, and the other forms cannot appear here.
    fn expand_comprehension_iteration(
        &mut self,
        lc: &ListComprehension,
        index: usize,
    ) -> Result<Vec<Statement>, SemanticAnalysisError> {
        // Register each iterable binding and its abstract value.
        //
        // The abstract value is either a constant (in which case it is concrete, not abstract), or
        // an expression which represents accessing the iterable at the index corresponding to the
        // current iteration.
        let mut bound_values = HashMap::<Identifier, Expr>::default();
        for (iterable, binding) in lc.iterables.iter().zip(lc.bindings.iter().copied()) {
            let abstract_value = match iterable {
                // If the iterable is constant, the value of its corresponding binding is also constant
                Expr::Const(constant) => {
                    let span = constant.span();
                    let value = match constant.item {
                        ConstantExpr::Vector(ref elems) => ConstantExpr::Scalar(elems[index]),
                        ConstantExpr::Matrix(ref rows) => ConstantExpr::Vector(rows[index].clone()),
                        // An iterable may never be a scalar value, this will be caught by semantic analysis
                        ConstantExpr::Scalar(_) => unreachable!(),
                    };
                    let binding_ty = BindingType::Constant(value.ty());
                    self.bindings.insert(binding, binding_ty);
                    Expr::Const(Span::new(span, value))
                }
                // Ranges are constant, so same rules as above apply here: the binding takes the
                // value `start + index`
                Expr::Range(range) => {
                    let span = range.span();
                    let range = range.to_slice_range();
                    let binding_ty = BindingType::Constant(Type::Felt);
                    self.bindings.insert(binding, binding_ty);
                    Expr::Const(Span::new(
                        span,
                        ConstantExpr::Scalar((range.start + index) as u64),
                    ))
                }
                // If the iterable was a vector, the abstract value is whatever expression is at
                // the corresponding index of the vector.
                Expr::Vector(elems) => {
                    let abstract_value = elems[index].clone();
                    let binding_ty = self.expr_binding_type(&abstract_value).unwrap();
                    self.bindings.insert(binding, binding_ty);
                    abstract_value
                }
                // If the iterable was a matrix, the abstract value is a vector of expressions
                // representing the current row of the matrix. We calculate the binding type of
                // each element in that vector so that accesses into the vector are well typed.
                Expr::Matrix(rows) => {
                    let row: Vec<Expr> = rows[index]
                        .iter()
                        .cloned()
                        .map(|se| se.try_into().unwrap())
                        .collect();
                    let mut tys = vec![];
                    for elem in row.iter() {
                        tys.push(self.expr_binding_type(elem).unwrap());
                    }
                    let binding_ty = BindingType::Vector(tys);
                    self.bindings.insert(binding, binding_ty);
                    Expr::Vector(Span::new(rows.span(), row))
                }
                // If the iterable was a variable/access, then we must first index into that
                // access, and then rewrite it, if applicable.
                Expr::SymbolAccess(access) => {
                    // The access here must be of aggregate type, so index into it for the current iteration
                    let mut current_access = access.access(AccessType::Index(index)).unwrap();
                    // Rewrite the resulting access if we have a rewrite for the underlying symbol
                    if let Some(rewrite) = self.get_trace_access_rewrite(&current_access) {
                        current_access = rewrite;
                    }
                    let binding_ty = self.access_binding_type(&current_access).unwrap();
                    self.bindings.insert(binding, binding_ty);
                    Expr::SymbolAccess(current_access)
                }
                // Binary expressions are scalar, so cannot be used as iterables, and we don't
                // (currently) support nested comprehensions, so it is never possible to observe
                // these expression types here. Calls should have been lifted prior to expansion.
                Expr::Call(_)
                | Expr::Binary(_)
                | Expr::ListComprehension(_)
                | Expr::Let(_)
                | Expr::BusOperation(_)
                | Expr::Null(_)
                | Expr::Unconstrained(_) => {
                    unreachable!()
                }
            };
            bound_values.insert(binding, abstract_value);
        }

        // Clone the comprehension body for this iteration, so we don't modify the original
        let mut body = lc.body.as_ref().clone();

        // Rewrite all references to the iterable bindings in the comprehension body
        let mut visitor = RewriteIterableBindingsVisitor {
            values: &bound_values,
        };
        if let ControlFlow::Break(err) = visitor.visit_mut_scalar_expr(&mut body) {
            return Err(err);
        }

        // Next, handle comprehension filters/selectors as follows:
        //
        // 1. Selectors are evaluated in the same context as the body, so we must visit iterable references in the same way.
        // 2. If a selector has a constant value, we can elide the selector for this iteration. Furthermore, in situations where
        // the selector is known to be false (zero), we can elide the expansion of this iteration entirely.
        //
        // Since the selector is the last piece we need to construct the Statement corresponding to the expansion of
        // this iteration, we do that now before proceeding to the next step.
        let statement = if let Some(mut selector) = lc.selector.clone() {
            assert!(
                self.in_comprehension_constraint,
                "selectors are not permitted in list comprehensions"
            );
            // #1
            if let ControlFlow::Break(err) = visitor.visit_mut_scalar_expr(&mut selector) {
                return Err(err);
            }
            // #2
            match selector {
                // If the selector value is zero, or false, we can elide the expansion entirely
                ScalarExpr::Const(value) if value.item == 0 => return Ok(vec![]),
                // If the selector value is non-zero, or true, we can elide just the selector
                ScalarExpr::Const(_) => Statement::Enforce(body),
                // We have a selector that requires evaluation at runtime, we need to emit a conditional scalar constraint
                other => Statement::EnforceIf(body, other),
            }
        } else if self.in_comprehension_constraint {
            Statement::Enforce(body)
        } else {
            // List comprehension: the iteration's body is a value, not a constraint
            Statement::Expr(body.try_into().unwrap())
        };

        // Next, although we've rewritten the comprehension body corresponding to this iteration, we
        // haven't yet performed inlining on it. We do that now, while all of the bindings are
        // in scope with the proper values. The result of that expansion is what we emit as the result
        // for this iteration.
        self.expand_statement(statement)
    }
1066
1067    /// This function handles inlining evaluator function calls.
1068    ///
1069    /// At this point, semantic analysis has verified that the call arguments are valid, in
1070    /// that the number of trace columns passed matches the number of columns expected by the
1071    /// function parameters. However, the number and type of bindings are permitted to be
1072    /// different, as long as the vectors are the same size when expanded - in effect, re-grouping
1073    /// the trace columns at the call boundary.
1074    fn expand_evaluator_callsite(
1075        &mut self,
1076        call: Call,
1077    ) -> Result<Vec<Statement>, SemanticAnalysisError> {
1078        // The callee is guaranteed to be resolved and exist at this point
1079        let callee = call
1080            .callee
1081            .resolved()
1082            .expect("callee should have been resolved by now");
1083        // We clone the evaluator here as we will be modifying the body during the
1084        // inlining process, and we must not modify the original
1085        let mut evaluator = self.evaluators.get(&callee).unwrap().clone();
1086
1087        // This will be the initial set of bindings visible within the evaluator body
1088        //
1089        // This is distinct from `self.bindings` at this point, because the evaluator doesn't
1090        // inherit the caller's scope, it has an entirely new one.
1091        let mut eval_bindings = LexicalScope::default();
1092
1093        // Add all referenced (and thus imported) items from the evaluator module
1094        //
1095        // NOTE: This will include constants, periodic columns, and other functions
1096        for (qid, binding_ty) in self.imported.iter() {
1097            if qid.module == callee.module {
1098                eval_bindings.insert(*qid.as_ref(), binding_ty.clone());
1099            }
1100        }
1101
1102        // Add trace columns, and other root declarations to the set of
1103        // bindings visible in the evaluator body, _if_ the evaluator is defined in the
1104        // root module.
1105        let is_evaluator_in_root = callee.module == self.root;
1106        if is_evaluator_in_root {
1107            for segment in self.trace.iter() {
1108                eval_bindings.insert(
1109                    segment.name,
1110                    BindingType::TraceColumn(TraceBinding {
1111                        span: segment.name.span(),
1112                        segment: segment.id,
1113                        name: Some(segment.name),
1114                        offset: 0,
1115                        size: segment.size,
1116                        ty: Type::Vector(segment.size),
1117                    }),
1118                );
1119                for binding in segment.bindings.iter().copied() {
1120                    eval_bindings.insert(
1121                        binding.name.unwrap(),
1122                        BindingType::TraceColumn(TraceBinding {
1123                            span: segment.name.span(),
1124                            segment: segment.id,
1125                            name: binding.name,
1126                            offset: binding.offset,
1127                            size: binding.size,
1128                            ty: binding.ty,
1129                        }),
1130                    );
1131                }
1132            }
1133
1134            for input in self.public_inputs.values() {
1135                eval_bindings.insert(
1136                    input.name(),
1137                    BindingType::PublicInput(Type::Vector(input.size())),
1138                );
1139            }
1140        }
1141
1142        // Match call arguments to function parameters, populating the set of rewrites
1143        // which should be performed on the inlined function body.
1144        //
1145        // NOTE: We create a new nested scope for the parameters in order to avoid conflicting
1146        // with the root declarations
1147        eval_bindings.enter();
1148        self.populate_evaluator_rewrites(
1149            &mut eval_bindings,
1150            call.args.as_slice(),
1151            evaluator.params.as_slice(),
1152        );
1153
1154        // While we're inlining the body, use the set of evaluator bindings we built above
1155        let prev_bindings = core::mem::replace(&mut self.bindings, eval_bindings);
1156
1157        // Expand the evaluator body into a block of statements
1158        self.expand_statement_block(&mut evaluator.body)?;
1159
1160        // Restore the caller's bindings before we leave
1161        self.bindings = prev_bindings;
1162
1163        Ok(evaluator.body)
1164    }
1165
1166    /// This function handles inlining pure function calls, which must produce an expression
1167    fn expand_function_callsite(&mut self, call: Call) -> Result<Expr, SemanticAnalysisError> {
1168        self.bindings.enter();
1169        // The callee is guaranteed to be resolved and exist at this point
1170        let callee = call
1171            .callee
1172            .resolved()
1173            .expect("callee should have been resolved by now");
1174
1175        if self.call_stack.contains(&callee) {
1176            let ifd = self
1177                .diagnostics
1178                .diagnostic(Severity::Error)
1179                .with_message("invalid recursive function call")
1180                .with_primary_label(call.span, "recursion occurs due to this function call");
1181            self.call_stack
1182                .iter()
1183                .rev()
1184                .fold(ifd, |ifd, caller| {
1185                    ifd.with_secondary_label(caller.span(), "which was called from")
1186                })
1187                .emit();
1188            return Err(SemanticAnalysisError::Invalid);
1189        } else {
1190            self.call_stack.push(callee);
1191        }
1192
1193        // We clone the function here as we will be modifying the body during the
1194        // inlining process, and we must not modify the original
1195        let mut function = self.functions.get(&callee).unwrap().clone();
1196
1197        // This will be the initial set of bindings visible within the function body
1198        //
1199        // This is distinct from `self.bindings` at this point, because the function doesn't
1200        // inherit the caller's scope, it has an entirely new one.
1201        let mut function_bindings = LexicalScope::default();
1202
1203        // Add all referenced (and thus imported) items from the function module
1204        //
1205        // NOTE: This will include constants, periodic columns, and other functions
1206        for (qid, binding_ty) in self.imported.iter() {
1207            if qid.module == callee.module {
1208                function_bindings.insert(*qid.as_ref(), binding_ty.clone());
1209            }
1210        }
1211
1212        // Add trace columns, and other root declarations to the set of
1213        // bindings visible in the function body, _if_ the function is defined in the
1214        // root module.
1215        let is_function_in_root = callee.module == self.root;
1216        if is_function_in_root {
1217            for segment in self.trace.iter() {
1218                function_bindings.insert(
1219                    segment.name,
1220                    BindingType::TraceColumn(TraceBinding {
1221                        span: segment.name.span(),
1222                        segment: segment.id,
1223                        name: Some(segment.name),
1224                        offset: 0,
1225                        size: segment.size,
1226                        ty: Type::Vector(segment.size),
1227                    }),
1228                );
1229                for binding in segment.bindings.iter().copied() {
1230                    function_bindings.insert(
1231                        binding.name.unwrap(),
1232                        BindingType::TraceColumn(TraceBinding {
1233                            span: segment.name.span(),
1234                            segment: segment.id,
1235                            name: binding.name,
1236                            offset: binding.offset,
1237                            size: binding.size,
1238                            ty: binding.ty,
1239                        }),
1240                    );
1241                }
1242            }
1243
1244            for input in self.public_inputs.values() {
1245                function_bindings.insert(
1246                    input.name(),
1247                    BindingType::PublicInput(Type::Vector(input.size())),
1248                );
1249            }
1250        }
1251
1252        // Match call arguments to function parameters, populating the set of rewrites
1253        // which should be performed on the inlined function body.
1254        //
1255        // NOTE: We create a new nested scope for the parameters in order to avoid conflicting
1256        // with the root declarations
1257        function_bindings.enter();
1258        self.populate_function_rewrites(
1259            &mut function_bindings,
1260            call.args.as_slice(),
1261            function.params.as_slice(),
1262        );
1263
1264        // While we're inlining the body, use the set of function bindings we built above
1265        let prev_bindings = core::mem::replace(&mut self.bindings, function_bindings);
1266
1267        // Expand the function body into a block of statements
1268        self.expand_statement_block(&mut function.body)?;
1269
1270        // Restore the caller's bindings before we leave
1271        self.bindings = prev_bindings;
1272
1273        // We're done expanding this call, so remove it from the call stack
1274        self.call_stack.pop();
1275
1276        match function.body.pop().unwrap() {
1277            Statement::Expr(expr) => Ok(expr),
1278            Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))),
1279            Statement::Enforce(_)
1280            | Statement::EnforceIf(_, _)
1281            | Statement::EnforceAll(_)
1282            | Statement::BusEnforce(_) => {
1283                panic!("unexpected constraint in function body")
1284            }
1285        }
1286    }
1287
1288    /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function.
1289    ///
1290    /// This is done by resolving the arguments provided by the call to the evaluator, with the parameter list of the evaluator itself.
1291    fn populate_evaluator_rewrites(
1292        &mut self,
1293        eval_bindings: &mut LexicalScope<Identifier, BindingType>,
1294        args: &[Expr],
1295        params: &[TraceSegment],
1296    ) {
1297        // Reset the rewrites set
1298        self.rewrites.clear();
1299
1300        // Each argument corresponds to a function parameter, each of which represents a single trace segment
1301        for (arg, segment) in args.iter().zip(params.iter()) {
1302            match arg {
1303                // A variable was passed as an argument for this segment
1304                //
1305                // Arguments by now must have been validated by semantic analysis, and specifically
1306                // in this case, the number of columns in the variable and the number expected by the
1307                // parameter we're binding must be the same. However, a variable may represent a single
1308                // column, a contiguous slice of columns, or a vector of such variables which may be
1309                // non-contiguous.
1310                Expr::SymbolAccess(access) => {
1311                    // We use a `BindingType` to track the state of the current input binding being processed.
1312                    //
1313                    // The initial state is given by the binding type of the access itself, but as we destructure
1314                    // the binding according to the parameter binding pattern, we may pop off columns, in which
1315                    // case the binding type here gets updated with the remaining columns
1316                    let mut binding_ty = Some(self.access_binding_type(access).unwrap());
1317                    // We visit each binding in the trace segment represented by the parameter pattern,
1318                    // consuming columns from the input argument until all bindings are matched up.
1319                    for binding in segment.bindings.iter() {
1320                        // Trace binding declarations are never anonymous, i.e. always have a name
1321                        let binding_name = binding.name.unwrap();
1322                        // We can safely assume that there is a binding type available here,
1323                        // otherwise the semantic analysis pass missed something
1324                        let bt = binding_ty.take().unwrap();
1325                        // Split out the needed columns from the input binding
1326                        //
1327                        // We can safely assume we were able to obtain all of the needed columns,
1328                        // as the semantic analyzer should have caught mismatches. Note, however,
1329                        // that these columns may have been gathered from multiple bindings in the caller
1330                        let (matched, rest) = bt.split_columns(binding.size).unwrap();
1331                        self.rewrites.insert(binding_name);
1332                        eval_bindings.insert(binding_name, matched);
1333                        // Update `binding_ty` with whatever remains of the input
1334                        binding_ty = rest;
1335                    }
1336                }
1337                // An empty vector means there are no bindings for this segment
1338                Expr::Const(Span {
1339                    item: ConstantExpr::Vector(items),
1340                    ..
1341                }) if items.is_empty() => {
1342                    continue;
1343                }
1344                // A vector of bindings was passed as an argument for this segment
1345                //
1346                // This is by far the most complicated scenario to handle when matching up arguments
1347                // to parameters, as we can get them in a variety of combinations:
1348                //
1349                // 1. An exact match in the number and size of bindings in both the input vector and the
1350                //    segment represented by the current parameter
1351                // 2. The same number of elements in the vector as bindings in the segment, but the elements
1352                //    have different sizes, implicitly regrouping columns between caller/callee
1353                // 3. More elements in the vector than bindings in the segment, typically because the function
1354                //    parameter groups together columns passed individually in the caller
1355                // 4. Fewer elements in the vector than bindings in the segment, typically because the function
1356                //    parameter destructures an input into multiple bindings
1357                Expr::Vector(inputs) => {
1358                    // The index of the input we're currently extracting columns from
1359                    let mut index = 0;
1360                    // A `BindingType` representing the current trace binding we're extracting columns from,
1361                    // can be either of TraceColumn or Vector type
1362                    let mut binding_ty = None;
1363                    // We drive the matching process by consuming input columns for each segment binding in turn
1364                    'next_binding: for binding in segment.bindings.iter() {
1365                        let binding_name = binding.name.unwrap();
1366                        let mut needed = binding.size;
1367
1368                        // When there are insufficient columns for the current parameter binding in the current
1369                        // input, we must construct a vector of trace bindings to use as the binding type of
1370                        // the current parameter binding when we have all of the needed columns. This is because
1371                        // the input columns may come from different trace bindings in the caller, so we can't
1372                        // use a single trace binding to represent them.
1373                        let mut set = vec![];
1374
1375                        // We may need to consume multiple input elements to fulfill the needed columns of
1376                        // the current parameter binding - we advance this loop whenever we have exhausted
1377                        // an input and need to move on to the next one. We may enter this loop with the
1378                        // same input index across multiple parameter bindings when the input element is
1379                        // larger than the parameter binding, in which case we have split the input and
1380                        // stored the remainder in `binding_ty`.
1381                        loop {
1382                            let input = &inputs[index];
1383                            // The input expression must have been a symbol access, as matrices of columns
1384                            // aren't a thing, and there is no other expression type which can produce trace
1385                            // bindings.
1386                            let Expr::SymbolAccess(access) = input else {
1387                                panic!("unexpected element in trace column vector: {input:#?}")
1388                            };
1389                            // Unless we have leftover input, initialize `binding_ty` with the binding type of this input
1390                            let bt = binding_ty
1391                                .take()
1392                                .unwrap_or_else(|| self.access_binding_type(access).unwrap());
1393                            match bt.split_columns(needed) {
1394                                Ok((matched, rest)) => {
1395                                    let eval_binding = match matched {
1396                                        BindingType::TraceColumn(matched) => {
1397                                            if !set.is_empty() {
1398                                                // We've obtained all the remaining columns from the current input element,
1399                                                // possibly with leftovers in the input. However, because we've started
1400                                                // constructing a vector binding, we must ensure the matched binding is
1401                                                // expanded into individual columns
1402                                                for offset in 0..matched.size {
1403                                                    set.push(BindingType::TraceColumn(
1404                                                        TraceBinding {
1405                                                            offset: matched.offset + offset,
1406                                                            size: 1,
1407                                                            ..matched
1408                                                        },
1409                                                    ));
1410                                                }
1411                                                BindingType::Vector(set)
1412                                            } else {
1413                                                // The input element perfectly matched the current binding
1414                                                BindingType::TraceColumn(matched)
1415                                            }
1416                                        }
1417                                        BindingType::Vector(mut matched) => {
1418                                            if set.is_empty() {
1419                                                // The input binding was a vector, and had the same number, or
1420                                                // more, of columns expected by the parameter binding, but may contain
1421                                                // non-contiguous bindings, so we are unable to use the symbol of
1422                                                // the access when rewriting accesses to this parameter
1423                                                BindingType::Vector(matched)
1424                                            } else {
1425                                                // Same as above, but we need to append the matched bindings to
1426                                                // the set we've already started building
1427                                                set.append(&mut matched);
1428                                                BindingType::Vector(set)
1429                                            }
1430                                        }
1431                                        _ => unreachable!(),
1432                                    };
1433                                    // This binding has been fulfilled, move to the next one
1434                                    self.rewrites.insert(binding_name);
1435                                    eval_bindings.insert(binding_name, eval_binding);
1436                                    binding_ty = rest;
1437                                    // If we have no more columns remaining in this input, advance
1438                                    // to the next input starting with the next binding
1439                                    if binding_ty.is_none() {
1440                                        index += 1;
1441                                    }
1442                                    continue 'next_binding;
1443                                }
1444                                Err(BindingType::TraceColumn(partial)) => {
1445                                    // The input binding wasn't big enough for the parameter, so we must
1446                                    // start constructing a vector of bindings since the next input is
1447                                    // unlikely to be contiguous with the current input
1448                                    for offset in 0..partial.size {
1449                                        set.push(BindingType::TraceColumn(TraceBinding {
1450                                            offset: partial.offset + offset,
1451                                            size: 1,
1452                                            ..partial
1453                                        }));
1454                                    }
1455                                    needed -= partial.size;
1456                                    index += 1;
1457                                }
1458                                Err(BindingType::Vector(mut partial)) => {
1459                                    // Same as above, but we got a vector instead
1460                                    set.append(&mut partial);
1461                                    needed -= partial.len();
1462                                    index += 1;
1463                                }
1464                                Err(_) => unreachable!(),
1465                            }
1466                        }
1467                    }
1468                }
1469                // This should not be possible at this point, but would be an invalid evaluator call,
1470                // only trace columns are permitted
1471                expr => unreachable!("{:#?}", expr),
1472            }
1473        }
1474    }
1475
1476    fn populate_function_rewrites(
1477        &mut self,
1478        function_bindings: &mut LexicalScope<Identifier, BindingType>,
1479        args: &[Expr],
1480        params: &[(Identifier, Type)],
1481    ) {
1482        // Reset the rewrites set
1483        self.rewrites.clear();
1484
1485        for (arg, (param_name, param_ty)) in args.iter().zip(params.iter()) {
1486            // We can safely assume that there is a binding type available here,
1487            // otherwise the semantic analysis pass missed something
1488            let binding_ty = self.expr_binding_type(arg).unwrap();
1489            debug_assert_eq!(binding_ty.ty(), Some(*param_ty), "unexpected type mismatch");
1490            self.rewrites.insert(*param_name);
1491            function_bindings.insert(*param_name, binding_ty);
1492        }
1493    }
1494
1495    /// Returns a new [SymbolAccess] which should be used in place of `access` in the current scope.
1496    ///
1497    /// This function should only be called on accesses which have a trace column/param [BindingType],
1498    /// but it will simply return `None` for other types, so it is safe to call on all accesses.
1499    fn get_trace_access_rewrite(&self, access: &SymbolAccess) -> Option<SymbolAccess> {
1500        if self.rewrites.contains(access.name.as_ref()) {
1501            // If we have a rewrite for this access, then the bindings map will
1502            // have an accurate trace binding for us; rewrite this access to be
1503            // relative to that trace binding
1504            match self.access_binding_type(access).unwrap() {
1505                BindingType::TraceColumn(tb) => {
1506                    let original_binding = self.trace[tb.segment]
1507                        .bindings
1508                        .iter()
1509                        .find(|b| b.name == tb.name)
1510                        .unwrap();
1511                    let (access_type, ty) = if original_binding.size == 1 {
1512                        (AccessType::Default, Type::Felt)
1513                    } else if tb.size == 1 {
1514                        (
1515                            AccessType::Index(tb.offset - original_binding.offset),
1516                            Type::Felt,
1517                        )
1518                    } else {
1519                        let start = tb.offset - original_binding.offset;
1520                        (
1521                            AccessType::Slice(RangeExpr::from(start..(start + tb.size))),
1522                            Type::Vector(tb.size),
1523                        )
1524                    };
1525                    Some(SymbolAccess {
1526                        span: access.span(),
1527                        name: ResolvableIdentifier::Local(tb.name.unwrap()),
1528                        access_type,
1529                        offset: access.offset,
1530                        ty: Some(ty),
1531                    })
1532                }
1533                // We only have a rewrite when the binding type is TraceColumn
1534                invalid => panic!(
1535                    "unexpected trace access binding type, expected column(s), got: {:#?}",
1536                    &invalid
1537                ),
1538            }
1539        } else {
1540            None
1541        }
1542    }
1543
1544    fn expr_binding_type(&self, expr: &Expr) -> Result<BindingType, InvalidAccessError> {
1545        let mut bindings = self.bindings.clone();
1546        eval_expr_binding_type(expr, &mut bindings, &self.imported)
1547    }
1548
1549    /// Returns the effective [BindingType] of the value produced by the given access
1550    fn access_binding_type(&self, expr: &SymbolAccess) -> Result<BindingType, InvalidAccessError> {
1551        eval_access_binding_type(expr, &self.bindings, &self.imported)
1552    }
1553}
1554
1555/// Returns the effective [BindingType] of the given expression
1556fn eval_expr_binding_type(
1557    expr: &Expr,
1558    bindings: &mut LexicalScope<Identifier, BindingType>,
1559    imported: &HashMap<QualifiedIdentifier, BindingType>,
1560) -> Result<BindingType, InvalidAccessError> {
1561    match expr {
1562        Expr::Const(constant) => Ok(BindingType::Local(constant.ty())),
1563        Expr::Range(range) => Ok(BindingType::Local(Type::Vector(
1564            range.to_slice_range().len(),
1565        ))),
1566        Expr::Vector(elems) => match elems[0].ty() {
1567            None | Some(Type::Felt) => {
1568                let mut binding_tys = Vec::with_capacity(elems.len());
1569                for elem in elems.iter() {
1570                    binding_tys.push(eval_expr_binding_type(elem, bindings, imported)?);
1571                }
1572                Ok(BindingType::Vector(binding_tys))
1573            }
1574            Some(Type::Vector(cols)) => {
1575                let rows = elems.len();
1576                Ok(BindingType::Local(Type::Matrix(rows, cols)))
1577            }
1578            Some(_) => unreachable!(),
1579        },
1580        Expr::Matrix(expr) => {
1581            let rows = expr.len();
1582            let columns = expr[0].len();
1583            Ok(BindingType::Local(Type::Matrix(rows, columns)))
1584        }
1585        Expr::SymbolAccess(access) => eval_access_binding_type(access, bindings, imported),
1586        Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding),
1587        Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)),
1588        Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)),
1589        Expr::ListComprehension(lc) => {
1590            // The types of all iterables must be the same, so the type of
1591            // the comprehension is given by the type of the iterables. We
1592            // just pick the first iterable to tell us the type
1593            eval_expr_binding_type(&lc.iterables[0], bindings, imported)
1594        }
1595        Expr::Let(let_expr) => eval_let_binding_ty(let_expr, bindings, imported),
1596        Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => {
1597            unimplemented!("buses are not implemented for this Pipeline")
1598        }
1599    }
1600}
1601
1602/// Returns the effective [BindingType] of the value produced by the given access
1603fn eval_access_binding_type(
1604    expr: &SymbolAccess,
1605    bindings: &LexicalScope<Identifier, BindingType>,
1606    imported: &HashMap<QualifiedIdentifier, BindingType>,
1607) -> Result<BindingType, InvalidAccessError> {
1608    let binding_ty = bindings
1609        .get(expr.name.as_ref())
1610        .or_else(|| match expr.name {
1611            ResolvableIdentifier::Resolved(qid) => imported.get(&qid),
1612            _ => None,
1613        })
1614        .ok_or(InvalidAccessError::UndefinedVariable)
1615        .clone()?;
1616    binding_ty.access(expr.access_type.clone())
1617}
1618
1619fn eval_let_binding_ty(
1620    let_expr: &Let,
1621    bindings: &mut LexicalScope<Identifier, BindingType>,
1622    imported: &HashMap<QualifiedIdentifier, BindingType>,
1623) -> Result<BindingType, InvalidAccessError> {
1624    let variable_ty = eval_expr_binding_type(&let_expr.value, bindings, imported)?;
1625    bindings.enter();
1626    bindings.insert(let_expr.name, variable_ty);
1627    let binding_ty = match let_expr.body.last().unwrap() {
1628        Statement::Let(inner_let) => eval_let_binding_ty(inner_let, bindings, imported)?,
1629        Statement::Expr(expr) => eval_expr_binding_type(expr, bindings, imported)?,
1630        Statement::Enforce(_)
1631        | Statement::EnforceIf(_, _)
1632        | Statement::EnforceAll(_)
1633        | Statement::BusEnforce(_) => {
1634            unreachable!()
1635        }
1636    };
1637    bindings.exit();
1638    Ok(binding_ty)
1639}
1640
1641/// This visitor is used to rewrite uses of iterable bindings within a comprehension body,
1642/// including expansion of constant accesses.
1643struct RewriteIterableBindingsVisitor<'a> {
1644    /// This map contains the set of symbols to be rewritten, and the abstract values which
1645    /// should replace them in the comprehension body.
1646    values: &'a HashMap<Identifier, Expr>,
1647}
1648impl RewriteIterableBindingsVisitor<'_> {
1649    fn rewrite_scalar_access(
1650        &mut self,
1651        access: SymbolAccess,
1652    ) -> ControlFlow<SemanticAnalysisError, Option<ScalarExpr>> {
1653        let result = match self.values.get(access.name.as_ref()) {
1654            Some(Expr::Const(constant)) => {
1655                let span = constant.span();
1656                match constant.item {
1657                    ConstantExpr::Scalar(value) => {
1658                        assert_eq!(access.access_type, AccessType::Default);
1659                        Some(ScalarExpr::Const(Span::new(span, value)))
1660                    }
1661                    ConstantExpr::Vector(ref elems) => match access.access_type {
1662                        AccessType::Index(idx) => {
1663                            Some(ScalarExpr::Const(Span::new(span, elems[idx])))
1664                        }
1665                        invalid => panic!(
1666                            "expected vector to be reduced to scalar by access, got {invalid:#?}"
1667                        ),
1668                    },
1669                    ConstantExpr::Matrix(ref rows) => match access.access_type {
1670                        AccessType::Matrix(row, col) => {
1671                            Some(ScalarExpr::Const(Span::new(span, rows[row][col])))
1672                        }
1673                        invalid => panic!(
1674                            "expected matrix to be reduced to scalar by access, got {invalid:#?}",
1675                        ),
1676                    },
1677                }
1678            }
1679            Some(Expr::Range(range)) => {
1680                let span = range.span();
1681                let range = range.to_slice_range();
1682                match access.access_type {
1683                    AccessType::Index(idx) => Some(ScalarExpr::Const(Span::new(
1684                        span,
1685                        (range.start + idx) as u64,
1686                    ))),
1687                    invalid => {
1688                        panic!("expected range to be reduced to scalar by access, got {invalid:#?}",)
1689                    }
1690                }
1691            }
1692            Some(Expr::Vector(elems)) => {
1693                match access.access_type {
1694                    AccessType::Index(idx) => Some(elems[idx].clone().try_into().unwrap()),
1695                    // This implies that the vector contains an element which is vector-like,
1696                    // if the value at `idx` is not, this is an invalid access
1697                    AccessType::Matrix(idx, nested_idx) => match &elems[idx] {
1698                        Expr::SymbolAccess(saccess) => {
1699                            let access = saccess.access(AccessType::Index(nested_idx)).unwrap();
1700                            self.rewrite_scalar_access(access)?
1701                        }
1702                        invalid => panic!(
1703                            "expected vector-like value at {}[{idx}], got: {invalid:#?}",
1704                            access.name.as_ref(),
1705                        ),
1706                    },
1707                    invalid => panic!(
1708                        "expected vector to be reduced to scalar by access, got {invalid:#?}"
1709                    ),
1710                }
1711            }
1712            Some(Expr::Matrix(elems)) => match access.access_type {
1713                AccessType::Matrix(row, col) => Some(elems[row][col].clone()),
1714                invalid => {
1715                    panic!("expected matrix to be reduced to scalar by access, got {invalid:#?}")
1716                }
1717            },
1718            Some(Expr::SymbolAccess(symbol_access)) => {
1719                let mut new_access = symbol_access.access(access.access_type).unwrap();
1720                new_access.offset = access.offset;
1721                Some(ScalarExpr::SymbolAccess(new_access))
1722            }
1723            // These types of expressions will never be observed in this context, as they are
1724            // not valid iterable expressions (except calls, but those are lifted prior to rewrite
1725            // so that their use in this context is always a symbol access).
1726            Some(
1727                Expr::Call(_)
1728                | Expr::Binary(_)
1729                | Expr::ListComprehension(_)
1730                | Expr::Let(_)
1731                | Expr::BusOperation(_)
1732                | Expr::Null(_)
1733                | Expr::Unconstrained(_),
1734            ) => {
1735                unreachable!()
1736            }
1737            None => None,
1738        };
1739        ControlFlow::Continue(result)
1740    }
1741}
1742impl VisitMut<SemanticAnalysisError> for RewriteIterableBindingsVisitor<'_> {
1743    fn visit_mut_scalar_expr(
1744        &mut self,
1745        expr: &mut ScalarExpr,
1746    ) -> ControlFlow<SemanticAnalysisError> {
1747        match expr {
1748            // Nothing to do with constants
1749            ScalarExpr::Const(_) => ControlFlow::Continue(()),
1750            // If we observe an access, try to rewrite it as an iterable binding, if it is
1751            // not a candidate for rewrite, leave it alone.
1752            //
1753            // NOTE: We handle BoundedSymbolAccess here even though comprehension constraints are not
1754            // permitted in boundary_constraints currently. That is handled elsewhere, we just need to
1755            // make sure the symbols themselves are rewritten properly here.
1756            ScalarExpr::SymbolAccess(access)
1757            | ScalarExpr::BoundedSymbolAccess(BoundedSymbolAccess { column: access, .. }) => {
1758                if let Some(replacement) = self.rewrite_scalar_access(access.clone())? {
1759                    *expr = replacement;
1760                    return ControlFlow::Continue(());
1761                }
1762                ControlFlow::Continue(())
1763            }
1764            // We need to visit both operands of a binary expression - but while we're here,
1765            // check to see if resolving the operands reduces to a constant expression that
1766            // can be folded.
1767            ScalarExpr::Binary(binary_expr) => {
1768                self.visit_mut_binary_expr(binary_expr)?;
1769                match constant_propagation::try_fold_binary_expr(binary_expr) {
1770                    Ok(Some(folded)) => {
1771                        *expr = ScalarExpr::Const(folded);
1772                        ControlFlow::Continue(())
1773                    }
1774                    Ok(None) => ControlFlow::Continue(()),
1775                    Err(err) => ControlFlow::Break(SemanticAnalysisError::InvalidExpr(err)),
1776                }
1777            }
1778            // If we observe a call here, just rewrite the arguments, inlining happens elsewhere
1779            ScalarExpr::Call(call) => {
1780                for arg in call.args.iter_mut() {
1781                    self.visit_mut_expr(arg)?;
1782                }
1783                ControlFlow::Continue(())
1784            }
1785            // We rewrite comprehension bodies before they are expanded, so it should never be
1786            // the case that we encounter a let here, as they can only be introduced in scalar
1787            // expression position as a result of inlining/expansion
1788            ScalarExpr::Let(_) => unreachable!(),
1789            ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
1790                ControlFlow::Break(SemanticAnalysisError::Invalid)
1791            }
1792        }
1793    }
1794}
1795
1796/// This visitor is used to apply a selector expression to all constraints in a block
1797///
1798/// For constraints which already have a selector, this rewrites those selectors to be the
1799/// logical AND of the original selector and the selector being applied.
1800struct ApplyConstraintSelector<'a> {
1801    selector: &'a ScalarExpr,
1802}
1803impl VisitMut<SemanticAnalysisError> for ApplyConstraintSelector<'_> {
1804    fn visit_mut_statement(
1805        &mut self,
1806        statement: &mut Statement,
1807    ) -> ControlFlow<SemanticAnalysisError> {
1808        match statement {
1809            Statement::Let(expr) => self.visit_mut_let(expr),
1810            Statement::Enforce(expr) => {
1811                let expr =
1812                    core::mem::replace(expr, ScalarExpr::Const(Span::new(SourceSpan::UNKNOWN, 0)));
1813                *statement = Statement::EnforceIf(expr, self.selector.clone());
1814                ControlFlow::Continue(())
1815            }
1816            Statement::EnforceIf(_, selector) => {
1817                // Combine the selectors
1818                let lhs = core::mem::replace(
1819                    selector,
1820                    ScalarExpr::Const(Span::new(SourceSpan::UNKNOWN, 0)),
1821                );
1822                let rhs = self.selector.clone();
1823                *selector = ScalarExpr::Binary(BinaryExpr::new(
1824                    self.selector.span(),
1825                    BinaryOp::Mul,
1826                    lhs,
1827                    rhs,
1828                ));
1829                ControlFlow::Continue(())
1830            }
1831            Statement::EnforceAll(_) => unreachable!(),
1832            Statement::Expr(_) => ControlFlow::Continue(()),
1833            Statement::BusEnforce(_) => ControlFlow::Break(SemanticAnalysisError::Invalid),
1834        }
1835    }
1836}
1837
1838/// This helper function is used to perform a mutation/replacement based on the expression
1839/// representing the effective value of a `let`-tree.
1840///
1841/// In particular, this function traverses the tree until it reaches the final `let` body
1842/// and the last `Expr` in that body. When it does, it invokes `callback` with a mutable
1843/// reference to that `Expr`. The callback may choose to simply mutate the `Expr`, or it
1844/// may return a new `Statement` which will be used to replace the `Statement` which
1845/// contained the `Expr` given to the callback.
1846///
1847/// This is used when expanding calls and list comprehensions, where the expanded form
1848/// of these is potentially a `let` tree, and we desire to place additional statements
1849/// in the bottom-most block, or perform some transformation on the expression which acts
1850/// as the result of the tree.
1851fn with_let_result<F>(
1852    inliner: &mut Inlining,
1853    entry: &mut Vec<Statement>,
1854    callback: F,
1855) -> Result<(), SemanticAnalysisError>
1856where
1857    F: FnOnce(&mut Inlining, &mut Expr) -> Result<Option<Statement>, SemanticAnalysisError>,
1858{
1859    // Preserve the original lexical scope to be restored on exit
1860    let prev = inliner.bindings.clone();
1861
1862    // SAFETY: We must use a raw pointer here because the Rust compiler is not able to
1863    // see that we only ever use the mutable reference once, and that the reference
1864    // is never aliased.
1865    //
1866    // Both of these guarantees are in fact upheld here however, as each iteration of the loop
1867    // is either the last iteration (when we use the mutable reference to mutate the end of the
1868    // bottom-most block), or a traversal to the last child of the current let expression.
1869    // We never alias the mutable reference, and in fact immediately convert back to a mutable
1870    // reference inside the loop to ensure that within the loop body we have some degree of
1871    // compiler-assisted checking of that invariant.
1872    let mut current_block = Some(entry as *mut Vec<Statement>);
1873    while let Some(parent_block) = current_block.take() {
1874        // SAFETY: We convert the pointer back to a mutable reference here before
1875        // we do anything else to ensure the usual aliasing rules are enforced.
1876        //
1877        // It is further guaranteed that this reference is never improperly aliased
1878        // across iterations, as each iteration is visiting a child of the previous
1879        // iteration's node, i.e. what we're doing here is equivalent to holding a
1880        // mutable reference and using it to mutate a field in a deeply nested struct.
1881        let parent_block = unsafe { &mut *parent_block };
1882        // A block is guaranteed to always have at least one statement here
1883        match parent_block.last_mut().unwrap() {
1884            // When we hit a block whose last statement is an expression, which
1885            // must also be the bottom-most block of this tree. This expression
1886            // is the effective value of the `let` tree. We will replace this
1887            // node if the callback we were given returns a new `Statement`. In
1888            // either case, we're done once we've handled the callback result.
1889            Statement::Expr(value) => match callback(inliner, value) {
1890                Ok(Some(replacement)) => {
1891                    parent_block.pop();
1892                    parent_block.push(replacement);
1893                    break;
1894                }
1895                Ok(None) => break,
1896                Err(err) => {
1897                    inliner.bindings = prev;
1898                    return Err(err);
1899                }
1900            },
1901            // We've traversed down a level in the let-tree, but there are more to go.
1902            // Set up the next iteration to visit the next block down in the tree.
1903            Statement::Let(let_expr) => {
1904                // Register this binding
1905                let binding_ty = inliner.expr_binding_type(&let_expr.value).unwrap();
1906                inliner.bindings.insert(let_expr.name, binding_ty);
1907                // Set up the next iteration
1908                current_block = Some(&mut let_expr.body as *mut Vec<Statement>);
1909                continue;
1910            }
1911            // No other statements types are possible here
1912            _ => unreachable!(),
1913        }
1914    }
1915
1916    // Restore the original lexical scope
1917    inliner.bindings = prev;
1918
1919    Ok(())
1920}