miden_lifted_stark/prover/commit.rs
1//! Trace commitment (LDE + LMCS).
2//!
3//! This module provides types and functions for committing traces with lifting support:
4//!
5//! - [`commit_traces`]: Commit traces with lifting support (LDE → LMCS)
6//! - [`Committed`]: Wrapper around LMCS tree with domain metadata
7
8use alloc::vec::Vec;
9
10use miden_lifted_air::log2_strict_u8;
11use p3_dft::TwoAdicSubgroupDft;
12use p3_field::{ExtensionField, TwoAdicField};
13use p3_matrix::{
14 Matrix,
15 bitrev::{BitReversedMatrixView, BitReversibleMatrix},
16 dense::{RowMajorMatrix, RowMajorMatrixView},
17};
18use tracing::info_span;
19
20use crate::{
21 StarkConfig,
22 coset::LiftedCoset,
23 lmcs::{Lmcs, LmcsTree, bitrev::materialize_bitrev},
24};
25
26// ============================================================================
27// Committed
28// ============================================================================
29
30/// Committed polynomial evaluations with domain metadata.
31///
32/// Wraps an LMCS tree and stores the blowup used for the LDE.
33///
34/// We keep `log_blowup` here (instead of threading it through every call) so we can
35/// recover per-matrix domain information (trace height, lift ratio, coset shift) from
36/// just the committed matrices. This is especially useful when committing multiple
37/// traces of different heights into one tree.
38///
39/// # Type Parameters
40///
41/// - `F`: Scalar field element type
42/// - `M`: Matrix type (e.g., `RowMajorMatrix<F>`)
43/// - `L`: LMCS configuration type
44///
45/// # Usage
46///
47/// ```ignore
48/// let committed = commit_traces(config, traces);
49/// let root = committed.root();
50/// let view = committed.evals_on_quotient_domain(0, constraint_degree);
51/// ```
52///
53/// Storing the blowup also avoids re-deriving `trace_height = lde_height / blowup` for each
54/// matrix, which is needed for quotient-domain views and lifting shifts.
55pub struct Committed<F, M, L>
56where
57 F: TwoAdicField,
58 L: Lmcs<F = F>,
59 M: Matrix<F>,
60{
61 /// The underlying LMCS tree.
62 tree: L::Tree<M>,
63 /// Log₂ of the blowup factor used during LDE.
64 log_blowup: u8,
65}
66
67impl<F, M, L> Committed<F, M, L>
68where
69 F: TwoAdicField,
70 L: Lmcs<F = F>,
71 M: Matrix<F>,
72{
73 /// Create a new `Committed` wrapper.
74 ///
75 /// # Arguments
76 ///
77 /// - `tree`: The LMCS tree containing committed LDE matrices
78 /// - `log_blowup`: Log₂ of the blowup factor used during LDE
79 #[inline]
80 pub fn new(tree: L::Tree<M>, log_blowup: u8) -> Self {
81 Self { tree, log_blowup }
82 }
83
84 /// Get the commitment root.
85 #[inline]
86 pub fn root(&self) -> L::Commitment {
87 self.tree.root()
88 }
89
90 /// Get a reference to the underlying tree.
91 #[inline]
92 pub fn tree(&self) -> &L::Tree<M> {
93 &self.tree
94 }
95
96 /// Get log₂ of the maximum LDE height across all matrices.
97 ///
98 /// This is the height of the tree (the largest matrix height).
99 #[inline]
100 fn log_max_lde_height(&self) -> u8 {
101 log2_strict_u8(self.tree.height())
102 }
103
104 /// Returns the [`LiftedCoset`] the `m`-th matrix was committed on.
105 ///
106 /// # Panics
107 ///
108 /// Panics if `m >= num_matrices()`.
109 fn lifted_coset(&self, m: usize) -> LiftedCoset {
110 let matrix = &self.tree.leaves()[m];
111 let log_lde_height = log2_strict_u8(matrix.height());
112 let log_trace_height = log_lde_height - self.log_blowup;
113 let log_max_trace_height = self.log_max_lde_height() - self.log_blowup;
114
115 LiftedCoset::new(log_trace_height, self.log_blowup, log_max_trace_height)
116 }
117}
118
119impl<F, L> Committed<F, RowMajorMatrix<F>, L>
120where
121 F: TwoAdicField,
122 L: Lmcs<F = F>,
123{
124 /// Return a zero-copy view of matrix `m` on the quotient evaluation domain.
125 ///
126 /// This returns evaluations over the quotient coset `gJ ⊆ gK`.
127 ///
128 /// The tree commits to LDE evaluations on `gK` (size `N·B`). The `RowMajorMatrix`
129 /// stores bit-reversed evaluations; `gJ` appears as the first `N·D` rows, so this is
130 /// a zero-copy prefix view followed by `bit_reverse_rows()` to expose natural order.
131 ///
132 /// # Panics
133 ///
134 /// Panics if `m >= num_matrices()`.
135 pub fn evals_on_quotient_domain(
136 &self,
137 m: usize,
138 constraint_degree: usize,
139 ) -> BitReversedMatrixView<RowMajorMatrixView<'_, F>> {
140 let quotient_height = self.lifted_coset(m).trace_height() * constraint_degree;
141 self.tree.leaves()[m].split_rows(quotient_height).0.bit_reverse_rows()
142 }
143}
144
145// ============================================================================
146// commit_traces
147// ============================================================================
148
149/// Commit multiple trace matrices with lifting: LDE → LMCS tree.
150///
151/// Traces must be sorted by height in ascending order. Each trace is lifted to
152/// the max LDE domain using the appropriate nested coset shift.
153///
154/// The DFT output is wrapped in `BitReversedMatrixView` (zero-cost view) and
155/// passed directly to the LMCS — no materialization needed.
156///
157/// Returns a [`Committed`] wrapper providing:
158/// - Commitment root via [`Committed::root()`]
159/// - Underlying LMCS tree via [`Committed::tree()`]
160/// - Quotient domain views via [`Committed::evals_on_quotient_domain()`]
161///
162/// # Arguments
163/// - `config`: STARK configuration containing PCS params, LMCS, and DFT
164/// - `traces`: Trace matrices sorted by height (ascending)
165///
166/// # Panics
167/// - If `traces` is empty
168/// - If trace heights are not powers of two
169/// - If traces are not sorted by height in ascending order
170///
171/// Lifting note: for a trace of height `n` embedded into a max height `n_max`, let
172/// `r = n_max / n`. The commitment should behave as if it contains evaluations of the
173/// lifted polynomial `f_lift(X) = f(Xʳ)` on the max LDE coset. This is achieved by
174/// evaluating the original trace on a *nested* coset with shift gʳ: the map
175/// `(g·ω)ʳ = gʳ·ωʳ` sends the max domain down to the smaller one.
176pub fn commit_traces<F, EF, SC>(
177 config: &SC,
178 traces: Vec<RowMajorMatrix<F>>,
179) -> Committed<F, RowMajorMatrix<F>, SC::Lmcs>
180where
181 F: TwoAdicField,
182 EF: ExtensionField<F>,
183 SC: StarkConfig<F, EF>,
184{
185 assert!(!traces.is_empty(), "at least one trace required");
186
187 assert!(
188 traces.windows(2).all(|w| w[0].height() <= w[1].height()),
189 "traces must be sorted by height in ascending order"
190 );
191
192 let log_blowup = config.pcs().log_blowup();
193
194 // Find max trace height
195 let max_trace_height = traces.last().unwrap().height();
196 let log_max_trace_height = log2_strict_u8(max_trace_height);
197
198 let ldes: Vec<_> = traces
199 .into_iter()
200 .enumerate()
201 .map(|(idx, trace)| {
202 let trace_height = trace.height();
203 let width = trace.width();
204
205 // Validate height is power of two
206 assert!(
207 trace_height.is_power_of_two(),
208 "trace height must be power of two (index {idx})"
209 );
210
211 let log_trace_height = log2_strict_u8(trace_height);
212
213 // Use LiftedCoset to compute the coset shift
214 let coset = LiftedCoset::new(log_trace_height, log_blowup, log_max_trace_height);
215 let coset_shift = coset.lde_shift::<F>();
216
217 info_span!("LDE", trace = idx, log_height = log_trace_height, width).in_scope(|| {
218 let lde = config.dft().coset_lde_batch(trace, log_blowup.into(), coset_shift);
219 materialize_bitrev(lde)
220 })
221 })
222 .collect();
223
224 // Build aligned LMCS tree and wrap in Committed
225 let tree = config.lmcs().build_aligned_tree(ldes);
226 Committed::new(tree, log_blowup)
227}
228
229// ============================================================================
230// Tests
231// ============================================================================
232
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_field::PrimeCharacteristicRing;
    use p3_util::reverse_bits_len;

    use super::*;
    use crate::testing::configs::goldilocks_poseidon2::Felt;

    /// Build a matrix of the given `width` whose `len` entries are `0, 1, 2, …` in
    /// row-major order, so row `r` holds `[r·width, …, r·width + width − 1]`.
    fn counting_matrix(len: u64, width: usize) -> RowMajorMatrix<Felt> {
        RowMajorMatrix::new((0..len).map(Felt::from_u64).collect(), width)
    }

    #[test]
    fn split_rows_truncates_correctly() {
        // 16x4 matrix (LDE height = 16, width = 4)
        let matrix = counting_matrix(64, 4);

        // Truncate to 8 rows via split_rows
        let truncated = matrix.split_rows(8).0;
        assert_eq!(truncated.height(), 8);
        assert_eq!(truncated.width(), 4);

        // The first row is unchanged
        let first: Vec<Felt> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(
            first,
            vec![Felt::from_u64(0), Felt::from_u64(1), Felt::from_u64(2), Felt::from_u64(3)]
        );

        // The last retained row is row 7 of the original matrix
        let last: Vec<Felt> = truncated.row(7).unwrap().into_iter().collect();
        assert_eq!(
            last,
            vec![Felt::from_u64(28), Felt::from_u64(29), Felt::from_u64(30), Felt::from_u64(31)]
        );
    }

    #[test]
    fn bit_reverse_rows_gives_natural_order() {
        // 8x2 matrix: stored (bit-reversed) row i contains [2*i, 2*i+1]
        let matrix = counting_matrix(16, 2);

        let natural = matrix.as_view().bit_reverse_rows();
        assert_eq!(natural.height(), 8);
        assert_eq!(natural.width(), 2);

        // Natural row i must expose the values of stored row bitrev(i)
        for i in 0..8 {
            let br_i = reverse_bits_len(i, 3);
            let natural_row: Vec<Felt> = natural.row(i).unwrap().into_iter().collect();
            let expected: Vec<Felt> =
                vec![Felt::from_u64((br_i * 2) as u64), Felt::from_u64((br_i * 2 + 1) as u64)];
            assert_eq!(natural_row, expected, "mismatch at natural row {i}");
        }
    }

    #[test]
    fn truncate_then_bit_reverse() {
        // 16x2 matrix: stored row i contains [2*i, 2*i+1]
        let matrix = counting_matrix(32, 2);

        // Truncate to 8 rows and convert to natural order
        let truncated_natural = matrix.split_rows(8).0.bit_reverse_rows();
        assert_eq!(truncated_natural.height(), 8);
        assert_eq!(truncated_natural.width(), 2);

        // The view covers only the 8-row prefix, so the bit reversal is over 3 bits and
        // natural row i must expose stored row bitrev_3(i) of that prefix.
        for i in 0..8 {
            let br_i = reverse_bits_len(i, 3);
            let natural_row: Vec<Felt> = truncated_natural.row(i).unwrap().into_iter().collect();
            let expected: Vec<Felt> =
                vec![Felt::from_u64((br_i * 2) as u64), Felt::from_u64((br_i * 2 + 1) as u64)];
            assert_eq!(natural_row, expected, "mismatch at natural row {i}");
        }
    }
}
302}