air_codegen_masm/
codegen.rs

1use crate::config::CodegenConfig;
2use crate::constants::{AUX_TRACE, MAIN_TRACE};
3use crate::error::CodegenError;
4use crate::utils::{
5    boundary_group_to_procedure_name, load_quadratic_element, periodic_group_to_memory_offset,
6    quadratic_element_square,
7};
8use crate::visitor::{
9    walk_boundary_constraints, walk_integrity_constraints, walk_periodic_columns, AirVisitor,
10};
11use crate::writer::Writer;
12use air_ir::{
13    Air, ConstraintDomain, ConstraintRoot, Identifier, NodeIndex, Operation, PeriodicColumn,
14    TraceSegmentId, Value,
15};
16use miden_core::{Felt, StarkField};
17use std::collections::btree_map::BTreeMap;
18use std::mem::{replace, take};
19use winter_math::fft;
20
/// Code generator that produces Miden Assembly from an [Air].
///
/// The generator itself only stores the [CodegenConfig]; the actual work is delegated to a
/// backend that is created for each call to `generate`.
#[derive(Default)]
pub struct CodeGenerator {
    /// Configuration handed to the backend on every `generate` call.
    config: CodegenConfig,
}
25impl CodeGenerator {
26    pub fn new(config: CodegenConfig) -> Self {
27        Self { config }
28    }
29}
30impl air_ir::CodeGenerator for CodeGenerator {
31    type Output = String;
32
33    fn generate(&self, ir: &Air) -> anyhow::Result<Self::Output> {
34        let generator = Backend::new(ir, self.config);
35        generator.generate()
36    }
37}
38
/// State carried while visiting an [Air] and emitting the corresponding Miden Assembly.
///
/// A `Backend` is built once per generation run, mutated while the [Air] is visited, and finally
/// consumed to obtain the generated code.
struct Backend<'ast> {
    /// Miden Assembly writer.
    ///
    /// Tracks the indentation level, and performs basic validations for generated instructions
    /// and closing of blocks.
    writer: Writer,

    /// Counts how many periodic columns have been visited so far.
    ///
    /// Periodic columns are visited in order, so the counter is the same as the column's ID.
    periodic_column: u32,

    /// A list of the distinct periodic column lengths in decreasing order.
    ///
    /// The index in this vector corresponds to the offset of the pre-computed z value.
    periods: Vec<usize>,

    /// Counts how many composition coefficients have been used so far, used to compute the correct
    /// offset in memory. This counter is shared among integrity and boundary constraints for both
    /// the main and auxiliary traces.
    composition_coefficient_count: u32,

    /// Counts how many integrity constraint roots have been visited so far, used for
    /// emitting documentation.
    integrity_contraints: usize,

    /// Counts how many boundary constraint roots have been visited so far, used for
    /// emitting documentation.
    boundary_contraints: usize,

    /// Counts the size of a given boundary constraint category, keyed by trace segment and
    /// constraint domain. The counter is used to emit the correct number of operations needed
    /// to accumulate the constraints of that category.
    boundary_constraint_count: BTreeMap<(TraceSegmentId, ConstraintDomain), usize>,

    /// Maps each public input to its start offset.
    public_input_to_offset: BTreeMap<Identifier, usize>,

    /// The [Air] to visit.
    ir: &'ast Air,

    /// Configuration for the codegen.
    config: CodegenConfig,
}
82
83impl<'ast> Backend<'ast> {
84    fn new(ir: &'ast Air, config: CodegenConfig) -> Self {
85        // remove duplicates and sort period lengths in descending order, since larger periods will
86        // have smaller number of cycles (which means a smaller number of exponentiations)
87        let mut periods: Vec<usize> = ir.periodic_columns().map(|e| e.period()).collect();
88        periods.sort();
89        periods.dedup();
90        periods.reverse();
91
92        // Maps the public input name to its memory offset, were the memory offset is the
93        // accumulated number of inputs laid out in memory prior to our target. For example:
94        //
95        //  Input "a" starts at offset 0
96        // |      Input "b" starts at offset 4, after the 4 values of "a"
97        // v      v                   Input "c" starts at offset 20, after the values of "a" and "b"
98        // [ .... | ................ | ....]
99        //
100        // The offset is used by the codegen to load public input values.
101        let public_input_to_offset = ir
102            .public_inputs()
103            .scan(0, |public_input_count, input| {
104                let start_offset = *public_input_count;
105                *public_input_count += input.size;
106                Some((input.name, start_offset))
107            })
108            .collect();
109
110        // count the boundary constraints
111        let mut boundary_constraint_count = BTreeMap::new();
112        for segment in [MAIN_TRACE, AUX_TRACE] {
113            for boundary in ir.boundary_constraints(segment) {
114                boundary_constraint_count
115                    .entry((segment, boundary.domain()))
116                    .and_modify(|c| *c += 1)
117                    .or_insert(1);
118            }
119        }
120
121        Self {
122            writer: Writer::new(),
123            periodic_column: 0,
124            periods,
125            composition_coefficient_count: 0,
126            integrity_contraints: 0,
127            boundary_contraints: 0,
128            boundary_constraint_count,
129            public_input_to_offset,
130            ir,
131            config,
132        }
133    }
134
135    /// Emits the Miden Assembly code  after visiting the [AirIR].
136    fn generate(mut self) -> anyhow::Result<String> {
137        self.visit_air()?;
138        Ok(self.writer.into_code())
139    }
140
141    /// Emits code for the procedure `cache_z_exp`.
142    ///
143    /// The procedure computes and caches the necessary exponentiation of `z`. These values are
144    /// later on used to evaluate each periodic column polynomial and the constraint divisor.
145    ///
146    /// This procedure exists because the VM doesn't have native instructions for exponentiation of
147    /// quadratic extension elements, and this is an expensive operation.
148    ///
149    /// The generated code is optimized to perform the fewest number of exponentiations, this is
150    /// achieved by observing that periodic columns and trace length are both powers-of-two, since
151    /// the exponent is defined as `exp = trace_len / periodic_column_len`, all exponents are
152    /// themselves powers-of-two. This allows the results to be computed from smallest to largest,
153    /// re-using the intermediary values.
154    fn gen_cache_z_exp(&mut self) -> Result<(), CodegenError> {
155        // NOTE:
156        // - the trace length is a power-of-two.
157        //   Ref: https://github.com/0xPolygonMiden/miden-vm/blob/next/stdlib/asm/crypto/stark/random_coin.masm#L82-L87
158        // - the periodic columns are powers-of-two.
159        //   Ref: https://github.com/0xPolygonMiden/air-script/blob/next/ir/src/symbol_table/mod.rs#L305-L309
160        // - the trace length is always greater-than-or-equal the periodic column length.
161        //   Ref: https://github.com/facebook/winterfell/blob/main/air/src/air/mod.rs#L322-L326
162
163        self.writer
164            .header("Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use.");
165        self.writer.header("");
166        self.writer.header("This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors");
167        self.writer.header("");
168        self.writer.header("Input: [...]");
169        self.writer.header("Output: [...]");
170
171        self.writer.proc("cache_z_exp");
172
173        self.load_z();
174        self.writer.header("=> [z_1, z_0, ...]");
175
176        // The loop below needs to mutably borrow the codegen, so take the field for the iteration
177        // (must reset value after the loop).
178        let periods = take(&mut self.periods);
179
180        // Emit code to precompute the exponentiation of z for the periodic columns.
181        let mut previous_period_size: Option<u64> = None;
182        for (idx, period) in periods.iter().enumerate() {
183            assert!(
184                period.is_power_of_two(),
185                "The length of a periodic column must be a power-of-two"
186            );
187
188            match previous_period_size {
189                None => {
190                    self.writer.header(format!(
191                        "Find number exponentiations required to get for a period of length {}",
192                        period
193                    ));
194
195                    // This procedure caches the result of `z.exp(trace_len / period_size)`. Note
196                    // that `trace_len = 2^x` and `period_len = 2^y`, so the result of the division
197                    // is the same as `2^(x - y)`, the code below computes `x-y` because both
198                    // values are in log2 form.
199                    //
200                    // The result is the number of times that `z` needs to be squared. The
201                    // instructions below result in a negative value, as `add.1` is optimized in
202                    // the VM (IOW, counting up is faster than counting down).
203                    self.load_log2_trace_len();
204                    self.writer.neg();
205                    self.writer.add(period.ilog2().into());
206                    self.writer.header(format!(
207                        "=> [count, z_1, z_0, ...] where count = -log2(trace_len) + {}",
208                        period.ilog2()
209                    ));
210                }
211                Some(prev) => {
212                    self.writer.header(format!(
213                        "Find number of exponentiations to bring from length {} to {}",
214                        prev, *period,
215                    ));
216
217                    // The previous iteration computed `log2(trace_len) - log2(prev_period_size)`,
218                    // this iteration will compute `log2(trace_len) - log2(period_size)`. The goal
219                    // is to reuse the previous value as a cache, so only compute the difference of
220                    // the two values which is just `log2(prev_period_size) - log2(period_size)`.
221                    let prev = Felt::new(prev.ilog2().into());
222                    let new = Felt::new(period.ilog2().into());
223                    let diff = new - prev; // this is a negative value
224                    self.writer.push(diff.as_int());
225                    self.writer.header(format!(
226                        "=> [count, (z_1, z_0)^{}, ...] where count = {} - {}",
227                        prev.as_int(),
228                        new.as_int(),
229                        prev.as_int(),
230                    ));
231                }
232            }
233
234            self.writer.header("Exponentiate z");
235            self.writer.ext2_exponentiate();
236
237            let idx: u32 = idx.try_into().expect("periodic column length is too large");
238            let addr = self.config.z_exp_address + idx;
239            self.writer.push(0);
240            self.writer.mem_storew(addr);
241            self.writer.comment(format!("z^{}", *period));
242
243            self.writer.header(format!(
244                "=> [0, 0, (z_1, z_0)^n, ...] where n = trace_len-{}",
245                *period
246            ));
247            self.writer.drop();
248            self.writer.drop();
249
250            previous_period_size = Some((*period).try_into().expect("diff must fit in a u64"));
251        }
252
253        // Re-set the periods now that the loop is over
254        let _ = replace(&mut self.periods, periods);
255
256        // Emit code to precompute the exponentiation of z for the divisor.
257        match previous_period_size {
258            None => {
259                self.writer.header("Exponentiate z trace_len times");
260                self.load_log2_trace_len();
261                self.writer.neg();
262                self.writer
263                    .header("=> [count, z_1, z_0, ...] where count = -log2(trace_len)");
264            }
265            Some(prev) => {
266                self.writer
267                    .header(format!("Exponentiate z {} times, until trace_len", prev));
268                let prev = Felt::new(prev.ilog2().into());
269                let neg_prev = -prev;
270                self.writer.push(neg_prev.as_int());
271                self.writer.header(format!(
272                    "=> [count, (z_1, z_0)^n, ...] where count=-{} , n=trace_len-{}",
273                    prev.as_int(),
274                    prev.as_int(),
275                ));
276            }
277        }
278
279        self.writer.ext2_exponentiate();
280
281        let idx: u32 = self
282            .periods
283            .len()
284            .try_into()
285            .expect("periodic column length is too large");
286        let addr = self.config.z_exp_address + idx;
287        self.writer.push(0);
288        self.writer.mem_storew(addr);
289        self.writer.comment("z^trace_len");
290
291        self.writer.header("=> [0, 0, (z_1, z_0)^trace_len, ...]");
292        self.writer.dropw();
293        self.writer.comment("Clean stack");
294
295        self.writer.end();
296
297        Ok(())
298    }
299
300    /// Emits code for the procedure `cache_periodic_polys`.
301    ///
302    /// This procedure first computes the `z**exp` for each periodic column, and then evaluates
303    /// each periodic polynomial using Horner's method. The results are cached to memory.
304    fn gen_evaluate_periodic_polys(&mut self) -> Result<(), CodegenError> {
305        self.writer
306            .header("Procedure to evaluate the periodic polynomials.");
307        self.writer.header("");
308        self.writer
309            .header("Procedure `cache_z_exp` must have been called prior to this.");
310        self.writer.header("");
311        self.writer.header("Input: [...]");
312        self.writer.header("Output: [...]");
313
314        self.writer.proc("cache_periodic_polys");
315        walk_periodic_columns(self, self.ir)?;
316        self.writer.end();
317
318        Ok(())
319    }
320
321    fn gen_compute_integrity_constraint_divisor(&mut self) -> Result<(), CodegenError> {
322        self.writer
323            .header("Procedure to compute the integrity constraint divisor.");
324        self.writer.header("");
325        self.writer.header(
326            "The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))`",
327        );
328        self.writer
329            .header("Procedure `cache_z_exp` must have been called prior to this.");
330        self.writer.header("");
331        self.writer.header("Input: [...]");
332        self.writer.header("Output: [divisor_1, divisor_0, ...]");
333
334        self.writer.proc("compute_integrity_constraint_divisor");
335
336        // `z^trace_len` is saved after all the period column points
337        let group: u32 = self.periods.len().try_into().expect("periods are u32");
338        load_quadratic_element(
339            &mut self.writer,
340            self.config.z_exp_address,
341            periodic_group_to_memory_offset(group),
342        )?;
343        self.writer.comment("load z^trace_len");
344
345        self.writer.header("Comments below use zt = `z^trace_len`");
346        self.writer.header("=> [zt_1, zt_0, ...]");
347
348        // Compute the numerator `z^trace_len - 1`
349        self.writer.push(1);
350        self.writer.push(0);
351        self.writer.ext2sub();
352        self.writer.header("=> [zt_1-1, zt_0-1, ...]");
353
354        // Compute the denominator of the divisor
355        self.load_z();
356        self.writer.header("=> [z_1, z_0, zt_1-1, zt_0-1, ...]");
357
358        self.writer.exec("get_exemptions_points");
359        self.writer
360            .header("=> [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...]");
361
362        self.writer.dup(0);
363        self.writer.mem_store(self.config.exemption_two_address);
364        self.writer
365            .comment("Save a copy of `g^{trace_len-2} to be used by the boundary divisor");
366
367        // Compute `z - g^{trace_len-2}`
368        self.writer.dup(3);
369        self.writer.dup(3);
370        self.writer.movup(3);
371        self.writer.push(0);
372        self.writer.ext2sub();
373        self.writer
374            .header("=> [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...]");
375
376        // Compute `z - g^{trace_len-1}`
377        self.writer.movup(4);
378        self.writer.movup(4);
379        self.writer.movup(4);
380        self.writer.push(0);
381        self.writer.ext2sub();
382        self.writer
383            .header("=> [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...]");
384
385        // Compute the denominator `(z - g^{trace_len-2}) * (z - g^{trace_len-1})`
386        self.writer.ext2mul();
387        self.writer
388            .header("=> [denominator_1, denominator_0, zt_1-1, zt_0-1, ...]");
389
390        // Compute the divisor `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))`
391        self.writer.ext2div();
392        self.writer.header("=> [divisor_1, divisor_0, ...]");
393        self.writer.end();
394
395        Ok(())
396    }
397
398    /// Emits code for the procedure `compute_integrity_constraints`.
399    ///
400    /// This procedure evaluates each top-level integrity constraint and leaves the result on the
401    /// stack. This is useful for testing the evaluation. Later on the value is aggregated.
402    fn gen_compute_integrity_constraints(&mut self) -> Result<(), CodegenError> {
403        let main_trace_count = self.ir.integrity_constraints(MAIN_TRACE).len();
404        let aux_trace_count = self.ir.integrity_constraints(AUX_TRACE).len();
405
406        self.writer
407            .header("Procedure to evaluate numerators of all integrity constraints.");
408        self.writer.header("");
409        self.writer.header(format!(
410            "All the {} main and {} auxiliary constraints are evaluated.",
411            main_trace_count, aux_trace_count
412        ));
413        self.writer.header(
414            "The result of each evaluation is kept on the stack, with the top of the stack",
415        );
416        self.writer.header(
417            "containing the evaluations for the auxiliary trace (if any) followed by the main trace.",
418        );
419        self.writer.header("");
420        self.writer.header("Input: [...]");
421        self.writer.header("Output: [(r_1, r_0)*, ...]");
422        self.writer.header(
423            "where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation.",
424        );
425        self.writer.header(format!(
426            "       This procedure pushes {} quadratic extension field elements to the stack",
427            main_trace_count + aux_trace_count
428        ));
429
430        self.writer.proc("compute_integrity_constraints");
431        walk_integrity_constraints(self, self.ir, MAIN_TRACE)?;
432        self.integrity_contraints = 0; // reset counter for the aux trace
433        walk_integrity_constraints(self, self.ir, AUX_TRACE)?;
434        self.writer.end();
435
436        Ok(())
437    }
438
439    /// Emits procedure to compute boundary constraints values.
440    ///
441    /// This will emit four procedures:
442    ///
443    /// - compute_boundary_constraints_main_first
444    /// - compute_boundary_constraints_main_last
445    /// - compute_boundary_constraints_aux_first
446    /// - compute_boundary_constraints_aux_last
447    ///
448    /// Each procedure corresponds to a specific boundary constraint group. They are emitted
449    /// separetely because each value is divided by a different divisor, and it is best to
450    /// manipulate each point separetely.
451    fn gen_compute_boundary_constraints(&mut self) -> Result<(), CodegenError> {
452        // The boundary constraints have a natural order defined as (trace, domain, column_pos).
453        // The code below iterates using that order
454
455        if self
456            .boundary_constraint_count
457            .contains_key(&(MAIN_TRACE, ConstraintDomain::FirstRow))
458        {
459            let name = boundary_group_to_procedure_name(MAIN_TRACE, ConstraintDomain::FirstRow);
460            self.writer.header(
461                "Procedure to evaluate the boundary constraint numerator for the first row of the main trace",
462            );
463            self.writer.header("");
464            self.writer.header("Input: [...]");
465            self.writer.header("Output: [(r_1, r_0)*, ...]");
466            self.writer.header(
467                "Where: (r_1, r_0) is one quadratic extension field element for each constraint",
468            );
469            self.writer.proc(name);
470            walk_boundary_constraints(self, self.ir, MAIN_TRACE, ConstraintDomain::FirstRow)?;
471            self.writer.end();
472        }
473
474        if self
475            .boundary_constraint_count
476            .contains_key(&(MAIN_TRACE, ConstraintDomain::LastRow))
477        {
478            let name = boundary_group_to_procedure_name(MAIN_TRACE, ConstraintDomain::LastRow);
479            self.writer.header(
480                "Procedure to evaluate the boundary constraint numerator for the last row of the main trace",
481            );
482            self.writer.header("");
483            self.writer.header("Input: [...]");
484            self.writer.header("Output: [(r_1, r_0)*, ...]");
485            self.writer.header(
486                "Where: (r_1, r_0) is one quadratic extension field element for each constraint",
487            );
488            self.writer.proc(name);
489            walk_boundary_constraints(self, self.ir, MAIN_TRACE, ConstraintDomain::LastRow)?;
490            self.writer.end();
491        }
492
493        if self
494            .boundary_constraint_count
495            .contains_key(&(AUX_TRACE, ConstraintDomain::FirstRow))
496        {
497            let name = boundary_group_to_procedure_name(AUX_TRACE, ConstraintDomain::FirstRow);
498            self.writer.header(
499            "Procedure to evaluate the boundary constraint numerator for the first row of the auxiliary trace",
500        );
501            self.writer.header("");
502            self.writer.header("Input: [...]");
503            self.writer.header("Output: [(r_1, r_0)*, ...]");
504            self.writer.header(
505                "Where: (r_1, r_0) is one quadratic extension field element for each constraint",
506            );
507            self.writer.proc(name);
508            walk_boundary_constraints(self, self.ir, AUX_TRACE, ConstraintDomain::FirstRow)?;
509            self.writer.end();
510        }
511
512        if self
513            .boundary_constraint_count
514            .contains_key(&(AUX_TRACE, ConstraintDomain::LastRow))
515        {
516            let name = boundary_group_to_procedure_name(AUX_TRACE, ConstraintDomain::LastRow);
517            self.writer.header(
518            "Procedure to evaluate the boundary constraint numerator for the last row of the auxiliary trace",
519        );
520            self.writer.header("");
521            self.writer.header("Input: [...]");
522            self.writer.header("Output: [(r_1, r_0)*, ...]");
523            self.writer.header(
524                "Where: (r_1, r_0) is one quadratic extension field element for each constraint",
525            );
526            self.writer.proc(name);
527            walk_boundary_constraints(self, self.ir, AUX_TRACE, ConstraintDomain::LastRow)?;
528            self.writer.end();
529        }
530
531        Ok(())
532    }
533
534    /// Emits code for the procedure `get_exemptions_points`.
535    ///
536    /// Generate code to push the exemption points to the top of the stack.
537    /// Stack: [g^{trace_len-2}, g^{trace_len-1}, ...]
538    fn gen_get_exemptions_points(&mut self) -> Result<(), CodegenError> {
539        self.writer
540            .header("Procedure to compute the exemption points.");
541        self.writer.header("");
542        self.writer.header("Input: [...]");
543        self.writer.header("Output: [g^{-2}, g^{-1}, ...]");
544
545        self.writer.proc("get_exemptions_points");
546        self.load_trace_domain_generator();
547        self.writer.header("=> [g, ...]");
548
549        self.writer.push(1);
550        self.writer.swap();
551        self.writer.div();
552        self.writer.header("=> [g^{-1}, ...]");
553
554        self.writer.dup(0);
555        self.writer.dup(0);
556        self.writer.mul();
557        self.writer.header("=> [g^{-2}, g^{-1}, ...]");
558
559        self.writer.end(); // end proc
560
561        Ok(())
562    }
563
564    /// Emits code for the procedure `evaluate_integrity_constraints`.
565    ///
566    /// Evaluates the integrity constraints for both the main and auxiliary traces.
567    fn gen_evaluate_integrity_constraints(&mut self) -> Result<(), CodegenError> {
568        self.writer
569            .header("Procedure to evaluate all integrity constraints.");
570        self.writer.header("");
571        self.writer.header("Input: [...]");
572        self.writer.header("Output: [(r_1, r_0), ...]");
573        self.writer
574            .header("Where: (r_1, r_0) is the final result with the divisor applied");
575
576        self.writer.proc("evaluate_integrity_constraints");
577
578        if !self.ir.periodic_columns.is_empty() {
579            self.writer.exec("cache_periodic_polys");
580        }
581
582        self.writer.exec("compute_integrity_constraints");
583
584        self.writer
585            .header("Numerator of the transition constraint polynomial");
586
587        let total_len = self.ir.integrity_constraints(MAIN_TRACE).len()
588            + self.ir.integrity_constraints(AUX_TRACE).len();
589
590        for _ in 0..total_len {
591            self.writer.ext2add();
592        }
593
594        self.writer
595            .header("Divisor of the transition constraint polynomial");
596
597        self.writer.exec("compute_integrity_constraint_divisor");
598
599        self.writer.ext2div();
600        self.writer.comment("divide the numerator by the divisor");
601
602        self.writer.end();
603
604        Ok(())
605    }
606
607    /// Emits code for the procedure `evaluate_boundary_constraints`.
608    ///
609    /// Evaluates the boundary constraints for both the main and auxiliary traces.
610    fn gen_evaluate_boundary_constraints(&mut self) -> Result<(), CodegenError> {
611        self.writer
612            .header("Procedure to evaluate all boundary constraints.");
613        self.writer.header("");
614        self.writer.header("Input: [...]");
615        self.writer.header("Output: [(r_1, r_0), ...]");
616        self.writer
617            .header("Where: (r_1, r_0) is the final result with the divisor applied");
618
619        self.writer.proc("evaluate_boundary_constraints");
620
621        let last = self.boundary_constraint_group(ConstraintDomain::LastRow);
622        let first = self.boundary_constraint_group(ConstraintDomain::FirstRow);
623
624        if last != 0 && first != 0 {
625            self.writer.header("Add first and last row groups");
626            self.writer.ext2add();
627        }
628
629        self.writer.end();
630
631        Ok(())
632    }
633
634    /// Emits code to evaluate the boundary constraint for a given group determined by the domain.
635    fn boundary_constraint_group(&mut self, domain: ConstraintDomain) -> usize {
636        let aux_count = self
637            .boundary_constraint_count
638            .get(&(AUX_TRACE, domain))
639            .cloned();
640
641        let name = match domain {
642            ConstraintDomain::LastRow => "last",
643            ConstraintDomain::FirstRow => "first",
644            _ => panic!("unexpected domain"),
645        };
646
647        if let Some(count) = aux_count {
648            self.boundary_constraint_numerator(count, AUX_TRACE, domain);
649            self.writer
650                .header(format!("=> [(aux_{name}1, aux_{name}0), ...]"));
651        }
652
653        let main_count = self
654            .boundary_constraint_count
655            .get(&(MAIN_TRACE, domain))
656            .cloned();
657
658        if let Some(count) = main_count {
659            self.boundary_constraint_numerator(count, MAIN_TRACE, domain);
660
661            if aux_count.is_some() {
662                self.writer.header(format!(
663                    "=> [(main_{name}1, main_{name}0), (aux_{name}1, aux_{name}0), ...]"
664                ));
665                self.writer.ext2add();
666            }
667
668            self.writer.header(format!("=> [({name}1, {name}0), ...]"));
669        }
670
671        if aux_count.is_some() || main_count.is_some() {
672            self.writer
673                .header(format!("Compute the denominator for domain {:?}", domain));
674
675            match domain {
676                ConstraintDomain::FirstRow => {
677                    self.load_z();
678                    self.writer.push(1);
679                    self.writer.push(0);
680                    self.writer.ext2sub();
681                }
682                ConstraintDomain::LastRow => {
683                    self.load_z();
684                    self.writer.mem_load(self.config.exemption_two_address);
685                    self.writer.push(0);
686                    self.writer.ext2sub();
687                }
688                _ => panic!("unexpected constraint domain"),
689            };
690
691            self.writer
692                .header(format!("Compute numerator/denominator for {name} row"));
693            self.writer.ext2div();
694
695            aux_count.unwrap_or(0) + main_count.unwrap_or(0)
696        } else {
697            0
698        }
699    }
700
701    /// Emits code to evaluate the numerator portion of a boundary constraint point determined by
702    /// `segment` and `domain`.
703    fn boundary_constraint_numerator(
704        &mut self,
705        count: usize,
706        segment: TraceSegmentId,
707        domain: ConstraintDomain,
708    ) {
709        let name = boundary_group_to_procedure_name(segment, domain);
710        self.writer.exec(name);
711
712        if count > 1 {
713            self.writer.header(format!(
714                "Accumulate the numerator for segment {} {:?}",
715                segment, domain
716            ));
717            for _ in 0..count {
718                self.writer.ext2add();
719            }
720        }
721    }
722
723    /// Emits code for the procedure `evaluate_constraints`.
724    ///
725    /// This will compute and cache values, the transition and boundary constraints for both the main and auxiliary traces.
726    fn gen_evaluate_constraints(&mut self) {
727        self.writer
728            .header("Procedure to evaluate the integrity and boundary constraints.");
729        self.writer.header("");
730        self.writer.header("Input: [...]");
731        self.writer.header("Output: [(r_1, r_0), ...]");
732
733        self.writer.export("evaluate_constraints");
734
735        // The order of execution below is important. These are the dependencies:
736        // - `z^trace_len` is computed and cached to be used by integrity contraints
737        // - `g^{trace_len-2}` is computed and cached to be used by boundary constraints
738        self.writer.exec("cache_z_exp");
739        self.writer.exec("evaluate_integrity_constraints");
740        self.writer.exec("evaluate_boundary_constraints");
741        self.writer.ext2add();
742
743        self.writer.end();
744    }
745
746    /// Emits code to load the `log_2(trace_len)` onto the top of the stack.
747    fn load_log2_trace_len(&mut self) {
748        self.writer.mem_load(self.config.log2_trace_len_address);
749    }
750
751    /// Emits code to load `z` onto the top of the stack.
752    fn load_z(&mut self) {
753        self.writer.padw();
754        self.writer.mem_loadw(self.config.z_address);
755        self.writer.drop();
756        self.writer.drop();
757        self.writer.comment("load z");
758    }
759
760    /// Emits code to load `g`, the trace domain generator.
761    fn load_trace_domain_generator(&mut self) {
762        self.writer
763            .mem_load(self.config.trace_domain_generator_address);
764    }
765}
766
767impl<'ast> AirVisitor<'ast> for Backend<'ast> {
768    type Value = ();
769    type Error = CodegenError;
770
771    fn visit_integrity_constraint(
772        &mut self,
773        constraint: &'ast ConstraintRoot,
774        trace_segment: TraceSegmentId,
775    ) -> Result<Self::Value, Self::Error> {
776        if !constraint.domain().is_integrity() {
777            return Err(CodegenError::InvalidIntegrityConstraint);
778        }
779
780        let segment = if trace_segment == MAIN_TRACE {
781            "main"
782        } else {
783            "aux"
784        };
785
786        self.writer.header(format!(
787            "integrity constraint {} for {}",
788            self.integrity_contraints, segment
789        ));
790
791        self.visit_node_index(constraint.node_index())?;
792
793        self.writer
794            .header("Multiply by the composition coefficient");
795
796        load_quadratic_element(
797            &mut self.writer,
798            self.config.composition_coef_address,
799            self.composition_coefficient_count,
800        )?;
801        self.writer.ext2mul();
802        self.composition_coefficient_count += 1;
803
804        self.integrity_contraints += 1;
805        Ok(())
806    }
807
808    fn visit_boundary_constraint(
809        &mut self,
810        constraint: &'ast ConstraintRoot,
811        trace_segment: TraceSegmentId,
812    ) -> Result<Self::Value, Self::Error> {
813        if !constraint.domain().is_boundary() {
814            return Err(CodegenError::InvalidBoundaryConstraint);
815        }
816
817        let segment = if trace_segment == MAIN_TRACE {
818            "main"
819        } else {
820            "aux"
821        };
822
823        self.writer.header(format!(
824            "boundary constraint {} for {}",
825            self.boundary_contraints, segment
826        ));
827
828        // Note: AirScript's boundary constraints are only defined for the first or last row.
829        // Meaning they are implemented as an assertion for a single element. Visiting the
830        // [NodeIndex] will emit code to compute the difference of the expected value and the
831        // evaluation frame value.
832        self.visit_node_index(constraint.node_index())?;
833
834        self.writer
835            .header("Multiply by the composition coefficient");
836
837        // Note: The correctness of the load below relies on the integrity constraint being
838        // iterated first _and_ the boundary constraints being iterated in natural order.
839        load_quadratic_element(
840            &mut self.writer,
841            self.config.composition_coef_address,
842            self.composition_coefficient_count,
843        )?;
844        self.writer.ext2mul();
845        self.composition_coefficient_count += 1;
846
847        self.boundary_contraints += 1;
848        Ok(())
849    }
850
    /// Drives the generation of every procedure of the output module, in order.
    ///
    /// The helper procedures are emitted first, then the procedures that compute the
    /// constraints, then the ones that evaluate them, and finally `evaluate_constraints`.
    /// The periodic polynomial procedure is only emitted when the AIR declares periodic columns.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by any of the fallible generators.
    fn visit_air(&mut self) -> Result<Self::Value, Self::Error> {
        self.gen_cache_z_exp()?;
        self.gen_get_exemptions_points()?;

        // Skipped entirely when there are no periodic columns.
        if !self.ir.periodic_columns.is_empty() {
            self.gen_evaluate_periodic_polys()?;
        }

        self.gen_compute_integrity_constraint_divisor()?;

        self.gen_compute_integrity_constraints()?;
        self.gen_compute_boundary_constraints()?;

        // NOTE: Order of the following two methods is important! The iteration order is used to
        // determine the composition coefficient index. The correct order is:
        // 1. Integrity constraints for the MAIN trace.
        // 2. Integrity constraints for the AUX trace.
        // 3. Boundary constraints for the MAIN trace.
        // 4. Boundary constraints for the AUX trace.
        self.gen_evaluate_integrity_constraints()?;
        self.gen_evaluate_boundary_constraints()?;

        // Infallible: only writes the top-level procedure that calls the ones above.
        self.gen_evaluate_constraints();

        Ok(())
    }
877
878    fn visit_node_index(
879        &mut self,
880        node_index: &'ast NodeIndex,
881    ) -> Result<Self::Value, Self::Error> {
882        let op = self.ir.constraint_graph().node(node_index).op();
883        self.visit_operation(op)
884    }
885
886    fn visit_operation(&mut self, op: &'ast Operation) -> Result<Self::Value, Self::Error> {
887        match op {
888            Operation::Value(value) => {
889                self.visit_value(value)?;
890            }
891            Operation::Add(left, right) => {
892                self.visit_node_index(left)?;
893                self.visit_node_index(right)?;
894                self.writer.ext2add();
895            }
896            Operation::Sub(left, right) => {
897                self.visit_node_index(left)?;
898                self.visit_node_index(right)?;
899                self.writer.ext2sub();
900            }
901            Operation::Mul(left, right) => {
902                self.visit_node_index(left)?;
903                self.visit_node_index(right)?;
904                self.writer.ext2mul();
905            }
906            Operation::Exp(left, exp) => {
907                // NOTE: The VM doesn't support exponentiation of extension elements.
908                //
909                // Ref: https://github.com/facebook/winterfell/blob/0acb2a148e2e8445d5f6a3511fa9d852e54818dd/math/src/field/traits.rs#L124-L150
910
911                self.visit_node_index(left)?;
912
913                self.writer.header("push the accumulator to the stack");
914                self.writer.push(1);
915                self.writer.movdn(2);
916                self.writer.push(0);
917                self.writer.movdn(2);
918                self.writer.header("=> [b1, b0, r1, r0, ...]");
919
920                // emitted code computes exponentiation via square-and-multiply
921                let mut e: usize = *exp;
922                while e != 0 {
923                    self.writer
924                        .header(format!("square {} times", e.trailing_zeros()));
925                    quadratic_element_square(&mut self.writer, e.trailing_zeros());
926
927                    // account for the exponentiations done above
928                    e = e >> e.trailing_zeros();
929
930                    self.writer.header("multiply");
931                    self.writer.dup(1);
932                    self.writer.dup(1);
933                    self.writer.movdn(5);
934                    self.writer.movdn(5);
935                    self.writer
936                        .header("=> [b1, b0, r1, r0, b1, b0, ...] (4 cycles)");
937
938                    self.writer.ext2mul();
939                    self.writer.movdn(3);
940                    self.writer.movdn(3);
941                    self.writer.header("=> [b1, b0, r1', r0', ...] (5 cycles)");
942
943                    // account for the multiply done above
944                    assert!(
945                        e & 1 == 1,
946                        "this loop is only executed if the number is non-zero"
947                    );
948                    e ^= 1;
949                }
950
951                self.writer.header("clean stack");
952                self.writer.drop();
953                self.writer.drop();
954                self.writer.header("=> [r1, r0, ...] (2 cycles)");
955            }
956        };
957
958        Ok(())
959    }
960
961    fn visit_periodic_column(
962        &mut self,
963        column: &'ast PeriodicColumn,
964    ) -> Result<Self::Value, Self::Error> {
965        // convert the periodic column to a polynomial
966        let inv_twiddles = fft::get_inv_twiddles::<Felt>(column.period());
967        let mut poly: Vec<Felt> = column.values.iter().map(|e| Felt::new(*e)).collect();
968        fft::interpolate_poly(&mut poly, &inv_twiddles);
969
970        self.writer
971            .comment(format!("periodic column {}", self.periodic_column));
972
973        // LOAD OOD ELEMENT
974        // ---------------------------------------------------------------------------------------
975
976        // assumes that cache_z_exp has been called before, which precomputes the value of z**exp
977        let group: u32 = self
978            .periods
979            .iter()
980            .position(|&p| p == column.period())
981            .expect("All periods are added in the constructor")
982            .try_into()
983            .expect("periods are u32");
984        load_quadratic_element(
985            &mut self.writer,
986            self.config.z_exp_address,
987            periodic_group_to_memory_offset(group),
988        )?;
989        self.writer.header("=> [z_exp_1, z_exp_0, ...]");
990
991        // EVALUATE PERIODIC POLY
992        // ---------------------------------------------------------------------------------------
993
994        // convert coefficients from Montgomery form (Masm uses plain integers).
995        let coef: Vec<u64> = poly.iter().map(|e| e.as_int()).collect();
996
997        // periodic columns have at least 2 values, push the first as the accumulator
998        self.writer.push(coef[0]);
999        self.writer.push(0);
1000        self.writer.header("=> [a_1, a_0, z_exp_1, z_exp_0, ...]");
1001
1002        // Evaluate the periodic polynomial at point z**exp using Horner's algorithm
1003        for c in coef.iter().skip(1) {
1004            self.writer.header("duplicate z_exp");
1005            self.writer.dup(3);
1006            self.writer.dup(3);
1007            self.writer
1008                .header("=> [z_exp_1, z_exp_0, a_1, a_0, z_exp_1, z_exp_0, ...]");
1009
1010            self.writer.ext2mul();
1011            self.writer.push(*c);
1012            self.writer.push(0);
1013            self.writer.ext2add();
1014            self.writer.header("=> [a_1, a_0, z_exp_1, z_exp_0, ...]");
1015        }
1016
1017        self.writer.header("Clean z_exp from the stack");
1018        self.writer.movup(3);
1019        self.writer.movup(3);
1020        self.writer.drop();
1021        self.writer.drop();
1022        self.writer.header("=> [a_1, a_0, ...]");
1023
1024        self.writer.header(
1025            "Save the evaluation of the periodic polynomial at point z**exp, and clean stack",
1026        );
1027        let addr = self.config.periodic_values_address + self.periodic_column;
1028        self.writer.push(0);
1029        self.writer.push(0);
1030        self.writer.mem_storew(addr);
1031        self.writer.dropw();
1032
1033        self.periodic_column += 1;
1034        Ok(())
1035    }
1036
1037    fn visit_value(&mut self, value: &'ast Value) -> Result<Self::Value, Self::Error> {
1038        match value {
1039            Value::Constant(value) => {
1040                self.writer.push(*value);
1041                self.writer.push(0);
1042            }
1043            Value::TraceAccess(access) => {
1044                // eventually larger offsets will be supported
1045                if access.row_offset > 1 {
1046                    return Err(CodegenError::InvalidRowOffset);
1047                }
1048
1049                // Compute the target address for this variable. Each memory address contains the
1050                // curr and next values of a single variable.
1051                //
1052                // Layout defined at: https://github.com/0xPolygonMiden/miden-vm/issues/875
1053                let target_word: u32 = access
1054                    .column
1055                    .try_into()
1056                    .map_err(|_| CodegenError::InvalidIndex)?;
1057                let el_pos: u32 = access
1058                    .row_offset
1059                    .try_into()
1060                    .or(Err(CodegenError::InvalidIndex))?;
1061                let target_element = target_word * 2 + el_pos;
1062
1063                let base_address = if access.segment == MAIN_TRACE {
1064                    self.config.ood_frame_address
1065                } else {
1066                    self.config.ood_aux_frame_address
1067                };
1068
1069                load_quadratic_element(&mut self.writer, base_address, target_element)?;
1070            }
1071            Value::PeriodicColumn(access) => {
1072                let group: u32 = self
1073                    .periods
1074                    .iter()
1075                    .position(|&p| p == access.cycle)
1076                    .expect("All periods are added in the constructor")
1077                    .try_into()
1078                    .expect("periods are u32");
1079                load_quadratic_element(
1080                    &mut self.writer,
1081                    self.config.periodic_values_address,
1082                    periodic_group_to_memory_offset(group),
1083                )?;
1084            }
1085            Value::PublicInput(access) => {
1086                let start_offset = self
1087                    .public_input_to_offset
1088                    .get(&access.name)
1089                    .unwrap_or_else(|| panic!("public input {} unknown", access.name));
1090
1091                self.writer.header(format!(
1092                    "Load public input {} pos {} with final offset {}",
1093                    access.name, access.index, start_offset,
1094                ));
1095                let index: u32 = (start_offset + access.index)
1096                    .try_into()
1097                    .or(Err(CodegenError::InvalidIndex))?;
1098                load_quadratic_element(&mut self.writer, self.config.public_inputs_address, index)?;
1099            }
1100            Value::RandomValue(element) => {
1101                // Compute the target address for the random value. Each memory address contains
1102                // two values.
1103                //
1104                // Layout defined at: https://github.com/0xPolygonMiden/miden-vm/blob/next/stdlib/asm/crypto/stark/random_coin.masm#L169-L172
1105                load_quadratic_element(
1106                    &mut self.writer,
1107                    self.config.aux_rand_address,
1108                    (*element).try_into().or(Err(CodegenError::InvalidIndex))?,
1109                )?;
1110            }
1111        };
1112
1113        Ok(())
1114    }
1115}