Skip to main content

miden_air/lookup/debug/trace/
mod.rs

1//! Combined real-trace balance + per-column `(V, U)` oracle debug surface.
2//!
3//! One walk over a concrete main trace, row by row, produces three outputs projected
4//! out of the shared `run_trace_walk` driver:
5//!
6//! - **Balance** — signed multiplicities keyed by encoded denominator. Any residual at the end of
7//!   the walk is an unmatched interaction.
8//! - **Push log** — a [`PushRecord`] per interaction emission, capturing pre-encoding payload,
9//!   encoded denominator, signed multiplicity, and `(row, column, group)` source coordinates.
10//!   Joined back against the balance map at finalize time so each unmatched denominator lists the
11//!   exact pushes that summed to it.
12//! - **Column oracle folds** — per-row per-column `(V_col, U_col)` pairs computed via the
13//!   constraint-path cross-multiplication rule, used by the processor's LogUp cross-check.
14//!
15//! Layout:
16//!
17//! - [`builder`] — the `DebugTraceBuilder` (plus column / group / batch handles) that drives each
18//!   per-row walk.
19//! - This file — the report types ([`BalanceReport`], [`Unmatched`], [`PushRecord`], …), the
20//!   row-by-row `run_trace_walk` driver, and the two public entry points.
21
22use alloc::{string::String, vec, vec::Vec};
23use core::{borrow::Borrow, fmt};
24use std::collections::HashMap;
25
26use miden_core::{
27    field::{PrimeCharacteristicRing, QuadFelt},
28    utils::{Matrix, RowMajorMatrix},
29};
30use miden_crypto::stark::air::RowWindow;
31
32use super::super::{Challenges, LookupAir};
33use crate::Felt;
34
35pub mod builder;
36
37pub use builder::{
38    DebugBoundaryEmitter, DebugTraceBatch, DebugTraceBuilder, DebugTraceColumn, DebugTraceGroup,
39};
40
41// REPORT TYPES
42// ================================================================================================
43
44/// An unmatched interaction: an encoded denom with non-zero net multiplicity after walking
45/// the full trace.
46#[derive(Debug, Clone)]
47pub struct Unmatched {
48    pub denom: QuadFelt,
49    /// Net signed multiplicity modulo the field prime.
50    pub net_multiplicity: Felt,
51    /// Every push that landed on this encoded denominator during the walk, in emission
52    /// order. The caller can bucket these by `msg_repr` / column / row to isolate the
53    /// specific emit that left the denom unbalanced.
54    pub contributions: Vec<PushRecord>,
55}
56
57/// One interaction emission captured during a trace walk.
58///
59/// Populated for every push that passes its flag check, regardless of whether the
60/// interaction eventually balances. When a denom lands in [`BalanceReport::unmatched`],
61/// the join against the push log shows exactly which emits (row, column, group,
62/// payload) summed to the residual multiplicity.
63#[derive(Debug, Clone)]
64pub struct PushRecord {
65    pub row: usize,
66    pub column_idx: usize,
67    pub group_idx: usize,
68    /// `format!("{:?}", msg)` of the `LookupMessage` instance. `"<encoded>"` for
69    /// `insert_encoded` sites, where only the pre-computed denominator is known.
70    pub msg_repr: String,
71    pub denom: QuadFelt,
72    pub multiplicity: Felt,
73}
74
75/// Per-row mutual-exclusion violation inside a cached-encoding group.
76#[derive(Debug, Clone)]
77pub struct MutualExclusionViolation {
78    pub row: usize,
79    pub column_idx: usize,
80    pub group_idx: usize,
81    pub active_flags: usize,
82}
83
84/// Full report returned by [`check_trace_balance`].
85#[derive(Debug, Default)]
86pub struct BalanceReport {
87    pub unmatched: Vec<Unmatched>,
88    pub mutex_violations: Vec<MutualExclusionViolation>,
89}
90
91impl BalanceReport {
92    pub fn is_ok(&self) -> bool {
93        self.unmatched.is_empty() && self.mutex_violations.is_empty()
94    }
95}
96
97impl fmt::Display for BalanceReport {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        /// How many contributing pushes to print per unmatched denom before truncating.
100        const MAX_CONTRIB_LINES: usize = 4;
101
102        if self.is_ok() {
103            return writeln!(f, "BalanceReport: OK");
104        }
105        writeln!(
106            f,
107            "BalanceReport: {} unmatched, {} mutex violations",
108            self.unmatched.len(),
109            self.mutex_violations.len(),
110        )?;
111        for u in &self.unmatched {
112            writeln!(f, "  denom {:?} net multiplicity {:?}", u.denom, u.net_multiplicity)?;
113            for r in u.contributions.iter().take(MAX_CONTRIB_LINES) {
114                writeln!(
115                    f,
116                    "    row={} col={} group={} mult={:?} msg={}",
117                    r.row, r.column_idx, r.group_idx, r.multiplicity, r.msg_repr,
118                )?;
119            }
120            if u.contributions.len() > MAX_CONTRIB_LINES {
121                writeln!(
122                    f,
123                    "    … {} more contributions",
124                    u.contributions.len() - MAX_CONTRIB_LINES,
125                )?;
126            }
127        }
128        for m in &self.mutex_violations {
129            writeln!(
130                f,
131                "  mutex violation at row {} col {} group {}: {} active flags",
132                m.row, m.column_idx, m.group_idx, m.active_flags,
133            )?;
134        }
135        Ok(())
136    }
137}
138
139// STATE
140// ================================================================================================
141
142/// Scratch state threaded through [`DebugTraceBuilder`] for every row in the walk. The
143/// driver creates one instance per walk; it resets `column_folds` at the start of each
144/// row and keeps `balances` / `push_log` / `mutex_violations` accumulating across rows.
145pub struct DebugTraceState {
146    /// Signed-multiplicity accumulator keyed by encoded denominator. Sorted at
147    /// finalize time for deterministic output.
148    pub(super) balances: HashMap<QuadFelt, Felt>,
149    /// Per-push record of every interaction emission. Joined against `balances` in
150    /// [`finalize`] so each unmatched denom carries its source pushes.
151    pub(super) push_log: Vec<PushRecord>,
152    pub(super) mutex_violations: Vec<MutualExclusionViolation>,
153    /// Per-column `(V_col, U_col)`. Reset to `(ZERO, ONE)` at the start of each row by
154    /// [`run_trace_walk`].
155    pub(super) column_folds: Vec<(QuadFelt, QuadFelt)>,
156}
157
158// ENTRY POINTS
159// ================================================================================================
160
161/// Walk a complete main trace and return the balance report (unmatched interactions +
162/// mutex violations).
163///
164/// Includes boundary contributions from [`LookupAir::eval_boundary`], so a fully
165/// closed AIR produces `BalanceReport::is_ok() == true`. `var_len_public_inputs` is
166/// the same shape the prover hands to `miden_crypto::stark::prover::prove_single`
167/// (e.g. `&[&kernel_felts]`); pass `&[]` if the AIR has no variable-length public
168/// inputs or no boundary contributions that consume them.
169pub fn check_trace_balance<A>(
170    air: &A,
171    main_trace: &RowMajorMatrix<Felt>,
172    periodic_columns: &[Vec<Felt>],
173    public_values: &[Felt],
174    var_len_public_inputs: &[&[Felt]],
175    challenges: &Challenges<QuadFelt>,
176) -> BalanceReport
177where
178    for<'a> A: LookupAir<DebugTraceBuilder<'a>>,
179{
180    run_trace_walk(
181        air,
182        main_trace,
183        periodic_columns,
184        public_values,
185        var_len_public_inputs,
186        challenges,
187    )
188    .balance
189}
190
191/// Walk a complete main trace and return the per-row constraint-path `(V_col, U_col)`
192/// folds. `folds[r][col]` is the fold for column `col` at row `r`.
193///
194/// Does not incorporate boundary contributions — the folds are a per-row property of
195/// the main trace, independent of once-per-proof outer emissions.
196pub fn collect_column_oracle_folds<A>(
197    air: &A,
198    main_trace: &RowMajorMatrix<Felt>,
199    periodic_columns: &[Vec<Felt>],
200    public_values: &[Felt],
201    challenges: &Challenges<QuadFelt>,
202) -> Vec<Vec<(QuadFelt, QuadFelt)>>
203where
204    for<'a> A: LookupAir<DebugTraceBuilder<'a>>,
205{
206    run_trace_walk(air, main_trace, periodic_columns, public_values, &[], challenges).folds_per_row
207}
208
209// SHARED DRIVER
210// ================================================================================================
211
212struct TraceWalkOutput {
213    balance: BalanceReport,
214    folds_per_row: Vec<Vec<(QuadFelt, QuadFelt)>>,
215}
216
217/// Shared row-by-row driver used by both public entry points. Each row gets a fresh
218/// [`DebugTraceBuilder`] with column folds reset to `(ZERO, ONE)`; the balance accumulator
219/// persists across rows, the folds snapshot at row end.
220fn run_trace_walk<A>(
221    air: &A,
222    main_trace: &RowMajorMatrix<Felt>,
223    periodic_columns: &[Vec<Felt>],
224    public_values: &[Felt],
225    var_len_public_inputs: &[&[Felt]],
226    challenges: &Challenges<QuadFelt>,
227) -> TraceWalkOutput
228where
229    for<'a> A: LookupAir<DebugTraceBuilder<'a>>,
230{
231    let num_rows = main_trace.height();
232    let width = main_trace.width();
233    let flat: &[Felt] = main_trace.values.borrow();
234    let num_cols = air.num_columns();
235
236    let mut state = DebugTraceState {
237        balances: HashMap::new(),
238        push_log: Vec::new(),
239        mutex_violations: Vec::new(),
240        column_folds: vec![(QuadFelt::ZERO, QuadFelt::ONE); num_cols],
241    };
242    let mut folds_per_row: Vec<Vec<(QuadFelt, QuadFelt)>> = Vec::with_capacity(num_rows);
243    let mut periodic_row: Vec<Felt> = vec![Felt::ZERO; periodic_columns.len()];
244
245    for r in 0..num_rows {
246        let curr = &flat[r * width..(r + 1) * width];
247        let nxt_idx = (r + 1) % num_rows;
248        let next = &flat[nxt_idx * width..(nxt_idx + 1) * width];
249        let window = RowWindow::from_two_rows(curr, next);
250
251        for (i, col) in periodic_columns.iter().enumerate() {
252            periodic_row[i] = col[r % col.len()];
253        }
254
255        // Reset per-row folds; balances and mutex_violations persist.
256        for fold in state.column_folds.iter_mut() {
257            *fold = (QuadFelt::ZERO, QuadFelt::ONE);
258        }
259
260        {
261            let mut lb = DebugTraceBuilder::new(window, &periodic_row, challenges, &mut state, r);
262            air.eval(&mut lb);
263        }
264
265        folds_per_row.push(state.column_folds.clone());
266    }
267
268    // Boundary / outer interactions (once per proof, no row): kernel init, block
269    // hash, log-precompile terminals, …. Accumulates into the same balance map as
270    // the per-row trace emissions — a fully closed AIR produces `is_ok() == true`.
271    {
272        let mut boundary = DebugBoundaryEmitter {
273            challenges,
274            state: &mut state,
275            public_values,
276            var_len_public_inputs,
277        };
278        air.eval_boundary(&mut boundary);
279    }
280
281    TraceWalkOutput { balance: finalize(state), folds_per_row }
282}
283
284fn finalize(state: DebugTraceState) -> BalanceReport {
285    let DebugTraceState { balances, push_log, mutex_violations, .. } = state;
286
287    // Group every push by its encoded denom so each unmatched denom can pull its
288    // contributing records in O(1). Preserves emission order within each bucket.
289    let mut contrib_by_denom: HashMap<QuadFelt, Vec<PushRecord>> = HashMap::new();
290    for record in push_log {
291        contrib_by_denom.entry(record.denom).or_default().push(record);
292    }
293
294    let mut unmatched = Vec::new();
295    for (denom, net) in balances {
296        if net == Felt::ZERO {
297            continue;
298        }
299        let contributions = contrib_by_denom.remove(&denom).unwrap_or_default();
300        unmatched.push(Unmatched {
301            denom,
302            net_multiplicity: net,
303            contributions,
304        });
305    }
306    // Sort for deterministic output — `HashMap` iteration order is arbitrary.
307    unmatched.sort_by_key(|u| u.denom);
308    BalanceReport { unmatched, mutex_violations }
309}