Skip to main content

miden_processor/trace/chiplets/ace/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use miden_air::trace::{MainTrace, RowIndex, chiplets::ace::ACE_CHIPLET_NUM_COLS};
4use miden_core::{
5    Felt, ZERO,
6    field::{ExtensionField, PrimeCharacteristicRing},
7};
8
9use crate::trace::TraceFragment;
10
11mod trace;
12pub use trace::{CircuitEvaluation, NUM_ACE_LOGUP_FRACTIONS_EVAL, NUM_ACE_LOGUP_FRACTIONS_READ};
13
14mod instruction;
15#[cfg(test)]
16mod tests;
17
18pub const PTR_OFFSET_ELEM: Felt = Felt::ONE;
19pub const PTR_OFFSET_WORD: Felt = Felt::new(4);
20pub const MAX_NUM_ACE_WIRES: u32 = instruction::MAX_ID;
21
22/// Arithmetic circuit evaluation (ACE) chiplet.
23///
24/// This is a VM chiplet used to evaluate arithmetic circuits given some input, which is equivalent
25/// to evaluating some multi-variate polynomial at a tuple representing the input.
26///
27/// During the course of the VM execution, we keep track of all calls to the ACE chiplet in an
28/// [`CircuitEvaluation`] per call. This is then used to generate the full trace of the ACE chiplet.
29#[derive(Debug, Default)]
30pub struct Ace {
31    circuit_evaluations: BTreeMap<RowIndex, CircuitEvaluation>,
32}
33
34impl Ace {
35    /// Gets the total trace length of the ACE chiplet.
36    pub(crate) fn trace_len(&self) -> usize {
37        self.circuit_evaluations.values().map(|eval_ctx| eval_ctx.num_rows()).sum()
38    }
39
40    /// Fills the portion of the main trace allocated to the ACE chiplet.
41    ///
42    /// This also returns helper data needed for generating the part of the auxiliary trace
43    /// associated with the ACE chiplet.
44    pub(crate) fn fill_trace(self, trace: &mut TraceFragment) -> Vec<EvaluatedCircuitsMetadata> {
45        // make sure fragment dimensions are consistent with the dimensions of this trace
46        debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
47        debug_assert_eq!(ACE_CHIPLET_NUM_COLS, trace.width(), "inconsistent trace widths");
48
49        let mut gen_trace: [Vec<Felt>; ACE_CHIPLET_NUM_COLS] = (0..ACE_CHIPLET_NUM_COLS)
50            .map(|_| vec![ZERO; self.trace_len()])
51            .collect::<Vec<_>>()
52            .try_into()
53            .expect("failed to convert vector to array");
54
55        let mut sections_info = Vec::with_capacity(self.circuit_evaluations.keys().count());
56
57        let mut offset = 0;
58        for eval_ctx in self.circuit_evaluations.into_values() {
59            eval_ctx.fill(offset, &mut gen_trace);
60            offset += eval_ctx.num_rows();
61            let section = EvaluatedCircuitsMetadata::from_evaluation_context(&eval_ctx);
62            sections_info.push(section);
63        }
64
65        for (out_column, column) in trace.columns().zip(gen_trace) {
66            out_column.copy_from_slice(&column);
67        }
68
69        sections_info
70    }
71
72    /// Adds an entry resulting from a call to the ACE chiplet.
73    pub(crate) fn add_circuit_evaluation(
74        &mut self,
75        clk: RowIndex,
76        circuit_eval: CircuitEvaluation,
77    ) {
78        self.circuit_evaluations.insert(clk, circuit_eval);
79    }
80}
81
/// Metadata of a single evaluated circuit.
///
/// It is collected while the ACE chiplet fills the main trace and is later used to build the ACE
/// part of the auxiliary trace segment.
#[derive(Clone, Debug, Default)]
pub struct EvaluatedCircuitsMetadata {
    /// Context value recorded for the evaluation.
    ctx: u32,
    /// Clock cycle recorded for the evaluation.
    clk: u32,
    /// Number of rows in the READ part of the evaluation.
    num_vars: u32,
    /// Number of rows in the EVAL part of the evaluation.
    num_evals: u32,
}
91
92impl EvaluatedCircuitsMetadata {
93    pub fn clk(&self) -> u32 {
94        self.clk
95    }
96
97    pub fn ctx(&self) -> u32 {
98        self.ctx
99    }
100
101    pub fn num_vars(&self) -> u32 {
102        self.num_vars
103    }
104
105    pub fn num_evals(&self) -> u32 {
106        self.num_evals
107    }
108
109    fn from_evaluation_context(eval_ctx: &CircuitEvaluation) -> EvaluatedCircuitsMetadata {
110        EvaluatedCircuitsMetadata {
111            ctx: eval_ctx.ctx(),
112            clk: eval_ctx.clk(),
113            num_vars: eval_ctx.num_read_rows(),
114            num_evals: eval_ctx.num_eval_rows(),
115        }
116    }
117}
118
119/// Stores metadata for the ACE chiplet useful when building the portion of the auxiliary
120/// trace segment relevant for the ACE chiplet.
121///
122/// This data is already present in the main trace but collecting it here allows us to simplify
123/// the logic for building the auxiliary segment portion for the ACE chiplet.
124/// For example, we know that `clk` and `ctx` are constant throughout each circuit evaluation
125/// and we also know the exact number of ACE chiplet rows per circuit evaluation and the exact
126/// number of rows per `READ` and `EVAL` portions, which allows us to avoid the need to compute
127/// selectors as part of the logic of auxiliary trace generation.
128#[derive(Clone, Debug, Default)]
129pub struct AceHints {
130    offset_chiplet_trace: usize,
131    pub sections: Vec<EvaluatedCircuitsMetadata>,
132}
133
134impl AceHints {
135    pub fn new(offset_chiplet_trace: usize, sections: Vec<EvaluatedCircuitsMetadata>) -> Self {
136        Self { offset_chiplet_trace, sections }
137    }
138
139    pub(crate) fn offset(&self) -> usize {
140        self.offset_chiplet_trace
141    }
142
143    pub(crate) fn build_divisors<E: ExtensionField<Felt>>(
144        &self,
145        main_trace: &MainTrace,
146        alphas: &[E],
147    ) -> Vec<E> {
148        let num_fractions = self.num_fractions();
149        let mut total_values = vec![E::ZERO; num_fractions];
150        let mut total_inv_values = vec![E::ZERO; num_fractions];
151
152        let mut chiplet_offset = self.offset_chiplet_trace;
153        let mut values_offset = 0;
154        let mut acc = E::ONE;
155        for section in self.sections.iter() {
156            let clk = section.clk();
157            let ctx = section.ctx();
158
159            let values = &mut total_values[values_offset
160                ..values_offset + NUM_ACE_LOGUP_FRACTIONS_READ * section.num_vars() as usize];
161            let inv_values = &mut total_inv_values[values_offset
162                ..values_offset + NUM_ACE_LOGUP_FRACTIONS_READ * section.num_vars() as usize];
163
164            // read section
165            for (i, (value, inv_value)) in values
166                .chunks_mut(NUM_ACE_LOGUP_FRACTIONS_READ)
167                .zip(inv_values.chunks_mut(NUM_ACE_LOGUP_FRACTIONS_READ))
168                .enumerate()
169            {
170                let trace_row = i + chiplet_offset;
171
172                let wire_0 = main_trace.chiplet_ace_wire_0(trace_row.into());
173                let wire_1 = main_trace.chiplet_ace_wire_1(trace_row.into());
174
175                let value_0 = alphas[0]
176                    + alphas[1] * Felt::from_u32(clk)
177                    + alphas[2] * Felt::from_u32(ctx)
178                    + alphas[3] * wire_0[0]
179                    + alphas[4] * wire_0[1]
180                    + alphas[5] * wire_0[2];
181                let value_1 = alphas[0]
182                    + alphas[1] * Felt::from_u32(clk)
183                    + alphas[2] * Felt::from_u32(ctx)
184                    + alphas[3] * wire_1[0]
185                    + alphas[4] * wire_1[1]
186                    + alphas[5] * wire_1[2];
187
188                value[0] = value_0;
189                value[1] = value_1;
190                inv_value[0] = acc;
191                acc *= value_0;
192                inv_value[1] = acc;
193                acc *= value_1;
194            }
195
196            chiplet_offset += section.num_vars() as usize;
197            values_offset += NUM_ACE_LOGUP_FRACTIONS_READ * section.num_vars() as usize;
198
199            // eval section
200            let values = &mut total_values[values_offset
201                ..values_offset + NUM_ACE_LOGUP_FRACTIONS_EVAL * section.num_evals() as usize];
202            let inv_values = &mut total_inv_values[values_offset
203                ..values_offset + NUM_ACE_LOGUP_FRACTIONS_EVAL * section.num_evals() as usize];
204            for (i, (value, inv_value)) in values
205                .chunks_mut(NUM_ACE_LOGUP_FRACTIONS_EVAL)
206                .zip(inv_values.chunks_mut(NUM_ACE_LOGUP_FRACTIONS_EVAL))
207                .enumerate()
208            {
209                let trace_row = i + chiplet_offset;
210
211                let wire_0 = main_trace.chiplet_ace_wire_0(trace_row.into());
212                let wire_1 = main_trace.chiplet_ace_wire_1(trace_row.into());
213                let wire_2 = main_trace.chiplet_ace_wire_2(trace_row.into());
214
215                let value_0 = alphas[0]
216                    + alphas[1] * Felt::from_u32(clk)
217                    + alphas[2] * Felt::from_u32(ctx)
218                    + alphas[3] * wire_0[0]
219                    + alphas[4] * wire_0[1]
220                    + alphas[5] * wire_0[2];
221
222                let value_1 = alphas[0]
223                    + alphas[1] * Felt::from_u32(clk)
224                    + alphas[2] * Felt::from_u32(ctx)
225                    + alphas[3] * wire_1[0]
226                    + alphas[4] * wire_1[1]
227                    + alphas[5] * wire_1[2];
228
229                let value_2 = alphas[0]
230                    + alphas[1] * Felt::from_u32(clk)
231                    + alphas[2] * Felt::from_u32(ctx)
232                    + alphas[3] * wire_2[0]
233                    + alphas[4] * wire_2[1]
234                    + alphas[5] * wire_2[2];
235
236                value[0] = value_0;
237                value[1] = value_1;
238                value[2] = value_2;
239                inv_value[0] = acc;
240                acc *= value_0;
241                inv_value[1] = acc;
242                acc *= value_1;
243                inv_value[2] = acc;
244                acc *= value_2;
245            }
246
247            chiplet_offset += section.num_evals() as usize;
248            values_offset += NUM_ACE_LOGUP_FRACTIONS_EVAL * section.num_evals() as usize;
249        }
250
251        // invert the accumulated product
252        acc = acc.inverse();
253
254        for i in (0..total_values.len()).rev() {
255            total_inv_values[i] *= acc;
256            acc *= total_values[i];
257        }
258
259        total_inv_values
260    }
261
262    fn num_fractions(&self) -> usize {
263        self.sections
264            .iter()
265            .map(|section| {
266                NUM_ACE_LOGUP_FRACTIONS_READ * (section.num_vars as usize)
267                    + NUM_ACE_LOGUP_FRACTIONS_EVAL * (section.num_evals as usize)
268            })
269            .sum()
270    }
271}