use alloc::{boxed::Box, vec::Vec};

use itertools::Itertools;
use miden_air::{
    Felt, RowIndex,
    trace::{
        CLK_COL_IDX, CTX_COL_IDX, DECODER_TRACE_OFFSET, DECODER_TRACE_WIDTH, FN_HASH_RANGE,
        MIN_TRACE_LEN, PADDED_TRACE_WIDTH, STACK_TRACE_OFFSET, STACK_TRACE_WIDTH, SYS_TRACE_WIDTH,
        TRACE_WIDTH,
        decoder::{
            ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_OFFSET, IN_SPAN_COL_IDX,
            NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS, OP_BATCH_FLAGS_OFFSET,
            OP_BITS_EXTRA_COLS_OFFSET, OP_BITS_OFFSET, OP_INDEX_COL_IDX,
        },
        main_trace::MainTrace,
        stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, STACK_TOP_OFFSET},
    },
};
use miden_core::{
    Kernel, ONE, Operation, Word, ZERO, stack::MIN_STACK_DEPTH, utils::uninit_vector,
};
use rayon::prelude::*;
use winter_prover::{crypto::RandomCoin, math::batch_inversion};

use crate::{
    ChipletsLengths, ColMatrix, ContextId, ExecutionTrace, TraceLenSummary,
    chiplets::Chiplets,
    crypto::RpoRandomCoin,
    decoder::AuxTraceBuilder as DecoderAuxTraceBuilder,
    fast::{
        ExecutionOutput,
        execution_tracer::TraceGenerationContext,
        trace_state::{
            AceReplay, BitwiseOp, BitwiseReplay, CoreTraceFragmentContext, HasherOp,
            HasherRequestReplay, KernelReplay, MemoryWritesReplay,
        },
    },
    parallel::core_trace_fragment::{CoreTraceFragment, CoreTraceFragmentFiller},
    range::RangeChecker,
    stack::AuxTraceBuilder as StackAuxTraceBuilder,
    trace::{AuxTraceBuilders, NUM_RAND_ROWS},
};

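/// Number of columns in the core trace: the system, decoder and stack columns combined.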
pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH;

mod core_trace_fragment;

#[cfg(test)]
mod tests;

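/// Builds the main execution trace from the outputs of fast execution.
///
/// Core trace fragments are filled in parallel from the data recorded in the
/// [`TraceGenerationContext`], the range checker and chiplet traces are built alongside them, and
/// the resulting columns are padded, randomized in the last `NUM_RAND_ROWS` rows, and assembled
/// into an [`ExecutionTrace`].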
pub fn build_trace(
    execution_output: ExecutionOutput,
    trace_generation_context: TraceGenerationContext,
    program_hash: Word,
    kernel: Kernel,
) -> ExecutionTrace {
    let TraceGenerationContext {
        core_trace_contexts,
        range_checker_replay,
        memory_writes,
        bitwise_replay: bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        final_pc_transcript,
        fragment_size,
    } = trace_generation_context;

    let chiplets = initialize_chiplets(
        kernel.clone(),
        &core_trace_contexts,
        memory_writes,
        bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
    );

    let range_checker = initialize_range_checker(range_checker_replay, &chiplets);

    let mut core_trace_columns = generate_core_trace_columns(core_trace_contexts, fragment_size);

    // Exclude the final HALT row appended by `generate_core_trace_columns()` from the reported
    // core trace length.
    let core_trace_len = {
        let core_trace_len: usize = core_trace_columns[0].len();

        core_trace_len - 1
    };

    let range_table_len = range_checker.get_number_range_checker_rows();

    let trace_len_summary =
        TraceLenSummary::new(core_trace_len, range_table_len, ChipletsLengths::new(&chiplets));

    let main_trace_len =
        compute_main_trace_length(core_trace_len, range_table_len, chiplets.trace_len());

    // Pad the core trace columns and build the range checker and chiplets traces in parallel.
    let ((), (range_checker_trace, chiplets_trace)) = rayon::join(
        || pad_trace_columns(&mut core_trace_columns, main_trace_len),
        || {
            rayon::join(
                || {
                    range_checker.into_trace_with_table(
                        range_table_len,
                        main_trace_len,
                        NUM_RAND_ROWS,
                    )
                },
                || chiplets.into_trace(main_trace_len, NUM_RAND_ROWS, final_pc_transcript.state()),
            )
        },
    );

    // Zero-filled columns bringing the trace up to its padded width.
    let padding_columns = vec![vec![ZERO; main_trace_len]; PADDED_TRACE_WIDTH - TRACE_WIDTH];

    let mut trace_columns: Vec<Vec<Felt>> = core_trace_columns
        .into_iter()
        .chain(range_checker_trace.trace)
        .chain(chiplets_trace.trace)
        .chain(padding_columns)
        .collect();

    // Overwrite the last NUM_RAND_ROWS rows of every column with random values drawn from a coin
    // seeded with the program hash.
    let mut rng = RpoRandomCoin::new(program_hash);

    for i in main_trace_len - NUM_RAND_ROWS..main_trace_len {
        for column in trace_columns.iter_mut() {
            column[i] = rng.draw().expect("failed to draw a random value");
        }
    }

    let main_trace = {
        let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1));
        let col_matrix = ColMatrix::new(trace_columns);
        MainTrace::new(col_matrix, last_program_row)
    };

    let aux_trace_builders = AuxTraceBuilders {
        decoder: DecoderAuxTraceBuilder::default(),
        range: range_checker_trace.aux_builder,
        chiplets: chiplets_trace.aux_builder,
        stack: StackAuxTraceBuilder,
    };

    ExecutionTrace::new_from_parts(
        program_hash,
        kernel,
        execution_output,
        main_trace,
        aux_trace_builders,
        trace_len_summary,
    )
}

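/// Computes the length of the main trace: the next power of two large enough to fit the longest
/// trace segment plus `NUM_RAND_ROWS`, and at least `MIN_TRACE_LEN`.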
fn compute_main_trace_length(
    core_trace_len: usize,
    range_table_len: usize,
    chiplets_trace_len: usize,
) -> usize {
    let max_len = range_table_len.max(core_trace_len).max(chiplets_trace_len);

    let trace_len = (max_len + NUM_RAND_ROWS).next_power_of_two();
    core::cmp::max(trace_len, MIN_TRACE_LEN)
}

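/// Generates the core trace columns (system, decoder and stack) by filling one fragment per
/// [`CoreTraceFragmentContext`] in parallel, then fixing up the boundary rows between fragments
/// and appending a final HALT row.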
fn generate_core_trace_columns(
    core_trace_contexts: Vec<CoreTraceFragmentContext>,
    fragment_size: usize,
) -> Vec<Vec<Felt>> {
    // Allocate uninitialized column memory up front; it is populated by the fragment fillers
    // below.
    let mut core_trace_columns: Vec<Vec<Felt>> =
        unsafe { vec![uninit_vector(core_trace_contexts.len() * fragment_size); CORE_TRACE_WIDTH] };

    let first_stack_top = if let Some(first_context) = core_trace_contexts.first() {
        first_context.state.stack.stack_top.to_vec()
    } else {
        vec![ZERO; MIN_STACK_DEPTH]
    };

    let mut fragments = create_fragments_from_trace_columns(&mut core_trace_columns, fragment_size);

    // Fill each fragment in parallel; each filler returns the final stack and system rows of its
    // fragment along with the number of rows it actually wrote.
    let fragment_results: Vec<([Felt; STACK_TRACE_WIDTH], [Felt; SYS_TRACE_WIDTH], usize)> =
        core_trace_contexts
            .into_par_iter()
            .zip(fragments.par_iter_mut())
            .map(|(trace_state, fragment)| {
                let core_trace_fragment_filler =
                    CoreTraceFragmentFiller::new(trace_state, fragment);
                core_trace_fragment_filler.fill_fragment()
            })
            .collect();

    let mut stack_rows = Vec::new();
    let mut system_rows = Vec::new();
    let mut total_core_trace_rows = 0;

    for (stack_row, system_row, num_rows_written) in fragment_results {
        stack_rows.push(stack_row);
        system_rows.push(system_row);
        total_core_trace_rows += num_rows_written;
    }

    fixup_stack_and_system_rows(
        &mut core_trace_columns,
        fragment_size,
        &stack_rows,
        &system_rows,
        &first_stack_top,
    );

    // Drop any unused rows at the end of the trace.
    for col in core_trace_columns.iter_mut() {
        col.truncate(total_core_trace_rows);
    }

    push_halt_opcode_row(
        &mut core_trace_columns,
        system_rows.last().expect(
            "system_rows should not be empty, which indicates that there are no trace fragments",
        ),
        stack_rows.last().expect(
            "stack_rows should not be empty, which indicates that there are no trace fragments",
        ),
    );

    // Replace the stack's h0 helper column with its element-wise inverses using a single batch
    // inversion.
    core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX] =
        batch_inversion(&core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX]);

    core_trace_columns
}

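/// Splits the core trace columns into [`CoreTraceFragment`]s of `fragment_size` rows each, where
/// every fragment borrows the corresponding chunk of each column.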
fn create_fragments_from_trace_columns(
    core_trace_columns: &mut [Vec<Felt>],
    fragment_size: usize,
) -> Vec<CoreTraceFragment<'_>> {
    let mut column_chunks: Vec<_> = core_trace_columns
        .iter_mut()
        .map(|col| col.chunks_exact_mut(fragment_size))
        .collect();
    let mut core_trace_fragments = Vec::new();

    loop {
        let fragment_cols: Vec<&mut [Felt]> =
            column_chunks.iter_mut().filter_map(|col_chunk| col_chunk.next()).collect();
        assert!(
            fragment_cols.is_empty() || fragment_cols.len() == CORE_TRACE_WIDTH,
            "column chunks don't all have the same size"
        );

        if fragment_cols.is_empty() {
            return core_trace_fragments;
        } else {
            core_trace_fragments.push(CoreTraceFragment {
                columns: fragment_cols.try_into().expect("fragment has CORE_TRACE_WIDTH columns"),
            });
        }
    }
}

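/// Writes the stack and system columns of the first row of every fragment.
///
/// The first row of the trace gets the initial system state and the initial stack, while the
/// first row of every subsequent fragment gets the final stack and system rows returned by the
/// previous fragment's filler.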
fn fixup_stack_and_system_rows(
    core_trace_columns: &mut [Vec<Felt>],
    fragment_size: usize,
    stack_rows: &[[Felt; STACK_TRACE_WIDTH]],
    system_rows: &[[Felt; SYS_TRACE_WIDTH]],
    first_stack_top: &[Felt],
) {
    const MIN_STACK_DEPTH_FELT: Felt = Felt::new(MIN_STACK_DEPTH as u64);

    let system_state_first_row = [ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];

    // Write the first row of the trace: the initial system columns, the initial stack top (in
    // reverse order), and the initial stack bookkeeping columns.
    {
        for (col_idx, &value) in system_state_first_row.iter().enumerate() {
            core_trace_columns[col_idx][0] = value;
        }

        for (stack_col_idx, &value) in first_stack_top.iter().rev().enumerate() {
            core_trace_columns[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + stack_col_idx][0] = value;
        }

        core_trace_columns[STACK_TRACE_OFFSET + B0_COL_IDX][0] = MIN_STACK_DEPTH_FELT;
        core_trace_columns[STACK_TRACE_OFFSET + B1_COL_IDX][0] = ZERO;
        core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX][0] = ZERO;
    }

    // Row indices of the first row of every fragment except the first.
    let fragment_start_row_indices = {
        let num_fragments = core_trace_columns[0].len() / fragment_size;

        (0..).step_by(fragment_size).take(num_fragments).skip(1)
    };

    // Copy the final system and stack rows of each fragment into the first row of the next one.
    for (row_idx, (system_row, stack_row)) in
        fragment_start_row_indices.zip(system_rows.iter().zip(stack_rows.iter()))
    {
        for (col_idx, &value) in system_row.iter().enumerate() {
            core_trace_columns[col_idx][row_idx] = value;
        }

        for (col_idx, &value) in stack_row.iter().enumerate() {
            core_trace_columns[STACK_TRACE_OFFSET + col_idx][row_idx] = value;
        }
    }
}

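/// Appends the final HALT row to the core trace columns: the system and stack columns repeat the
/// last recorded state, the decoder op bits encode the HALT opcode, the first four hasher state
/// columns carry over their previous values, and the remaining decoder columns are set to ZERO
/// (with the second op-bits-extra column set to ONE).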
fn push_halt_opcode_row(
    core_trace_columns: &mut [Vec<Felt>],
    last_system_state: &[Felt; SYS_TRACE_WIDTH],
    last_stack_state: &[Felt; STACK_TRACE_WIDTH],
) {
    // System and stack columns: repeat the last recorded state.
    for (col_idx, &value) in last_system_state.iter().enumerate() {
        core_trace_columns[col_idx].push(value);
    }

    for (col_idx, &value) in last_stack_state.iter().enumerate() {
        core_trace_columns[STACK_TRACE_OFFSET + col_idx].push(value);
    }

    // Decoder address column.
    core_trace_columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX].push(ZERO);

    // Decoder op bits: encode the HALT opcode bit by bit.
    let halt_opcode = Operation::Halt.op_code();
    for bit_idx in 0..NUM_OP_BITS {
        let bit_value = Felt::from((halt_opcode >> bit_idx) & 1);
        core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + bit_idx].push(bit_value);
    }

    // Hasher state: the first four columns carry over the previous row's values; the rest are
    // zeroed.
    for hasher_col_idx in 0..NUM_HASHER_COLUMNS {
        let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + hasher_col_idx;
        if hasher_col_idx < 4 {
            let last_row_idx = core_trace_columns[col_idx].len() - 1;
            let last_hasher_value = core_trace_columns[col_idx][last_row_idx];
            core_trace_columns[col_idx].push(last_hasher_value);
        } else {
            core_trace_columns[col_idx].push(ZERO);
        }
    }

    // Remaining decoder columns.
    core_trace_columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX].push(ZERO);

    core_trace_columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX].push(ZERO);

    core_trace_columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX].push(ZERO);

    for batch_flag_idx in 0..NUM_OP_BATCH_FLAGS {
        let col_idx = DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + batch_flag_idx;
        core_trace_columns[col_idx].push(ZERO);
    }

    core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET].push(ZERO);
    core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1].push(ONE);
}

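/// Initializes the range checker from the recorded range check replay and from the range checks
/// required by the chiplets.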
fn initialize_range_checker(
    range_checker_replay: crate::fast::trace_state::RangeCheckerReplay,
    chiplets: &Chiplets,
) -> RangeChecker {
    let mut range_checker = RangeChecker::new();

    for (clk, values) in range_checker_replay.into_iter() {
        range_checker.add_range_checks(clk, &values);
    }

    chiplets.append_range_checks(&mut range_checker);

    range_checker
}

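/// Initializes the chiplets (hasher, bitwise, memory, ACE and kernel ROM) by replaying the
/// requests recorded during fast execution.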
fn initialize_chiplets(
    kernel: Kernel,
    core_trace_contexts: &[CoreTraceFragmentContext],
    memory_writes: MemoryWritesReplay,
    bitwise: BitwiseReplay,
    kernel_replay: KernelReplay,
    hasher_for_chiplet: HasherRequestReplay,
    ace_replay: AceReplay,
) -> Chiplets {
    let mut chiplets = Chiplets::new(kernel);

    // Replay all hasher requests.
    for hasher_op in hasher_for_chiplet.into_iter() {
        match hasher_op {
            HasherOp::Permute(input_state) => {
                chiplets.hasher.permute(input_state);
            },
            HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => {
                chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash);
            },
            HasherOp::HashBasicBlock((op_batches, expected_hash)) => {
                chiplets.hasher.hash_basic_block(&op_batches, expected_hash);
            },
            HasherOp::BuildMerkleRoot((value, path, index)) => {
                chiplets.hasher.build_merkle_root(value, &path, index);
            },
            HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => {
                chiplets.hasher.update_merkle_root(old_value, new_value, &path, index);
            },
        }
    }

    // Replay all bitwise operations.
    for (bitwise_op, a, b) in bitwise {
        match bitwise_op {
            BitwiseOp::U32And => {
                chiplets
                    .bitwise
                    .u32and(a, b, &())
                    .expect("bitwise AND operation failed when populating chiplet");
            },
            BitwiseOp::U32Xor => {
                chiplets
                    .bitwise
                    .u32xor(a, b, &())
                    .expect("bitwise XOR operation failed when populating chiplet");
            },
        }
    }

    // Replay all memory accesses. Writes are recorded globally, while reads are recorded per
    // fragment; the four streams are merged by clock cycle before being applied to the memory
    // chiplet.
    {
        let elements_written: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(memory_writes.iter_elements_written().map(|(element, addr, ctx, clk)| {
                MemoryAccess::WriteElement(*addr, *element, *ctx, *clk)
            }));
        let words_written: Box<dyn Iterator<Item = MemoryAccess>> = Box::new(
            memory_writes
                .iter_words_written()
                .map(|(word, addr, ctx, clk)| MemoryAccess::WriteWord(*addr, *word, *ctx, *clk)),
        );
        let elements_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_elements()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadElement(addr, ctx, clk))
            }));
        let words_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_words()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadWord(addr, ctx, clk))
            }));

        [elements_written, words_written, elements_read, words_read]
            .into_iter()
            .kmerge_by(|a, b| a.clk() < b.clk())
            .for_each(|mem_access| match mem_access {
                MemoryAccess::ReadElement(addr, ctx, clk) => {
                    chiplets
                        .memory
                        .read(ctx, addr, clk, &())
                        .expect("memory read element failed when populating chiplet");
                },
                MemoryAccess::WriteElement(addr, element, ctx, clk) => {
                    chiplets
                        .memory
                        .write(ctx, addr, clk, element, &())
                        .expect("memory write element failed when populating chiplet");
                },
                MemoryAccess::ReadWord(addr, ctx, clk) => {
                    chiplets
                        .memory
                        .read_word(ctx, addr, clk, &())
                        .expect("memory read word failed when populating chiplet");
                },
                MemoryAccess::WriteWord(addr, word, ctx, clk) => {
                    chiplets
                        .memory
                        .write_word(ctx, addr, clk, word, &())
                        .expect("memory write word failed when populating chiplet");
                },
            });

        // A single memory access recorded at a given context and clock cycle.
        enum MemoryAccess {
            ReadElement(Felt, ContextId, RowIndex),
            WriteElement(Felt, Felt, ContextId, RowIndex),
            ReadWord(Felt, ContextId, RowIndex),
            WriteWord(Felt, Word, ContextId, RowIndex),
        }

        impl MemoryAccess {
            fn clk(&self) -> RowIndex {
                match self {
                    MemoryAccess::ReadElement(_, _, clk) => *clk,
                    MemoryAccess::WriteElement(_, _, _, clk) => *clk,
                    MemoryAccess::ReadWord(_, _, clk) => *clk,
                    MemoryAccess::WriteWord(_, _, _, clk) => *clk,
                }
            }
        }
    }

    // Replay all ACE circuit evaluations.
    for (clk, circuit_eval) in ace_replay.into_iter() {
        chiplets.ace.add_circuit_evaluation(clk, circuit_eval);
    }

    // Replay all kernel procedure accesses.
    for proc_hash in kernel_replay.into_iter() {
        chiplets
            .kernel_rom
            .access_proc(proc_hash, &())
            .expect("kernel proc access failed when populating chiplet");
    }

    chiplets
}

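/// Pads the core trace columns up to `main_trace_len` rows.
///
/// The clock column keeps incrementing through the padding, the decoder columns are padded so
/// that every padding row encodes a HALT operation, and the stack columns repeat their last
/// values; all other columns are padded with zeros.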
fn pad_trace_columns(trace_columns: &mut [Vec<Felt>], main_trace_len: usize) {
    let total_program_rows = trace_columns[0].len();
    assert!(total_program_rows + NUM_RAND_ROWS - 1 <= main_trace_len);

    let num_padding_rows = main_trace_len - total_program_rows;

    // System columns: the clock keeps incrementing; the context and function hash columns are
    // padded with zeros.
    for padding_row_idx in 0..num_padding_rows {
        trace_columns[CLK_COL_IDX].push(Felt::from((total_program_rows + padding_row_idx) as u32));
    }

    trace_columns[CTX_COL_IDX].resize(main_trace_len, ZERO);

    for fn_hash_col_idx in FN_HASH_RANGE {
        trace_columns[fn_hash_col_idx].resize(main_trace_len, ZERO);
    }

    // Decoder columns: pad so that every padding row encodes a HALT operation.
    trace_columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX].resize(main_trace_len, ZERO);

    let halt_opcode = Operation::Halt.op_code();
    for i in 0..NUM_OP_BITS {
        let bit_value = Felt::from((halt_opcode >> i) & 1);
        trace_columns[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + i].resize(main_trace_len, bit_value);
    }

    // The first four hasher state columns repeat their last value; the remaining ones are zeroed.
    for i in 0..NUM_HASHER_COLUMNS {
        let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + i;
        if i < 4 {
            let last_hasher_value = trace_columns[col_idx][total_program_rows - 1];
            trace_columns[col_idx].resize(main_trace_len, last_hasher_value);
        } else {
            trace_columns[col_idx].resize(main_trace_len, ZERO);
        }
    }

    trace_columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX].resize(main_trace_len, ZERO);

    trace_columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX].resize(main_trace_len, ZERO);

    trace_columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX].resize(main_trace_len, ZERO);

    for i in 0..NUM_OP_BATCH_FLAGS {
        trace_columns[DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + i]
            .resize(main_trace_len, ZERO);
    }

    trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET].resize(main_trace_len, ZERO);
    trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1].resize(main_trace_len, ONE);

    // Stack columns: repeat the last stack state.
    for i in 0..STACK_TRACE_WIDTH {
        let col_idx = STACK_TRACE_OFFSET + i;
        let last_stack_value = trace_columns[col_idx][total_program_rows - 1];
        trace_columns[col_idx].resize(main_trace_len, last_stack_value);
    }
}