use alloc::{string::String, vec, vec::Vec};
use core::{borrow::Borrow, fmt};
use std::collections::HashMap;
use miden_core::{
field::{PrimeCharacteristicRing, QuadFelt},
utils::{Matrix, RowMajorMatrix},
};
use miden_crypto::stark::air::RowWindow;
use super::super::{Challenges, LookupAir};
use crate::Felt;
pub mod builder;
pub use builder::{
DebugBoundaryEmitter, DebugTraceBatch, DebugTraceBuilder, DebugTraceColumn, DebugTraceGroup,
};
/// A denominator whose accumulated net multiplicity was non-zero when the report was finalized.
///
/// Only denominators with a non-zero entry in `DebugTraceState::balances` produce one of these.
#[derive(Debug, Clone)]
pub struct Unmatched {
    /// The denominator (the key under which multiplicities were accumulated).
    pub denom: QuadFelt,
    /// The non-zero net multiplicity left over for `denom`.
    pub net_multiplicity: Felt,
    /// Every recorded push whose `denom` equals this one, in push-log order (may be empty if no
    /// push was logged for it).
    pub contributions: Vec<PushRecord>,
}
/// One entry of `DebugTraceState::push_log`, kept so an unmatched denominator can be traced back
/// to the pushes that contributed to it.
#[derive(Debug, Clone)]
pub struct PushRecord {
    /// Trace row at which the push was recorded (presumably the row index handed to
    /// `DebugTraceBuilder::new` — confirm in `builder`).
    pub row: usize,
    /// Index of the lookup column the push belongs to (printed as `col=`).
    pub column_idx: usize,
    /// Index of the group within that column (printed as `group=`).
    pub group_idx: usize,
    /// String form of the pushed message, printed verbatim as `msg=` in the report.
    pub msg_repr: String,
    /// Denominator of the push; `finalize` groups records by this value.
    pub denom: QuadFelt,
    /// Multiplicity of this individual push (printed as `mult=`).
    pub multiplicity: Felt,
}
/// A location where the builder recorded a mutual-exclusion violation.
///
/// NOTE(review): presumably emitted when more than one flag in a group that must be mutually
/// exclusive is active on the same row — confirm against `builder`.
#[derive(Debug, Clone)]
pub struct MutualExclusionViolation {
    /// Trace row of the violation.
    pub row: usize,
    /// Lookup column index of the violating group.
    pub column_idx: usize,
    /// Group index within that column.
    pub group_idx: usize,
    /// Number of flags observed active (rendered as "{n} active flags").
    pub active_flags: usize,
}
/// Result of a lookup-balance check over a whole trace.
///
/// The default value is the empty (passing) report.
#[derive(Debug, Default)]
pub struct BalanceReport {
    /// Denominators left with a non-zero net multiplicity, sorted by denominator.
    pub unmatched: Vec<Unmatched>,
    /// Mutual-exclusion violations in the order they were recorded.
    pub mutex_violations: Vec<MutualExclusionViolation>,
}

impl BalanceReport {
    /// Returns `true` iff the check found neither unmatched denominators nor mutex violations.
    pub fn is_ok(&self) -> bool {
        let has_unmatched = !self.unmatched.is_empty();
        let has_mutex_violations = !self.mutex_violations.is_empty();
        !(has_unmatched || has_mutex_violations)
    }
}
impl fmt::Display for BalanceReport {
    /// Renders a short summary line followed by one line per unmatched denominator (with up to
    /// four of its contributing pushes) and one line per mutex violation. A passing report
    /// renders as a single `BalanceReport: OK` line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Cap on contribution lines printed per unmatched denominator.
        const MAX_CONTRIB_LINES: usize = 4;

        if self.is_ok() {
            return writeln!(f, "BalanceReport: OK");
        }

        let unmatched_count = self.unmatched.len();
        let mutex_count = self.mutex_violations.len();
        writeln!(
            f,
            "BalanceReport: {} unmatched, {} mutex violations",
            unmatched_count, mutex_count,
        )?;

        for entry in self.unmatched.iter() {
            writeln!(
                f,
                " denom {:?} net multiplicity {:?}",
                entry.denom, entry.net_multiplicity
            )?;

            // `shown` holds at most MAX_CONTRIB_LINES records; `hidden` is whatever remains.
            let visible = entry.contributions.len().min(MAX_CONTRIB_LINES);
            let (shown, hidden) = entry.contributions.split_at(visible);
            for rec in shown {
                writeln!(
                    f,
                    " row={} col={} group={} mult={:?} msg={}",
                    rec.row, rec.column_idx, rec.group_idx, rec.multiplicity, rec.msg_repr,
                )?;
            }
            if !hidden.is_empty() {
                writeln!(f, " … {} more contributions", hidden.len())?;
            }
        }

        self.mutex_violations.iter().try_for_each(|violation| {
            writeln!(
                f,
                " mutex violation at row {} col {} group {}: {} active flags",
                violation.row, violation.column_idx, violation.group_idx, violation.active_flags,
            )
        })
    }
}
/// Mutable accumulator threaded through the per-row evaluation (`DebugTraceBuilder`) and the
/// boundary evaluation (`DebugBoundaryEmitter`), then consumed by `finalize`.
pub struct DebugTraceState {
    /// Net multiplicity per denominator; entries still non-zero at the end become
    /// [`Unmatched`] entries in the report.
    pub(super) balances: HashMap<QuadFelt, Felt>,
    /// Every individual push, kept so unmatched denominators can list their contributors.
    pub(super) push_log: Vec<PushRecord>,
    /// Mutual-exclusion violations, copied into the report as-is.
    pub(super) mutex_violations: Vec<MutualExclusionViolation>,
    /// One pair per AIR column, reset to `(ZERO, ONE)` before each row is evaluated.
    /// NOTE(review): presumably a running (numerator, denominator) fraction — confirm in
    /// `builder`.
    pub(super) column_folds: Vec<(QuadFelt, QuadFelt)>,
}
/// Evaluates `air`'s lookup interactions on every row of `main_trace`, then its boundary
/// interactions, and reports every denominator whose net multiplicity is non-zero together with
/// any recorded mutual-exclusion violations.
///
/// Per-row column folds are not retained here, since only the balance report is returned.
///
/// # Panics
///
/// Panics if any entry of `periodic_columns` is empty (the row index is reduced modulo each
/// periodic column's length).
pub fn check_trace_balance<A>(
    air: &A,
    main_trace: &RowMajorMatrix<Felt>,
    periodic_columns: &[Vec<Felt>],
    public_values: &[Felt],
    var_len_public_inputs: &[&[Felt]],
    challenges: &Challenges<QuadFelt>,
) -> BalanceReport
where
    for<'a> A: LookupAir<DebugTraceBuilder<'a>>,
{
    run_trace_walk(
        air,
        main_trace,
        periodic_columns,
        public_values,
        var_len_public_inputs,
        challenges,
        false,
    )
    .balance
}

/// Returns, for each trace row, a snapshot of every column's fold as left by the AIR's
/// evaluation of that row (outer index: row, inner index: column).
///
/// Boundary evaluation still runs, with no variable-length public inputs, but its balance report
/// is discarded.
///
/// # Panics
///
/// Panics if any entry of `periodic_columns` is empty.
pub fn collect_column_oracle_folds<A>(
    air: &A,
    main_trace: &RowMajorMatrix<Felt>,
    periodic_columns: &[Vec<Felt>],
    public_values: &[Felt],
    challenges: &Challenges<QuadFelt>,
) -> Vec<Vec<(QuadFelt, QuadFelt)>>
where
    for<'a> A: LookupAir<DebugTraceBuilder<'a>>,
{
    run_trace_walk(air, main_trace, periodic_columns, public_values, &[], challenges, true)
        .folds_per_row
}

/// Everything produced by one pass over the trace.
struct TraceWalkOutput {
    /// Balance / mutual-exclusion report, including boundary contributions.
    balance: BalanceReport,
    /// One snapshot of `DebugTraceState::column_folds` per row. Empty unless the walk was run
    /// with `record_folds = true`.
    folds_per_row: Vec<Vec<(QuadFelt, QuadFelt)>>,
}

/// Runs `air.eval` on every two-row window `(r, (r + 1) mod height)` of `main_trace`, then
/// `air.eval_boundary`, and finalizes the accumulated state into a [`BalanceReport`].
///
/// When `record_folds` is `true`, the per-column folds left after each row are cloned into
/// `TraceWalkOutput::folds_per_row`. Otherwise nothing per-row is retained, which avoids
/// O(rows × columns) memory for callers that only need the balance report.
fn run_trace_walk<A>(
    air: &A,
    main_trace: &RowMajorMatrix<Felt>,
    periodic_columns: &[Vec<Felt>],
    public_values: &[Felt],
    var_len_public_inputs: &[&[Felt]],
    challenges: &Challenges<QuadFelt>,
    record_folds: bool,
) -> TraceWalkOutput
where
    for<'a> A: LookupAir<DebugTraceBuilder<'a>>,
{
    let num_rows = main_trace.height();
    let width = main_trace.width();
    let flat: &[Felt] = main_trace.values.borrow();
    let num_cols = air.num_columns();

    let mut state = DebugTraceState {
        balances: HashMap::new(),
        push_log: Vec::new(),
        mutex_violations: Vec::new(),
        column_folds: vec![(QuadFelt::ZERO, QuadFelt::ONE); num_cols],
    };
    let mut folds_per_row: Vec<Vec<(QuadFelt, QuadFelt)>> =
        if record_folds { Vec::with_capacity(num_rows) } else { Vec::new() };
    // Reused across rows: the value of each periodic column at the current row.
    let mut periodic_row: Vec<Felt> = vec![Felt::ZERO; periodic_columns.len()];

    for r in 0..num_rows {
        let curr = &flat[r * width..(r + 1) * width];
        // The last row's "next" row wraps around to row 0.
        let nxt_idx = (r + 1) % num_rows;
        let next = &flat[nxt_idx * width..(nxt_idx + 1) * width];
        let window = RowWindow::from_two_rows(curr, next);

        for (slot, col) in periodic_row.iter_mut().zip(periodic_columns) {
            *slot = col[r % col.len()];
        }
        // Every row starts from a fresh (ZERO, ONE) fold in each column.
        state.column_folds.fill((QuadFelt::ZERO, QuadFelt::ONE));
        {
            // Scoped so the builder's `&mut state` borrow ends before the snapshot below.
            let mut lb = DebugTraceBuilder::new(window, &periodic_row, challenges, &mut state, r);
            air.eval(&mut lb);
        }
        if record_folds {
            folds_per_row.push(state.column_folds.clone());
        }
    }

    {
        let mut boundary = DebugBoundaryEmitter {
            challenges,
            state: &mut state,
            public_values,
            var_len_public_inputs,
        };
        air.eval_boundary(&mut boundary);
    }

    TraceWalkOutput { balance: finalize(state), folds_per_row }
}
fn finalize(state: DebugTraceState) -> BalanceReport {
let DebugTraceState { balances, push_log, mutex_violations, .. } = state;
let mut contrib_by_denom: HashMap<QuadFelt, Vec<PushRecord>> = HashMap::new();
for record in push_log {
contrib_by_denom.entry(record.denom).or_default().push(record);
}
let mut unmatched = Vec::new();
for (denom, net) in balances {
if net == Felt::ZERO {
continue;
}
let contributions = contrib_by_denom.remove(&denom).unwrap_or_default();
unmatched.push(Unmatched {
denom,
net_multiplicity: net,
contributions,
});
}
unmatched.sort_by_key(|u| u.denom);
BalanceReport { unmatched, mutex_violations }
}