Skip to main content

miden_lifted_stark/prover/
commit.rs

1//! Trace commitment (LDE + LMCS).
2//!
3//! This module provides types and functions for committing traces with lifting support:
4//!
5//! - [`commit_traces`]: Commit traces with lifting support (LDE → LMCS)
6//! - [`Committed`]: Wrapper around LMCS tree with domain metadata
7
8use alloc::vec::Vec;
9
10use miden_lifted_air::log2_strict_u8;
11use p3_dft::TwoAdicSubgroupDft;
12use p3_field::{ExtensionField, TwoAdicField};
13use p3_matrix::{
14    Matrix,
15    bitrev::{BitReversedMatrixView, BitReversibleMatrix},
16    dense::{RowMajorMatrix, RowMajorMatrixView},
17};
18use tracing::info_span;
19
20use crate::{
21    StarkConfig,
22    coset::LiftedCoset,
23    lmcs::{Lmcs, LmcsTree, bitrev::materialize_bitrev},
24};
25
26// ============================================================================
27// Committed
28// ============================================================================
29
30/// Committed polynomial evaluations with domain metadata.
31///
32/// Wraps an LMCS tree and stores the blowup used for the LDE.
33///
34/// We keep `log_blowup` here (instead of threading it through every call) so we can
35/// recover per-matrix domain information (trace height, lift ratio, coset shift) from
36/// just the committed matrices. This is especially useful when committing multiple
37/// traces of different heights into one tree.
38///
39/// # Type Parameters
40///
41/// - `F`: Scalar field element type
42/// - `M`: Matrix type (e.g., `RowMajorMatrix<F>`)
43/// - `L`: LMCS configuration type
44///
45/// # Usage
46///
47/// ```ignore
48/// let committed = commit_traces(config, traces);
49/// let root = committed.root();
50/// let view = committed.evals_on_quotient_domain(0, constraint_degree);
51/// ```
52///
53/// Storing the blowup also avoids re-deriving `trace_height = lde_height / blowup` for each
54/// matrix, which is needed for quotient-domain views and lifting shifts.
55pub struct Committed<F, M, L>
56where
57    F: TwoAdicField,
58    L: Lmcs<F = F>,
59    M: Matrix<F>,
60{
61    /// The underlying LMCS tree.
62    tree: L::Tree<M>,
63    /// Log₂ of the blowup factor used during LDE.
64    log_blowup: u8,
65}
66
67impl<F, M, L> Committed<F, M, L>
68where
69    F: TwoAdicField,
70    L: Lmcs<F = F>,
71    M: Matrix<F>,
72{
73    /// Create a new `Committed` wrapper.
74    ///
75    /// # Arguments
76    ///
77    /// - `tree`: The LMCS tree containing committed LDE matrices
78    /// - `log_blowup`: Log₂ of the blowup factor used during LDE
79    #[inline]
80    pub fn new(tree: L::Tree<M>, log_blowup: u8) -> Self {
81        Self { tree, log_blowup }
82    }
83
84    /// Get the commitment root.
85    #[inline]
86    pub fn root(&self) -> L::Commitment {
87        self.tree.root()
88    }
89
90    /// Get a reference to the underlying tree.
91    #[inline]
92    pub fn tree(&self) -> &L::Tree<M> {
93        &self.tree
94    }
95
96    /// Get log₂ of the maximum LDE height across all matrices.
97    ///
98    /// This is the height of the tree (the largest matrix height).
99    #[inline]
100    fn log_max_lde_height(&self) -> u8 {
101        log2_strict_u8(self.tree.height())
102    }
103
104    /// Returns the [`LiftedCoset`] the `m`-th matrix was committed on.
105    ///
106    /// # Panics
107    ///
108    /// Panics if `m >= num_matrices()`.
109    fn lifted_coset(&self, m: usize) -> LiftedCoset {
110        let matrix = &self.tree.leaves()[m];
111        let log_lde_height = log2_strict_u8(matrix.height());
112        let log_trace_height = log_lde_height - self.log_blowup;
113        let log_max_trace_height = self.log_max_lde_height() - self.log_blowup;
114
115        LiftedCoset::new(log_trace_height, self.log_blowup, log_max_trace_height)
116    }
117}
118
119impl<F, L> Committed<F, RowMajorMatrix<F>, L>
120where
121    F: TwoAdicField,
122    L: Lmcs<F = F>,
123{
124    /// Return a zero-copy view of matrix `m` on the quotient evaluation domain.
125    ///
126    /// This returns evaluations over the quotient coset `gJ ⊆ gK`.
127    ///
128    /// The tree commits to LDE evaluations on `gK` (size `N·B`). The `RowMajorMatrix`
129    /// stores bit-reversed evaluations; `gJ` appears as the first `N·D` rows, so this is
130    /// a zero-copy prefix view followed by `bit_reverse_rows()` to expose natural order.
131    ///
132    /// # Panics
133    ///
134    /// Panics if `m >= num_matrices()`.
135    pub fn evals_on_quotient_domain(
136        &self,
137        m: usize,
138        constraint_degree: usize,
139    ) -> BitReversedMatrixView<RowMajorMatrixView<'_, F>> {
140        let quotient_height = self.lifted_coset(m).trace_height() * constraint_degree;
141        self.tree.leaves()[m].split_rows(quotient_height).0.bit_reverse_rows()
142    }
143}
144
145// ============================================================================
146// commit_traces
147// ============================================================================
148
149/// Commit multiple trace matrices with lifting: LDE → LMCS tree.
150///
151/// Traces must be sorted by height in ascending order. Each trace is lifted to
152/// the max LDE domain using the appropriate nested coset shift.
153///
154/// The DFT output is wrapped in `BitReversedMatrixView` (zero-cost view) and
155/// passed directly to the LMCS — no materialization needed.
156///
157/// Returns a [`Committed`] wrapper providing:
158/// - Commitment root via [`Committed::root()`]
159/// - Underlying LMCS tree via [`Committed::tree()`]
160/// - Quotient domain views via [`Committed::evals_on_quotient_domain()`]
161///
162/// # Arguments
163/// - `config`: STARK configuration containing PCS params, LMCS, and DFT
164/// - `traces`: Trace matrices sorted by height (ascending)
165///
166/// # Panics
167/// - If `traces` is empty
168/// - If trace heights are not powers of two
169/// - If traces are not sorted by height in ascending order
170///
171/// Lifting note: for a trace of height `n` embedded into a max height `n_max`, let
172/// `r = n_max / n`. The commitment should behave as if it contains evaluations of the
173/// lifted polynomial `f_lift(X) = f(Xʳ)` on the max LDE coset. This is achieved by
174/// evaluating the original trace on a *nested* coset with shift gʳ: the map
175/// `(g·ω)ʳ = gʳ·ωʳ` sends the max domain down to the smaller one.
176pub fn commit_traces<F, EF, SC>(
177    config: &SC,
178    traces: Vec<RowMajorMatrix<F>>,
179) -> Committed<F, RowMajorMatrix<F>, SC::Lmcs>
180where
181    F: TwoAdicField,
182    EF: ExtensionField<F>,
183    SC: StarkConfig<F, EF>,
184{
185    assert!(!traces.is_empty(), "at least one trace required");
186
187    assert!(
188        traces.windows(2).all(|w| w[0].height() <= w[1].height()),
189        "traces must be sorted by height in ascending order"
190    );
191
192    let log_blowup = config.pcs().log_blowup();
193
194    // Find max trace height
195    let max_trace_height = traces.last().unwrap().height();
196    let log_max_trace_height = log2_strict_u8(max_trace_height);
197
198    let ldes: Vec<_> = traces
199        .into_iter()
200        .enumerate()
201        .map(|(idx, trace)| {
202            let trace_height = trace.height();
203            let width = trace.width();
204
205            // Validate height is power of two
206            assert!(
207                trace_height.is_power_of_two(),
208                "trace height must be power of two (index {idx})"
209            );
210
211            let log_trace_height = log2_strict_u8(trace_height);
212
213            // Use LiftedCoset to compute the coset shift
214            let coset = LiftedCoset::new(log_trace_height, log_blowup, log_max_trace_height);
215            let coset_shift = coset.lde_shift::<F>();
216
217            info_span!("LDE", trace = idx, log_height = log_trace_height, width).in_scope(|| {
218                let lde = config.dft().coset_lde_batch(trace, log_blowup.into(), coset_shift);
219                materialize_bitrev(lde)
220            })
221        })
222        .collect();
223
224    // Build aligned LMCS tree and wrap in Committed
225    let tree = config.lmcs().build_aligned_tree(ldes);
226    Committed::new(tree, log_blowup)
227}
228
229// ============================================================================
230// Tests
231// ============================================================================
232
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_field::PrimeCharacteristicRing;
    use p3_util::reverse_bits_len;

    use super::*;
    use crate::testing::configs::goldilocks_poseidon2::Felt;

    // These tests pin the `p3_matrix` behavior that `Committed::evals_on_quotient_domain`
    // relies on: take a row prefix with `split_rows`, then expose it in natural order
    // with `bit_reverse_rows`.

    #[test]
    fn split_rows_truncates_correctly() {
        // Create a 16x4 matrix (LDE height = 16, width = 4)
        let data: Vec<Felt> = (0u64..64).map(Felt::from_u64).collect();
        let matrix = RowMajorMatrix::new(data, 4);

        // Truncate to 8 rows via split_rows
        let truncated = matrix.split_rows(8).0;
        assert_eq!(truncated.height(), 8);
        assert_eq!(truncated.width(), 4);

        // Verify first row is unchanged
        let row: Vec<Felt> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(
            row,
            vec![Felt::from_u64(0), Felt::from_u64(1), Felt::from_u64(2), Felt::from_u64(3)]
        );
    }

    #[test]
    fn bit_reverse_rows_gives_natural_order() {
        // Create an 8x2 matrix with values that let us verify bit-reversal
        // Row i (bit-reversed) contains [2*i, 2*i+1]
        let data: Vec<Felt> = (0u64..16).map(Felt::from_u64).collect();
        let matrix = RowMajorMatrix::new(data, 2);

        let natural = matrix.as_view().bit_reverse_rows();
        assert_eq!(natural.height(), 8);
        assert_eq!(natural.width(), 2);

        // General verification: natural row i should have values from bit-reversed row bitrev(i)
        for i in 0..8 {
            let br_i = reverse_bits_len(i, 3);
            let natural_row: Vec<Felt> = natural.row(i).unwrap().into_iter().collect();
            let expected: Vec<Felt> =
                vec![Felt::from_u64((br_i * 2) as u64), Felt::from_u64((br_i * 2 + 1) as u64)];
            assert_eq!(natural_row, expected, "mismatch at natural row {i}");
        }
    }

    #[test]
    fn truncate_then_bit_reverse() {
        // Create a 16x2 matrix; stored row r holds [2r, 2r+1].
        let data: Vec<Felt> = (0u64..32).map(Felt::from_u64).collect();
        let matrix = RowMajorMatrix::new(data, 2);

        // Truncate to 8 rows and convert to natural order
        let truncated_natural = matrix.split_rows(8).0.bit_reverse_rows();
        assert_eq!(truncated_natural.height(), 8);
        assert_eq!(truncated_natural.width(), 2);

        // Natural row i must come from stored row bitrev_3(i) of the *prefix*. The bit
        // reversal is taken over the truncated height (log2(8) = 3), not the full
        // height 16. This is what makes the quotient-domain view correct.
        for i in 0..8 {
            let br_i = reverse_bits_len(i, 3);
            let row: Vec<Felt> = truncated_natural.row(i).unwrap().into_iter().collect();
            let expected: Vec<Felt> =
                vec![Felt::from_u64((br_i * 2) as u64), Felt::from_u64((br_i * 2 + 1) as u64)];
            assert_eq!(row, expected, "mismatch at natural row {i}");
        }
    }
}