use core::ops::ControlFlow;
use miden_air::trace::{
DECODER_TRACE_OFFSET,
decoder::{
ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_OFFSET, IN_SPAN_COL_IDX,
NUM_OP_BATCH_FLAGS, NUM_OP_BITS, NUM_USER_OP_HELPERS, OP_BATCH_FLAGS_OFFSET,
OP_INDEX_COL_IDX,
},
};
use miden_core::{
Felt, ONE, Operation, ZERO,
mast::{BasicBlockNode, MastNodeExt, OpBatch},
};
use super::{CoreTraceFragmentGenerator, trace_builder::OperationTraceConfig};
/// Length of one hasher cycle, lifted into the field so it can be added directly to the decoder
/// address (done in `respan` below).
const HASH_CYCLE_LEN: Felt = {
    let cycle_len = miden_air::trace::chiplets::hasher::HASH_CYCLE_LEN as u64;
    Felt::new(cycle_len)
};
impl CoreTraceFragmentGenerator {
    /// Appends the `SPAN` row that opens a basic block, then advances the clock.
    ///
    /// The row stores `parent_addr` in the address column, every group of `first_op_batch` in
    /// the hasher state columns, and `num_groups` in the group count column. The op batch flags
    /// are derived from `num_groups`.
    ///
    /// # Panics
    /// Panics if `num_groups` does not map to a valid batch size (see `get_op_batch_flags`).
    pub fn add_span_start_trace_row(
        &mut self,
        first_op_batch: &OpBatch,
        num_groups: Felt,
        parent_addr: Felt,
    ) -> ControlFlow<()> {
        self.add_batch_header_trace_row(
            Operation::Span.op_code(),
            first_op_batch,
            parent_addr,
            num_groups,
        )
    }

    /// Appends the `END` row that closes the current basic block and clears the span context.
    ///
    /// The hasher state is the block digest followed by the flags word obtained from
    /// `update_decoder_state_on_node_end`, and `addr` is the address of the node being ended.
    pub fn add_span_end_trace_row(&mut self, basic_block_node: &BasicBlockNode) -> ControlFlow<()> {
        let (ended_node_addr, flags) = self.update_decoder_state_on_node_end();
        let config = OperationTraceConfig {
            opcode: Operation::End.op_code(),
            hasher_state: (basic_block_node.digest(), flags.to_hasher_state_second_word()),
            addr: ended_node_addr,
        };
        // After this row no basic block is active any more.
        self.span_context = None;
        self.add_control_flow_trace_row(config)
    }

    /// Starts the next operation batch of the current basic block.
    ///
    /// Appends a `RESPAN` row. Then it advances the decoder address by one hash cycle, decrements
    /// the remaining group count, and loads the first group of `op_batch` as the group whose
    /// operations are about to be executed.
    ///
    /// # Panics
    /// Panics if no span context is active.
    pub fn respan(&mut self, op_batch: &OpBatch) -> ControlFlow<()> {
        self.add_respan_trace_row(op_batch)?;
        self.context.state.decoder.current_addr += HASH_CYCLE_LEN;

        let span_context = self
            .span_context
            .as_mut()
            .expect("Span context should be initialized for RESPAN");
        span_context.num_groups_left -= ONE;
        span_context.group_ops_left = op_batch.groups()[0];

        ControlFlow::Continue(())
    }

    /// Appends the `RESPAN` row for `op_batch`, then advances the clock.
    ///
    /// The row is recorded at the current decoder address. The span context's remaining group
    /// count is used both for the group count column and for the op batch flags.
    ///
    /// # Panics
    /// Panics if no span context is active, or if the remaining group count does not map to a
    /// valid batch size.
    fn add_respan_trace_row(&mut self, op_batch: &OpBatch) -> ControlFlow<()> {
        let group_count = self
            .span_context
            .as_ref()
            .expect("Span context should be initialized for RESPAN")
            .num_groups_left;
        let addr = self.context.state.decoder.current_addr;

        self.add_batch_header_trace_row(Operation::Respan.op_code(), op_batch, addr, group_count)
    }

    /// Writes a row that begins an operation batch (`SPAN` or `RESPAN`), then advances the clock.
    ///
    /// Shared by `add_span_start_trace_row` and `add_respan_trace_row`. The two differ only in
    /// the opcode, the address recorded, and where `group_count` comes from.
    ///
    /// # Panics
    /// Panics if `group_count` does not map to a valid batch size (see `get_op_batch_flags`).
    fn add_batch_header_trace_row(
        &mut self,
        opcode: u8,
        op_batch: &OpBatch,
        addr: Felt,
        group_count: Felt,
    ) -> ControlFlow<()> {
        let row_idx = self.num_rows_built();
        self.populate_system_trace_columns(row_idx);
        self.fragment.columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX][row_idx] = addr;
        self.append_opcode(opcode, row_idx);

        // Every op group of the batch is laid out across the hasher state columns.
        for (i, &group) in op_batch.groups().iter().enumerate() {
            self.fragment.columns[DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + i][row_idx] = group;
        }

        // On a batch-start row the in-span flag and the op index are both zero.
        self.fragment.columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX][row_idx] = ZERO;
        self.fragment.columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX][row_idx] = group_count;
        self.fragment.columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX][row_idx] = ZERO;

        let op_batch_flags = get_op_batch_flags(group_count);
        for (i, &flag) in op_batch_flags.iter().enumerate().take(NUM_OP_BATCH_FLAGS) {
            self.fragment.columns[DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + i][row_idx] = flag;
        }

        self.populate_stack_trace_columns(row_idx);
        self.increment_clk()
    }

    /// Appends the trace row for a user operation executed inside a basic block, then advances
    /// the clock.
    ///
    /// The opcode of `operation` is removed from the current op group *before* the row is
    /// written, so the first hasher state column holds the group's remaining opcodes. The second
    /// hasher state column holds the decoder's parent address. The remaining hasher state
    /// columns hold `user_op_helpers`, or zeros when it is `None`. If `operation` carries an
    /// immediate value, the span context's remaining group count is decremented. The row itself
    /// records the count from before that decrement.
    ///
    /// # Panics
    /// Panics if no span context is active.
    pub fn add_operation_trace_row(
        &mut self,
        operation: Operation,
        op_idx_in_group: usize,
        user_op_helpers: Option<[Felt; NUM_USER_OP_HELPERS]>,
    ) -> ControlFlow<()> {
        let row_idx = self.num_rows_built();
        self.populate_system_trace_columns(row_idx);
        self.append_opcode(operation.op_code(), row_idx);
        self.fragment.columns[DECODER_TRACE_OFFSET + ADDR_COL_IDX][row_idx] =
            self.context.state.decoder.current_addr;

        // Update the span context and capture the values this row needs.
        let ctx = self.span_context.as_mut().expect("not in span");
        ctx.group_ops_left = remove_opcode_from_group(ctx.group_ops_left, operation);
        let group_ops_left = ctx.group_ops_left;
        let num_groups_left = ctx.num_groups_left;
        if operation.imm_value().is_some() {
            ctx.num_groups_left -= ONE;
        }

        let hasher_state = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET;
        self.fragment.columns[hasher_state][row_idx] = group_ops_left;
        self.fragment.columns[hasher_state + 1][row_idx] = self.context.state.decoder.parent_addr;
        let helpers = user_op_helpers.unwrap_or([ZERO; NUM_USER_OP_HELPERS]);
        for (i, &helper) in helpers.iter().enumerate() {
            self.fragment.columns[hasher_state + 2 + i][row_idx] = helper;
        }

        self.fragment.columns[DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX][row_idx] = ONE;
        self.fragment.columns[DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX][row_idx] =
            num_groups_left;
        // NOTE(review): `op_idx_in_group` indexes within a single op group, so it is assumed to
        // fit in `u32`.
        self.fragment.columns[DECODER_TRACE_OFFSET + OP_INDEX_COL_IDX][row_idx] =
            Felt::from(op_idx_in_group as u32);
        for i in 0..NUM_OP_BATCH_FLAGS {
            self.fragment.columns[DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET + i][row_idx] = ZERO;
        }

        self.populate_stack_trace_columns(row_idx);
        self.increment_clk()
    }
}
fn get_op_batch_flags(num_groups_left: Felt) -> [Felt; 3] {
use miden_air::trace::decoder::{
OP_BATCH_1_GROUPS, OP_BATCH_2_GROUPS, OP_BATCH_4_GROUPS, OP_BATCH_8_GROUPS,
};
use miden_core::mast::OP_BATCH_SIZE;
let num_groups = core::cmp::min(num_groups_left.as_int() as usize, OP_BATCH_SIZE);
match num_groups {
8 => OP_BATCH_8_GROUPS,
4 => OP_BATCH_4_GROUPS,
2 => OP_BATCH_2_GROUPS,
1 => OP_BATCH_1_GROUPS,
_ => panic!(
"invalid number of groups in a batch: {num_groups}, group count: {num_groups_left}"
),
}
}
/// Removes the opcode of `op` from `op_group` and returns the remaining packed opcodes.
///
/// The opcode is subtracted from the group value and the result is shifted right by
/// `NUM_OP_BITS`. This discards the lowest `NUM_OP_BITS` bits, which are expected to hold
/// `op`'s opcode.
///
/// In debug builds, panics with "op group underflow" if `op_group` is smaller than the opcode.
pub(crate) fn remove_opcode_from_group(op_group: Felt, op: Operation) -> Felt {
    let group = op_group.as_int();
    let opcode = op.op_code() as u64;
    // The guard must run *before* the subtraction. Checked afterwards, the underflow has already
    // happened, and `(group - opcode) >> NUM_OP_BITS <= group` holds trivially, so it never fires.
    debug_assert!(group >= opcode, "op group underflow");
    Felt::new((group - opcode) >> NUM_OP_BITS)
}