Skip to main content

p3_miden_lmcs/
utils.rs

1//! Utility functions for LMCS operations.
2
3use alloc::vec::Vec;
4use core::array;
5
6use p3_field::PackedValue;
7use p3_matrix::{Matrix, dense::RowMajorMatrix};
8use p3_util::log2_strict_usize;
9use serde::{Deserialize, Serialize};
10
11/// Strict logâ‚‚ returning `u8`.
12///
13/// Panics if `n` is not a power of two.
14#[inline]
15pub fn log2_strict_u8(n: usize) -> u8 {
16    log2_strict_usize(n) as u8
17}
18
19// ============================================================================
20// RowList
21// ============================================================================
22
23/// Flat storage of variable-width rows.
24///
25/// In a STARK proof, each row typically holds one committed matrix's evaluations at a
26/// leaf index queried by the verifier as part of the low-degree test (LDT). Matrices
27/// have different widths because they encode different sets of constraint polynomials
28/// (e.g., main trace vs auxiliary trace).
29///
30/// Stores all elements contiguously in a single `Vec<T>`, with a separate `Vec<usize>`
31/// tracking the width of each row. This avoids N+1 heap allocations compared to
32/// `Vec<Vec<T>>` and enables efficient flat iteration.
33///
34/// Invariant: `widths.iter().sum::<usize>() == elems.len()`.
35#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))]
37pub struct RowList<T> {
38    elems: Vec<T>,
39    widths: Vec<usize>,
40}
41
42impl<T> RowList<T> {
43    /// Create a `RowList` from raw elements and widths.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `widths.iter().sum() != elems.len()`.
48    pub fn new(elems: Vec<T>, widths: &[usize]) -> Self {
49        let expected: usize = widths.iter().sum();
50        assert_eq!(
51            elems.len(),
52            expected,
53            "RowList invariant violated: {} elems but widths sum to {}",
54            elems.len(),
55            expected,
56        );
57        Self {
58            elems,
59            widths: widths.to_vec(),
60        }
61    }
62
63    /// Build a `RowList` from an iterator of row-like items.
64    ///
65    /// Accepts anything that derefs to `[T]`: owned `Vec<T>`, `&Vec<T>`, `&[T]`, etc.
66    pub fn from_rows<R: AsRef<[T]>>(rows: impl IntoIterator<Item = R>) -> Self
67    where
68        T: Clone,
69    {
70        let mut elems = Vec::new();
71        let mut widths = Vec::new();
72        for row in rows {
73            let row = row.as_ref();
74            widths.push(row.len());
75            elems.extend_from_slice(row);
76        }
77        Self { elems, widths }
78    }
79
80    /// Contiguous element slice.
81    #[inline]
82    pub fn as_slice(&self) -> &[T] {
83        &self.elems
84    }
85
86    /// Iterate over all elements by value.
87    #[inline]
88    pub fn iter_values(&self) -> impl Iterator<Item = T> + '_
89    where
90        T: Copy,
91    {
92        self.elems.iter().copied()
93    }
94
95    /// Number of rows.
96    #[inline]
97    pub fn num_rows(&self) -> usize {
98        self.widths.len()
99    }
100
101    /// Iterate over rows as slices.
102    pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
103        let mut offset = 0;
104        self.widths.iter().map(move |&w| {
105            let row = &self.elems[offset..offset + w];
106            offset += w;
107            row
108        })
109    }
110
111    /// Get a single row by index.
112    ///
113    /// # Panics
114    ///
115    /// Panics if `idx >= self.num_rows()`.
116    pub fn row(&self, idx: usize) -> &[T] {
117        let offset: usize = self.widths[..idx].iter().sum();
118        &self.elems[offset..offset + self.widths[idx]]
119    }
120}
121
122impl<T: Copy + Default> RowList<T> {
123    /// Iterate over all elements with each row zero-padded to a multiple of `alignment`.
124    ///
125    /// Alignment matches the cryptographic sponge's absorption rate. Both prover and
126    /// verifier must hash identical padded data for the Merkle commitment to verify,
127    /// so OOD evaluations sent over the transcript use the same padding convention.
128    ///
129    /// Yields the original row elements followed by implicit zeros, without allocating
130    /// a padded copy.
131    pub fn iter_aligned(&self, alignment: usize) -> impl Iterator<Item = T> + '_ {
132        self.iter_rows().flat_map(move |row| {
133            let padding = aligned_len(row.len(), alignment) - row.len();
134            row.iter()
135                .copied()
136                .chain(core::iter::repeat_n(T::default(), padding))
137        })
138    }
139}
140
141impl<T: Default + Clone> RowList<T> {
142    /// Build a `RowList` from an iterator of row-like items, padding each to `alignment`.
143    pub fn from_rows_aligned<R: AsRef<[T]>>(
144        rows: impl IntoIterator<Item = R>,
145        alignment: usize,
146    ) -> Self {
147        let mut elems = Vec::new();
148        let mut widths = Vec::new();
149        for row in rows {
150            let row = row.as_ref();
151            let padded_len = aligned_len(row.len(), alignment);
152            widths.push(padded_len);
153            elems.extend_from_slice(row);
154            elems.resize(elems.len() + (padded_len - row.len()), T::default());
155        }
156        Self { elems, widths }
157    }
158}
159
160/// Extension trait for `PackedValue` providing columnar pack/unpack operations.
161///
162/// These methods perform transpose operations on packed data, useful for
163/// SIMD-parallelized Merkle tree construction.
164pub trait PackedValueExt: PackedValue {
165    /// Pack columns from `WIDTH` rows of scalar values.
166    ///
167    /// Given `WIDTH` rows of `N` scalar values, extract each column and pack it
168    /// into a single packed value. This performs a transpose operation.
169    #[inline]
170    #[must_use]
171    fn pack_columns<const N: usize>(rows: &[[Self::Value; N]]) -> [Self; N] {
172        assert_eq!(rows.len(), Self::WIDTH);
173        array::from_fn(|col| Self::from_fn(|lane| rows[lane][col]))
174    }
175}
176
177// Blanket implementation for all PackedValue types
178impl<T: PackedValue> PackedValueExt for T {}
179
/// Rounds `len` up to the next multiple of `alignment`.
///
/// An `alignment` of `0` or `1` means "no alignment" and returns `len` unchanged.
#[inline]
pub const fn aligned_len(len: usize, alignment: usize) -> usize {
    match alignment {
        0 | 1 => len,
        step => len.next_multiple_of(step),
    }
}
189
190/// Align each width in place, returning the same `Vec`.
191pub fn aligned_widths(mut widths: Vec<usize>, alignment: usize) -> Vec<usize> {
192    for w in &mut widths {
193        *w = aligned_len(*w, alignment);
194    }
195    widths
196}
197
198/// Pad a row with `Default::default()` so its length is a multiple of `alignment`.
199///
200/// This is a formatting convention for transcript hints; LMCS does not enforce that
201/// padded values are zero unless the caller checks them.
202pub fn pad_row_to_alignment<F: Default>(mut row: Vec<F>, alignment: usize) -> Vec<F> {
203    debug_assert!(alignment > 0, "alignment must be non-zero");
204    let padded_len = aligned_len(row.len(), alignment);
205    row.resize_with(padded_len, || F::default());
206    row
207}
208
209/// Upsample matrix to exactly `target_height` rows via nearest-neighbor repetition.
210///
211/// Each original row is repeated `target_height / height` times.
212/// Requires `target_height >= height` and both be powers of two.
213///
214/// This is the explicit form of the "lifting" operation used in LMCS, where smaller
215/// matrices are virtually extended to match the height of the tallest matrix.
216pub fn upsample_matrix<F: Clone + Send + Sync>(
217    matrix: &impl Matrix<F>,
218    target_height: usize,
219) -> RowMajorMatrix<F> {
220    let height = matrix.height();
221    assert!(target_height >= height);
222    assert!(height.is_power_of_two() && target_height.is_power_of_two());
223
224    let repeat_factor = target_height / height;
225    let width = matrix.width();
226
227    let mut values = Vec::with_capacity(target_height * width);
228    for row in matrix.rows() {
229        let row_vec: Vec<F> = row.collect();
230        for _ in 0..repeat_factor {
231            values.extend(row_vec.iter().cloned());
232        }
233    }
234
235    RowMajorMatrix::new(values, width)
236}