use alloc::{boxed::Box, sync::Arc, vec::Vec};
use itertools::Itertools;
use miden_air::{
Felt,
trace::{
CLK_COL_IDX, CTX_COL_IDX, DECODER_TRACE_OFFSET, DECODER_TRACE_WIDTH, FN_HASH_RANGE,
MIN_TRACE_LEN, MainTrace, PADDED_TRACE_WIDTH, RowIndex, STACK_TRACE_OFFSET,
STACK_TRACE_WIDTH, SYS_TRACE_WIDTH, TRACE_WIDTH,
decoder::{
ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_OFFSET, IN_SPAN_COL_IDX,
NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS, OP_BATCH_FLAGS_OFFSET,
OP_BITS_EXTRA_COLS_OFFSET, OP_BITS_OFFSET, OP_INDEX_COL_IDX,
},
stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, STACK_TOP_OFFSET},
},
};
use miden_core::{
ONE, Word, ZERO,
field::batch_inversion_allow_zeros,
mast::{MastForest, MastNode},
operations::opcodes,
program::{Kernel, MIN_STACK_DEPTH, ProgramInfo},
utils::ColMatrix,
};
use rayon::prelude::*;
use tracing::instrument;
use crate::{
ContextId, ExecutionError,
continuation_stack::ContinuationStack,
errors::MapExecErrNoCtx,
fast::ExecutionOutput,
trace::{
AuxTraceBuilders, ChipletsLengths, ExecutionTrace, TraceLenSummary,
parallel::{processor::ReplayProcessor, tracer::CoreTraceGenerationTracer},
range::RangeChecker,
},
};
/// Number of columns in the "core" portion of the trace: the system, decoder, and stack
/// column groups, which are generated fragment-by-fragment in parallel.
pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH;
/// Hard upper bound (2^29 rows) on the trace length used by the default [`build_trace`]
/// entry point.
const MAX_TRACE_LEN: usize = 1 << 29;
pub(crate) mod core_trace_fragment;
use core_trace_fragment::CoreTraceFragment;
mod processor;
mod tracer;
use super::{
chiplets::Chiplets,
decoder::AuxTraceBuilder as DecoderAuxTraceBuilder,
execution_tracer::TraceGenerationContext,
stack::AuxTraceBuilder as StackAuxTraceBuilder,
trace_state::{
AceReplay, BitwiseOp, BitwiseReplay, CoreTraceFragmentContext, CoreTraceState,
ExecutionReplay, HasherOp, HasherRequestReplay, KernelReplay, MemoryWritesReplay,
RangeCheckerReplay,
},
};
#[cfg(test)]
mod tests;
/// Builds the full execution trace for a completed program execution.
///
/// This is a thin convenience wrapper around [`build_trace_with_max_len`] that applies the
/// default trace-length cap of [`MAX_TRACE_LEN`] rows.
///
/// # Errors
/// Returns an error if trace construction fails or the trace would exceed the default cap.
#[instrument(name = "build_trace", skip_all)]
pub fn build_trace(
    execution_output: ExecutionOutput,
    trace_generation_context: TraceGenerationContext,
    program_info: ProgramInfo,
) -> Result<ExecutionTrace, ExecutionError> {
    let max_len = MAX_TRACE_LEN;
    build_trace_with_max_len(execution_output, trace_generation_context, program_info, max_len)
}
/// Builds the full execution trace, enforcing `max_trace_len` as an upper bound on both the
/// core trace and the chiplets trace.
///
/// High-level flow:
/// 1. initialize the chiplets and range checker from the recorded replays,
/// 2. regenerate the core (system/decoder/stack) trace columns in parallel,
/// 3. pad all column groups to a common power-of-two length,
/// 4. assemble the padded columns into a [`MainTrace`] plus aux-trace builders.
///
/// # Errors
/// - [`ExecutionError::TraceLenExceeded`] if the core or chiplets trace would exceed
///   `max_trace_len` (or if the row count computation overflows).
/// - [`ExecutionError::Internal`] if no trace fragments were provided.
pub fn build_trace_with_max_len(
    execution_output: ExecutionOutput,
    trace_generation_context: TraceGenerationContext,
    program_info: ProgramInfo,
    max_trace_len: usize,
) -> Result<ExecutionTrace, ExecutionError> {
    let TraceGenerationContext {
        core_trace_contexts,
        range_checker_replay,
        memory_writes,
        bitwise_replay: bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        fragment_size,
    } = trace_generation_context;

    // Upper bound on core-trace rows: one full fragment per context, plus 1 for the final
    // HALT row appended after the last fragment. `checked_*` guards against overflow.
    let total_core_trace_rows = core_trace_contexts
        .len()
        .checked_mul(fragment_size)
        .and_then(|n| n.checked_add(1))
        .ok_or(ExecutionError::TraceLenExceeded(max_trace_len))?;
    if total_core_trace_rows > max_trace_len {
        return Err(ExecutionError::TraceLenExceeded(max_trace_len));
    }
    if core_trace_contexts.is_empty() {
        return Err(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ));
    }

    // Replay all chiplet requests (hasher, bitwise, memory, ACE, kernel ROM); this also
    // validates that the chiplets trace stays within `max_trace_len`.
    let chiplets = initialize_chiplets(
        program_info.kernel().clone(),
        &core_trace_contexts,
        memory_writes,
        bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        max_trace_len,
    )?;
    // The range checker must be built after the chiplets, since the chiplets contribute
    // their own range checks.
    let range_checker = initialize_range_checker(range_checker_replay, &chiplets);

    // Regenerate the system/decoder/stack columns by re-executing each fragment in parallel.
    let mut core_trace_columns = generate_core_trace_columns(
        core_trace_contexts,
        program_info.kernel().clone(),
        fragment_size,
    )?;

    let core_trace_len = core_trace_columns[0].len();
    let range_table_len = range_checker.get_number_range_checker_rows();
    let trace_len_summary =
        TraceLenSummary::new(core_trace_len, range_table_len, ChipletsLengths::new(&chiplets));
    // All column groups are padded to the same power-of-two length.
    let main_trace_len =
        compute_main_trace_length(core_trace_len, range_table_len, chiplets.trace_len());

    // Pad the core columns while the range-checker and chiplets traces are materialized in
    // parallel; the three tasks are independent.
    let ((), (range_checker_trace, chiplets_trace)) = rayon::join(
        || pad_trace_columns(&mut core_trace_columns, main_trace_len),
        || {
            rayon::join(
                || range_checker.into_trace_with_table(range_table_len, main_trace_len),
                || chiplets.into_trace(main_trace_len),
            )
        },
    );

    // Extra all-zero columns bring the trace up to the padded width expected by the prover.
    let padding_columns = vec![vec![ZERO; main_trace_len]; PADDED_TRACE_WIDTH - TRACE_WIDTH];
    // Column order matters: core, then range checker, then chiplets, then padding.
    let trace_columns: Vec<Vec<Felt>> = core_trace_columns
        .into_iter()
        .chain(range_checker_trace.trace)
        .chain(chiplets_trace.trace)
        .chain(padding_columns)
        .collect();

    let main_trace = {
        // Index of the last row produced by the program itself (before padding).
        let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1));
        let col_matrix = ColMatrix::new(trace_columns);
        MainTrace::new(col_matrix, last_program_row)
    };

    let aux_trace_builders = AuxTraceBuilders {
        decoder: DecoderAuxTraceBuilder::default(),
        range: range_checker_trace.aux_builder,
        chiplets: chiplets_trace.aux_builder,
        stack: StackAuxTraceBuilder,
    };

    Ok(ExecutionTrace::new_from_parts(
        program_info,
        execution_output,
        main_trace,
        aux_trace_builders,
        trace_len_summary,
    ))
}
/// Computes the padded length of the main trace: the smallest power of two that can hold
/// the longest of the three column groups, but never less than `MIN_TRACE_LEN`.
fn compute_main_trace_length(
    core_trace_len: usize,
    range_table_len: usize,
    chiplets_trace_len: usize,
) -> usize {
    let required = core_trace_len.max(range_table_len).max(chiplets_trace_len);
    required.next_power_of_two().max(MIN_TRACE_LEN)
}
/// Regenerates the core (system/decoder/stack) trace columns by re-executing each recorded
/// fragment in parallel against pre-allocated, disjoint chunks of the column storage.
///
/// Returns the core trace columns, truncated to the number of rows actually written and
/// extended with one final HALT row.
///
/// # Errors
/// Returns an error if fragment re-execution fails, or [`ExecutionError::Internal`] if no
/// fragments were provided.
fn generate_core_trace_columns(
    core_trace_contexts: Vec<CoreTraceFragmentContext>,
    kernel: Kernel,
    fragment_size: usize,
) -> Result<Vec<Vec<Felt>>, ExecutionError> {
    // Allocate every column up-front at full capacity (one fragment-sized chunk per
    // context); each fragment writes only into its own chunk, which makes the parallel
    // writes below alias-free.
    let mut core_trace_columns: Vec<Vec<Felt>> =
        vec![vec![ZERO; core_trace_contexts.len() * fragment_size]; CORE_TRACE_WIDTH];
    // Stack-top values for the very first row of the trace, taken from the initial state of
    // the first fragment. The else branch is effectively unreachable (the empty case errors
    // below), but keeps this infallible.
    let first_stack_top = if let Some(first_context) = core_trace_contexts.first() {
        first_context.state.stack.stack_top.to_vec()
    } else {
        vec![ZERO; MIN_STACK_DEPTH]
    };
    // Split the columns into per-fragment views (one `CoreTraceFragment` per chunk).
    let mut fragments = create_fragments_from_trace_columns(&mut core_trace_columns, fragment_size);

    // Re-execute each fragment in parallel; each execution fills its fragment's columns and
    // reports its final state so fragment boundaries can be stitched together afterwards.
    let fragment_results: Result<Vec<_>, ExecutionError> = core_trace_contexts
        .into_par_iter()
        .zip(fragments.par_iter_mut())
        .map(|(trace_state, fragment)| {
            let (mut processor, mut tracer, mut continuation_stack, mut current_forest) =
                split_trace_fragment_context(trace_state, fragment, fragment_size);
            processor.execute(
                &mut continuation_stack,
                &mut current_forest,
                &kernel,
                &mut tracer,
            )?;
            tracer.into_final_state()
        })
        .collect();
    let fragment_results = fragment_results?;

    // Collect the last stack/system rows of each fragment (used to fix up the first row of
    // the following fragment) and the total number of rows actually written.
    let mut stack_rows = Vec::new();
    let mut system_rows = Vec::new();
    let mut total_core_trace_rows = 0;
    for final_state in fragment_results {
        stack_rows.push(final_state.last_stack_cols);
        system_rows.push(final_state.last_system_cols);
        total_core_trace_rows += final_state.num_rows_written;
    }

    // Patch the first row of the trace and the first row of every subsequent fragment with
    // the correct system/stack state.
    fixup_stack_and_system_rows(
        &mut core_trace_columns,
        fragment_size,
        &stack_rows,
        &system_rows,
        &first_stack_top,
    );

    {
        // The H0 column was populated with values that still need to be inverted; batch the
        // inversions per fragment-sized chunk so they can run in parallel.
        let h0_column = &mut core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX];
        h0_column[..total_core_trace_rows]
            .par_chunks_mut(fragment_size)
            .for_each(batch_inversion_allow_zeros);
    }

    // Drop the unused tail of each column (the last fragment may be partially filled).
    for col in core_trace_columns.iter_mut() {
        col.truncate(total_core_trace_rows);
    }

    // Append the final HALT row, carrying forward the last system/stack state.
    push_halt_opcode_row(
        &mut core_trace_columns,
        system_rows.last().ok_or(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ))?,
        stack_rows.last().ok_or(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ))?,
    );
    Ok(core_trace_columns)
}
/// Splits the full-length core trace columns into per-fragment views.
///
/// Each returned [`CoreTraceFragment`] borrows one `fragment_size`-long chunk from every
/// column, so the fragments reference mutually disjoint regions of the column storage.
fn create_fragments_from_trace_columns(
    core_trace_columns: &mut [Vec<Felt>],
    fragment_size: usize,
) -> Vec<CoreTraceFragment<'_>> {
    // One chunk iterator per column; advancing all of them in lockstep yields the columns
    // of one fragment at a time.
    let mut per_column_chunks: Vec<_> = core_trace_columns
        .iter_mut()
        .map(|column| column.chunks_exact_mut(fragment_size))
        .collect();

    let mut fragments = Vec::new();
    loop {
        let fragment_cols: Vec<&mut [Felt]> =
            per_column_chunks.iter_mut().filter_map(|chunks| chunks.next()).collect();
        // All columns have the same length, so their chunk iterators must be exhausted at
        // the same time: each step yields either 0 or CORE_TRACE_WIDTH chunks.
        if fragment_cols.is_empty() {
            break fragments;
        }
        assert!(
            fragment_cols.len() == CORE_TRACE_WIDTH,
            "column chunks don't all have the same size"
        );
        fragments.push(CoreTraceFragment {
            columns: fragment_cols.try_into().expect("fragment has CORE_TRACE_WIDTH columns"),
        });
    }
}
/// Patches the rows that parallel fragment execution cannot fill in on its own: the very
/// first row of the trace, and the first row of every fragment after the first.
///
/// Row 0 is set to the boot state (zeroed system columns, the initial stack top, and the
/// initial stack-helper values). The first row of fragment `i > 0` is overwritten with the
/// final system/stack state of fragment `i - 1`.
fn fixup_stack_and_system_rows(
    core_trace_columns: &mut [Vec<Felt>],
    fragment_size: usize,
    stack_rows: &[[Felt; STACK_TRACE_WIDTH]],
    system_rows: &[[Felt; SYS_TRACE_WIDTH]],
    first_stack_top: &[Felt],
) {
    const MIN_STACK_DEPTH_FELT: Felt = Felt::new(MIN_STACK_DEPTH as u64);

    // --- row 0: boot state ------------------------------------------------------------
    // The first 6 system columns all start at zero.
    for col_idx in 0..6 {
        core_trace_columns[col_idx][0] = ZERO;
    }
    // The initial stack top is written in reverse order into the stack-top columns.
    for (stack_col_idx, &value) in first_stack_top.iter().rev().enumerate() {
        core_trace_columns[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + stack_col_idx][0] = value;
    }
    // Stack helper columns: depth starts at the minimum, overflow pointer and helper at 0.
    core_trace_columns[STACK_TRACE_OFFSET + B0_COL_IDX][0] = MIN_STACK_DEPTH_FELT;
    core_trace_columns[STACK_TRACE_OFFSET + B1_COL_IDX][0] = ZERO;
    core_trace_columns[STACK_TRACE_OFFSET + H0_COL_IDX][0] = ZERO;

    // --- fragment boundaries ----------------------------------------------------------
    // First row index of every fragment except the very first one.
    let num_fragments = core_trace_columns[0].len() / fragment_size;
    let boundary_rows = (1..num_fragments).map(|frag_idx| frag_idx * fragment_size);
    for (row_idx, (system_row, stack_row)) in
        boundary_rows.zip(system_rows.iter().zip(stack_rows.iter()))
    {
        for (col_idx, &value) in system_row.iter().enumerate() {
            core_trace_columns[col_idx][row_idx] = value;
        }
        for (col_idx, &value) in stack_row.iter().enumerate() {
            core_trace_columns[STACK_TRACE_OFFSET + col_idx][row_idx] = value;
        }
    }
}
fn push_halt_opcode_row(
core_trace_columns: &mut [Vec<Felt>],
last_system_state: &[Felt; SYS_TRACE_WIDTH],
last_stack_state: &[Felt; STACK_TRACE_WIDTH],
) {
for (col_idx, &value) in last_system_state.iter().enumerate() {
core_trace_columns[col_idx].push(value);
}
for (col_idx, &value) in last_stack_state.iter().enumerate() {
core_trace_columns[STACK_TRACE_OFFSET + col_idx].push(value);
}
core_trace_columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX].push(ZERO);
let halt_opcode = opcodes::HALT;
for bit_idx in 0..NUM_OP_BITS {
let bit_value = Felt::from_u8((halt_opcode >> bit_idx) & 1);
core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + bit_idx].push(bit_value);
}
for hasher_col_idx in 0..NUM_HASHER_COLUMNS {
let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + hasher_col_idx;
if hasher_col_idx < 4 {
let last_row_idx = core_trace_columns[col_idx].len() - 1;
let last_hasher_value = core_trace_columns[col_idx][last_row_idx];
core_trace_columns[col_idx].push(last_hasher_value);
} else {
core_trace_columns[col_idx].push(ZERO);
}
}
core_trace_columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX].push(ZERO);
core_trace_columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX].push(ZERO);
core_trace_columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX].push(ZERO);
for batch_flag_idx in 0..NUM_OP_BATCH_FLAGS {
let col_idx = DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + batch_flag_idx;
core_trace_columns[col_idx].push(ZERO);
}
core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET].push(ZERO);
core_trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1].push(ONE);
}
/// Builds the range checker from the recorded replay, then adds the range checks
/// contributed by the chiplets.
fn initialize_range_checker(
    range_checker_replay: RangeCheckerReplay,
    chiplets: &Chiplets,
) -> RangeChecker {
    let mut checker = RangeChecker::new();
    // Replay the range checks recorded during execution, tagged with their clock cycle.
    range_checker_replay
        .into_iter()
        .for_each(|(clk, values)| checker.add_range_checks(clk, &values));
    // The chiplets (e.g. memory) contribute additional range checks of their own.
    chiplets.append_range_checks(&mut checker);
    checker
}
/// Builds the chiplets (hasher, bitwise, memory, ACE, kernel ROM) by replaying the
/// chiplet-related requests recorded during execution.
///
/// After every appended operation the accumulated chiplets trace length is checked against
/// `max_trace_len`, so an oversized trace fails fast with
/// [`ExecutionError::TraceLenExceeded`] instead of being fully materialized.
///
/// # Errors
/// Returns an error if the chiplets trace exceeds `max_trace_len`, if a replayed request is
/// malformed, or if a chiplet rejects a replayed operation.
fn initialize_chiplets(
    kernel: Kernel,
    core_trace_contexts: &[CoreTraceFragmentContext],
    memory_writes: MemoryWritesReplay,
    bitwise: BitwiseReplay,
    kernel_replay: KernelReplay,
    hasher_for_chiplet: HasherRequestReplay,
    ace_replay: AceReplay,
    max_trace_len: usize,
) -> Result<Chiplets, ExecutionError> {
    // Shared guard: error out as soon as the chiplets trace grows past the cap.
    let check_chiplets_trace_len = |chiplets: &Chiplets| -> Result<(), ExecutionError> {
        if chiplets.trace_len() > max_trace_len {
            return Err(ExecutionError::TraceLenExceeded(max_trace_len));
        }
        Ok(())
    };
    let mut chiplets = Chiplets::new(kernel);

    // Replay hasher requests in recorded order. The return values are discarded: the
    // responses were presumably already consumed during the original execution, and the
    // replay only needs to reproduce the hasher's trace rows.
    for hasher_op in hasher_for_chiplet.into_iter() {
        match hasher_op {
            HasherOp::Permute(input_state) => {
                let _ = chiplets.hasher.permute(input_state);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => {
                let _ = chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::HashBasicBlock((forest, node_id, expected_hash)) => {
                // Basic-block hashes are recorded as (forest, node id); resolve the node
                // back to its op batches before hashing.
                let node = forest
                    .get_node_by_id(node_id)
                    .ok_or(ExecutionError::Internal("invalid node ID in hasher replay"))?;
                let MastNode::Block(basic_block_node) = node else {
                    return Err(ExecutionError::Internal(
                        "expected basic block node in hasher replay",
                    ));
                };
                let op_batches = basic_block_node.op_batches();
                let _ = chiplets.hasher.hash_basic_block(op_batches, expected_hash);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::BuildMerkleRoot((value, path, index)) => {
                let _ = chiplets.hasher.build_merkle_root(value, &path, index);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => {
                chiplets.hasher.update_merkle_root(old_value, new_value, &path, index);
                check_chiplets_trace_len(&chiplets)?;
            },
        }
    }

    // Replay the bitwise (u32and / u32xor) operations.
    for (bitwise_op, a, b) in bitwise {
        match bitwise_op {
            BitwiseOp::U32And => {
                chiplets.bitwise.u32and(a, b).map_exec_err_no_ctx()?;
                check_chiplets_trace_len(&chiplets)?;
            },
            BitwiseOp::U32Xor => {
                chiplets.bitwise.u32xor(a, b).map_exec_err_no_ctx()?;
                check_chiplets_trace_len(&chiplets)?;
            },
        }
    }

    // Replay all memory accesses. Writes were recorded globally; reads were recorded per
    // fragment. The four streams are merged by clock cycle via `kmerge_by` so the memory
    // chiplet observes them in execution order — this assumes each individual stream is
    // already sorted by `clk` (TODO confirm; `kmerge_by` requires pre-sorted inputs for a
    // fully sorted result).
    {
        let elements_written: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(memory_writes.iter_elements_written().map(|(element, addr, ctx, clk)| {
                MemoryAccess::WriteElement(*addr, *element, *ctx, *clk)
            }));
        let words_written: Box<dyn Iterator<Item = MemoryAccess>> = Box::new(
            memory_writes
                .iter_words_written()
                .map(|(word, addr, ctx, clk)| MemoryAccess::WriteWord(*addr, *word, *ctx, *clk)),
        );
        let elements_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_elements()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadElement(addr, ctx, clk))
            }));
        let words_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_words()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadWord(addr, ctx, clk))
            }));
        [elements_written, words_written, elements_read, words_read]
            .into_iter()
            .kmerge_by(|a, b| a.clk() < b.clk())
            .try_for_each(|mem_access| {
                match mem_access {
                    MemoryAccess::ReadElement(addr, ctx, clk) => chiplets
                        .memory
                        .read(ctx, addr, clk)
                        .map(|_| ())
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::WriteElement(addr, element, ctx, clk) => chiplets
                        .memory
                        .write(ctx, addr, clk, element)
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::ReadWord(addr, ctx, clk) => chiplets
                        .memory
                        .read_word(ctx, addr, clk)
                        .map(|_| ())
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::WriteWord(addr, word, ctx, clk) => chiplets
                        .memory
                        .write_word(ctx, addr, clk, word)
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                }
                check_chiplets_trace_len(&chiplets)
            })?;

        // Helper type unifying the four kinds of memory access so they can be merged into
        // a single clk-ordered stream. Items may be declared anywhere in a scope in Rust,
        // so this is visible to the code above.
        enum MemoryAccess {
            ReadElement(Felt, ContextId, RowIndex),
            WriteElement(Felt, Felt, ContextId, RowIndex),
            ReadWord(Felt, ContextId, RowIndex),
            WriteWord(Felt, Word, ContextId, RowIndex),
        }
        impl MemoryAccess {
            /// Returns the clock cycle at which this memory access occurred.
            fn clk(&self) -> RowIndex {
                match self {
                    MemoryAccess::ReadElement(_, _, clk) => *clk,
                    MemoryAccess::WriteElement(_, _, _, clk) => *clk,
                    MemoryAccess::ReadWord(_, _, clk) => *clk,
                    MemoryAccess::WriteWord(_, _, _, clk) => *clk,
                }
            }
        }
    }

    // Replay the ACE (arithmetic circuit evaluation) requests.
    for (clk, circuit_eval) in ace_replay.into_iter() {
        chiplets.ace.add_circuit_evaluation(clk, circuit_eval);
        check_chiplets_trace_len(&chiplets)?;
    }
    // Replay kernel procedure accesses into the kernel ROM.
    for proc_hash in kernel_replay.into_iter() {
        chiplets.kernel_rom.access_proc(proc_hash).map_exec_err_no_ctx()?;
        check_chiplets_trace_len(&chiplets)?;
    }
    Ok(chiplets)
}
fn pad_trace_columns(trace_columns: &mut [Vec<Felt>], main_trace_len: usize) {
let total_program_rows = trace_columns[0].len();
assert!(total_program_rows <= main_trace_len);
let num_padding_rows = main_trace_len - total_program_rows;
for padding_row_idx in 0..num_padding_rows {
trace_columns[CLK_COL_IDX]
.push(Felt::from_u32((total_program_rows + padding_row_idx) as u32));
}
trace_columns[CTX_COL_IDX].resize(main_trace_len, ZERO);
for fn_hash_col_idx in FN_HASH_RANGE {
trace_columns[fn_hash_col_idx].resize(main_trace_len, ZERO);
}
trace_columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX].resize(main_trace_len, ZERO);
let halt_opcode = opcodes::HALT;
for i in 0..NUM_OP_BITS {
let bit_value = Felt::from_u8((halt_opcode >> i) & 1);
trace_columns[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + i].resize(main_trace_len, bit_value);
}
for i in 0..NUM_HASHER_COLUMNS {
let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + i;
if i < 4 {
let last_hasher_value = trace_columns[col_idx][total_program_rows - 1];
trace_columns[col_idx].resize(main_trace_len, last_hasher_value);
} else {
trace_columns[col_idx].resize(main_trace_len, ZERO);
}
}
trace_columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX].resize(main_trace_len, ZERO);
trace_columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX].resize(main_trace_len, ZERO);
trace_columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX].resize(main_trace_len, ZERO);
for i in 0..NUM_OP_BATCH_FLAGS {
trace_columns[DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + i]
.resize(main_trace_len, ZERO);
}
trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET].resize(main_trace_len, ZERO);
trace_columns[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1].resize(main_trace_len, ONE);
for i in 0..STACK_TRACE_WIDTH {
let col_idx = STACK_TRACE_OFFSET + i;
let last_stack_value = trace_columns[col_idx][total_program_rows - 1];
trace_columns[col_idx].resize(main_trace_len, last_stack_value);
}
}
/// Decomposes a fragment context into the four pieces needed to re-execute the fragment:
/// a replay-driven processor, the trace-writing tracer, the continuation stack, and the
/// MAST forest the fragment starts in.
fn split_trace_fragment_context<'a>(
    fragment_context: CoreTraceFragmentContext,
    fragment: &'a mut CoreTraceFragment<'a>,
    fragment_size: usize,
) -> (
    ReplayProcessor,
    CoreTraceGenerationTracer<'a>,
    ContinuationStack,
    Arc<MastForest>,
) {
    // Take the context apart in two steps: outer container first, then the nested state
    // and replay structs.
    let CoreTraceFragmentContext { state, replay, continuation, initial_mast_forest } =
        fragment_context;
    let CoreTraceState { system, decoder, stack } = state;
    let ExecutionReplay {
        block_stack,
        execution_context,
        stack_overflow,
        memory_reads,
        advice,
        hasher,
        block_address,
        mast_forest_resolution,
    } = replay;

    // The processor consumes the replays that drive re-execution...
    let processor = ReplayProcessor::new(
        system,
        stack,
        stack_overflow,
        execution_context,
        advice,
        memory_reads,
        hasher,
        mast_forest_resolution,
        fragment_size.into(),
    );
    // ...while the tracer consumes the replays needed to write decoder trace rows.
    let tracer =
        CoreTraceGenerationTracer::new(fragment, decoder, block_address, block_stack);

    (processor, tracer, continuation, initial_mast_forest)
}