1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use itertools::Itertools;
4use miden_air::{
5 Felt,
6 trace::{
7 CLK_COL_IDX, DECODER_TRACE_OFFSET, DECODER_TRACE_WIDTH, MIN_TRACE_LEN, MainTrace, RowIndex,
8 STACK_TRACE_OFFSET, STACK_TRACE_WIDTH, SYS_TRACE_WIDTH,
9 decoder::{
10 HASHER_STATE_OFFSET, NUM_HASHER_COLUMNS, NUM_OP_BITS, OP_BITS_EXTRA_COLS_OFFSET,
11 OP_BITS_OFFSET,
12 },
13 stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, STACK_TOP_OFFSET},
14 },
15};
16use miden_core::{
17 ONE, Word, ZERO,
18 field::batch_inversion_allow_zeros,
19 mast::{MastForest, MastNode},
20 operations::opcodes,
21 program::{Kernel, MIN_STACK_DEPTH},
22};
23use rayon::prelude::*;
24use tracing::instrument;
25
26use crate::{
27 ContextId, ExecutionError,
28 continuation_stack::ContinuationStack,
29 errors::MapExecErrNoCtx,
30 trace::{
31 ChipletsLengths, ExecutionTrace, TraceBuildInputs, TraceLenSummary,
32 parallel::{processor::ReplayProcessor, tracer::CoreTraceGenerationTracer},
33 range::RangeChecker,
34 utils::RowMajorTraceWriter,
35 },
36};
37
/// Width of the "core" trace: the system, decoder, and stack column groups laid out
/// contiguously in each row.
pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH;

/// Hard upper bound on trace length (2^29 rows) applied by [`build_trace`].
const MAX_TRACE_LEN: usize = 1 << 29;
45
46pub(crate) mod core_trace_fragment;
47
48mod processor;
49mod tracer;
50
51use super::{
52 chiplets::Chiplets,
53 execution_tracer::TraceGenerationContext,
54 trace_state::{
55 AceReplay, BitwiseOp, BitwiseReplay, CoreTraceFragmentContext, CoreTraceState,
56 ExecutionReplay, HasherOp, HasherRequestReplay, KernelReplay, MemoryWritesReplay,
57 RangeCheckerReplay,
58 },
59};
60
61#[cfg(test)]
62mod tests;
63
/// Builds the full execution trace from the given inputs, capping the trace length at
/// [`MAX_TRACE_LEN`].
///
/// This is a thin wrapper around [`build_trace_with_max_len`]; see that function for details
/// on the construction stages and error conditions.
#[instrument(name = "build_trace", skip_all)]
pub fn build_trace(inputs: TraceBuildInputs) -> Result<ExecutionTrace, ExecutionError> {
    build_trace_with_max_len(inputs, MAX_TRACE_LEN)
}
88
/// Builds the full execution trace from the given inputs, returning an error if the resulting
/// trace would exceed `max_trace_len` rows.
///
/// Construction proceeds in stages:
/// 1. replay recorded chiplet requests (hasher, bitwise, memory, ACE, kernel ROM) to build the
///    chiplets,
/// 2. replay recorded range checks (plus chiplet-originated ones) to build the range checker,
/// 3. re-execute the program fragments in parallel to produce the core
///    (system/decoder/stack) trace in row-major form,
/// 4. pad all components to a common power-of-two length and assemble the main trace.
pub fn build_trace_with_max_len(
    inputs: TraceBuildInputs,
    max_trace_len: usize,
) -> Result<ExecutionTrace, ExecutionError> {
    let TraceBuildInputs {
        trace_output,
        trace_generation_context,
        program_info,
    } = inputs;

    // The trace output must match the digest of deferred precompile requests recorded during
    // execution; a mismatch means the inputs were assembled inconsistently.
    if !trace_output.has_matching_precompile_requests_digest() {
        return Err(ExecutionError::Internal(
            "trace inputs do not match deferred precompile requests",
        ));
    }

    let TraceGenerationContext {
        core_trace_contexts,
        range_checker_replay,
        memory_writes,
        bitwise_replay: bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        fragment_size,
        max_stack_depth,
    } = trace_generation_context;

    // Upper bound on core trace rows: one fragment-sized slice per context plus the final
    // HALT row. Checked arithmetic guards against overflow on pathological inputs.
    let total_core_trace_rows = core_trace_contexts
        .len()
        .checked_mul(fragment_size)
        .and_then(|n| n.checked_add(1))
        .ok_or(ExecutionError::TraceLenExceeded(max_trace_len))?;
    if total_core_trace_rows > max_trace_len {
        return Err(ExecutionError::TraceLenExceeded(max_trace_len));
    }

    if core_trace_contexts.is_empty() {
        return Err(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ));
    }

    // Stage 1: replay chiplet requests; this also enforces `max_trace_len` incrementally.
    let chiplets = initialize_chiplets(
        program_info.kernel().clone(),
        &core_trace_contexts,
        memory_writes,
        bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        max_trace_len,
    )?;

    // Stage 2: range checker needs the finalized chiplets so it can absorb their checks.
    let range_checker = initialize_range_checker(range_checker_replay, &chiplets);

    // Stage 3: parallel re-execution of the program fragments.
    let mut core_trace_data = generate_core_trace_row_major(
        core_trace_contexts,
        program_info.kernel().clone(),
        fragment_size,
        max_stack_depth,
    )?;

    let core_trace_len = core_trace_data.len() / CORE_TRACE_WIDTH;

    let range_table_len = range_checker.get_number_range_checker_rows();

    let trace_len_summary =
        TraceLenSummary::new(core_trace_len, range_table_len, ChipletsLengths::new(&chiplets));

    // Stage 4: all components are padded to the same power-of-two length.
    let main_trace_len =
        compute_main_trace_length(core_trace_len, range_table_len, chiplets.trace_len());

    // Build the range-checker and chiplet traces while the core trace is padded concurrently.
    let ((range_checker_trace, chiplets_trace), ()) = rayon::join(
        || {
            rayon::join(
                || range_checker.into_trace_with_table(range_table_len, main_trace_len),
                || chiplets.into_trace(main_trace_len),
            )
        },
        || pad_core_row_major(&mut core_trace_data, main_trace_len),
    );

    let main_trace = {
        // `core_trace_len` >= 1 (at least the HALT row), so this points at the last real row.
        let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1));
        MainTrace::from_parts(
            core_trace_data,
            chiplets_trace.trace,
            range_checker_trace.trace,
            main_trace_len,
            last_program_row,
        )
    };

    Ok(ExecutionTrace::new_from_parts(
        program_info,
        trace_output,
        main_trace,
        trace_len_summary,
    ))
}
203
204fn compute_main_trace_length(
208 core_trace_len: usize,
209 range_table_len: usize,
210 chiplets_trace_len: usize,
211) -> usize {
212 let max_len = range_table_len.max(core_trace_len).max(chiplets_trace_len);
214
215 let trace_len = max_len.next_power_of_two();
217 core::cmp::max(trace_len, MIN_TRACE_LEN)
218}
219
/// Re-executes the program fragments in parallel and produces the core trace
/// (system + decoder + stack columns) in row-major form.
///
/// Each fragment writes into its own disjoint `fragment_size`-row slice of a single
/// contiguous buffer. After parallel generation, fragment-boundary rows are fixed up, the
/// stack `h0` helper column is inverted in place, the buffer is truncated to the rows
/// actually written, and a final HALT row is appended.
fn generate_core_trace_row_major(
    core_trace_contexts: Vec<CoreTraceFragmentContext>,
    kernel: Kernel,
    fragment_size: usize,
    max_stack_depth: usize,
) -> Result<Vec<Felt>, ExecutionError> {
    let num_fragments = core_trace_contexts.len();
    let total_allocated_rows = num_fragments * fragment_size;

    // One contiguous row-major buffer; a fragment may write fewer rows than allocated.
    let mut core_trace_data: Vec<Felt> = vec![ZERO; total_allocated_rows * CORE_TRACE_WIDTH];

    // The stack top of the first fragment seeds row 0 of the trace.
    let first_stack_top = if let Some(first_context) = core_trace_contexts.first() {
        first_context.state.stack.stack_top.to_vec()
    } else {
        vec![ZERO; MIN_STACK_DEPTH]
    };

    // Split the buffer into disjoint per-fragment writers so fragments can run in parallel
    // without synchronization.
    let writers: Vec<RowMajorTraceWriter<'_, Felt>> = core_trace_data
        .chunks_exact_mut(fragment_size * CORE_TRACE_WIDTH)
        .map(|chunk| RowMajorTraceWriter::new(chunk, CORE_TRACE_WIDTH))
        .collect();

    // Replay each fragment in parallel; each returns its final column states and the number
    // of rows it actually wrote.
    let fragment_results: Result<Vec<_>, ExecutionError> = core_trace_contexts
        .into_par_iter()
        .zip(writers.into_par_iter())
        .map(|(trace_state, writer)| {
            let (mut processor, mut tracer, mut continuation_stack, mut current_forest) =
                split_trace_fragment_context(trace_state, writer, fragment_size, max_stack_depth);

            processor.execute(
                &mut continuation_stack,
                &mut current_forest,
                &kernel,
                &mut tracer,
            )?;

            tracer.into_final_state()
        })
        .collect();
    let fragment_results = fragment_results?;

    let mut stack_rows = Vec::new();
    let mut system_rows = Vec::new();
    let mut total_core_trace_rows = 0;

    for final_state in fragment_results {
        stack_rows.push(final_state.last_stack_cols);
        system_rows.push(final_state.last_system_cols);
        total_core_trace_rows += final_state.num_rows_written;
    }

    // Each fragment's first row must carry the system/stack state left by its predecessor;
    // row 0 gets the initial stack state computed above.
    fixup_stack_and_system_rows(
        &mut core_trace_data,
        fragment_size,
        &stack_rows,
        &system_rows,
        &first_stack_top,
    );

    // Invert the h0 stack helper column in place, one batch inversion per fragment
    // (zero entries are permitted and handled by `batch_inversion_allow_zeros`).
    {
        let h0_col_offset = STACK_TRACE_OFFSET + H0_COL_IDX;
        let w = CORE_TRACE_WIDTH;
        core_trace_data[..total_core_trace_rows * w]
            .par_chunks_mut(fragment_size * w)
            .for_each(|fragment_chunk| {
                let num_rows = fragment_chunk.len() / w;
                let mut h0_vals: Vec<Felt> =
                    (0..num_rows).map(|r| fragment_chunk[r * w + h0_col_offset]).collect();
                batch_inversion_allow_zeros(&mut h0_vals);
                for (r, &val) in h0_vals.iter().enumerate() {
                    fragment_chunk[r * w + h0_col_offset] = val;
                }
            });
    }

    // Drop the unused tail of the last fragment's allocation.
    core_trace_data.truncate(total_core_trace_rows * CORE_TRACE_WIDTH);

    // Append the final HALT row, carrying over the last fragment's system/stack state.
    push_halt_opcode_row(
        &mut core_trace_data,
        total_core_trace_rows,
        system_rows.last().ok_or(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ))?,
        stack_rows.last().ok_or(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ))?,
    );

    Ok(core_trace_data)
}
318
319fn fixup_stack_and_system_rows(
326 core_trace_data: &mut [Felt],
327 fragment_size: usize,
328 stack_rows: &[[Felt; STACK_TRACE_WIDTH]],
329 system_rows: &[[Felt; SYS_TRACE_WIDTH]],
330 first_stack_top: &[Felt],
331) {
332 const MIN_STACK_DEPTH_FELT: Felt = Felt::new_unchecked(MIN_STACK_DEPTH as u64);
333 let w = CORE_TRACE_WIDTH;
334
335 {
336 let row = &mut core_trace_data[..w];
337
338 for (stack_col_idx, &value) in first_stack_top.iter().rev().enumerate() {
340 row[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + stack_col_idx] = value;
341 }
342
343 row[STACK_TRACE_OFFSET + B0_COL_IDX] = MIN_STACK_DEPTH_FELT;
344 row[STACK_TRACE_OFFSET + B1_COL_IDX] = ZERO;
345 row[STACK_TRACE_OFFSET + H0_COL_IDX] = ZERO;
346 }
347
348 let total_rows = core_trace_data.len() / w;
349 let num_fragments = total_rows / fragment_size;
350
351 for frag_idx in 1..num_fragments {
352 let row_idx = frag_idx * fragment_size;
353 let row_start = row_idx * w;
354 let system_row = &system_rows[frag_idx - 1];
355 let stack_row = &stack_rows[frag_idx - 1];
356
357 core_trace_data[row_start..row_start + SYS_TRACE_WIDTH].copy_from_slice(system_row);
358
359 let stack_start = row_start + STACK_TRACE_OFFSET;
360 core_trace_data[stack_start..stack_start + STACK_TRACE_WIDTH].copy_from_slice(stack_row);
361 }
362}
363
364fn push_halt_opcode_row(
369 core_trace_data: &mut Vec<Felt>,
370 num_rows_before: usize,
371 last_system_state: &[Felt; SYS_TRACE_WIDTH],
372 last_stack_state: &[Felt; STACK_TRACE_WIDTH],
373) {
374 let w = CORE_TRACE_WIDTH;
375 let mut row = [ZERO; CORE_TRACE_WIDTH];
376
377 row[..SYS_TRACE_WIDTH].copy_from_slice(last_system_state);
380
381 row[STACK_TRACE_OFFSET..STACK_TRACE_OFFSET + STACK_TRACE_WIDTH]
384 .copy_from_slice(last_stack_state);
385
386 let halt_opcode = opcodes::HALT;
388 for bit_idx in 0..NUM_OP_BITS {
389 row[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + bit_idx] =
390 Felt::from_u8((halt_opcode >> bit_idx) & 1);
391 }
392
393 if num_rows_before > 0 {
397 let last_row_start = (num_rows_before - 1) * w;
398 for hasher_col_idx in 0..4 {
400 let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + hasher_col_idx;
401 row[col_idx] = core_trace_data[last_row_start + col_idx];
402 }
403 }
404
405 row[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1] = ONE;
409
410 core_trace_data.extend_from_slice(&row);
411}
412
413fn initialize_range_checker(
419 range_checker_replay: RangeCheckerReplay,
420 chiplets: &Chiplets,
421) -> RangeChecker {
422 let mut range_checker = RangeChecker::new();
423
424 for (clk, values) in range_checker_replay.into_iter() {
426 range_checker.add_range_checks(clk, &values);
427 }
428
429 chiplets.append_range_checks(&mut range_checker);
431
432 range_checker
433}
434
/// Builds the chiplets (hasher, bitwise, memory, ACE, kernel ROM) by replaying all chiplet
/// requests recorded during execution.
///
/// After every replayed operation the accumulated chiplets trace length is checked against
/// `max_trace_len` so that an over-long trace fails fast instead of growing unbounded.
fn initialize_chiplets(
    kernel: Kernel,
    core_trace_contexts: &[CoreTraceFragmentContext],
    memory_writes: MemoryWritesReplay,
    bitwise: BitwiseReplay,
    kernel_replay: KernelReplay,
    hasher_for_chiplet: HasherRequestReplay,
    ace_replay: AceReplay,
    max_trace_len: usize,
) -> Result<Chiplets, ExecutionError> {
    // Bail out as soon as the chiplets trace grows past the allowed bound.
    let check_chiplets_trace_len = |chiplets: &Chiplets| -> Result<(), ExecutionError> {
        if chiplets.trace_len() > max_trace_len {
            return Err(ExecutionError::TraceLenExceeded(max_trace_len));
        }
        Ok(())
    };

    let mut chiplets = Chiplets::new(kernel);

    // Replay hasher requests in recorded order; results are discarded since only the trace
    // side effects matter here.
    for hasher_op in hasher_for_chiplet.into_iter() {
        match hasher_op {
            HasherOp::Permute(input_state) => {
                let _ = chiplets.hasher.permute(input_state);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => {
                let _ = chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::HashBasicBlock((forest, node_id, expected_hash)) => {
                // The replay stores a forest + node id rather than the op batches themselves;
                // resolve the node and verify it really is a basic block.
                let node = forest
                    .get_node_by_id(node_id)
                    .ok_or(ExecutionError::Internal("invalid node ID in hasher replay"))?;
                let MastNode::Block(basic_block_node) = node else {
                    return Err(ExecutionError::Internal(
                        "expected basic block node in hasher replay",
                    ));
                };
                let op_batches = basic_block_node.op_batches();
                let _ = chiplets.hasher.hash_basic_block(op_batches, expected_hash);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::BuildMerkleRoot((value, path, index)) => {
                let _ = chiplets.hasher.build_merkle_root(value, &path, index);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => {
                chiplets.hasher.update_merkle_root(old_value, new_value, &path, index);
                check_chiplets_trace_len(&chiplets)?;
            },
        }
    }

    // Replay bitwise chiplet operations.
    for (bitwise_op, a, b) in bitwise {
        match bitwise_op {
            BitwiseOp::U32And => {
                chiplets.bitwise.u32and(a, b).map_exec_err_no_ctx()?;
                check_chiplets_trace_len(&chiplets)?;
            },
            BitwiseOp::U32Xor => {
                chiplets.bitwise.u32xor(a, b).map_exec_err_no_ctx()?;
                check_chiplets_trace_len(&chiplets)?;
            },
        }
    }

    // Replay memory accesses. Writes are recorded globally, reads per fragment; the four
    // streams are merged by clock cycle so the memory chiplet sees accesses in execution
    // order.
    {
        let elements_written: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(memory_writes.iter_elements_written().map(|(element, addr, ctx, clk)| {
                MemoryAccess::WriteElement(*addr, *element, *ctx, *clk)
            }));
        let words_written: Box<dyn Iterator<Item = MemoryAccess>> = Box::new(
            memory_writes
                .iter_words_written()
                .map(|(word, addr, ctx, clk)| MemoryAccess::WriteWord(*addr, *word, *ctx, *clk)),
        );
        let elements_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_elements()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadElement(addr, ctx, clk))
            }));
        let words_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_words()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadWord(addr, ctx, clk))
            }));

        // k-way merge of the four clock-ordered streams into one globally ordered stream.
        [elements_written, words_written, elements_read, words_read]
            .into_iter()
            .kmerge_by(|a, b| a.clk() < b.clk())
            .try_for_each(|mem_access| {
                match mem_access {
                    MemoryAccess::ReadElement(addr, ctx, clk) => chiplets
                        .memory
                        .read(ctx, addr, clk)
                        .map(|_| ())
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::WriteElement(addr, element, ctx, clk) => chiplets
                        .memory
                        .write(ctx, addr, clk, element)
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::ReadWord(addr, ctx, clk) => chiplets
                        .memory
                        .read_word(ctx, addr, clk)
                        .map(|_| ())
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::WriteWord(addr, word, ctx, clk) => chiplets
                        .memory
                        .write_word(ctx, addr, clk, word)
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                }
                check_chiplets_trace_len(&chiplets)
            })?;

        // Unified representation of a memory access so the four streams can be merged.
        enum MemoryAccess {
            ReadElement(Felt, ContextId, RowIndex),
            WriteElement(Felt, Felt, ContextId, RowIndex),
            ReadWord(Felt, ContextId, RowIndex),
            WriteWord(Felt, Word, ContextId, RowIndex),
        }

        impl MemoryAccess {
            // Clock cycle at which this access occurred; used as the merge key.
            fn clk(&self) -> RowIndex {
                match self {
                    MemoryAccess::ReadElement(_, _, clk) => *clk,
                    MemoryAccess::WriteElement(_, _, _, clk) => *clk,
                    MemoryAccess::ReadWord(_, _, clk) => *clk,
                    MemoryAccess::WriteWord(_, _, _, clk) => *clk,
                }
            }
        }
    }

    // Replay ACE circuit evaluations.
    for (clk, circuit_eval) in ace_replay.into_iter() {
        chiplets.ace.add_circuit_evaluation(clk, circuit_eval);
        check_chiplets_trace_len(&chiplets)?;
    }

    // Replay kernel ROM procedure accesses.
    for proc_hash in kernel_replay.into_iter() {
        chiplets.kernel_rom.access_proc(proc_hash).map_exec_err_no_ctx()?;
        check_chiplets_trace_len(&chiplets)?;
    }

    Ok(chiplets)
}
595
596fn pad_core_row_major(core_trace_data: &mut Vec<Felt>, main_trace_len: usize) {
598 let w = CORE_TRACE_WIDTH;
599 let total_program_rows = core_trace_data.len() / w;
600 assert!(total_program_rows <= main_trace_len);
601 assert!(total_program_rows > 0);
602
603 let num_padding_rows = main_trace_len - total_program_rows;
604 if num_padding_rows == 0 {
605 return;
606 }
607 let last_row_start = (total_program_rows - 1) * w;
608
609 let mut template = [ZERO; CORE_TRACE_WIDTH];
613 let halt_opcode = opcodes::HALT;
615 for i in 0..NUM_OP_BITS {
616 let bit_value = Felt::from_u8((halt_opcode >> i) & 1);
617 template[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + i] = bit_value;
618 }
619 for i in 0..NUM_HASHER_COLUMNS {
623 let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + i;
624 template[col_idx] = if i < 4 {
625 core_trace_data[last_row_start + col_idx]
629 } else {
630 ZERO
631 };
632 }
633
634 template[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1] = ONE;
638
639 for i in 0..STACK_TRACE_WIDTH {
644 let col_idx = STACK_TRACE_OFFSET + i;
645 template[col_idx] = core_trace_data[last_row_start + col_idx];
648 }
649
650 let pad_start = total_program_rows * w;
656 core_trace_data.resize(pad_start + num_padding_rows * w, ZERO);
657 core_trace_data[pad_start..]
658 .par_chunks_mut(w)
659 .enumerate()
660 .for_each(|(idx, row)| {
661 row.copy_from_slice(&template);
662 row[CLK_COL_IDX] = Felt::from_u32((total_program_rows + idx) as u32);
663 });
664}
665
666fn split_trace_fragment_context<'a>(
669 fragment_context: CoreTraceFragmentContext,
670 writer: RowMajorTraceWriter<'a, Felt>,
671 fragment_size: usize,
672 max_stack_depth: usize,
673) -> (
674 ReplayProcessor,
675 CoreTraceGenerationTracer<'a>,
676 ContinuationStack,
677 Arc<MastForest>,
678) {
679 let CoreTraceFragmentContext {
680 state: CoreTraceState { system, decoder, stack },
681 replay:
682 ExecutionReplay {
683 block_stack: block_stack_replay,
684 execution_context: execution_context_replay,
685 stack_overflow: stack_overflow_replay,
686 memory_reads: memory_reads_replay,
687 advice: advice_replay,
688 hasher: hasher_response_replay,
689 block_address: block_address_replay,
690 mast_forest_resolution: mast_forest_resolution_replay,
691 },
692 continuation,
693 initial_mast_forest,
694 } = fragment_context;
695
696 let processor = ReplayProcessor::new(
697 system,
698 stack,
699 stack_overflow_replay,
700 execution_context_replay,
701 advice_replay,
702 memory_reads_replay,
703 hasher_response_replay,
704 mast_forest_resolution_replay,
705 max_stack_depth,
706 fragment_size.into(),
707 );
708 let tracer =
709 CoreTraceGenerationTracer::new(writer, decoder, block_address_replay, block_stack_replay);
710
711 (processor, tracer, continuation, initial_mast_forest)
712}