miden_processor/trace/chiplets/ace/
trace.rs1use alloc::vec::Vec;
2use core::borrow::BorrowMut;
3
4use miden_air::{
5 AceCols, QuadFeltExpr,
6 trace::{RowIndex, chiplets::ace::ACE_CHIPLET_NUM_COLS},
7};
8use miden_core::{
9 Felt, Word,
10 field::{BasedVectorSpace, QuadFelt},
11};
12
13use super::{
14 MAX_NUM_ACE_WIRES,
15 instruction::{Op, decode_instruction},
16};
17use crate::{ContextId, errors::AceError};
18
19#[derive(Debug, Clone, Copy)]
22struct ReadNode {
23 ptr: Felt,
24 id_0: Felt,
25 v_0: QuadFelt,
26 id_1: Felt,
27 v_1: QuadFelt,
28}
29
30#[derive(Debug, Clone, Copy)]
34struct EvalNode {
35 ptr: Felt,
36 eval_op: Felt,
37 id_0: Felt,
38 v_0: QuadFelt,
39 id_1: Felt,
40 v_1: QuadFelt,
41 id_2: Felt,
42 v_2: QuadFelt,
43}
44
45#[derive(Debug, Clone)]
50pub struct CircuitEvaluation {
51 ctx: ContextId,
52 clk: RowIndex,
53 wire_bus: WireBus,
54 read_nodes: Vec<ReadNode>,
55 eval_nodes: Vec<EvalNode>,
56}
57
58impl CircuitEvaluation {
59 pub fn new(ctx: ContextId, clk: RowIndex, num_read_rows: u32, num_eval_rows: u32) -> Self {
66 let num_wires = 2 * (num_read_rows as u64) + (num_eval_rows as u64);
67 assert!(num_wires <= MAX_NUM_ACE_WIRES as u64, "too many wires");
68
69 Self {
70 ctx,
71 clk,
72 wire_bus: WireBus::new(num_wires as u32),
73 read_nodes: Vec::with_capacity(num_read_rows as usize),
74 eval_nodes: Vec::with_capacity(num_eval_rows as usize),
75 }
76 }
77
78 pub fn num_rows(&self) -> usize {
79 self.read_nodes.len() + self.eval_nodes.len()
80 }
81
82 pub fn clk(&self) -> u32 {
83 self.clk.into()
84 }
85
86 pub fn ctx(&self) -> u32 {
87 self.ctx.into()
88 }
89
90 pub fn num_read_rows(&self) -> u32 {
91 self.read_nodes.len() as u32
92 }
93
94 pub fn num_eval_rows(&self) -> u32 {
95 self.eval_nodes.len() as u32
96 }
97
98 pub fn do_read(&mut self, ptr: Felt, word: Word) {
101 let v_0 = QuadFelt::from_basis_coefficients_fn(|i: usize| [word[0], word[1]][i]);
102 let id_0 = self.wire_bus.insert(v_0);
103
104 let v_1 = QuadFelt::from_basis_coefficients_fn(|i: usize| [word[2], word[3]][i]);
105 let id_1 = self.wire_bus.insert(v_1);
106
107 self.read_nodes.push(ReadNode { ptr, id_0, v_0, id_1, v_1 });
108 }
109
110 pub fn do_eval(&mut self, ptr: Felt, instruction: Felt) -> Result<(), AceError> {
113 let (id_l, id_r, op) = decode_instruction(instruction)
114 .ok_or(AceError("failed to decode instruction".into()))?;
115
116 let v_l = self
117 .wire_bus
118 .read_value(id_l)
119 .ok_or(AceError("failed to read from the wiring bus".into()))?;
120 let id_1 = Felt::from_u32(id_l);
121
122 let v_r = self
123 .wire_bus
124 .read_value(id_r)
125 .ok_or(AceError("failed to read from the wiring bus".into()))?;
126 let id_2 = Felt::from_u32(id_r);
127
128 let v_0 = match op {
129 Op::Sub => v_l - v_r,
130 Op::Mul => v_l * v_r,
131 Op::Add => v_l + v_r,
132 };
133 let id_0 = self.wire_bus.insert(v_0);
134
135 let eval_op = match op {
136 Op::Sub => -Felt::ONE,
137 Op::Mul => Felt::ZERO,
138 Op::Add => Felt::ONE,
139 };
140
141 self.eval_nodes.push(EvalNode {
142 ptr,
143 eval_op,
144 id_0,
145 v_0,
146 id_1,
147 v_1: v_l,
148 id_2,
149 v_2: v_r,
150 });
151 Ok(())
152 }
153
154 pub fn fill(&self, offset: usize, out: &mut [Felt]) {
158 const W: usize = ACE_CHIPLET_NUM_COLS;
159 let (out_rows, _) = out.as_chunks_mut::<W>();
160 let num_read_rows = self.read_nodes.len();
161 let num_eval_rows = self.eval_nodes.len();
162
163 let ctx_felt: Felt = self.ctx.into();
164 let clk_felt: Felt = self.clk.into();
165 let eval_section_first_idx = Felt::from_u32(num_eval_rows as u32 - 1);
166 let mut multiplicities_iter = self.wire_bus.wires.iter().map(|(_v, m)| Felt::from_u32(*m));
167
168 for (i, node) in self.read_nodes.iter().enumerate() {
170 let cols: &mut AceCols<Felt> = out_rows[offset + i].as_mut_slice().borrow_mut();
171 cols.s_start = if i == 0 { Felt::ONE } else { Felt::ZERO };
172 cols.s_block = Felt::ZERO;
173 cols.ctx = ctx_felt;
174 cols.clk = clk_felt;
175 cols.ptr = node.ptr;
176 cols.id_0 = node.id_0;
177 cols.v_0 = quad_to_expr(node.v_0);
178 cols.id_1 = node.id_1;
179 cols.v_1 = quad_to_expr(node.v_1);
180
181 let m_0 = multiplicities_iter
182 .next()
183 .expect("the m0 multiplicities were not constructed properly");
184 let m_1 = multiplicities_iter
185 .next()
186 .expect("the m1 multiplicities were not constructed properly");
187
188 let read = cols.read_mut();
189 read.num_eval = eval_section_first_idx;
190 read.m_0 = m_0;
191 read.m_1 = m_1;
192 }
193
194 for (i, node) in self.eval_nodes.iter().enumerate() {
196 let cols: &mut AceCols<Felt> =
197 out_rows[offset + num_read_rows + i].as_mut_slice().borrow_mut();
198 cols.s_start = Felt::ZERO;
199 cols.s_block = Felt::ONE;
200 cols.ctx = ctx_felt;
201 cols.clk = clk_felt;
202 cols.ptr = node.ptr;
203 cols.eval_op = node.eval_op;
204 cols.id_0 = node.id_0;
205 cols.v_0 = quad_to_expr(node.v_0);
206 cols.id_1 = node.id_1;
207 cols.v_1 = quad_to_expr(node.v_1);
208
209 let m_0 = multiplicities_iter
210 .next()
211 .expect("the m0 multiplicities were not constructed properly");
212
213 let eval = cols.eval_mut();
214 eval.id_2 = node.id_2;
215 eval.v_2 = quad_to_expr(node.v_2);
216 eval.m_0 = m_0;
217 }
218
219 debug_assert!(multiplicities_iter.next().is_none());
220 }
221
222 pub fn output_value(&self) -> Option<QuadFelt> {
224 if !self.wire_bus.is_finalized() {
225 return None;
226 }
227 self.wire_bus.wires.last().map(|(v, _m)| *v)
228 }
229}
230
231fn quad_to_expr(v: QuadFelt) -> QuadFeltExpr<Felt> {
234 let c = v.as_basis_coefficients_slice();
235 QuadFeltExpr(c[0], c[1])
236}
237
238#[derive(Debug, Clone)]
247struct WireBus {
248 id_next: Felt,
250 wires: Vec<(QuadFelt, u32)>,
253 num_wires: u32,
255}
256
257impl WireBus {
258 fn new(num_wires: u32) -> Self {
259 Self {
260 wires: Vec::with_capacity(num_wires as usize),
261 num_wires,
262 id_next: Felt::from_u32(num_wires - 1),
263 }
264 }
265
266 fn insert(&mut self, value: QuadFelt) -> Felt {
268 debug_assert!(!self.is_finalized());
269 self.wires.push((value, 0));
270 let id = self.id_next;
271 self.id_next -= Felt::ONE;
272 id
273 }
274
275 fn read_value(&mut self, id: u32) -> Option<QuadFelt> {
278 let (v, m) = self
280 .num_wires
281 .checked_sub(id + 1)
282 .and_then(|id| self.wires.get_mut(id as usize))?;
283 *m += 1;
284 Some(*v)
285 }
286
287 fn is_finalized(&self) -> bool {
289 self.wires.len() == self.num_wires as usize
290 }
291}