use alloc::{boxed::Box, sync::Arc, vec::Vec};

use itertools::Itertools;
use miden_air::{
    Felt,
    trace::{
        CLK_COL_IDX, CTX_COL_IDX, DECODER_TRACE_OFFSET, DECODER_TRACE_WIDTH, FN_HASH_RANGE,
        MIN_TRACE_LEN, MainTrace, PADDED_TRACE_WIDTH, RowIndex, STACK_TRACE_OFFSET,
        STACK_TRACE_WIDTH, SYS_TRACE_WIDTH, TRACE_WIDTH,
        decoder::{
            ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_OFFSET, IN_SPAN_COL_IDX,
            NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS, OP_BATCH_FLAGS_OFFSET,
            OP_BITS_EXTRA_COLS_OFFSET, OP_BITS_OFFSET, OP_INDEX_COL_IDX,
        },
        stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, STACK_TOP_OFFSET},
    },
};
use miden_core::{
    ONE, Word, ZERO,
    field::{PrimeCharacteristicRing, batch_inversion_allow_zeros},
    mast::MastForest,
    operations::Operation,
    program::{Kernel, MIN_STACK_DEPTH, ProgramInfo},
    utils::{ColMatrix, uninit_vector},
};
use rayon::prelude::*;
use tracing::instrument;

use crate::{
    ContextId,
    continuation_stack::ContinuationStack,
    fast::ExecutionOutput,
    trace::{
        AuxTraceBuilders, ChipletsLengths, ExecutionTrace, TraceLenSummary,
        parallel::{processor::ReplayProcessor, tracer::CoreTraceGenerationTracer},
        range::RangeChecker,
    },
};

pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH;

pub(crate) mod core_trace_fragment;
use core_trace_fragment::CoreTraceFragment;

mod processor;
mod tracer;

use super::{
    chiplets::Chiplets,
    decoder::AuxTraceBuilder as DecoderAuxTraceBuilder,
    execution_tracer::TraceGenerationContext,
    stack::AuxTraceBuilder as StackAuxTraceBuilder,
    trace_state::{
        AceReplay, BitwiseOp, BitwiseReplay, CoreTraceFragmentContext, CoreTraceState,
        ExecutionReplay, HasherOp, HasherRequestReplay, KernelReplay, MemoryWritesReplay,
        RangeCheckerReplay,
    },
};

#[cfg(test)]
mod tests;

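/// Builds the execution trace from the output of fast execution and the replay data collected
/// while executing: the chiplets, range checker, and core trace columns are generated
/// independently, padded to a common power-of-two length, and assembled into an
/// [`ExecutionTrace`].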
#[instrument(name = "build_trace", skip_all)]
pub fn build_trace(
    execution_output: ExecutionOutput,
    trace_generation_context: TraceGenerationContext,
    program_info: ProgramInfo,
) -> ExecutionTrace {
    let TraceGenerationContext {
        core_trace_contexts,
        range_checker_replay,
        memory_writes,
        bitwise_replay: bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        final_pc_transcript,
        fragment_size,
    } = trace_generation_context;

    let chiplets = initialize_chiplets(
        program_info.kernel().clone(),
        &core_trace_contexts,
        memory_writes,
        bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
    );

    let range_checker = initialize_range_checker(range_checker_replay, &chiplets);

    let mut core_trace_columns = generate_core_trace_columns(
        core_trace_contexts,
        program_info.kernel().clone(),
        fragment_size,
    );

    let core_trace_len = core_trace_columns[0].len();

    let range_table_len = range_checker.get_number_range_checker_rows();

    let trace_len_summary =
        TraceLenSummary::new(core_trace_len, range_table_len, ChipletsLengths::new(&chiplets));

    let main_trace_len =
        compute_main_trace_length(core_trace_len, range_table_len, chiplets.trace_len());

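    // Pad the core trace and, in parallel, build the range checker and chiplets traces, all at
    // the padded trace length.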
    let ((), (range_checker_trace, chiplets_trace)) = rayon::join(
        || pad_trace_columns(&mut core_trace_columns, main_trace_len),
        || {
            rayon::join(
                || range_checker.into_trace_with_table(range_table_len, main_trace_len),
                || chiplets.into_trace(main_trace_len, final_pc_transcript.state()),
            )
        },
    );

    let padding_columns = vec![vec![ZERO; main_trace_len]; PADDED_TRACE_WIDTH - TRACE_WIDTH];

    let trace_columns: Vec<Vec<Felt>> = core_trace_columns
        .into_iter()
        .chain(range_checker_trace.trace)
        .chain(chiplets_trace.trace)
        .chain(padding_columns)
        .collect();

    let main_trace = {
        let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1));
        let col_matrix = ColMatrix::new(trace_columns);
        MainTrace::new(col_matrix, last_program_row)
    };

    let aux_trace_builders = AuxTraceBuilders {
        decoder: DecoderAuxTraceBuilder::default(),
        range: range_checker_trace.aux_builder,
        chiplets: chiplets_trace.aux_builder,
        stack: StackAuxTraceBuilder,
    };

    ExecutionTrace::new_from_parts(
        program_info,
        execution_output,
        main_trace,
        aux_trace_builders,
        trace_len_summary,
    )
}

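/// Computes the main trace length: the longest of the core, range checker, and chiplets traces,
/// rounded up to the next power of two, with a floor of `MIN_TRACE_LEN`.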
fn compute_main_trace_length(
    core_trace_len: usize,
    range_table_len: usize,
    chiplets_trace_len: usize,
) -> usize {
    let max_len = range_table_len.max(core_trace_len).max(chiplets_trace_len);

    let trace_len = max_len.next_power_of_two();
    core::cmp::max(trace_len, MIN_TRACE_LEN)
}

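/// Generates the core trace columns (system, decoder, and stack) by re-executing all trace
/// fragments in parallel and then fixing up the rows at fragment boundaries.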
fn generate_core_trace_columns(
    core_trace_contexts: Vec<CoreTraceFragmentContext>,
    kernel: Kernel,
    fragment_size: usize,
) -> Vec<Vec<Felt>> {
    let mut core_trace_columns: Vec<Vec<Felt>> =
        unsafe { vec![uninit_vector(core_trace_contexts.len() * fragment_size); CORE_TRACE_WIDTH] };

    let first_stack_top = if let Some(first_context) = core_trace_contexts.first() {
        first_context.state.stack.stack_top.to_vec()
    } else {
        vec![ZERO; MIN_STACK_DEPTH]
    };

    let mut fragments = create_fragments_from_trace_columns(&mut core_trace_columns, fragment_size);

    let fragment_results: Vec<([Felt; STACK_TRACE_WIDTH], [Felt; SYS_TRACE_WIDTH], usize)> =
        core_trace_contexts
            .into_par_iter()
            .zip(fragments.par_iter_mut())
            .map(|(trace_state, fragment)| {
                let (mut processor, mut tracer, mut continuation_stack, mut current_forest) =
                    split_trace_fragment_context(trace_state, fragment, fragment_size);

                processor
                    .execute(&mut continuation_stack, &mut current_forest, &kernel, &mut tracer)
                    .expect("fragment execution failed");

                tracer.into_parts()
            })
            .collect();

    let mut stack_rows = Vec::new();
    let mut system_rows = Vec::new();
    let mut total_core_trace_rows = 0;

    for (stack_row, system_row, num_rows_written) in fragment_results {
        stack_rows.push(stack_row);
        system_rows.push(system_row);
        total_core_trace_rows += num_rows_written;
    }

    fixup_stack_and_system_rows(
        &mut core_trace_columns,
        fragment_size,
        &stack_rows,
        &system_rows,
        &first_stack_top,
    );

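    // The stack's H0 helper column is accumulated un-inverted during fragment execution;
    // batch-invert it in place, one fragment-sized chunk at a time (zeros stay zero).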
    {
        let h0_column = &mut core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX];
        h0_column.par_chunks_mut(fragment_size).for_each(batch_inversion_allow_zeros);
    }

    for col in core_trace_columns.iter_mut() {
        col.truncate(total_core_trace_rows);
    }

    push_halt_opcode_row(
        &mut core_trace_columns,
        system_rows
            .last()
            .expect("system_rows is empty, indicating that there are no trace fragments"),
        stack_rows
            .last()
            .expect("stack_rows is empty, indicating that there are no trace fragments"),
    );

    core_trace_columns
}

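/// Splits the trace columns into fragments of `fragment_size` rows, where each fragment borrows
/// the corresponding chunk of every core trace column.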
fn create_fragments_from_trace_columns(
    core_trace_columns: &mut [Vec<Felt>],
    fragment_size: usize,
) -> Vec<CoreTraceFragment<'_>> {
    let mut column_chunks: Vec<_> = core_trace_columns
        .iter_mut()
        .map(|col| col.chunks_exact_mut(fragment_size))
        .collect();
    let mut core_trace_fragments = Vec::new();

    loop {
        let fragment_cols: Vec<&mut [Felt]> =
            column_chunks.iter_mut().filter_map(|col_chunk| col_chunk.next()).collect();
        assert!(
            fragment_cols.is_empty() || fragment_cols.len() == CORE_TRACE_WIDTH,
            "column chunks don't all have the same size"
        );

        if fragment_cols.is_empty() {
            return core_trace_fragments;
        } else {
            core_trace_fragments.push(CoreTraceFragment {
                columns: fragment_cols.try_into().expect("fragment has CORE_TRACE_WIDTH columns"),
            });
        }
    }
}

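/// Writes the initial system and stack state into the first row of the trace, and the boundary
/// rows returned by each fragment's tracer into the first row of the following fragment.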
fn fixup_stack_and_system_rows(
    core_trace_columns: &mut [Vec<Felt>],
    fragment_size: usize,
    stack_rows: &[[Felt; STACK_TRACE_WIDTH]],
    system_rows: &[[Felt; SYS_TRACE_WIDTH]],
    first_stack_top: &[Felt],
) {
    const MIN_STACK_DEPTH_FELT: Felt = Felt::new(MIN_STACK_DEPTH as u64);

    let system_state_first_row = [ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];

    {
        for (col_idx, &value) in system_state_first_row.iter().enumerate() {
            core_trace_columns[col_idx][0] = value;
        }

        for (stack_col_idx, &value) in first_stack_top.iter().rev().enumerate() {
            core_trace_columns[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + stack_col_idx][0] = value;
        }

        core_trace_columns[STACK_TRACE_OFFSET + B0_COL_IDX][0] = MIN_STACK_DEPTH_FELT;
        core_trace_columns[STACK_TRACE_OFFSET + B1_COL_IDX][0] = ZERO;
        core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX][0] = ZERO;
    }

    let fragment_start_row_indices = {
        let num_fragments = core_trace_columns[0].len() / fragment_size;

        (0..).step_by(fragment_size).take(num_fragments).skip(1)
    };

    for (row_idx, (system_row, stack_row)) in
        fragment_start_row_indices.zip(system_rows.iter().zip(stack_rows.iter()))
    {
        for (col_idx, &value) in system_row.iter().enumerate() {
            core_trace_columns[col_idx][row_idx] = value;
        }

        for (col_idx, &value) in stack_row.iter().enumerate() {
            core_trace_columns[STACK_TRACE_OFFSET + col_idx][row_idx] = value;
        }
    }
}

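/// Appends a final row in which the decoder executes a HALT operation. The system and stack
/// columns repeat the last program row, and the first four hasher state columns carry over the
/// previous row's values.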
fn push_halt_opcode_row(
    core_trace_columns: &mut [Vec<Felt>],
    last_system_state: &[Felt; SYS_TRACE_WIDTH],
    last_stack_state: &[Felt; STACK_TRACE_WIDTH],
) {
    for (col_idx, &value) in last_system_state.iter().enumerate() {
        core_trace_columns[col_idx].push(value);
    }

    for (col_idx, &value) in last_stack_state.iter().enumerate() {
        core_trace_columns[STACK_TRACE_OFFSET + col_idx].push(value);
    }

    core_trace_columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX].push(ZERO);

    let halt_opcode = Operation::Halt.op_code();
    for bit_idx in 0..NUM_OP_BITS {
        let bit_value = Felt::from_u8((halt_opcode >> bit_idx) & 1);
        core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + bit_idx].push(bit_value);
    }

    for hasher_col_idx in 0..NUM_HASHER_COLUMNS {
        let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + hasher_col_idx;
        if hasher_col_idx < 4 {
            let last_row_idx = core_trace_columns[col_idx].len() - 1;
            let last_hasher_value = core_trace_columns[col_idx][last_row_idx];
            core_trace_columns[col_idx].push(last_hasher_value);
        } else {
            core_trace_columns[col_idx].push(ZERO);
        }
    }

    core_trace_columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX].push(ZERO);

    core_trace_columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX].push(ZERO);

    core_trace_columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX].push(ZERO);

    for batch_flag_idx in 0..NUM_OP_BATCH_FLAGS {
        let col_idx = DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + batch_flag_idx;
        core_trace_columns[col_idx].push(ZERO);
    }

    core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET].push(ZERO);
    core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1].push(ONE);
}

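/// Initializes the range checker from the replayed range check requests and from the range
/// checks required by the chiplets.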
fn initialize_range_checker(
    range_checker_replay: RangeCheckerReplay,
    chiplets: &Chiplets,
) -> RangeChecker {
    let mut range_checker = RangeChecker::new();

    for (clk, values) in range_checker_replay.into_iter() {
        range_checker.add_range_checks(clk, &values);
    }

    chiplets.append_range_checks(&mut range_checker);

    range_checker
}

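/// Initializes the chiplets by replaying the hasher, bitwise, memory, ACE, and kernel ROM
/// requests recorded during execution.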
fn initialize_chiplets(
    kernel: Kernel,
    core_trace_contexts: &[CoreTraceFragmentContext],
    memory_writes: MemoryWritesReplay,
    bitwise: BitwiseReplay,
    kernel_replay: KernelReplay,
    hasher_for_chiplet: HasherRequestReplay,
    ace_replay: AceReplay,
) -> Chiplets {
    let mut chiplets = Chiplets::new(kernel);

    for hasher_op in hasher_for_chiplet.into_iter() {
        match hasher_op {
            HasherOp::Permute(input_state) => {
                let _ = chiplets.hasher.permute(input_state);
            },
            HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => {
                let _ = chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash);
            },
            HasherOp::HashBasicBlock((forest, node_id, expected_hash)) => {
                let op_batches = forest[node_id].unwrap_basic_block().op_batches();
                let _ = chiplets.hasher.hash_basic_block(op_batches, expected_hash);
            },
            HasherOp::BuildMerkleRoot((value, path, index)) => {
                let _ = chiplets.hasher.build_merkle_root(value, &path, index);
            },
            HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => {
                chiplets.hasher.update_merkle_root(old_value, new_value, &path, index);
            },
        }
    }

    for (bitwise_op, a, b) in bitwise {
        match bitwise_op {
            BitwiseOp::U32And => {
                let _ = chiplets
                    .bitwise
                    .u32and(a, b)
                    .expect("bitwise AND operation failed when populating chiplet");
            },
            BitwiseOp::U32Xor => {
                let _ = chiplets
                    .bitwise
                    .u32xor(a, b)
                    .expect("bitwise XOR operation failed when populating chiplet");
            },
        }
    }

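    // Memory accesses must be replayed in clock-cycle order, so merge the four access streams
    // (element/word reads and writes) by `clk` before feeding them to the memory chiplet.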
    {
        let elements_written: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(memory_writes.iter_elements_written().map(|(element, addr, ctx, clk)| {
                MemoryAccess::WriteElement(*addr, *element, *ctx, *clk)
            }));
        let words_written: Box<dyn Iterator<Item = MemoryAccess>> = Box::new(
            memory_writes
                .iter_words_written()
                .map(|(word, addr, ctx, clk)| MemoryAccess::WriteWord(*addr, *word, *ctx, *clk)),
        );
        let elements_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_elements()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadElement(addr, ctx, clk))
            }));
        let words_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_words()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadWord(addr, ctx, clk))
            }));

        [elements_written, words_written, elements_read, words_read]
            .into_iter()
            .kmerge_by(|a, b| a.clk() < b.clk())
            .for_each(|mem_access| match mem_access {
                MemoryAccess::ReadElement(addr, ctx, clk) => {
                    let _ = chiplets
                        .memory
                        .read(ctx, addr, clk)
                        .expect("memory read element failed when populating chiplet");
                },
                MemoryAccess::WriteElement(addr, element, ctx, clk) => {
                    chiplets
                        .memory
                        .write(ctx, addr, clk, element)
                        .expect("memory write element failed when populating chiplet");
                },
                MemoryAccess::ReadWord(addr, ctx, clk) => {
                    chiplets
                        .memory
                        .read_word(ctx, addr, clk)
                        .expect("memory read word failed when populating chiplet");
                },
                MemoryAccess::WriteWord(addr, word, ctx, clk) => {
                    chiplets
                        .memory
                        .write_word(ctx, addr, clk, word)
                        .expect("memory write word failed when populating chiplet");
                },
            });

        enum MemoryAccess {
            ReadElement(Felt, ContextId, RowIndex),
            WriteElement(Felt, Felt, ContextId, RowIndex),
            ReadWord(Felt, ContextId, RowIndex),
            WriteWord(Felt, Word, ContextId, RowIndex),
        }

        impl MemoryAccess {
            fn clk(&self) -> RowIndex {
                match self {
                    MemoryAccess::ReadElement(_, _, clk) => *clk,
                    MemoryAccess::WriteElement(_, _, _, clk) => *clk,
                    MemoryAccess::ReadWord(_, _, clk) => *clk,
                    MemoryAccess::WriteWord(_, _, _, clk) => *clk,
                }
            }
        }
    }

    for (clk, circuit_eval) in ace_replay.into_iter() {
        chiplets.ace.add_circuit_evaluation(clk, circuit_eval);
    }

    for proc_hash in kernel_replay.into_iter() {
        chiplets
            .kernel_rom
            .access_proc(proc_hash)
            .expect("kernel proc access failed when populating chiplet");
    }

    chiplets
}

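/// Pads each core trace column up to `main_trace_len`: the clock keeps incrementing, the decoder
/// repeats the HALT operation (with the first four hasher state columns holding their last
/// value), and the stack columns repeat their last row.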
fn pad_trace_columns(trace_columns: &mut [Vec<Felt>], main_trace_len: usize) {
    let total_program_rows = trace_columns[0].len();
    assert!(total_program_rows <= main_trace_len);

    let num_padding_rows = main_trace_len - total_program_rows;

    for padding_row_idx in 0..num_padding_rows {
        trace_columns[CLK_COL_IDX]
            .push(Felt::from_u32((total_program_rows + padding_row_idx) as u32));
    }

    trace_columns[CTX_COL_IDX].resize(main_trace_len, ZERO);

    for fn_hash_col_idx in FN_HASH_RANGE {
        trace_columns[fn_hash_col_idx].resize(main_trace_len, ZERO);
    }

    trace_columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX].resize(main_trace_len, ZERO);

    let halt_opcode = Operation::Halt.op_code();
    for i in 0..NUM_OP_BITS {
        let bit_value = Felt::from_u8((halt_opcode >> i) & 1);
        trace_columns[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + i].resize(main_trace_len, bit_value);
    }

    for i in 0..NUM_HASHER_COLUMNS {
        let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + i;
        if i < 4 {
            let last_hasher_value = trace_columns[col_idx][total_program_rows - 1];
            trace_columns[col_idx].resize(main_trace_len, last_hasher_value);
        } else {
            trace_columns[col_idx].resize(main_trace_len, ZERO);
        }
    }

    trace_columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX].resize(main_trace_len, ZERO);

    trace_columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX].resize(main_trace_len, ZERO);

    trace_columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX].resize(main_trace_len, ZERO);

    for i in 0..NUM_OP_BATCH_FLAGS {
        trace_columns[DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + i]
            .resize(main_trace_len, ZERO);
    }

    trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET].resize(main_trace_len, ZERO);
    trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1].resize(main_trace_len, ONE);

    for i in 0..STACK_TRACE_WIDTH {
        let col_idx = STACK_TRACE_OFFSET + i;
        let last_stack_value = trace_columns[col_idx][total_program_rows - 1];
        trace_columns[col_idx].resize(main_trace_len, last_stack_value);
    }
}

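/// Splits a fragment context into the parts needed to re-execute the fragment: a processor
/// driven by the recorded replays, a tracer that writes into the fragment's columns, and the
/// continuation stack and MAST forest at which execution resumes.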
fn split_trace_fragment_context<'a>(
    fragment_context: CoreTraceFragmentContext,
    fragment: &'a mut CoreTraceFragment<'a>,
    fragment_size: usize,
) -> (
    ReplayProcessor,
    CoreTraceGenerationTracer<'a>,
    ContinuationStack,
    Arc<MastForest>,
) {
    let CoreTraceFragmentContext {
        state: CoreTraceState { system, decoder, stack },
        replay:
            ExecutionReplay {
                block_stack: block_stack_replay,
                execution_context: execution_context_replay,
                stack_overflow: stack_overflow_replay,
                memory_reads: memory_reads_replay,
                advice: advice_replay,
                hasher: hasher_response_replay,
                block_address: block_address_replay,
                mast_forest_resolution: mast_forest_resolution_replay,
            },
        continuation,
        initial_mast_forest,
    } = fragment_context;

    let processor = ReplayProcessor::new(
        system,
        stack,
        stack_overflow_replay,
        execution_context_replay,
        advice_replay,
        memory_reads_replay,
        hasher_response_replay,
        mast_forest_resolution_replay,
        fragment_size.into(),
    );
    let tracer =
        CoreTraceGenerationTracer::new(fragment, decoder, block_address_replay, block_stack_replay);

    (processor, tracer, continuation, initial_mast_forest)
}