Skip to main content

miden_processor/trace/parallel/
mod.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::borrow::{Borrow, BorrowMut};
3
4use itertools::Itertools;
5use miden_air::{
6    CoreCols, Felt, StackCols, SystemCols,
7    trace::{
8        DECODER_TRACE_WIDTH, MIN_TRACE_LEN, MainTrace, RANGE_CHECK_TRACE_WIDTH, RowIndex,
9        STACK_TRACE_WIDTH, SYS_TRACE_WIDTH, decoder::NUM_OP_BITS,
10    },
11};
12use miden_core::{
13    ONE, Word, ZERO,
14    field::{PrimeCharacteristicRing, batch_inversion_allow_zeros},
15    mast::{ExecutableMastForest, MastForestId, MastNode, SparseMastForest},
16    operations::opcodes,
17    program::{Kernel, MIN_STACK_DEPTH},
18    utils::Idx,
19};
20use rayon::prelude::*;
21use tracing::instrument;
22
23use crate::{
24    ContextId, ExecutionError,
25    continuation_stack::{Continuation, ContinuationStack},
26    errors::MapExecErrNoCtx,
27    trace::{
28        ChipletsLengths, ExecutionTrace, TraceBuildInputs, TraceLenSummary,
29        parallel::{processor::ReplayProcessor, tracer::CoreTraceGenerationTracer},
30        range::RangeChecker,
31        utils::RowMajorTraceWriter,
32    },
33};
34
35/// Per-row payload written by the core tracer (system + decoder + stack).
36pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH;
37
38/// Physical row width of the core buffer: the [`CORE_TRACE_WIDTH`] payload plus the two
39/// trailing range-checker columns, which together form the per-AIR Core matrix
40/// (`NUM_CORE_COLS`) consumed directly by proving. The range columns are filled in-place
41/// after padding (see `write_range_into_core`).
42pub const CORE_STORAGE_WIDTH: usize = CORE_TRACE_WIDTH + RANGE_CHECK_TRACE_WIDTH;
43
44/// `build_trace()` uses this as a hard cap on trace rows.
45///
46/// The code checks `core_trace_contexts.len() * fragment_size` before allocation. It checks the
47/// same cap again while replaying chiplet activity. This keeps memory use bounded.
48const MAX_TRACE_LEN: usize = 1 << 29;
49
50pub(crate) mod core_trace_fragment;
51
52mod processor;
53mod tracer;
54
55use super::{
56    chiplets::Chiplets,
57    execution_tracer::TraceGenerationContext,
58    trace_state::{
59        AceReplay, BitwiseOp, BitwiseReplay, CoreTraceFragmentContext, CoreTraceState,
60        ExecutionReplay, HasherOp, HasherRequestReplay, KernelReplay, MemoryWritesReplay,
61        RangeCheckerReplay,
62    },
63};
64
65#[cfg(test)]
66mod tests;
67
68// BUILD TRACE
69// ================================================================================================
70
71/// Builds the main trace from the provided trace states in parallel.
72///
73/// # Example
74/// ```
75/// use miden_assembly::Assembler;
76/// use miden_processor::{DefaultHost, FastProcessor, StackInputs};
77///
78/// let program = Assembler::default()
79///     .assemble_program("prg", "begin push.1 drop end")
80///     .unwrap()
81///     .unwrap_program();
82/// let mut host = DefaultHost::default();
83///
84/// let trace_inputs = FastProcessor::new(StackInputs::default())
85///     .execute_trace_inputs_sync(&program, &mut host)
86///     .unwrap();
87/// let trace = miden_processor::trace::build_trace(trace_inputs).unwrap();
88///
89/// assert_eq!(*trace.program_hash(), program.hash());
90/// ```
91#[instrument(name = "build_trace", skip_all)]
92pub fn build_trace(inputs: TraceBuildInputs) -> Result<ExecutionTrace, ExecutionError> {
93    build_trace_with_max_len(inputs, MAX_TRACE_LEN)
94}
95
96/// Same as [`build_trace`], but with a custom hard cap.
97///
98/// When the trace would go over `max_trace_len`, this returns
99/// [`ExecutionError::TraceLenExceeded`].
100pub fn build_trace_with_max_len(
101    inputs: TraceBuildInputs,
102    max_trace_len: usize,
103) -> Result<ExecutionTrace, ExecutionError> {
104    let TraceBuildInputs {
105        trace_output,
106        trace_generation_context,
107        program_info,
108    } = inputs;
109
110    if !trace_output.has_matching_precompile_requests_digest() {
111        return Err(ExecutionError::Internal(
112            "trace inputs do not match deferred precompile requests",
113        ));
114    }
115
116    let TraceGenerationContext {
117        core_trace_contexts,
118        mast_forest_store,
119        range_checker_replay,
120        memory_writes,
121        bitwise_replay: bitwise,
122        kernel_replay,
123        hasher_for_chiplet,
124        ace_replay,
125        fragment_size,
126        max_stack_depth,
127    } = trace_generation_context;
128
129    // Before any trace generation, check that the number of core trace rows doesn't exceed the
130    // maximum trace length. This is a necessary check to avoid OOM panics during trace generation,
131    // which can occur if the execution produces an extremely large number of steps.
132    //
133    // Note that we add 1 to the total core trace rows to account for the additional HALT opcode row
134    // that is pushed at the end of the last fragment.
135    let total_core_trace_rows = core_trace_contexts
136        .len()
137        .checked_mul(fragment_size)
138        .and_then(|n| n.checked_add(1))
139        .ok_or(ExecutionError::TraceLenExceeded(max_trace_len))?;
140    if total_core_trace_rows > max_trace_len {
141        return Err(ExecutionError::TraceLenExceeded(max_trace_len));
142    }
143
144    if core_trace_contexts.is_empty() {
145        return Err(ExecutionError::Internal(
146            "no trace fragments provided in the trace generation context",
147        ));
148    }
149
150    let chiplets = initialize_chiplets(
151        program_info.kernel().clone(),
152        &core_trace_contexts,
153        memory_writes,
154        bitwise,
155        kernel_replay,
156        hasher_for_chiplet,
157        ace_replay,
158        &mast_forest_store,
159        max_trace_len,
160    )?;
161
162    let range_checker = initialize_range_checker(range_checker_replay, &chiplets);
163
164    let mut core_trace_data = generate_core_trace_row_major(
165        core_trace_contexts,
166        program_info.kernel().clone(),
167        fragment_size,
168        &mast_forest_store,
169        max_stack_depth,
170    )?;
171
172    let core_trace_len = core_trace_data.len() / CORE_STORAGE_WIDTH;
173
174    // Get the number of rows for the range checker
175    let range_table_len = range_checker.get_number_range_checker_rows();
176
177    let core_height = pad_to_trace_length(core_trace_len.max(range_table_len));
178    let chiplets_height = pad_to_trace_length(chiplets.trace_len());
179    let padded_trace_len = core_height.max(chiplets_height);
180
181    // Cap check against the padded height: pad-up can push over MAX_TRACE_LEN even
182    // when the unpadded check above passed.
183    if padded_trace_len > max_trace_len {
184        return Err(ExecutionError::TraceLenExceeded(max_trace_len));
185    }
186
187    let trace_len_summary = TraceLenSummary::new_with_padded(
188        core_trace_len,
189        range_table_len,
190        ChipletsLengths::new(&chiplets),
191        padded_trace_len,
192    );
193
194    // Each segment is built at its own per-AIR height (no cross-padding to the unified max).
195    let (chiplets_trace, ()) = rayon::join(
196        || chiplets.into_trace(chiplets_height),
197        || pad_core_row_major(&mut core_trace_data, core_height),
198    );
199
200    // The range checker occupies the two trailing columns of the core buffer.
201    range_checker.write_range_into_core(
202        &mut core_trace_data,
203        CORE_STORAGE_WIDTH,
204        CORE_TRACE_WIDTH,
205        CORE_TRACE_WIDTH + 1,
206        range_table_len,
207        core_height,
208    );
209
210    // Create the MainTrace
211    let main_trace = {
212        let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1));
213        MainTrace::from_parts(core_trace_data, chiplets_trace.trace, last_program_row)
214    };
215
216    Ok(ExecutionTrace::new_from_parts(
217        program_info,
218        trace_output,
219        main_trace,
220        trace_len_summary,
221    ))
222}
223
224// HELPERS
225// ================================================================================================
226
227/// Pad a logical row count to a valid trace length: next power of two, clamped to `MIN_TRACE_LEN`.
228fn pad_to_trace_length(logical_len: usize) -> usize {
229    logical_len.next_power_of_two().max(MIN_TRACE_LEN)
230}
231
232/// Generates row-major core trace in parallel from the provided trace fragment contexts.
233fn generate_core_trace_row_major(
234    core_trace_contexts: Vec<CoreTraceFragmentContext>,
235    kernel: Kernel,
236    fragment_size: usize,
237    mast_forest_store: &[Arc<SparseMastForest>],
238    max_stack_depth: usize,
239) -> Result<Vec<Felt>, ExecutionError> {
240    let num_fragments = core_trace_contexts.len();
241    let total_allocated_rows = num_fragments * fragment_size;
242
243    let mut core_trace_data = Felt::zero_vec(total_allocated_rows * CORE_STORAGE_WIDTH);
244
245    // Save the first stack top for initialization
246    let first_stack_top = if let Some(first_context) = core_trace_contexts.first() {
247        first_context.state.stack.stack_top.to_vec()
248    } else {
249        vec![ZERO; MIN_STACK_DEPTH]
250    };
251
252    let writers: Vec<RowMajorTraceWriter<'_, Felt>> = core_trace_data
253        .chunks_exact_mut(fragment_size * CORE_STORAGE_WIDTH)
254        .map(|chunk| {
255            RowMajorTraceWriter::with_stride(chunk, CORE_STORAGE_WIDTH, CORE_STORAGE_WIDTH)
256        })
257        .collect();
258
259    // Build the core trace fragments in parallel
260    let fragment_results: Result<Vec<_>, ExecutionError> = core_trace_contexts
261        .into_par_iter()
262        .zip(writers.into_par_iter())
263        .map(|(trace_state, writer)| {
264            let (mut processor, mut tracer, mut continuation_stack, mut current_forest) =
265                split_trace_fragment_context(
266                    trace_state,
267                    writer,
268                    fragment_size,
269                    mast_forest_store,
270                    max_stack_depth,
271                )?;
272
273            processor.execute(
274                &mut continuation_stack,
275                &mut current_forest,
276                &kernel,
277                &mut tracer,
278            )?;
279
280            tracer.into_final_state()
281        })
282        .collect();
283    let fragment_results = fragment_results?;
284
285    let mut stack_rows = Vec::new();
286    let mut system_rows = Vec::new();
287    let mut total_core_trace_rows = 0;
288
289    for final_state in fragment_results {
290        stack_rows.push(final_state.last_stack_cols);
291        system_rows.push(final_state.last_system_cols);
292        total_core_trace_rows += final_state.num_rows_written;
293    }
294
295    // Fix up stack and system rows
296    fixup_stack_and_system_rows(
297        &mut core_trace_data,
298        fragment_size,
299        &stack_rows,
300        &system_rows,
301        &first_stack_top,
302    );
303
304    // Run batch inversion on stack's H0 helper column, processing each fragment in parallel.
305    // This must be done after fixup_stack_and_system_rows since that function overwrites the first
306    // row of each fragment with non-inverted values.
307    {
308        let w = CORE_STORAGE_WIDTH;
309        core_trace_data[..total_core_trace_rows * w]
310            .par_chunks_mut(fragment_size * w)
311            .for_each(|fragment_chunk| {
312                let num_rows = fragment_chunk.len() / w;
313                let mut h0_vals: Vec<Felt> = (0..num_rows)
314                    .map(|r| {
315                        let row: &CoreCols<Felt> = fragment_chunk[r * w..(r + 1) * w].borrow();
316                        row.stack.h0
317                    })
318                    .collect();
319                batch_inversion_allow_zeros(&mut h0_vals);
320                for (r, &val) in h0_vals.iter().enumerate() {
321                    let row: &mut CoreCols<Felt> = fragment_chunk[r * w..(r + 1) * w].borrow_mut();
322                    row.stack.h0 = val;
323                }
324            });
325    }
326
327    // Truncate the core trace columns to the actual number of rows written.
328    core_trace_data.truncate(total_core_trace_rows * CORE_STORAGE_WIDTH);
329
330    push_halt_opcode_row(
331        &mut core_trace_data,
332        total_core_trace_rows,
333        system_rows.last().ok_or(ExecutionError::Internal(
334            "no trace fragments provided in the trace generation context",
335        ))?,
336        stack_rows.last().ok_or(ExecutionError::Internal(
337            "no trace fragments provided in the trace generation context",
338        ))?,
339    );
340
341    Ok(core_trace_data)
342}
343
344/// Initializing the first row of each fragment with the appropriate stack and system state.
345///
346/// This needs to be done as a separate pass after all fragments have been generated, because the
347/// system and stack rows write the state at clk `i` to the row at index `i+1`. Hence, the state of
348/// the last row of any given fragment cannot be written in parallel, since any given fragment
349/// filler doesn't have access to the next fragment's first row.
350fn fixup_stack_and_system_rows(
351    core_trace_data: &mut [Felt],
352    fragment_size: usize,
353    stack_rows: &[StackCols<Felt>],
354    system_rows: &[SystemCols<Felt>],
355    first_stack_top: &[Felt],
356) {
357    const MIN_STACK_DEPTH_FELT: Felt = Felt::new_unchecked(MIN_STACK_DEPTH as u64);
358    let w = CORE_STORAGE_WIDTH;
359
360    {
361        let row: &mut CoreCols<Felt> = core_trace_data[..w].borrow_mut();
362
363        // Stack order in the trace is reversed vs `first_stack_top`.
364        for (stack_col_idx, &value) in first_stack_top.iter().rev().enumerate() {
365            row.stack.top[stack_col_idx] = value;
366        }
367
368        row.stack.b0 = MIN_STACK_DEPTH_FELT;
369        row.stack.b1 = ZERO;
370        row.stack.h0 = ZERO;
371    }
372
373    let total_rows = core_trace_data.len() / w;
374    let num_fragments = total_rows / fragment_size;
375
376    for frag_idx in 1..num_fragments {
377        let row_idx = frag_idx * fragment_size;
378        let row_start = row_idx * w;
379        let row: &mut CoreCols<Felt> = core_trace_data[row_start..row_start + w].borrow_mut();
380        row.system = system_rows[frag_idx - 1].clone();
381        row.stack = stack_rows[frag_idx - 1].clone();
382    }
383}
384
385/// Appends a HALT row (`num_rows_before` is the row count before append).
386///
387/// This ensures that the trace ends with at least one HALT operation, which is necessary to satisfy
388/// the constraints.
389fn push_halt_opcode_row(
390    core_trace_data: &mut Vec<Felt>,
391    num_rows_before: usize,
392    last_system_state: &SystemCols<Felt>,
393    last_stack_state: &StackCols<Felt>,
394) {
395    let w = CORE_STORAGE_WIDTH;
396    let mut row_data = [ZERO; CORE_STORAGE_WIDTH];
397
398    // Read the previous row's hasher state first half before we take a mutable borrow on
399    // `row_data` (propagates the program hash into the HALT padding).
400    let prev_hasher_state_first_half: [Felt; 4] = if num_rows_before > 0 {
401        let last_row_start = (num_rows_before - 1) * w;
402        let prev: &CoreCols<Felt> = core_trace_data[last_row_start..last_row_start + w].borrow();
403        let hs = &prev.decoder.hasher_state;
404        [hs[0], hs[1], hs[2], hs[3]]
405    } else {
406        [ZERO; 4]
407    };
408
409    {
410        let row: &mut CoreCols<Felt> = row_data.as_mut_slice().borrow_mut();
411
412        row.system = last_system_state.clone();
413        row.stack = last_stack_state.clone();
414
415        // Pad op_bits columns with HALT opcode bits
416        let halt_opcode = opcodes::HALT;
417        for bit_idx in 0..NUM_OP_BITS {
418            row.decoder.op_bits[bit_idx] = Felt::from_u8((halt_opcode >> bit_idx) & 1);
419        }
420
421        // Pad hasher state columns (8 columns)
422        // - First 4 columns: copy the last value (to propagate program hash)
423        // - Remaining 4 columns: fill with ZEROs
424        row.decoder.hasher_state[..4].copy_from_slice(&prev_hasher_state_first_half);
425
426        // Pad op_bit_extra columns (2 columns)
427        // - First column: do nothing (pre-filled with ZEROs, HALT doesn't use this)
428        // - Second column: fill with ONEs (product of two most significant HALT bits, both are 1)
429        row.decoder.extra[1] = ONE;
430    }
431
432    core_trace_data.extend_from_slice(&row_data);
433}
434
435/// Initializes the ranger checker from the recorded range checks during execution and returns it.
436///
437/// Note that the maximum number of rows that the range checker can produce is 2^16, which is less
438/// than the maximum trace length (2^29). Hence, we can safely generate the entire range checker
439/// trace and then pad it to the final trace length, without worrying about hitting memory limits.
440fn initialize_range_checker(
441    range_checker_replay: RangeCheckerReplay,
442    chiplets: &Chiplets,
443) -> RangeChecker {
444    let mut range_checker = RangeChecker::new();
445
446    // Add all u32 range checks recorded during execution
447    for values in range_checker_replay {
448        range_checker.add_range_checks(&values);
449    }
450
451    // Add all memory-related range checks
452    chiplets.append_range_checks(&mut range_checker);
453
454    range_checker
455}
456
457/// Replays recorded operations to populate chiplet traces. Results were already used during
458/// execution; this pass only needs the trace-recording side effects.
459fn initialize_chiplets(
460    kernel: Kernel,
461    core_trace_contexts: &[CoreTraceFragmentContext],
462    memory_writes: MemoryWritesReplay,
463    bitwise: BitwiseReplay,
464    kernel_replay: KernelReplay,
465    hasher_for_chiplet: HasherRequestReplay,
466    ace_replay: AceReplay,
467    mast_forest_store: &[Arc<SparseMastForest>],
468    max_trace_len: usize,
469) -> Result<Chiplets, ExecutionError> {
470    let check_chiplets_trace_len = |chiplets: &Chiplets| -> Result<(), ExecutionError> {
471        if chiplets.trace_len() > max_trace_len {
472            return Err(ExecutionError::TraceLenExceeded(max_trace_len));
473        }
474        Ok(())
475    };
476
477    let mut chiplets = Chiplets::new(kernel);
478
479    // populate hasher chiplet
480    for hasher_op in hasher_for_chiplet.into_iter() {
481        match hasher_op {
482            HasherOp::Permute(input_state) => {
483                let _ = chiplets.hasher.permute(input_state);
484                check_chiplets_trace_len(&chiplets)?;
485            },
486            HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => {
487                let _ = chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash);
488                check_chiplets_trace_len(&chiplets)?;
489            },
490            HasherOp::HashBasicBlock((forest_id, node_id, expected_hash)) => {
491                let forest =
492                    mast_forest_store.get(forest_id.to_usize()).ok_or(ExecutionError::Internal(
493                        "MAST forest id in hasher replay out of range of mast_forest_store",
494                    ))?;
495                let node = forest
496                    .get_node_by_id(node_id)
497                    .ok_or(ExecutionError::Internal("invalid node ID in hasher replay"))?;
498                let MastNode::Block(basic_block_node) = node else {
499                    return Err(ExecutionError::Internal(
500                        "expected basic block node in hasher replay",
501                    ));
502                };
503                let op_batches = basic_block_node.op_batches();
504                let _ = chiplets.hasher.hash_basic_block(op_batches, expected_hash);
505                check_chiplets_trace_len(&chiplets)?;
506            },
507            HasherOp::BuildMerkleRoot((value, path, index)) => {
508                let _ = chiplets.hasher.build_merkle_root(value, &path, index);
509                check_chiplets_trace_len(&chiplets)?;
510            },
511            HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => {
512                chiplets.hasher.update_merkle_root(old_value, new_value, &path, index);
513                check_chiplets_trace_len(&chiplets)?;
514            },
515        }
516    }
517
518    // populate bitwise chiplet
519    for (bitwise_op, a, b) in bitwise {
520        match bitwise_op {
521            BitwiseOp::U32And => {
522                chiplets.bitwise.u32and(a, b).map_exec_err_no_ctx()?;
523                check_chiplets_trace_len(&chiplets)?;
524            },
525            BitwiseOp::U32Xor => {
526                chiplets.bitwise.u32xor(a, b).map_exec_err_no_ctx()?;
527                check_chiplets_trace_len(&chiplets)?;
528            },
529        }
530    }
531
532    // populate memory chiplet
533    //
534    // Note: care is taken to order all the accesses by clock cycle, since the memory chiplet
535    // currently assumes that all memory accesses are issued in the same order as they appear in
536    // the trace.
537    {
538        let elements_written: Box<dyn Iterator<Item = MemoryAccess>> =
539            Box::new(memory_writes.iter_elements_written().map(|(element, addr, ctx, clk)| {
540                MemoryAccess::WriteElement(*addr, *element, *ctx, *clk)
541            }));
542        let words_written: Box<dyn Iterator<Item = MemoryAccess>> = Box::new(
543            memory_writes
544                .iter_words_written()
545                .map(|(word, addr, ctx, clk)| MemoryAccess::WriteWord(*addr, *word, *ctx, *clk)),
546        );
547        let elements_read: Box<dyn Iterator<Item = MemoryAccess>> =
548            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
549                ctx.replay
550                    .memory_reads
551                    .iter_read_elements()
552                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadElement(addr, ctx, clk))
553            }));
554        let words_read: Box<dyn Iterator<Item = MemoryAccess>> =
555            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
556                ctx.replay
557                    .memory_reads
558                    .iter_read_words()
559                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadWord(addr, ctx, clk))
560            }));
561
562        [elements_written, words_written, elements_read, words_read]
563            .into_iter()
564            .kmerge_by(|a, b| a.clk() < b.clk())
565            .try_for_each(|mem_access| {
566                match mem_access {
567                    MemoryAccess::ReadElement(addr, ctx, clk) => chiplets
568                        .memory
569                        .read(ctx, addr, clk)
570                        .map(|_| ())
571                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
572                    MemoryAccess::WriteElement(addr, element, ctx, clk) => chiplets
573                        .memory
574                        .write(ctx, addr, clk, element)
575                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
576                    MemoryAccess::ReadWord(addr, ctx, clk) => chiplets
577                        .memory
578                        .read_word(ctx, addr, clk)
579                        .map(|_| ())
580                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
581                    MemoryAccess::WriteWord(addr, word, ctx, clk) => chiplets
582                        .memory
583                        .write_word(ctx, addr, clk, word)
584                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
585                }
586                check_chiplets_trace_len(&chiplets)
587            })?;
588
589        enum MemoryAccess {
590            ReadElement(Felt, ContextId, RowIndex),
591            WriteElement(Felt, Felt, ContextId, RowIndex),
592            ReadWord(Felt, ContextId, RowIndex),
593            WriteWord(Felt, Word, ContextId, RowIndex),
594        }
595
596        impl MemoryAccess {
597            fn clk(&self) -> RowIndex {
598                match self {
599                    MemoryAccess::ReadElement(_, _, clk) => *clk,
600                    MemoryAccess::WriteElement(_, _, _, clk) => *clk,
601                    MemoryAccess::ReadWord(_, _, clk) => *clk,
602                    MemoryAccess::WriteWord(_, _, _, clk) => *clk,
603                }
604            }
605        }
606    }
607
608    // populate ACE chiplet
609    for (clk, circuit_eval) in ace_replay.into_iter() {
610        chiplets.ace.add_circuit_evaluation(clk, circuit_eval);
611        check_chiplets_trace_len(&chiplets)?;
612    }
613
614    // populate kernel ROM
615    for proc_hash in kernel_replay.into_iter() {
616        chiplets.kernel_rom.access_proc(proc_hash).map_exec_err_no_ctx()?;
617        check_chiplets_trace_len(&chiplets)?;
618    }
619
620    Ok(chiplets)
621}
622
623/// Pads the core trace to `core_height` rows (HALT template, CLK incremented per row).
624fn pad_core_row_major(core_trace_data: &mut Vec<Felt>, core_height: usize) {
625    let w = CORE_STORAGE_WIDTH;
626    let total_program_rows = core_trace_data.len() / w;
627    assert!(total_program_rows <= core_height);
628    assert!(total_program_rows > 0);
629
630    let num_padding_rows = core_height - total_program_rows;
631    if num_padding_rows == 0 {
632        return;
633    }
634    let last_row_start = (total_program_rows - 1) * w;
635
636    // Safety: per our documented safety guarantees, we know that `total_program_rows > 0`,
637    // and row `total_program_rows - 1` is initialized.
638    let (last_hasher_first_half, last_stack): ([Felt; 4], StackCols<Felt>) = {
639        let last: &CoreCols<Felt> = core_trace_data[last_row_start..last_row_start + w].borrow();
640        let hs = &last.decoder.hasher_state;
641        let last_hasher: [Felt; 4] = [hs[0], hs[1], hs[2], hs[3]];
642        (last_hasher, last.stack.clone())
643    };
644
645    let mut template_data = [ZERO; CORE_STORAGE_WIDTH];
646    {
647        let template: &mut CoreCols<Felt> = template_data.as_mut_slice().borrow_mut();
648
649        // Decoder columns
650        // ------------------------
651
652        // Pad op_bits columns with HALT opcode bits
653        let halt_opcode = opcodes::HALT;
654        for i in 0..NUM_OP_BITS {
655            template.decoder.op_bits[i] = Felt::from_u8((halt_opcode >> i) & 1);
656        }
657        // Pad hasher state columns (8 columns)
658        // - First 4 columns: copy the last value (to propagate program hash)
659        // - Remaining 4 columns: fill with ZEROs
660        template.decoder.hasher_state[..4].copy_from_slice(&last_hasher_first_half);
661
662        // Pad op_bit_extra columns (2 columns)
663        // - First column: do nothing (filled with ZEROs, HALT doesn't use this)
664        // - Second column: fill with ONEs (product of two most significant HALT bits, both are 1)
665        template.decoder.extra[1] = ONE;
666
667        // Stack columns
668        // ------------------------
669
670        // Pad stack columns with the last value in each column (analogous to Stack::into_trace())
671        template.stack = last_stack;
672    }
673
674    // System columns
675    // ------------------------
676
677    // Pad CLK trace - fill with index values
678
679    let pad_start = total_program_rows * w;
680    core_trace_data.resize(pad_start + num_padding_rows * w, ZERO);
681    core_trace_data[pad_start..]
682        .par_chunks_mut(w)
683        .enumerate()
684        .for_each(|(idx, row_buf)| {
685            row_buf.copy_from_slice(&template_data);
686            let row: &mut CoreCols<Felt> = row_buf.borrow_mut();
687            row.system.clk = Felt::from_u32((total_program_rows + idx) as u32);
688        });
689}
690
691type SplitFragmentContext<'a> = (
692    ReplayProcessor,
693    CoreTraceGenerationTracer<'a>,
694    ContinuationStack<Arc<SparseMastForest>>,
695    Arc<SparseMastForest>,
696);
697
698/// Uses the provided `CoreTraceFragmentContext` to build and return a `ReplayProcessor` and
699/// `CoreTraceGenerationTracer` that can be used to execute the fragment.
700///
701/// `mast_forest_store` provides the [`SparseMastForest`]s that the indices stored in the fragment
702/// (the initial forest index and the `EnterForest` continuations) refer to.
703///
704/// # Errors
705///
706/// Returns [`ExecutionError::Internal`] if any [`MastForestId`] referenced by the fragment
707/// (either `initial_mast_forest_id` or an `EnterForest` continuation) is out of range of
708/// `mast_forest_store`. Because [`CoreTraceFragmentContext`] is attacker-controllable when fed in
709/// from outside, we validate these indices rather than indexing-and-panicking.
710fn split_trace_fragment_context<'a>(
711    fragment_context: CoreTraceFragmentContext,
712    writer: RowMajorTraceWriter<'a, Felt>,
713    fragment_size: usize,
714    mast_forest_store: &[Arc<SparseMastForest>],
715    max_stack_depth: usize,
716) -> Result<SplitFragmentContext<'a>, ExecutionError> {
717    let CoreTraceFragmentContext {
718        state: CoreTraceState { system, decoder, stack },
719        replay:
720            ExecutionReplay {
721                block_stack: block_stack_replay,
722                execution_context: execution_context_replay,
723                stack_overflow: stack_overflow_replay,
724                memory_reads: memory_reads_replay,
725                advice: advice_replay,
726                hasher: hasher_response_replay,
727                block_address: block_address_replay,
728                mast_forest_resolution: mast_forest_resolution_replay,
729            },
730        continuation,
731        initial_mast_forest_id,
732    } = fragment_context;
733
734    let translated_continuation =
735        translate_snapshot_continuation_stack(continuation, mast_forest_store)?;
736
737    let initial_mast_forest =
738        lookup_mast_forest(mast_forest_store, initial_mast_forest_id)?.clone();
739
740    let processor = ReplayProcessor::new(
741        system,
742        stack,
743        stack_overflow_replay,
744        execution_context_replay,
745        advice_replay,
746        memory_reads_replay,
747        hasher_response_replay,
748        mast_forest_resolution_replay,
749        mast_forest_store.to_vec(),
750        max_stack_depth,
751        fragment_size.into(),
752    );
753    let tracer =
754        CoreTraceGenerationTracer::new(writer, decoder, block_address_replay, block_stack_replay);
755
756    Ok((processor, tracer, translated_continuation, initial_mast_forest))
757}
758
759/// Translates a snapshotted `ContinuationStack<MastForestId>` into one carrying actual
760/// [`Arc<SparseMastForest>`] handles, ready to drive `execute_impl`.
761///
762/// Returns [`ExecutionError::Internal`] if any `EnterForest` continuation carries a
763/// [`MastForestId`] that is out of range of `mast_forest_store`.
764fn translate_snapshot_continuation_stack(
765    snapshot: ContinuationStack<MastForestId>,
766    mast_forest_store: &[Arc<SparseMastForest>],
767) -> Result<ContinuationStack<Arc<SparseMastForest>>, ExecutionError> {
768    let mut out: ContinuationStack<Arc<SparseMastForest>> = ContinuationStack::default();
769    for cont in snapshot.into_inner() {
770        let translated = match cont {
771            Continuation::EnterForest { forest: id, package_debug_info } => {
772                Continuation::EnterForest {
773                    forest: lookup_mast_forest(mast_forest_store, id)?.clone(),
774                    package_debug_info,
775                }
776            },
777            Continuation::StartNode(id) => Continuation::StartNode(id),
778            Continuation::FinishJoin(id) => Continuation::FinishJoin(id),
779            Continuation::FinishSplit(id) => Continuation::FinishSplit(id),
780            Continuation::FinishLoop(node_id) => Continuation::FinishLoop(node_id),
781            Continuation::FinishCall(id) => Continuation::FinishCall(id),
782            Continuation::FinishDyn(id) => Continuation::FinishDyn(id),
783            Continuation::ResumeBasicBlock { node_id, batch_index, op_idx_in_batch } => {
784                Continuation::ResumeBasicBlock { node_id, batch_index, op_idx_in_batch }
785            },
786            Continuation::Respan { node_id, batch_index } => {
787                Continuation::Respan { node_id, batch_index }
788            },
789            Continuation::FinishBasicBlock(id) => Continuation::FinishBasicBlock(id),
790        };
791        out.push_continuation(translated);
792    }
793    Ok(out)
794}
795
796/// Looks up `id` in `mast_forest_store`, returning [`ExecutionError::Internal`] if it is out of
797/// range.
798pub(super) fn lookup_mast_forest(
799    mast_forest_store: &[Arc<SparseMastForest>],
800    id: MastForestId,
801) -> Result<&Arc<SparseMastForest>, ExecutionError> {
802    mast_forest_store
803        .get(id.to_usize())
804        .ok_or(ExecutionError::Internal("MastForestId out of range of mast_forest_store"))
805}