air_parser/transforms/inlining.rs
1use std::{
2 collections::{BTreeMap, HashMap, HashSet, VecDeque},
3 ops::ControlFlow,
4 vec,
5};
6
7use air_pass::Pass;
8use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned};
9
10use crate::{
11 ast::{visit::VisitMut, *},
12 sema::{BindingType, LexicalScope, SemanticAnalysisError},
13 symbols,
14};
15
16use super::constant_propagation;
17
18/// This pass performs the following transformations on a [Program]:
19///
20/// * Monomorphizing and inlining evaluators/functions at their call sites
21/// * Unrolling constraint comprehensions into a sequence of scalar constraints
22/// * Unrolling list comprehensions into a tree of `let` statements which end in
23/// a vector expression (the implicit result of the tree). Each iteration of the
24/// unrolled comprehension is reified as a value and bound to a variable so that
25/// other transformations may refer to it directly.
26/// * Rewriting aliases of top-level declarations to refer to those declarations directly
27/// * Removing let-bound variables which are unused, which is also used to clean up
28/// after the aliasing rewrite mentioned above.
29///
30/// The trickiest transformation comes with inlining the body of evaluators at their
31/// call sites, as evaluator parameter lists can arbitrarily destructure/regroup columns
32/// provided as arguments for each trace segment. This means that columns can be passed
33/// in a variety of configurations as arguments, and the patterns expressed in the evaluator
34/// parameter list can arbitrarily reconfigure them for use in the evaluator body.
35///
36/// For example, let's say you call an evaluator `foo` with three columns, passed as individual
37/// bindings, like so: `foo([a, b, c])`. Let's further assume that the evaluator signature
38/// is defined as `ev foo([x[2], y])`. While you might expect that this would be an error,
39/// and that the caller would need to provide the columns in the same configuration, that
40/// is not the case. Instead, `a` and `b` are implicitly re-bound as a vector of trace column
41/// bindings for use in the function body. There is further no requirement that `a` and `b`
42/// are consecutive bindings either, as long as they are from the same trace segment. During
43/// compilation however, accesses to individual elements of the vector will be rewritten to use
44/// the correct binding in the caller after inlining, e.g. an access like `x[1]` becomes `b`.
45///
46/// This pass accomplishes three goals:
47///
48/// * Remove all function abstractions from the program
49/// * Remove all comprehensions from the program
50/// * Inline all constraints into the integrity and boundary constraints sections
51/// * Make all references to top-level declarations concrete
52///
53/// When done, it should be impossible for there to be any invalid trace column references.
54///
55/// It is expected that the provided [Program] has already been run through semantic analysis
56/// and constant propagation, so a number of assumptions are made with regard to what syntax can
57/// be observed at this stage of compilation (e.g. no references to constant declarations, no
58/// undefined variables, expressions are well-typed, etc.).
59pub struct Inlining<'a> {
60 // This may be unused for now, but it's helpful to assume its needed in case we want it in the future
61 #[allow(unused)]
62 diagnostics: &'a DiagnosticsHandler,
63 /// The name of the root module
64 root: Identifier,
65 /// The global trace segment configuration
66 trace: Vec<TraceSegment>,
67 /// The public_inputs declaration
68 public_inputs: BTreeMap<Identifier, PublicInput>,
69 /// All local/global bindings in scope
70 bindings: LexicalScope<Identifier, BindingType>,
71 /// The values of all let-bound variables in scope
72 let_bound: LexicalScope<Identifier, Expr>,
73 /// All items which must be referenced fully-qualified, namely periodic columns at this point
74 imported: HashMap<QualifiedIdentifier, BindingType>,
75 /// All evaluator functions in the program
76 evaluators: HashMap<QualifiedIdentifier, EvaluatorFunction>,
77 /// All pure functions in the program
78 functions: HashMap<QualifiedIdentifier, Function>,
79 /// A set of identifiers for which accesses should be rewritten.
80 ///
81 /// When an identifier is in this set, it means it is a local alias for a trace column,
82 /// and should be rewritten based on the current `BindingType` associated with the alias
83 /// identifier in `bindings`.
84 rewrites: HashSet<Identifier>,
85 /// The call stack during expansion of a function call.
86 ///
87 /// Each time we begin to expand a call, we check if it is already present on the call
88 /// stack, and if so, raise a diagnostic due to infinite recursion. If not, the callee
89 /// is pushed on the stack while we expand its body. When we finish expanding the body
90 /// of the callee, we pop it off this stack, and proceed as usual.
91 call_stack: Vec<QualifiedIdentifier>,
92 in_comprehension_constraint: bool,
93 next_ident_lc: usize,
94 next_ident: usize,
95}
96impl Pass for Inlining<'_> {
97 type Input<'a> = Program;
98 type Output<'a> = Program;
99 type Error = SemanticAnalysisError;
100
101 fn run<'a>(&mut self, mut program: Self::Input<'a>) -> Result<Self::Output<'a>, Self::Error> {
102 self.root = program.name;
103 self.evaluators = program
104 .evaluators
105 .iter()
106 .map(|(k, v)| (*k, v.clone()))
107 .collect();
108
109 self.functions = program
110 .functions
111 .iter()
112 .map(|(k, v)| (*k, v.clone()))
113 .collect();
114
115 // We'll be referencing the trace configuration during inlining, so keep a copy of it
116 self.trace.clone_from(&program.trace_columns);
117 // And the public inputs
118 self.public_inputs.clone_from(&program.public_inputs);
119
120 // Add all of the local bindings visible in the root module, except for
121 // constants and periodic columns, which by this point have been rewritten
122 // to use fully-qualified names (or in the case of constants, have been
123 // eliminated entirely)
124 //
125 // Trace first..
126 for segment in program.trace_columns.iter() {
127 self.bindings.insert(
128 segment.name,
129 BindingType::TraceColumn(TraceBinding {
130 span: segment.name.span(),
131 segment: segment.id,
132 name: Some(segment.name),
133 offset: 0,
134 size: segment.size,
135 ty: Type::Vector(segment.size),
136 }),
137 );
138 for binding in segment.bindings.iter().copied() {
139 self.bindings.insert(
140 binding.name.unwrap(),
141 BindingType::TraceColumn(TraceBinding {
142 span: segment.name.span(),
143 segment: segment.id,
144 name: binding.name,
145 offset: binding.offset,
146 size: binding.size,
147 ty: binding.ty,
148 }),
149 );
150 }
151 }
152 // Public inputs..
153 for input in program.public_inputs.values() {
154 self.bindings.insert(
155 input.name(),
156 BindingType::PublicInput(Type::Vector(input.size())),
157 );
158 }
159 // For periodic columns, we register the imported item, but do not add any to the local bindings.
160 for (name, periodic) in program.periodic_columns.iter() {
161 let binding_ty = BindingType::PeriodicColumn(periodic.values.len());
162 self.imported.insert(*name, binding_ty);
163 }
164
165 // The root of the inlining process is the integrity_constraints and
166 // boundary_constraints blocks. Function calls in inlined functions are
167 // inlined at the same time as the parent
168 self.expand_boundary_constraints(&mut program.boundary_constraints)?;
169 self.expand_integrity_constraints(&mut program.integrity_constraints)?;
170
171 Ok(program)
172 }
173}
174impl<'a> Inlining<'a> {
175 pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self {
176 Self {
177 diagnostics,
178 root: Identifier::new(SourceSpan::UNKNOWN, crate::symbols::Main),
179 trace: vec![],
180 public_inputs: Default::default(),
181 bindings: Default::default(),
182 let_bound: Default::default(),
183 imported: Default::default(),
184 evaluators: Default::default(),
185 functions: Default::default(),
186 rewrites: Default::default(),
187 in_comprehension_constraint: false,
188 call_stack: vec![],
189 next_ident_lc: 0,
190 next_ident: 0,
191 }
192 }
193
194 /// Generate a new variable
195 ///
196 /// This is only used when expanding list comprehensions, so we use a special prefix for
197 /// these generated identifiers to make it clear what they were expanded from.
198 fn get_next_ident_lc(&mut self, span: SourceSpan) -> Identifier {
199 let id = self.next_ident_lc;
200 self.next_ident_lc += 1;
201 Identifier::new(span, crate::Symbol::intern(format!("%lc{id}")))
202 }
203
204 fn get_next_ident(&mut self, span: SourceSpan) -> Identifier {
205 let id = self.next_ident;
206 self.next_ident += 1;
207 Identifier::new(span, crate::Symbol::intern(format!("%{id}")))
208 }
209
210 /// Inline/expand all of the statements in the `boundary_constraints` section
211 fn expand_boundary_constraints(
212 &mut self,
213 body: &mut Vec<Statement>,
214 ) -> Result<(), SemanticAnalysisError> {
215 // Save the current bindings set, as we're entering a new lexical scope
216 self.bindings.enter();
217 // Visit all of the statements, check variable usage, and track referenced imports
218 self.expand_statement_block(body)?;
219 // Restore the original lexical scope
220 self.bindings.exit();
221
222 Ok(())
223 }
224
225 /// Inline/expand all of the statements in the `integrity_constraints` section
226 fn expand_integrity_constraints(
227 &mut self,
228 body: &mut Vec<Statement>,
229 ) -> Result<(), SemanticAnalysisError> {
230 // Save the current bindings set, as we're entering a new lexical scope
231 self.bindings.enter();
232 // Visit all of the statements, check variable usage, and track referenced imports
233 self.expand_statement_block(body)?;
234 // Restore the original lexical scope
235 self.bindings.exit();
236
237 Ok(())
238 }
239
240 /// Expand a block of statements by visiting each statement front-to-back
241 fn expand_statement_block(
242 &mut self,
243 statements: &mut Vec<Statement>,
244 ) -> Result<(), SemanticAnalysisError> {
245 // This conversion is free, and gives us a natural way to treat the block as a queue
246 let mut buffer: VecDeque<Statement> = core::mem::take(statements).into();
247 // Visit each statement, appending the resulting expansion to the original vector
248 while let Some(statement) = buffer.pop_front() {
249 let mut expanded = self.expand_statement(statement)?;
250 if expanded.is_empty() {
251 continue;
252 }
253 statements.append(&mut expanded);
254 }
255
256 Ok(())
257 }
258
259 /// Expand a single statement into one or more statements which are fully-expanded
260 fn expand_statement(
261 &mut self,
262 statement: Statement,
263 ) -> Result<Vec<Statement>, SemanticAnalysisError> {
264 match statement {
265 // Expanding a let requires special treatment, as let-bound values may be inlined as a block
266 // of statements, which requires us to rewrite the `let` into a `let` tree
267 Statement::Let(expr) => self.expand_let(expr),
268 // A call to an evaluator function is expanded by inlining the function itself at the call site
269 Statement::Enforce(ScalarExpr::Call(call)) => self.expand_evaluator_callsite(call),
270 // Constraints are inlined by expanding the constraint expression
271 Statement::Enforce(expr) => self.expand_constraint(expr),
272 // Constraint comprehensions are inlined by unrolling the comprehension into a sequence of constraints
273 Statement::EnforceAll(expr) => {
274 let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, true);
275 let result = self.expand_comprehension(expr);
276 self.in_comprehension_constraint = in_cc;
277 result
278 }
279 // Conditional constraints are expanded like regular constraints, except the selector is applied
280 // to all constraints in the expansion.
281 Statement::EnforceIf(expr, mut selector) => {
282 let mut statements = match expr {
283 ScalarExpr::Call(call) => self.expand_evaluator_callsite(call)?,
284 expr => self.expand_constraint(expr)?,
285 };
286 self.rewrite_scalar_expr(&mut selector)?;
287 // We need to make sure the selector is applied to all constraints in the expansion
288 for statement in statements.iter_mut() {
289 let mut visitor = ApplyConstraintSelector {
290 selector: &selector,
291 };
292 if let ControlFlow::Break(err) = visitor.visit_mut_statement(statement) {
293 return Err(err);
294 }
295 }
296 Ok(statements)
297 }
298 // Expresssions containing function calls require expansion via inlining, otherwise
299 // all other expression types are introduced during inlining and are thus already expanded,
300 // but we must still visit them to apply rewrites.
301 Statement::Expr(expr) => match self.expand_expr(expr)? {
302 Expr::Let(let_expr) => Ok(vec![Statement::Let(*let_expr)]),
303 expr => Ok(vec![Statement::Expr(expr)]),
304 },
305 Statement::BusEnforce(_) => {
306 self.diagnostics
307 .diagnostic(Severity::Error)
308 .with_message("buses are not implemented for this Pipeline")
309 .emit();
310 Err(SemanticAnalysisError::Invalid)
311 }
312 }
313 }
314
315 fn expand_expr(&mut self, expr: Expr) -> Result<Expr, SemanticAnalysisError> {
316 match expr {
317 Expr::Vector(mut elements) => {
318 let elems = Vec::with_capacity(elements.len());
319 for elem in core::mem::replace(&mut elements.item, elems) {
320 elements.push(self.expand_expr(elem)?);
321 }
322 Ok(Expr::Vector(elements))
323 }
324 Expr::Matrix(mut rows) => {
325 for row in rows.iter_mut() {
326 let cols = Vec::with_capacity(row.len());
327 for col in core::mem::replace(row, cols) {
328 row.push(self.expand_scalar_expr(col)?);
329 }
330 }
331 Ok(Expr::Matrix(rows))
332 }
333 Expr::Binary(expr) => self.expand_binary_expr(expr),
334 Expr::Call(expr) => self.expand_call(expr),
335 Expr::ListComprehension(expr) => {
336 let mut block = self.expand_comprehension(expr)?;
337 assert_eq!(block.len(), 1);
338 Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr)
339 }
340 Expr::Let(expr) => {
341 let mut block = self.expand_let(*expr)?;
342 assert_eq!(block.len(), 1);
343 Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr)
344 }
345 expr @ (Expr::Const(_) | Expr::Range(_) | Expr::SymbolAccess(_)) => Ok(expr),
346 Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => {
347 self.diagnostics
348 .diagnostic(Severity::Error)
349 .with_message("buses are not implemented for this Pipeline")
350 .emit();
351 Err(SemanticAnalysisError::Invalid)
352 }
353 }
354 }
355
356 /// Let expressions are expanded using the following rules:
357 ///
358 /// * The let-bound expression is expanded first. If it expands to a statement block and
359 /// not an expression, the block is inlined in place of the let being expanded, and the
360 /// rest of the expansion takes place at the end of the block; replacing the last statement
361 /// in the block. If the last statement in the block was an expression, it is treated as
362 /// the let-bound value. If the last statement in the block was another `let` however, then
363 /// we recursively walk down the let tree until we reach the bottom, which must always be
364 /// an expression statement.
365 ///
366 /// * The body is expanded in-place after the previous step has been completed.
367 ///
368 /// * If a let-bound variable is an alias for a declaration, we replace all uses
369 /// of the variable with direct references to the declaration, making the let-bound
370 /// variable dead
371 ///
372 /// * If a let-bound variable is dead (i.e. has no references), then the let is elided,
373 /// by replacing it with the result of expanding its body
374 fn expand_let(&mut self, expr: Let) -> Result<Vec<Statement>, SemanticAnalysisError> {
375 let span = expr.span();
376 let name = expr.name;
377 let body = expr.body;
378
379 // Visit the let-bound expression first, since it determines how the rest of the process goes
380 let value = match expr.value {
381 // When expanding a call in this context, we're expecting a single
382 // statement of either `Expr` or `Let` type, as calls to pure functions
383 // can never contain constraints.
384 Expr::Call(call) => self.expand_call(call)?,
385 // Same as above, but for list comprehensions.
386 //
387 // The rules for expansion are the same.
388 Expr::ListComprehension(lc) => {
389 let mut expanded = self.expand_comprehension(lc)?;
390 match expanded.pop().unwrap() {
391 Statement::Let(let_expr) => Expr::Let(Box::new(let_expr)),
392 Statement::Expr(expr) => expr,
393 Statement::Enforce(_)
394 | Statement::EnforceIf(_, _)
395 | Statement::EnforceAll(_)
396 | Statement::BusEnforce(_) => unreachable!(),
397 }
398 }
399 // The operands of a binary expression can contain function calls, so we must ensure
400 // that we expand the operands as needed, and then proceed with expanding the let.
401 Expr::Binary(expr) => self.expand_binary_expr(expr)?,
402 // Other expressions we visit just to expand rewrites
403 mut expr => {
404 self.rewrite_expr(&mut expr)?;
405 expr
406 }
407 };
408
409 let expr = Let {
410 span,
411 name,
412 value,
413 body,
414 };
415
416 self.expand_let_tree(expr)
417 }
418
419 /// This is only expected to be called on a let tree which is guaranteed to only have
420 /// simple values as let-bound expressions, i.e. the `value` of the `Let` requires no
421 /// expansion or rewrites. You should use `expand_let` in general.
422 fn expand_let_tree(&mut self, mut expr: Let) -> Result<Vec<Statement>, SemanticAnalysisError> {
423 // Start new lexical scope for the body
424 self.bindings.enter();
425 self.let_bound.enter();
426 let prev_rewrites = self.rewrites.clone();
427
428 // Register the binding
429 let binding_ty = self.expr_binding_type(&expr.value).unwrap();
430
431 // If this let is a vector of trace column bindings, then we can
432 // elide the let, and rewrite all uses of the let-bound variable
433 // to the respective elements of the vector
434 let inline_body = binding_ty.is_trace_binding();
435 if inline_body {
436 self.rewrites.insert(expr.name);
437 }
438 self.bindings.insert(expr.name, binding_ty);
439 self.let_bound.insert(expr.name, expr.value.clone());
440
441 // Visit the let body
442 self.expand_statement_block(&mut expr.body)?;
443
444 // Restore the original lexical scope
445 self.bindings.exit();
446 self.let_bound.exit();
447 self.rewrites = prev_rewrites;
448
449 // If we're inlining the body, return the body block as the result;
450 // otherwise re-wrap the `let` as the sole statement in the resulting block
451 if inline_body {
452 Ok(expr.body)
453 } else {
454 Ok(vec![Statement::Let(expr)])
455 }
456 }
457
458 /// Expand a call to a pure function (including builtin list folding functions)
459 fn expand_call(&mut self, mut call: Call) -> Result<Expr, SemanticAnalysisError> {
460 if call.is_builtin() {
461 match call.callee.as_ref().name() {
462 symbols::Sum => {
463 assert_eq!(call.args.len(), 1);
464 self.expand_fold(BinaryOp::Add, call.args.pop().unwrap())
465 }
466 symbols::Prod => {
467 assert_eq!(call.args.len(), 1);
468 self.expand_fold(BinaryOp::Mul, call.args.pop().unwrap())
469 }
470 other => unimplemented!("unhandled builtin: {}", other),
471 }
472 } else {
473 self.expand_function_callsite(call)
474 }
475 }
476
477 fn expand_scalar_expr(
478 &mut self,
479 expr: ScalarExpr,
480 ) -> Result<ScalarExpr, SemanticAnalysisError> {
481 match expr {
482 ScalarExpr::Binary(expr) if expr.has_block_like_expansion() => {
483 self.expand_binary_expr(expr).and_then(|expr| {
484 ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr)
485 })
486 }
487 ScalarExpr::Call(lhs) => self.expand_call(lhs).and_then(|expr| {
488 ScalarExpr::try_from(expr).map_err(SemanticAnalysisError::InvalidExpr)
489 }),
490 mut expr => {
491 self.rewrite_scalar_expr(&mut expr)?;
492 Ok(expr)
493 }
494 }
495 }
496
497 fn expand_binary_expr(&mut self, expr: BinaryExpr) -> Result<Expr, SemanticAnalysisError> {
498 let span = expr.span();
499 let op = expr.op;
500 let lhs = self.expand_scalar_expr(*expr.lhs)?;
501 let rhs = self.expand_scalar_expr(*expr.rhs)?;
502
503 Ok(Expr::Binary(BinaryExpr {
504 span,
505 op,
506 lhs: Box::new(lhs),
507 rhs: Box::new(rhs),
508 }))
509 }
510
511 /// Expand a list folding operation (e.g. sum/prod) over an expression of aggregate type into an equivalent expression tree
512 fn expand_fold(&mut self, op: BinaryOp, list: Expr) -> Result<Expr, SemanticAnalysisError> {
513 let span = list.span();
514 match list {
515 Expr::Vector(mut elems) => self.expand_vector_fold(span, op, &mut elems),
516 Expr::ListComprehension(lc) => {
517 // Expand the comprehension, but ensure we don't treat it like a comprehension constraint
518 let in_cc = core::mem::replace(&mut self.in_comprehension_constraint, false);
519 let mut expanded = self.expand_comprehension(lc)?;
520 self.in_comprehension_constraint = in_cc;
521 // Apply the fold to the expanded comprehension in the bottom of the let tree
522 with_let_result(self, &mut expanded, |inliner, value| {
523 match value {
524 // The result value of expanding a comprehension _must_ be a vector
525 Expr::Vector(elems) => {
526 // We're going to replace the vector binding with the fold
527 let folded = inliner.expand_vector_fold(span, op, elems)?;
528 *value = folded;
529 Ok(None)
530 }
531 _ => unreachable!(),
532 }
533 })?;
534 match expanded.pop().unwrap() {
535 Statement::Expr(expr) => Ok(expr),
536 Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))),
537 Statement::Enforce(_)
538 | Statement::EnforceIf(_, _)
539 | Statement::EnforceAll(_)
540 | Statement::BusEnforce(_) => unreachable!(),
541 }
542 }
543 Expr::SymbolAccess(ref access) => {
544 match self.let_bound.get(access.name.as_ref()).cloned() {
545 Some(expr) => self.expand_fold(op, expr),
546 None => match self.access_binding_type(access) {
547 Ok(BindingType::TraceColumn(tb)) => {
548 let mut vector = vec![];
549 for i in 0..tb.size {
550 vector.push(Expr::SymbolAccess(
551 access.access(AccessType::Index(i)).unwrap(),
552 ));
553 }
554 self.expand_vector_fold(span, op, &mut vector)
555 }
556 Ok(_) | Err(_) => unimplemented!(),
557 },
558 }
559 }
560 // Constant propagation will have already folded calls to list-folding builtins
561 // with constant arguments, so we should panic if we ever see one here
562 Expr::Const(_) => panic!("expected constant to have been folded"),
563 // All other invalid expressions should have been caught by now
564 invalid => panic!("invalid argument to list folding builtin: {invalid:#?}"),
565 }
566 }
567
568 /// Expand a list folding operation (e.g. sum/prod) over a vector into an equivalent expression tree
569 fn expand_vector_fold(
570 &mut self,
571 span: SourceSpan,
572 op: BinaryOp,
573 vector: &mut Vec<Expr>,
574 ) -> Result<Expr, SemanticAnalysisError> {
575 // To expand this fold, we simply produce a nested sequence of BinaryExpr
576 let mut elems = vector.drain(..);
577 let mut acc = elems.next().unwrap();
578 self.rewrite_expr(&mut acc)?;
579 let mut acc: ScalarExpr = acc.try_into().map_err(SemanticAnalysisError::InvalidExpr)?;
580 for mut elem in elems {
581 self.rewrite_expr(&mut elem)?;
582 let elem: ScalarExpr = elem.try_into().expect("invalid scalar expr");
583 let new_acc = ScalarExpr::Binary(BinaryExpr::new(span, op, acc, elem));
584 acc = new_acc;
585 }
586 acc.try_into().map_err(SemanticAnalysisError::InvalidExpr)
587 }
588
589 fn expand_constraint(
590 &mut self,
591 constraint: ScalarExpr,
592 ) -> Result<Vec<Statement>, SemanticAnalysisError> {
593 // The constraint itself must be an equality at this point, as evaluator
594 // calls are handled separately in `expand_statement`
595 match constraint {
596 ScalarExpr::Binary(BinaryExpr {
597 op: BinaryOp::Eq,
598 lhs,
599 rhs,
600 span,
601 }) => {
602 let lhs = self.expand_scalar_expr(*lhs)?;
603 let rhs = self.expand_scalar_expr(*rhs)?;
604
605 Ok(vec![Statement::Enforce(ScalarExpr::Binary(BinaryExpr {
606 span,
607 op: BinaryOp::Eq,
608 lhs: Box::new(lhs),
609 rhs: Box::new(rhs),
610 }))])
611 }
612 invalid => unreachable!("unexpected constraint node: {:#?}", invalid),
613 }
614 }
615
616 /// This function rewrites expressions which contain accesses for which rewrites have been registered.
617 fn rewrite_expr(&mut self, expr: &mut Expr) -> Result<(), SemanticAnalysisError> {
618 match expr {
619 Expr::Const(_) | Expr::Range(_) => return Ok(()),
620 Expr::Vector(elems) => {
621 for elem in elems.iter_mut() {
622 self.rewrite_expr(elem)?;
623 }
624 }
625 Expr::Matrix(rows) => {
626 for row in rows.iter_mut() {
627 for col in row.iter_mut() {
628 self.rewrite_scalar_expr(col)?;
629 }
630 }
631 }
632 Expr::Binary(binary_expr) => {
633 self.rewrite_scalar_expr(binary_expr.lhs.as_mut())?;
634 self.rewrite_scalar_expr(binary_expr.rhs.as_mut())?;
635 }
636 Expr::SymbolAccess(access) => {
637 if let Some(rewrite) = self.get_trace_access_rewrite(access) {
638 *access = rewrite;
639 }
640 }
641 Expr::Call(call) => {
642 for arg in call.args.iter_mut() {
643 self.rewrite_expr(arg)?;
644 }
645 }
646 // Comprehension rewrites happen when they are expanded, but we do visit the iterables now
647 Expr::ListComprehension(lc) => {
648 for expr in lc.iterables.iter_mut() {
649 self.rewrite_expr(expr)?;
650 }
651 }
652 Expr::Let(let_expr) => {
653 let mut next = Some(let_expr.as_mut());
654 while let Some(next_let) = next.take() {
655 self.rewrite_expr(&mut next_let.value)?;
656 match next_let.body.last_mut().unwrap() {
657 Statement::Let(inner) => {
658 next = Some(inner);
659 }
660 Statement::Expr(expr) => {
661 self.rewrite_expr(expr)?;
662 }
663 Statement::Enforce(_)
664 | Statement::EnforceIf(_, _)
665 | Statement::EnforceAll(_)
666 | Statement::BusEnforce(_) => unreachable!(),
667 }
668 }
669 }
670 Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => {
671 self.diagnostics
672 .diagnostic(Severity::Error)
673 .with_message("buses are not implemented for this Pipeline")
674 .emit();
675 return Err(SemanticAnalysisError::Invalid);
676 }
677 }
678 Ok(())
679 }
680
681 /// This function rewrites scalar expressions which contain accesses for which rewrites have been registered.
682 fn rewrite_scalar_expr(&mut self, expr: &mut ScalarExpr) -> Result<(), SemanticAnalysisError> {
683 match expr {
684 ScalarExpr::Const(_) => Ok(()),
685 ScalarExpr::SymbolAccess(access)
686 | ScalarExpr::BoundedSymbolAccess(BoundedSymbolAccess { column: access, .. }) => {
687 if let Some(rewrite) = self.get_trace_access_rewrite(access) {
688 *access = rewrite;
689 }
690 Ok(())
691 }
692 ScalarExpr::Binary(BinaryExpr { op, lhs, rhs, .. }) => {
693 self.rewrite_scalar_expr(lhs.as_mut())?;
694 self.rewrite_scalar_expr(rhs.as_mut())?;
695 match op {
696 BinaryOp::Exp if !rhs.is_constant() => Err(SemanticAnalysisError::InvalidExpr(
697 InvalidExprError::NonConstantExponent(rhs.span()),
698 )),
699 _ => Ok(()),
700 }
701 }
702 ScalarExpr::Call(expr) => {
703 for arg in expr.args.iter_mut() {
704 self.rewrite_expr(arg)?;
705 }
706 Ok(())
707 }
708 ScalarExpr::Let(let_expr) => {
709 let mut next = Some(let_expr.as_mut());
710 while let Some(next_let) = next.take() {
711 self.rewrite_expr(&mut next_let.value)?;
712 match next_let.body.last_mut().unwrap() {
713 Statement::Let(inner) => {
714 next = Some(inner);
715 }
716 Statement::Expr(expr) => {
717 self.rewrite_expr(expr)?;
718 }
719 Statement::Enforce(_)
720 | Statement::EnforceIf(_, _)
721 | Statement::EnforceAll(_)
722 | Statement::BusEnforce(_) => unreachable!(),
723 }
724 }
725 Ok(())
726 }
727 ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
728 self.diagnostics
729 .diagnostic(Severity::Error)
730 .with_message("buses are not implemented for this Pipeline")
731 .emit();
732 Err(SemanticAnalysisError::Invalid)
733 }
734 }
735 }
736
737 /// This function expands a comprehension into a sequence of statements.
738 ///
739 /// This is done using abstract interpretation. By this point in the compilation process,
740 /// all iterables should have been typed and have known static sizes. Some iterables may even
741 /// be constant, such as in the case of ranges. Because of this, we are able to "unroll" the
742 /// comprehension, evaluating the effective value of all iterable bindings at each iteration,
743 /// and rewriting the comprehension body accordingly.
744 ///
745 /// Depending on whether this is a standard list comprehension, or a constraint comprehension,
746 /// the expansion is, respectively:
747 ///
748 /// * A tree of let statements (using generated variables), where each let binds the value of a
749 /// single iteration of the comprehension. The body of the final let, and thus the effective
750 /// value of the entire tree, is a vector containing all of the bindings in the evaluation
751 /// order of the comprehension.
752 /// * A flat list of constraint statements
753 fn expand_comprehension(
754 &mut self,
755 mut expr: ListComprehension,
756 ) -> Result<Vec<Statement>, SemanticAnalysisError> {
757 // Lift any function calls in iterable position out of the comprehension,
758 // binding the result of those calls via `let`. Rewrite the iterable as
759 // a symbol access to the newly-bound variable.
760 //
761 // NOTE: The actual expansion of the lifted iterables occurs after we expand
762 // the comprehension, so that we can place the expanded comprehension in the
763 // body of the final let
764 let mut lifted_bindings = vec![];
765 let mut lifted = vec![];
766 for param in expr.iterables.iter_mut() {
767 if !matches!(param, Expr::Call(_)) {
768 continue;
769 }
770
771 let span = param.span();
772 let name = self.get_next_ident(span);
773 let ty = match param {
774 Expr::Call(Call { callee, .. }) => {
775 let callee = callee
776 .resolved()
777 .expect("callee should have been resolved by now");
778 self.functions[&callee].return_type
779 }
780 _ => unsafe { core::hint::unreachable_unchecked() },
781 };
782 let param = core::mem::replace(
783 param,
784 Expr::SymbolAccess(SymbolAccess {
785 span,
786 name: ResolvableIdentifier::Local(name),
787 access_type: AccessType::Default,
788 offset: 0,
789 ty: Some(ty),
790 }),
791 );
792 match param {
793 Expr::Call(call) => {
794 lifted_bindings.push((name, BindingType::Local(ty)));
795 lifted.push((name, call));
796 }
797 _ => unsafe { core::hint::unreachable_unchecked() },
798 }
799 }
800
801 // Get the number of iterations in this comprehension
802 let Type::Vector(num_iterations) = expr.ty.unwrap() else {
803 panic!("invalid comprehension type");
804 };
805
806 // Step the iterables for each iteration, giving each it's own lexical scope
807 let mut statement_groups = vec![];
808 for i in 0..num_iterations {
809 self.bindings.enter();
810 // Ensure any lifted iterables are in scope for the expansion of this iteration
811 for (name, binding_ty) in lifted_bindings.iter() {
812 self.bindings.insert(*name, binding_ty.clone());
813 }
814 let expansion = self.expand_comprehension_iteration(&expr, i)?;
815 // An expansion can be empty if a constraint selector with a constant selector expression
816 // evaluates to false (allowing us to elide the constraint for that iteration entirely).
817 if !expansion.is_empty() {
818 statement_groups.push(expansion);
819 }
820 self.bindings.exit();
821 }
822
823 // At this point, we have one or more statement groups, representing the expansions
824 // of each iteration of the comprehension. Additionally, we may have a set of lifted
825 // iterables which we need to bind (and expand) "around" the expansion of the comprehension
826 // itself.
827 //
828 // In short, we must take this list of statement groups, and flatten/treeify it. Once
829 // a let binding is introduced into scope, all subsequent statements must occur in the body
830 // of that let, forming a tree. Consecutive statements which introduce no new bindings do
831 // not require any nesting, resulting in the groups containing those statements being flattened.
832 //
833 // Lastly, whether this is a list or constraint comprehension determines if we will also be
834 // constructing a vector from the values produced by each iteration, and returning it as the
835 // result of the comprehension itself.
836 let span = expr.span();
837 if self.in_comprehension_constraint {
838 Ok(statement_groups.into_iter().flatten().collect())
839 } else {
840 // For list comprehensions, we must emit a let tree that binds each iteration,
841 // and ensure that the expansion of the iteration itself is properly nested so
842 // that the lexical scope of all bound variables is correct. This is more complex
843 // than the constraint comprehension case, as we must emit a single expression
844 // representing the entire expansion of the comprehension as an aggregate, whereas
845 // constraints produce no results.
846
847 // Generate a new variable name for each element in the comprehension
848 let symbols = statement_groups
849 .iter()
850 .map(|_| self.get_next_ident_lc(span))
851 .collect::<Vec<_>>();
852 // Generate the list of elements for the vector which is to be the result of the let-tree
853 let vars = statement_groups
854 .iter()
855 .zip(symbols.iter().copied())
856 .map(|(group, name)| {
857 // The type of these statements must be known by now
858 let ty = match group.last().unwrap() {
859 Statement::Expr(value) => value.ty(),
860 Statement::Let(nested) => nested.ty(),
861 stmt => unreachable!(
862 "unexpected statement type in comprehension body: {}",
863 stmt.display(0)
864 ),
865 };
866 Expr::SymbolAccess(SymbolAccess {
867 span,
868 name: ResolvableIdentifier::Local(name),
869 access_type: AccessType::Default,
870 offset: 0,
871 ty,
872 })
873 })
874 .collect();
875 // Construct the let tree by visiting the statements bottom-up
876 let acc = vec![Statement::Expr(Expr::Vector(Span::new(span, vars)))];
877 let expanded = statement_groups.into_iter().zip(symbols).try_rfold(
878 acc,
879 |acc, (mut group, name)| {
880 match group.pop().unwrap() {
881 // If the current statement is an expression, it represents the value of this
882 // iteration of the comprehension, and we must generate a let to bind it, using
883 // the accumulator expression as the body
884 Statement::Expr(expr) => {
885 group.push(Statement::Let(Let::new(span, name, expr, acc)));
886 }
887 // If the current statement is a `let`-tree, we need to generate a new `let` at
888 // the bottom of the tree, which binds the result expression as the value of the
889 // generated `let`, and uses the accumulator as the body
890 Statement::Let(mut wrapper) => {
891 with_let_result(self, &mut wrapper.body, move |_, value| {
892 let value = core::mem::replace(
893 value,
894 Expr::Const(Span::new(span, ConstantExpr::Scalar(0))),
895 );
896 Ok(Some(Statement::Let(Let::new(span, name, value, acc))))
897 })?;
898 group.push(Statement::Let(wrapper));
899 }
900 _ => unreachable!(),
901 }
902 Ok::<_, SemanticAnalysisError>(group)
903 },
904 )?;
905 // Lastly, construct the let tree for the lifted iterables, placing the expanded
906 // comprehension at the bottom of that tree.
907 lifted.into_iter().try_rfold(expanded, |acc, (name, call)| {
908 let span = call.span();
909 match self.expand_call(call)? {
910 Expr::Let(mut wrapper) => {
911 with_let_result(self, &mut wrapper.body, move |_, value| {
912 let value = core::mem::replace(
913 value,
914 Expr::Const(Span::new(span, ConstantExpr::Scalar(0))),
915 );
916 Ok(Some(Statement::Let(Let::new(span, name, value, acc))))
917 })?;
918 Ok(vec![Statement::Let(*wrapper)])
919 }
920 expr => Ok(vec![Statement::Let(Let::new(span, name, expr, acc))]),
921 }
922 })
923 }
924 }
925
926 fn expand_comprehension_iteration(
927 &mut self,
928 lc: &ListComprehension,
929 index: usize,
930 ) -> Result<Vec<Statement>, SemanticAnalysisError> {
931 // Register each iterable binding and its abstract value.
932 //
933 // The abstract value is either a constant (in which case it is concrete, not abstract), or
934 // an expression which represents accessing the iterable at the index corresponding to the
935 // current iteration.
936 let mut bound_values = HashMap::<Identifier, Expr>::default();
937 for (iterable, binding) in lc.iterables.iter().zip(lc.bindings.iter().copied()) {
938 let abstract_value = match iterable {
939 // If the iterable is constant, the value of it's corresponding binding is also constant
940 Expr::Const(constant) => {
941 let span = constant.span();
942 let value = match constant.item {
943 ConstantExpr::Vector(ref elems) => ConstantExpr::Scalar(elems[index]),
944 ConstantExpr::Matrix(ref rows) => ConstantExpr::Vector(rows[index].clone()),
945 // An iterable may never be a scalar value, this will be caught by semantic analysis
946 ConstantExpr::Scalar(_) => unreachable!(),
947 };
948 let binding_ty = BindingType::Constant(value.ty());
949 self.bindings.insert(binding, binding_ty);
950 Expr::Const(Span::new(span, value))
951 }
952 // Ranges are constant, so same rules as above apply here
953 Expr::Range(range) => {
954 let span = range.span();
955 let range = range.to_slice_range();
956 let binding_ty = BindingType::Constant(Type::Felt);
957 self.bindings.insert(binding, binding_ty);
958 Expr::Const(Span::new(
959 span,
960 ConstantExpr::Scalar((range.start + index) as u64),
961 ))
962 }
963 // If the iterable was a vector, the abstract value is whatever expression is at
964 // the corresponding index of the vector.
965 Expr::Vector(elems) => {
966 let abstract_value = elems[index].clone();
967 let binding_ty = self.expr_binding_type(&abstract_value).unwrap();
968 self.bindings.insert(binding, binding_ty);
969 abstract_value
970 }
971 // If the iterable was a matrix, the abstract value is a vector of expressions
972 // representing the current row of the matrix. We calculate the binding type of
973 // each element in that vector so that accesses into the vector are well typed.
974 Expr::Matrix(rows) => {
975 let row: Vec<Expr> = rows[index]
976 .iter()
977 .cloned()
978 .map(|se| se.try_into().unwrap())
979 .collect();
980 let mut tys = vec![];
981 for elem in row.iter() {
982 tys.push(self.expr_binding_type(elem).unwrap());
983 }
984 let binding_ty = BindingType::Vector(tys);
985 self.bindings.insert(binding, binding_ty);
986 Expr::Vector(Span::new(rows.span(), row))
987 }
988 // If the iterable was a variable/access, then we must first index into that
989 // access, and then rewrite it, if applicable.
990 Expr::SymbolAccess(access) => {
991 // The access here must be of aggregate type, so index into it for the current iteration
992 let mut current_access = access.access(AccessType::Index(index)).unwrap();
993 // Rewrite the resulting access if we have a rewrite for the underlying symbol
994 if let Some(rewrite) = self.get_trace_access_rewrite(¤t_access) {
995 current_access = rewrite;
996 }
997 let binding_ty = self.access_binding_type(¤t_access).unwrap();
998 self.bindings.insert(binding, binding_ty);
999 Expr::SymbolAccess(current_access)
1000 }
1001 // Binary expressions are scalar, so cannot be used as iterables, and we don't
1002 // (currently) support nested comprehensions, so it is never possible to observe
1003 // these expression types here. Calls should have been lifted prior to expansion.
1004 Expr::Call(_)
1005 | Expr::Binary(_)
1006 | Expr::ListComprehension(_)
1007 | Expr::Let(_)
1008 | Expr::BusOperation(_)
1009 | Expr::Null(_)
1010 | Expr::Unconstrained(_) => {
1011 unreachable!()
1012 }
1013 };
1014 bound_values.insert(binding, abstract_value);
1015 }
1016
1017 // Clone the comprehension body for this iteration, so we don't modify the original
1018 let mut body = lc.body.as_ref().clone();
1019
1020 // Rewrite all references to the iterable bindings in the comprehension body
1021 let mut visitor = RewriteIterableBindingsVisitor {
1022 values: &bound_values,
1023 };
1024 if let ControlFlow::Break(err) = visitor.visit_mut_scalar_expr(&mut body) {
1025 return Err(err);
1026 }
1027
1028 // Next, handle comprehension filters/selectors as follows:
1029 //
1030 // 1. Selectors are evaluated in the same context as the body, so we must visit iterable references in the same way.
1031 // 2. If a selector has a constant value, we can elide the selector for this iteration. Furthermore, in situations where
1032 // the selector is known false, we can elide the expansion of this iteration entirely.
1033 //
1034 // Since the selector is the last piece we need to construct the Statement corresponding to the expansion of
1035 // this iteration, we do that now before proceeding to the next step.
1036 let statement = if let Some(mut selector) = lc.selector.clone() {
1037 assert!(
1038 self.in_comprehension_constraint,
1039 "selectors are not permitted in list comprehensions"
1040 );
1041 // #1
1042 if let ControlFlow::Break(err) = visitor.visit_mut_scalar_expr(&mut selector) {
1043 return Err(err);
1044 }
1045 // #2
1046 match selector {
1047 // If the selector value is zero, or false, we can elide the expansion entirely
1048 ScalarExpr::Const(value) if value.item == 0 => return Ok(vec![]),
1049 // If the selector value is non-zero, or true, we can elide just the selector
1050 ScalarExpr::Const(_) => Statement::Enforce(body),
1051 // We have a selector that requires evaluation at runtime, we need to emit a conditional scalar constraint
1052 other => Statement::EnforceIf(body, other),
1053 }
1054 } else if self.in_comprehension_constraint {
1055 Statement::Enforce(body)
1056 } else {
1057 Statement::Expr(body.try_into().unwrap())
1058 };
1059
1060 // Next, although we've rewritten the comprehension body corresponding to this iteration, we
1061 // haven't yet performed inlining on it. We do that now, while all of the bindings are
1062 // in scope with the proper values. The result of that expansion is what we emit as the result
1063 // for this iteration.
1064 self.expand_statement(statement)
1065 }
1066
1067 /// This function handles inlining evaluator function calls.
1068 ///
1069 /// At this point, semantic analysis has verified that the call arguments are valid, in
1070 /// that the number of trace columns passed matches the number of columns expected by the
1071 /// function parameters. However, the number and type of bindings are permitted to be
1072 /// different, as long as the vectors are the same size when expanded - in effect, re-grouping
1073 /// the trace columns at the call boundary.
1074 fn expand_evaluator_callsite(
1075 &mut self,
1076 call: Call,
1077 ) -> Result<Vec<Statement>, SemanticAnalysisError> {
1078 // The callee is guaranteed to be resolved and exist at this point
1079 let callee = call
1080 .callee
1081 .resolved()
1082 .expect("callee should have been resolved by now");
1083 // We clone the evaluator here as we will be modifying the body during the
1084 // inlining process, and we must not modify the original
1085 let mut evaluator = self.evaluators.get(&callee).unwrap().clone();
1086
1087 // This will be the initial set of bindings visible within the evaluator body
1088 //
1089 // This is distinct from `self.bindings` at this point, because the evaluator doesn't
1090 // inherit the caller's scope, it has an entirely new one.
1091 let mut eval_bindings = LexicalScope::default();
1092
1093 // Add all referenced (and thus imported) items from the evaluator module
1094 //
1095 // NOTE: This will include constants, periodic columns, and other functions
1096 for (qid, binding_ty) in self.imported.iter() {
1097 if qid.module == callee.module {
1098 eval_bindings.insert(*qid.as_ref(), binding_ty.clone());
1099 }
1100 }
1101
1102 // Add trace columns, and other root declarations to the set of
1103 // bindings visible in the evaluator body, _if_ the evaluator is defined in the
1104 // root module.
1105 let is_evaluator_in_root = callee.module == self.root;
1106 if is_evaluator_in_root {
1107 for segment in self.trace.iter() {
1108 eval_bindings.insert(
1109 segment.name,
1110 BindingType::TraceColumn(TraceBinding {
1111 span: segment.name.span(),
1112 segment: segment.id,
1113 name: Some(segment.name),
1114 offset: 0,
1115 size: segment.size,
1116 ty: Type::Vector(segment.size),
1117 }),
1118 );
1119 for binding in segment.bindings.iter().copied() {
1120 eval_bindings.insert(
1121 binding.name.unwrap(),
1122 BindingType::TraceColumn(TraceBinding {
1123 span: segment.name.span(),
1124 segment: segment.id,
1125 name: binding.name,
1126 offset: binding.offset,
1127 size: binding.size,
1128 ty: binding.ty,
1129 }),
1130 );
1131 }
1132 }
1133
1134 for input in self.public_inputs.values() {
1135 eval_bindings.insert(
1136 input.name(),
1137 BindingType::PublicInput(Type::Vector(input.size())),
1138 );
1139 }
1140 }
1141
1142 // Match call arguments to function parameters, populating the set of rewrites
1143 // which should be performed on the inlined function body.
1144 //
1145 // NOTE: We create a new nested scope for the parameters in order to avoid conflicting
1146 // with the root declarations
1147 eval_bindings.enter();
1148 self.populate_evaluator_rewrites(
1149 &mut eval_bindings,
1150 call.args.as_slice(),
1151 evaluator.params.as_slice(),
1152 );
1153
1154 // While we're inlining the body, use the set of evaluator bindings we built above
1155 let prev_bindings = core::mem::replace(&mut self.bindings, eval_bindings);
1156
1157 // Expand the evaluator body into a block of statements
1158 self.expand_statement_block(&mut evaluator.body)?;
1159
1160 // Restore the caller's bindings before we leave
1161 self.bindings = prev_bindings;
1162
1163 Ok(evaluator.body)
1164 }
1165
1166 /// This function handles inlining pure function calls, which must produce an expression
1167 fn expand_function_callsite(&mut self, call: Call) -> Result<Expr, SemanticAnalysisError> {
1168 self.bindings.enter();
1169 // The callee is guaranteed to be resolved and exist at this point
1170 let callee = call
1171 .callee
1172 .resolved()
1173 .expect("callee should have been resolved by now");
1174
1175 if self.call_stack.contains(&callee) {
1176 let ifd = self
1177 .diagnostics
1178 .diagnostic(Severity::Error)
1179 .with_message("invalid recursive function call")
1180 .with_primary_label(call.span, "recursion occurs due to this function call");
1181 self.call_stack
1182 .iter()
1183 .rev()
1184 .fold(ifd, |ifd, caller| {
1185 ifd.with_secondary_label(caller.span(), "which was called from")
1186 })
1187 .emit();
1188 return Err(SemanticAnalysisError::Invalid);
1189 } else {
1190 self.call_stack.push(callee);
1191 }
1192
1193 // We clone the function here as we will be modifying the body during the
1194 // inlining process, and we must not modify the original
1195 let mut function = self.functions.get(&callee).unwrap().clone();
1196
1197 // This will be the initial set of bindings visible within the function body
1198 //
1199 // This is distinct from `self.bindings` at this point, because the function doesn't
1200 // inherit the caller's scope, it has an entirely new one.
1201 let mut function_bindings = LexicalScope::default();
1202
1203 // Add all referenced (and thus imported) items from the function module
1204 //
1205 // NOTE: This will include constants, periodic columns, and other functions
1206 for (qid, binding_ty) in self.imported.iter() {
1207 if qid.module == callee.module {
1208 function_bindings.insert(*qid.as_ref(), binding_ty.clone());
1209 }
1210 }
1211
1212 // Add trace columns, and other root declarations to the set of
1213 // bindings visible in the function body, _if_ the function is defined in the
1214 // root module.
1215 let is_function_in_root = callee.module == self.root;
1216 if is_function_in_root {
1217 for segment in self.trace.iter() {
1218 function_bindings.insert(
1219 segment.name,
1220 BindingType::TraceColumn(TraceBinding {
1221 span: segment.name.span(),
1222 segment: segment.id,
1223 name: Some(segment.name),
1224 offset: 0,
1225 size: segment.size,
1226 ty: Type::Vector(segment.size),
1227 }),
1228 );
1229 for binding in segment.bindings.iter().copied() {
1230 function_bindings.insert(
1231 binding.name.unwrap(),
1232 BindingType::TraceColumn(TraceBinding {
1233 span: segment.name.span(),
1234 segment: segment.id,
1235 name: binding.name,
1236 offset: binding.offset,
1237 size: binding.size,
1238 ty: binding.ty,
1239 }),
1240 );
1241 }
1242 }
1243
1244 for input in self.public_inputs.values() {
1245 function_bindings.insert(
1246 input.name(),
1247 BindingType::PublicInput(Type::Vector(input.size())),
1248 );
1249 }
1250 }
1251
1252 // Match call arguments to function parameters, populating the set of rewrites
1253 // which should be performed on the inlined function body.
1254 //
1255 // NOTE: We create a new nested scope for the parameters in order to avoid conflicting
1256 // with the root declarations
1257 function_bindings.enter();
1258 self.populate_function_rewrites(
1259 &mut function_bindings,
1260 call.args.as_slice(),
1261 function.params.as_slice(),
1262 );
1263
1264 // While we're inlining the body, use the set of function bindings we built above
1265 let prev_bindings = core::mem::replace(&mut self.bindings, function_bindings);
1266
1267 // Expand the function body into a block of statements
1268 self.expand_statement_block(&mut function.body)?;
1269
1270 // Restore the caller's bindings before we leave
1271 self.bindings = prev_bindings;
1272
1273 // We're done expanding this call, so remove it from the call stack
1274 self.call_stack.pop();
1275
1276 match function.body.pop().unwrap() {
1277 Statement::Expr(expr) => Ok(expr),
1278 Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))),
1279 Statement::Enforce(_)
1280 | Statement::EnforceIf(_, _)
1281 | Statement::EnforceAll(_)
1282 | Statement::BusEnforce(_) => {
1283 panic!("unexpected constraint in function body")
1284 }
1285 }
1286 }
1287
1288 /// Populate the set of access rewrites, as well as the initial set of bindings to use when inlining an evaluator function.
1289 ///
1290 /// This is done by resolving the arguments provided by the call to the evaluator, with the parameter list of the evaluator itself.
1291 fn populate_evaluator_rewrites(
1292 &mut self,
1293 eval_bindings: &mut LexicalScope<Identifier, BindingType>,
1294 args: &[Expr],
1295 params: &[TraceSegment],
1296 ) {
1297 // Reset the rewrites set
1298 self.rewrites.clear();
1299
1300 // Each argument corresponds to a function parameter, each of which represents a single trace segment
1301 for (arg, segment) in args.iter().zip(params.iter()) {
1302 match arg {
1303 // A variable was passed as an argument for this segment
1304 //
1305 // Arguments by now must have been validated by semantic analysis, and specifically
1306 // in this case, the number of columns in the variable and the number expected by the
1307 // parameter we're binding must be the same. However, a variable may represent a single
1308 // column, a contiguous slice of columns, or a vector of such variables which may be
1309 // non-contiguous.
1310 Expr::SymbolAccess(access) => {
1311 // We use a `BindingType` to track the state of the current input binding being processed.
1312 //
1313 // The initial state is given by the binding type of the access itself, but as we destructure
1314 // the binding according to the parameter binding pattern, we may pop off columns, in which
1315 // case the binding type here gets updated with the remaining columns
1316 let mut binding_ty = Some(self.access_binding_type(access).unwrap());
1317 // We visit each binding in the trace segment represented by the parameter pattern,
1318 // consuming columns from the input argument until all bindings are matched up.
1319 for binding in segment.bindings.iter() {
1320 // Trace binding declarations are never anonymous, i.e. always have a name
1321 let binding_name = binding.name.unwrap();
1322 // We can safely assume that there is a binding type available here,
1323 // otherwise the semantic analysis pass missed something
1324 let bt = binding_ty.take().unwrap();
1325 // Split out the needed columns from the input binding
1326 //
1327 // We can safely assume we were able to obtain all of the needed columns,
1328 // as the semantic analyzer should have caught mismatches. Note, however,
1329 // that these columns may have been gathered from multiple bindings in the caller
1330 let (matched, rest) = bt.split_columns(binding.size).unwrap();
1331 self.rewrites.insert(binding_name);
1332 eval_bindings.insert(binding_name, matched);
1333 // Update `binding_ty` with whatever remains of the input
1334 binding_ty = rest;
1335 }
1336 }
1337 // An empty vector means there are no bindings for this segment
1338 Expr::Const(Span {
1339 item: ConstantExpr::Vector(items),
1340 ..
1341 }) if items.is_empty() => {
1342 continue;
1343 }
1344 // A vector of bindings was passed as an argument for this segment
1345 //
1346 // This is by far the most complicated scenario to handle when matching up arguments
1347 // to parameters, as we can get them in a variety of combinations:
1348 //
1349 // 1. An exact match in the number and size of bindings in both the input vector and the
1350 // segment represented by the current parameter
1351 // 2. The same number of elements in the vector as bindings in the segment, but the elements
1352 // have different sizes, implicitly regrouping columns between caller/callee
1353 // 3. More elements in the vector than bindings in the segment, typically because the function
1354 // parameter groups together columns passed individually in the caller
1355 // 4. Fewer elements in the vector than bindings in the segment, typically because the function
1356 // parameter destructures an input into multiple bindings
1357 Expr::Vector(inputs) => {
1358 // The index of the input we're currently extracting columns from
1359 let mut index = 0;
1360 // A `BindingType` representing the current trace binding we're extracting columns from,
1361 // can be either of TraceColumn or Vector type
1362 let mut binding_ty = None;
1363 // We drive the matching process by consuming input columns for each segment binding in turn
1364 'next_binding: for binding in segment.bindings.iter() {
1365 let binding_name = binding.name.unwrap();
1366 let mut needed = binding.size;
1367
1368 // When there are insufficient columns for the current parameter binding in the current
1369 // input, we must construct a vector of trace bindings to use as the binding type of
1370 // the current parameter binding when we have all of the needed columns. This is because
1371 // the input columns may come from different trace bindings in the caller, so we can't
1372 // use a single trace binding to represent them.
1373 let mut set = vec![];
1374
1375 // We may need to consume multiple input elements to fulfill the needed columns of
1376 // the current parameter binding - we advance this loop whenever we have exhausted
1377 // an input and need to move on to the next one. We may enter this loop with the
1378 // same input index across multiple parameter bindings when the input element is
1379 // larger than the parameter binding, in which case we have split the input and
1380 // stored the remainder in `binding_ty`.
1381 loop {
1382 let input = &inputs[index];
1383 // The input expression must have been a symbol access, as matrices of columns
1384 // aren't a thing, and there is no other expression type which can produce trace
1385 // bindings.
1386 let Expr::SymbolAccess(access) = input else {
1387 panic!("unexpected element in trace column vector: {input:#?}")
1388 };
1389 // Unless we have leftover input, initialize `binding_ty` with the binding type of this input
1390 let bt = binding_ty
1391 .take()
1392 .unwrap_or_else(|| self.access_binding_type(access).unwrap());
1393 match bt.split_columns(needed) {
1394 Ok((matched, rest)) => {
1395 let eval_binding = match matched {
1396 BindingType::TraceColumn(matched) => {
1397 if !set.is_empty() {
1398 // We've obtained all the remaining columns from the current input element,
1399 // possibly with leftovers in the input. However, because we've started
1400 // constructing a vector binding, we must ensure the matched binding is
1401 // expanded into individual columns
1402 for offset in 0..matched.size {
1403 set.push(BindingType::TraceColumn(
1404 TraceBinding {
1405 offset: matched.offset + offset,
1406 size: 1,
1407 ..matched
1408 },
1409 ));
1410 }
1411 BindingType::Vector(set)
1412 } else {
1413 // The input element perfectly matched the current binding
1414 BindingType::TraceColumn(matched)
1415 }
1416 }
1417 BindingType::Vector(mut matched) => {
1418 if set.is_empty() {
1419 // The input binding was a vector, and had the same number, or
1420 // more, of columns expected by the parameter binding, but may contain
1421 // non-contiguous bindings, so we are unable to use the symbol of
1422 // the access when rewriting accesses to this parameter
1423 BindingType::Vector(matched)
1424 } else {
1425 // Same as above, but we need to append the matched bindings to
1426 // the set we've already started building
1427 set.append(&mut matched);
1428 BindingType::Vector(set)
1429 }
1430 }
1431 _ => unreachable!(),
1432 };
1433 // This binding has been fulfilled, move to the next one
1434 self.rewrites.insert(binding_name);
1435 eval_bindings.insert(binding_name, eval_binding);
1436 binding_ty = rest;
1437 // If we have no more columns remaining in this input, advance
1438 // to the next input starting with the next binding
1439 if binding_ty.is_none() {
1440 index += 1;
1441 }
1442 continue 'next_binding;
1443 }
1444 Err(BindingType::TraceColumn(partial)) => {
1445 // The input binding wasn't big enough for the parameter, so we must
1446 // start constructing a vector of bindings since the next input is
1447 // unlikely to be contiguous with the current input
1448 for offset in 0..partial.size {
1449 set.push(BindingType::TraceColumn(TraceBinding {
1450 offset: partial.offset + offset,
1451 size: 1,
1452 ..partial
1453 }));
1454 }
1455 needed -= partial.size;
1456 index += 1;
1457 }
1458 Err(BindingType::Vector(mut partial)) => {
1459 // Same as above, but we got a vector instead
1460 set.append(&mut partial);
1461 needed -= partial.len();
1462 index += 1;
1463 }
1464 Err(_) => unreachable!(),
1465 }
1466 }
1467 }
1468 }
1469 // This should not be possible at this point, but would be an invalid evaluator call,
1470 // only trace columns are permitted
1471 expr => unreachable!("{:#?}", expr),
1472 }
1473 }
1474 }
1475
1476 fn populate_function_rewrites(
1477 &mut self,
1478 function_bindings: &mut LexicalScope<Identifier, BindingType>,
1479 args: &[Expr],
1480 params: &[(Identifier, Type)],
1481 ) {
1482 // Reset the rewrites set
1483 self.rewrites.clear();
1484
1485 for (arg, (param_name, param_ty)) in args.iter().zip(params.iter()) {
1486 // We can safely assume that there is a binding type available here,
1487 // otherwise the semantic analysis pass missed something
1488 let binding_ty = self.expr_binding_type(arg).unwrap();
1489 debug_assert_eq!(binding_ty.ty(), Some(*param_ty), "unexpected type mismatch");
1490 self.rewrites.insert(*param_name);
1491 function_bindings.insert(*param_name, binding_ty);
1492 }
1493 }
1494
1495 /// Returns a new [SymbolAccess] which should be used in place of `access` in the current scope.
1496 ///
1497 /// This function should only be called on accesses which have a trace column/param [BindingType],
1498 /// but it will simply return `None` for other types, so it is safe to call on all accesses.
1499 fn get_trace_access_rewrite(&self, access: &SymbolAccess) -> Option<SymbolAccess> {
1500 if self.rewrites.contains(access.name.as_ref()) {
1501 // If we have a rewrite for this access, then the bindings map will
1502 // have an accurate trace binding for us; rewrite this access to be
1503 // relative to that trace binding
1504 match self.access_binding_type(access).unwrap() {
1505 BindingType::TraceColumn(tb) => {
1506 let original_binding = self.trace[tb.segment]
1507 .bindings
1508 .iter()
1509 .find(|b| b.name == tb.name)
1510 .unwrap();
1511 let (access_type, ty) = if original_binding.size == 1 {
1512 (AccessType::Default, Type::Felt)
1513 } else if tb.size == 1 {
1514 (
1515 AccessType::Index(tb.offset - original_binding.offset),
1516 Type::Felt,
1517 )
1518 } else {
1519 let start = tb.offset - original_binding.offset;
1520 (
1521 AccessType::Slice(RangeExpr::from(start..(start + tb.size))),
1522 Type::Vector(tb.size),
1523 )
1524 };
1525 Some(SymbolAccess {
1526 span: access.span(),
1527 name: ResolvableIdentifier::Local(tb.name.unwrap()),
1528 access_type,
1529 offset: access.offset,
1530 ty: Some(ty),
1531 })
1532 }
1533 // We only have a rewrite when the binding type is TraceColumn
1534 invalid => panic!(
1535 "unexpected trace access binding type, expected column(s), got: {:#?}",
1536 &invalid
1537 ),
1538 }
1539 } else {
1540 None
1541 }
1542 }
1543
1544 fn expr_binding_type(&self, expr: &Expr) -> Result<BindingType, InvalidAccessError> {
1545 let mut bindings = self.bindings.clone();
1546 eval_expr_binding_type(expr, &mut bindings, &self.imported)
1547 }
1548
1549 /// Returns the effective [BindingType] of the value produced by the given access
1550 fn access_binding_type(&self, expr: &SymbolAccess) -> Result<BindingType, InvalidAccessError> {
1551 eval_access_binding_type(expr, &self.bindings, &self.imported)
1552 }
1553}
1554
1555/// Returns the effective [BindingType] of the given expression
1556fn eval_expr_binding_type(
1557 expr: &Expr,
1558 bindings: &mut LexicalScope<Identifier, BindingType>,
1559 imported: &HashMap<QualifiedIdentifier, BindingType>,
1560) -> Result<BindingType, InvalidAccessError> {
1561 match expr {
1562 Expr::Const(constant) => Ok(BindingType::Local(constant.ty())),
1563 Expr::Range(range) => Ok(BindingType::Local(Type::Vector(
1564 range.to_slice_range().len(),
1565 ))),
1566 Expr::Vector(elems) => match elems[0].ty() {
1567 None | Some(Type::Felt) => {
1568 let mut binding_tys = Vec::with_capacity(elems.len());
1569 for elem in elems.iter() {
1570 binding_tys.push(eval_expr_binding_type(elem, bindings, imported)?);
1571 }
1572 Ok(BindingType::Vector(binding_tys))
1573 }
1574 Some(Type::Vector(cols)) => {
1575 let rows = elems.len();
1576 Ok(BindingType::Local(Type::Matrix(rows, cols)))
1577 }
1578 Some(_) => unreachable!(),
1579 },
1580 Expr::Matrix(expr) => {
1581 let rows = expr.len();
1582 let columns = expr[0].len();
1583 Ok(BindingType::Local(Type::Matrix(rows, columns)))
1584 }
1585 Expr::SymbolAccess(access) => eval_access_binding_type(access, bindings, imported),
1586 Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding),
1587 Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)),
1588 Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)),
1589 Expr::ListComprehension(lc) => {
1590 // The types of all iterables must be the same, so the type of
1591 // the comprehension is given by the type of the iterables. We
1592 // just pick the first iterable to tell us the type
1593 eval_expr_binding_type(&lc.iterables[0], bindings, imported)
1594 }
1595 Expr::Let(let_expr) => eval_let_binding_ty(let_expr, bindings, imported),
1596 Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => {
1597 unimplemented!("buses are not implemented for this Pipeline")
1598 }
1599 }
1600}
1601
1602/// Returns the effective [BindingType] of the value produced by the given access
1603fn eval_access_binding_type(
1604 expr: &SymbolAccess,
1605 bindings: &LexicalScope<Identifier, BindingType>,
1606 imported: &HashMap<QualifiedIdentifier, BindingType>,
1607) -> Result<BindingType, InvalidAccessError> {
1608 let binding_ty = bindings
1609 .get(expr.name.as_ref())
1610 .or_else(|| match expr.name {
1611 ResolvableIdentifier::Resolved(qid) => imported.get(&qid),
1612 _ => None,
1613 })
1614 .ok_or(InvalidAccessError::UndefinedVariable)
1615 .clone()?;
1616 binding_ty.access(expr.access_type.clone())
1617}
1618
1619fn eval_let_binding_ty(
1620 let_expr: &Let,
1621 bindings: &mut LexicalScope<Identifier, BindingType>,
1622 imported: &HashMap<QualifiedIdentifier, BindingType>,
1623) -> Result<BindingType, InvalidAccessError> {
1624 let variable_ty = eval_expr_binding_type(&let_expr.value, bindings, imported)?;
1625 bindings.enter();
1626 bindings.insert(let_expr.name, variable_ty);
1627 let binding_ty = match let_expr.body.last().unwrap() {
1628 Statement::Let(inner_let) => eval_let_binding_ty(inner_let, bindings, imported)?,
1629 Statement::Expr(expr) => eval_expr_binding_type(expr, bindings, imported)?,
1630 Statement::Enforce(_)
1631 | Statement::EnforceIf(_, _)
1632 | Statement::EnforceAll(_)
1633 | Statement::BusEnforce(_) => {
1634 unreachable!()
1635 }
1636 };
1637 bindings.exit();
1638 Ok(binding_ty)
1639}
1640
1641/// This visitor is used to rewrite uses of iterable bindings within a comprehension body,
1642/// including expansion of constant accesses.
1643struct RewriteIterableBindingsVisitor<'a> {
1644 /// This map contains the set of symbols to be rewritten, and the abstract values which
1645 /// should replace them in the comprehension body.
1646 values: &'a HashMap<Identifier, Expr>,
1647}
1648impl RewriteIterableBindingsVisitor<'_> {
1649 fn rewrite_scalar_access(
1650 &mut self,
1651 access: SymbolAccess,
1652 ) -> ControlFlow<SemanticAnalysisError, Option<ScalarExpr>> {
1653 let result = match self.values.get(access.name.as_ref()) {
1654 Some(Expr::Const(constant)) => {
1655 let span = constant.span();
1656 match constant.item {
1657 ConstantExpr::Scalar(value) => {
1658 assert_eq!(access.access_type, AccessType::Default);
1659 Some(ScalarExpr::Const(Span::new(span, value)))
1660 }
1661 ConstantExpr::Vector(ref elems) => match access.access_type {
1662 AccessType::Index(idx) => {
1663 Some(ScalarExpr::Const(Span::new(span, elems[idx])))
1664 }
1665 invalid => panic!(
1666 "expected vector to be reduced to scalar by access, got {invalid:#?}"
1667 ),
1668 },
1669 ConstantExpr::Matrix(ref rows) => match access.access_type {
1670 AccessType::Matrix(row, col) => {
1671 Some(ScalarExpr::Const(Span::new(span, rows[row][col])))
1672 }
1673 invalid => panic!(
1674 "expected matrix to be reduced to scalar by access, got {invalid:#?}",
1675 ),
1676 },
1677 }
1678 }
1679 Some(Expr::Range(range)) => {
1680 let span = range.span();
1681 let range = range.to_slice_range();
1682 match access.access_type {
1683 AccessType::Index(idx) => Some(ScalarExpr::Const(Span::new(
1684 span,
1685 (range.start + idx) as u64,
1686 ))),
1687 invalid => {
1688 panic!("expected range to be reduced to scalar by access, got {invalid:#?}",)
1689 }
1690 }
1691 }
1692 Some(Expr::Vector(elems)) => {
1693 match access.access_type {
1694 AccessType::Index(idx) => Some(elems[idx].clone().try_into().unwrap()),
1695 // This implies that the vector contains an element which is vector-like,
1696 // if the value at `idx` is not, this is an invalid access
1697 AccessType::Matrix(idx, nested_idx) => match &elems[idx] {
1698 Expr::SymbolAccess(saccess) => {
1699 let access = saccess.access(AccessType::Index(nested_idx)).unwrap();
1700 self.rewrite_scalar_access(access)?
1701 }
1702 invalid => panic!(
1703 "expected vector-like value at {}[{idx}], got: {invalid:#?}",
1704 access.name.as_ref(),
1705 ),
1706 },
1707 invalid => panic!(
1708 "expected vector to be reduced to scalar by access, got {invalid:#?}"
1709 ),
1710 }
1711 }
1712 Some(Expr::Matrix(elems)) => match access.access_type {
1713 AccessType::Matrix(row, col) => Some(elems[row][col].clone()),
1714 invalid => {
1715 panic!("expected matrix to be reduced to scalar by access, got {invalid:#?}")
1716 }
1717 },
1718 Some(Expr::SymbolAccess(symbol_access)) => {
1719 let mut new_access = symbol_access.access(access.access_type).unwrap();
1720 new_access.offset = access.offset;
1721 Some(ScalarExpr::SymbolAccess(new_access))
1722 }
1723 // These types of expressions will never be observed in this context, as they are
1724 // not valid iterable expressions (except calls, but those are lifted prior to rewrite
1725 // so that their use in this context is always a symbol access).
1726 Some(
1727 Expr::Call(_)
1728 | Expr::Binary(_)
1729 | Expr::ListComprehension(_)
1730 | Expr::Let(_)
1731 | Expr::BusOperation(_)
1732 | Expr::Null(_)
1733 | Expr::Unconstrained(_),
1734 ) => {
1735 unreachable!()
1736 }
1737 None => None,
1738 };
1739 ControlFlow::Continue(result)
1740 }
1741}
1742impl VisitMut<SemanticAnalysisError> for RewriteIterableBindingsVisitor<'_> {
1743 fn visit_mut_scalar_expr(
1744 &mut self,
1745 expr: &mut ScalarExpr,
1746 ) -> ControlFlow<SemanticAnalysisError> {
1747 match expr {
1748 // Nothing to do with constants
1749 ScalarExpr::Const(_) => ControlFlow::Continue(()),
1750 // If we observe an access, try to rewrite it as an iterable binding, if it is
1751 // not a candidate for rewrite, leave it alone.
1752 //
1753 // NOTE: We handle BoundedSymbolAccess here even though comprehension constraints are not
1754 // permitted in boundary_constraints currently. That is handled elsewhere, we just need to
1755 // make sure the symbols themselves are rewritten properly here.
1756 ScalarExpr::SymbolAccess(access)
1757 | ScalarExpr::BoundedSymbolAccess(BoundedSymbolAccess { column: access, .. }) => {
1758 if let Some(replacement) = self.rewrite_scalar_access(access.clone())? {
1759 *expr = replacement;
1760 return ControlFlow::Continue(());
1761 }
1762 ControlFlow::Continue(())
1763 }
1764 // We need to visit both operands of a binary expression - but while we're here,
1765 // check to see if resolving the operands reduces to a constant expression that
1766 // can be folded.
1767 ScalarExpr::Binary(binary_expr) => {
1768 self.visit_mut_binary_expr(binary_expr)?;
1769 match constant_propagation::try_fold_binary_expr(binary_expr) {
1770 Ok(Some(folded)) => {
1771 *expr = ScalarExpr::Const(folded);
1772 ControlFlow::Continue(())
1773 }
1774 Ok(None) => ControlFlow::Continue(()),
1775 Err(err) => ControlFlow::Break(SemanticAnalysisError::InvalidExpr(err)),
1776 }
1777 }
1778 // If we observe a call here, just rewrite the arguments, inlining happens elsewhere
1779 ScalarExpr::Call(call) => {
1780 for arg in call.args.iter_mut() {
1781 self.visit_mut_expr(arg)?;
1782 }
1783 ControlFlow::Continue(())
1784 }
1785 // We rewrite comprehension bodies before they are expanded, so it should never be
1786 // the case that we encounter a let here, as they can only be introduced in scalar
1787 // expression position as a result of inlining/expansion
1788 ScalarExpr::Let(_) => unreachable!(),
1789 ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
1790 ControlFlow::Break(SemanticAnalysisError::Invalid)
1791 }
1792 }
1793 }
1794}
1795
1796/// This visitor is used to apply a selector expression to all constraints in a block
1797///
1798/// For constraints which already have a selector, this rewrites those selectors to be the
1799/// logical AND of the original selector and the selector being applied.
1800struct ApplyConstraintSelector<'a> {
1801 selector: &'a ScalarExpr,
1802}
1803impl VisitMut<SemanticAnalysisError> for ApplyConstraintSelector<'_> {
1804 fn visit_mut_statement(
1805 &mut self,
1806 statement: &mut Statement,
1807 ) -> ControlFlow<SemanticAnalysisError> {
1808 match statement {
1809 Statement::Let(expr) => self.visit_mut_let(expr),
1810 Statement::Enforce(expr) => {
1811 let expr =
1812 core::mem::replace(expr, ScalarExpr::Const(Span::new(SourceSpan::UNKNOWN, 0)));
1813 *statement = Statement::EnforceIf(expr, self.selector.clone());
1814 ControlFlow::Continue(())
1815 }
1816 Statement::EnforceIf(_, selector) => {
1817 // Combine the selectors
1818 let lhs = core::mem::replace(
1819 selector,
1820 ScalarExpr::Const(Span::new(SourceSpan::UNKNOWN, 0)),
1821 );
1822 let rhs = self.selector.clone();
1823 *selector = ScalarExpr::Binary(BinaryExpr::new(
1824 self.selector.span(),
1825 BinaryOp::Mul,
1826 lhs,
1827 rhs,
1828 ));
1829 ControlFlow::Continue(())
1830 }
1831 Statement::EnforceAll(_) => unreachable!(),
1832 Statement::Expr(_) => ControlFlow::Continue(()),
1833 Statement::BusEnforce(_) => ControlFlow::Break(SemanticAnalysisError::Invalid),
1834 }
1835 }
1836}
1837
1838/// This helper function is used to perform a mutation/replacement based on the expression
1839/// representing the effective value of a `let`-tree.
1840///
1841/// In particular, this function traverses the tree until it reaches the final `let` body
1842/// and the last `Expr` in that body. When it does, it invokes `callback` with a mutable
1843/// reference to that `Expr`. The callback may choose to simply mutate the `Expr`, or it
1844/// may return a new `Statement` which will be used to replace the `Statement` which
1845/// contained the `Expr` given to the callback.
1846///
1847/// This is used when expanding calls and list comprehensions, where the expanded form
1848/// of these is potentially a `let` tree, and we desire to place additional statements
1849/// in the bottom-most block, or perform some transformation on the expression which acts
1850/// as the result of the tree.
1851fn with_let_result<F>(
1852 inliner: &mut Inlining,
1853 entry: &mut Vec<Statement>,
1854 callback: F,
1855) -> Result<(), SemanticAnalysisError>
1856where
1857 F: FnOnce(&mut Inlining, &mut Expr) -> Result<Option<Statement>, SemanticAnalysisError>,
1858{
1859 // Preserve the original lexical scope to be restored on exit
1860 let prev = inliner.bindings.clone();
1861
1862 // SAFETY: We must use a raw pointer here because the Rust compiler is not able to
1863 // see that we only ever use the mutable reference once, and that the reference
1864 // is never aliased.
1865 //
1866 // Both of these guarantees are in fact upheld here however, as each iteration of the loop
1867 // is either the last iteration (when we use the mutable reference to mutate the end of the
1868 // bottom-most block), or a traversal to the last child of the current let expression.
1869 // We never alias the mutable reference, and in fact immediately convert back to a mutable
1870 // reference inside the loop to ensure that within the loop body we have some degree of
1871 // compiler-assisted checking of that invariant.
1872 let mut current_block = Some(entry as *mut Vec<Statement>);
1873 while let Some(parent_block) = current_block.take() {
1874 // SAFETY: We convert the pointer back to a mutable reference here before
1875 // we do anything else to ensure the usual aliasing rules are enforced.
1876 //
1877 // It is further guaranteed that this reference is never improperly aliased
1878 // across iterations, as each iteration is visiting a child of the previous
1879 // iteration's node, i.e. what we're doing here is equivalent to holding a
1880 // mutable reference and using it to mutate a field in a deeply nested struct.
1881 let parent_block = unsafe { &mut *parent_block };
1882 // A block is guaranteed to always have at least one statement here
1883 match parent_block.last_mut().unwrap() {
1884 // When we hit a block whose last statement is an expression, which
1885 // must also be the bottom-most block of this tree. This expression
1886 // is the effective value of the `let` tree. We will replace this
1887 // node if the callback we were given returns a new `Statement`. In
1888 // either case, we're done once we've handled the callback result.
1889 Statement::Expr(value) => match callback(inliner, value) {
1890 Ok(Some(replacement)) => {
1891 parent_block.pop();
1892 parent_block.push(replacement);
1893 break;
1894 }
1895 Ok(None) => break,
1896 Err(err) => {
1897 inliner.bindings = prev;
1898 return Err(err);
1899 }
1900 },
1901 // We've traversed down a level in the let-tree, but there are more to go.
1902 // Set up the next iteration to visit the next block down in the tree.
1903 Statement::Let(let_expr) => {
1904 // Register this binding
1905 let binding_ty = inliner.expr_binding_type(&let_expr.value).unwrap();
1906 inliner.bindings.insert(let_expr.name, binding_ty);
1907 // Set up the next iteration
1908 current_block = Some(&mut let_expr.body as *mut Vec<Statement>);
1909 continue;
1910 }
1911 // No other statements types are possible here
1912 _ => unreachable!(),
1913 }
1914 }
1915
1916 // Restore the original lexical scope
1917 inliner.bindings = prev;
1918
1919 Ok(())
1920}