p3-miden-lmcs 0.5.0

Lifted Matrix Commitment Scheme (LMCS) for matrices with power-of-two heights.
Documentation
use alloc::{vec, vec::Vec};
use core::{array, mem};

use p3_field::PackedValue;
use p3_matrix::{Matrix, dense::RowMajorMatrix};
use p3_maybe_rayon::prelude::*;
use p3_miden_stateful_hasher::StatefulHasher;
use p3_miden_transcript::ProverChannel;
use p3_symmetric::{Hash, PseudoCompressionFunction};
use p3_util::log2_strict_usize;
use serde::{Deserialize, Serialize};
use tracing::{debug_span, info_span};

use crate::{
    LmcsTree, Proof,
    utils::{PackedValueExt, RowList, aligned_widths, pad_row_to_alignment},
};

/// A uniform binary Merkle tree whose leaves are constructed from matrices with power-of-two heights.
///
/// # Type Parameters
///
/// * `F` – scalar field element type used in both matrices and hash words.
/// * `D` – digest element type.
/// * `M` – matrix type. Must implement [`Matrix<F>`].
/// * `DIGEST_ELEMS` – number of elements in one digest.
/// * `SALT_ELEMS` – number of salt elements per leaf (0 = non-hiding, >0 = hiding).
///
/// Unlike the standard `MerkleTree`, this uniform variant requires:
/// - **All matrix heights must be powers of two**
/// - **Matrices must be sorted by height** (shortest to tallest)
/// - Uses incremental hashing via [`StatefulHasher`] instead of one-shot hashing
///
/// The per-leaf row composition uses nearest-neighbor upsampling: each matrix Mᵢ is virtually
/// extended to height N (width unchanged) by repeating each row rᵢ = N/nᵢ times
/// contiguously. For leaf index `j`, the sponge absorbs the `j`-th row from each lifted matrix
/// in sequence. The sponge applies its own padding semantics during absorption; LMCS alignment
/// only affects transcript hints.
///
/// Note: alignment padding is a convention for transcript openings and does not affect the
/// commitment. It is independent of the sponge's absorption alignment. LMCS does not enforce
/// that padded columns are zero; verifiers cannot distinguish zero padding from arbitrary values
/// unless they check those columns or constrain them elsewhere.
///
/// Equivalent single-matrix view: this commitment is equivalent to first forming a single
/// height-`N` matrix by (a) lifting every input matrix to height `N`, (b) padding each lifted
/// matrix horizontally with zero columns to reflect the sponge's absorption alignment (if any),
/// and (c) concatenating the results side-by-side. The leaf hash at index `j` is then the
/// sponge of that single concatenated matrix's row `j`. This is a conceptual view: LMCS does
/// not enforce that those padded columns are zero.
///
/// Since [`StatefulHasher`] operates on a single field type, this tree uses the same type `F`
/// for both matrix elements and hash words, unlike `MerkleTree` which can hash `F → W`.
///
/// Use [`root`](Self::root) to fetch the final commitment once the tree is built.
///
/// ## Transcript Hints
///
/// `prove_batch` streams transcript hints in the format expected by
/// [`Lmcs::open_batch`](crate::Lmcs::open_batch):
/// - For each unique query index **in sorted tree index order** (ascending, deduplicated): one
///   row per matrix (in leaf order), then `SALT_ELEMS` field elements of salt.
/// - Each row is padded with explicit zeros to the LMCS alignment.
///   This allows verifiers to absorb fixed-size chunks without special-casing
///   the final partial chunk; padding is not enforced to be zero.
/// - After all indices: missing sibling hashes, level-by-level, left-to-right, bottom-to-top.
///
/// Hints are not observed into the Fiat-Shamir challenger.
///
/// This generally shouldn't be used directly. If you're using a Merkle tree as an MMCS,
/// see the MMCS wrapper types.
#[derive(Debug, Serialize, Deserialize)]
pub struct LiftedMerkleTree<F, D, M, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize = 0> {
    /// All leaf matrices in insertion order.
    ///
    /// Matrices must be sorted by height (shortest to tallest) and all heights must be
    /// powers of two. Each matrix's rows are absorbed into sponge states that are
    /// maintained and upsampled across matrices of increasing height.
    ///
    /// This vector is retained for inspection or re-opening of the tree; it is not used
    /// after construction time.
    pub(crate) leaves: Vec<M>,

    /// All intermediate hash layers (digest arrays), index 0 being the leaf hash layer
    /// and the last layer containing exactly one root hash.
    ///
    /// Every inner vector holds contiguous hashes. Higher layers are built by
    /// compressing pairs from the previous layer.
    ///
    /// Field order matters: derived `Serialize` emits fields in declaration order.
    #[serde(bound(
        serialize = "[D; DIGEST_ELEMS]: Serialize",
        deserialize = "[D; DIGEST_ELEMS]: Deserialize<'de>"
    ))]
    pub(crate) digest_layers: Vec<Vec<[D; DIGEST_ELEMS]>>,

    /// Salt matrix for hiding commitment. Each row contains `SALT_ELEMS` random field elements.
    /// `None` when `SALT_ELEMS = 0` (non-hiding mode).
    pub(crate) salt: Option<RowMajorMatrix<F>>,
    /// Column alignment used for transcript proofs.
    ///
    /// Opened rows are padded to a multiple of this width when streamed as hints;
    /// it does not affect the committed root.
    pub(crate) alignment: usize,
}

impl<F, D, M, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
    LmcsTree<F, Hash<F, D, DIGEST_ELEMS>, M> for LiftedMerkleTree<F, D, M, DIGEST_ELEMS, SALT_ELEMS>
where
    F: Copy + Default + PartialEq + Send + Sync,
    D: Copy + Default + PartialEq + Send + Sync,
    M: Matrix<F>,
{
    /// Return the committed Merkle root.
    fn root(&self) -> Hash<F, D, DIGEST_ELEMS> {
        // The last digest layer always holds exactly one hash: the root.
        Hash::from(self.digest_layers.last().unwrap()[0])
    }

    /// Height of the tree, i.e. the number of leaves.
    ///
    /// Matrices are sorted shortest-to-tallest, so the last matrix is the tallest
    /// and its height is the tree height.
    fn height(&self) -> usize {
        self.leaves.last().unwrap().height()
    }

    /// Borrow the committed matrices in insertion (leaf) order.
    fn leaves(&self) -> &[M] {
        &self.leaves
    }

    /// Return the upsampled rows for `index`, padded to `alignment`.
    ///
    /// Padding uses `Default::default()` and is not enforced by verification; callers
    /// that require zero padding must check these columns explicitly.
    ///
    /// Panics if `index` is out of range for the tree height.
    fn rows(&self, index: usize) -> RowList<F> {
        let max_height = self.height();
        let rows_iter = self.leaves.iter().map(|m| {
            // Lifting: a matrix of height h maps leaf index i to row i >> log₂(max_height/h).
            let log_scaling = log2_strict_usize(max_height / m.height());
            let row_index = index >> log_scaling;
            m.row_slice(row_index)
                .expect("row_index must be valid after upsampling")
                .to_vec()
        });
        RowList::from_rows_aligned(rows_iter, self.alignment)
    }

    /// Column alignment used when streaming openings.
    fn alignment(&self) -> usize {
        self.alignment
    }

    /// Widths of each leaf matrix after rounding up to the transcript alignment.
    fn widths(&self) -> Vec<usize> {
        let alignment = self.alignment;
        let widths = self.leaves.iter().map(|m| m.width()).collect();
        aligned_widths(widths, alignment)
    }

    /// Prove a batch opening and stream it into a transcript channel.
    ///
    /// Panics if any index is out of range. Rows are padded to `alignment` and those
    /// padding values are not validated by verification; callers that require zero
    /// padding must check the opened rows explicitly.
    ///
    /// Leaf openings are written in **sorted tree index order** (ascending, deduplicated).
    fn prove_batch<Ch>(&self, indices: impl IntoIterator<Item = usize>, channel: &mut Ch)
    where
        Ch: ProverChannel<F = F, Commitment = Hash<F, D, DIGEST_ELEMS>>,
    {
        use alloc::collections::BTreeSet;

        let final_height = self.leaves.last().unwrap().height();
        let depth = log2_strict_usize(final_height);
        let alignment = self.alignment;

        // Collect and deduplicate indices. BTreeSet iteration yields sorted order,
        // which is critical for transcript determinism: both prover and verifier
        // must process indices in the same order.
        let unique_indices: BTreeSet<usize> = indices.into_iter().collect();

        // Stream leaf openings in sorted tree index order.
        for &index in &unique_indices {
            assert!(
                index < final_height,
                "index {index} out of range {final_height}"
            );
            for m in self.leaves.iter() {
                // Same lifting rule as `rows`: shorter matrices repeat each row
                // `final_height / height` times, so shift the index down accordingly.
                let height = m.height();
                let log_scaling_factor = log2_strict_usize(final_height / height);
                let row_index = index >> log_scaling_factor;
                let row = m
                    .row_slice(row_index)
                    .expect("row_index must be valid after upsampling")
                    .to_vec();
                let row = pad_row_to_alignment(row, alignment);
                channel.hint_field_slice(&row);
            }
            if SALT_ELEMS > 0 {
                let salt = self.salt(index);
                channel.hint_field_slice(&salt);
            }
        }

        // Use the same sorted set for sibling traversal
        let mut known = unique_indices;

        // Walk up the tree level by level using the deduplicated set.
        // Siblings are emitted bottom-to-top; within a level, left-to-right
        // (BTreeSet iteration order), matching the verifier's consumption order.
        for layer_idx in 0..depth {
            let mut parents = BTreeSet::new();

            // BTreeSet iterates in sorted order (left-to-right)
            for &pos in &known {
                let parent_pos = pos / 2;
                if !parents.insert(parent_pos) {
                    continue; // Already processed this pair
                }

                let left_pos = parent_pos * 2;
                let right_pos = left_pos + 1;
                let have_left = known.contains(&left_pos);
                let have_right = known.contains(&right_pos);

                // Add sibling hash if exactly one child is known
                // (when both are known, the verifier can compute the parent itself).
                if have_left && !have_right {
                    channel.hint_commitment(Hash::from(self.digest_layers[layer_idx][right_pos]));
                } else if !have_left && have_right {
                    channel.hint_commitment(Hash::from(self.digest_layers[layer_idx][left_pos]));
                }
            }

            known = parents;
        }
    }
}

impl<F, D, M, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
    LiftedMerkleTree<F, D, M, DIGEST_ELEMS, SALT_ELEMS>
where
    F: Copy + Default + PartialEq + Send + Sync,
    D: Copy + Default + PartialEq + Send + Sync,
    M: Matrix<F>,
{
    /// Builder for creating trees with optional salt and explicit alignment.
    ///
    /// Preconditions:
    /// - `leaves` is non-empty and heights are powers of two.
    /// - Matrices are sorted by height (shortest to tallest).
    ///
    /// `alignment` controls transcript padding only; it does not affect the commitment.
    /// LMCS does not enforce that padded columns are zero.
    ///
    /// # Panics
    /// Panics if `leaves` is empty or `alignment` is zero.
    pub(crate) fn build_with_alignment<PF, PD, H, C, const WIDTH: usize>(
        h: &H,
        c: &C,
        leaves: Vec<M>,
        salt: Option<RowMajorMatrix<F>>,
        alignment: usize,
    ) -> Self
    where
        PF: PackedValue<Value = F>,
        PD: PackedValue<Value = D>,
        H: StatefulHasher<F, [D; DIGEST_ELEMS], State = [D; WIDTH]>
            + StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
            + Sync,
        C: PseudoCompressionFunction<[D; DIGEST_ELEMS], 2>
            + PseudoCompressionFunction<[PD; DIGEST_ELEMS], 2>
            + Sync,
    {
        const { assert!(PF::WIDTH == PD::WIDTH) }
        assert!(!leaves.is_empty(), "cannot commit empty batch");
        // Enforce the documented invariant in all build profiles. Previously this was
        // a `debug_assert!` followed by storing `alignment.max(1)`, which silently
        // clamped an invalid alignment to 1 in release builds instead of surfacing
        // the caller's bug; fail fast instead.
        assert!(alignment > 0, "alignment must be non-zero");

        // Build leaf hashes: absorb all matrix rows into sponge states, then squeeze.
        let leaf_digests: Vec<[PD::Value; DIGEST_ELEMS]> =
            info_span!("hash leaves").in_scope(|| {
                let mut leaf_states: Vec<[PD::Value; WIDTH]> =
                    build_leaf_states_upsampled::<PF, PD, M, H, WIDTH, DIGEST_ELEMS>(&leaves, h);

                // Absorb salt into states using SIMD-parallelized path (no-op when salt is None)
                if let Some(ref salt_matrix) = salt {
                    debug_assert_eq!(salt_matrix.height(), leaf_states.len());
                    debug_assert_eq!(salt_matrix.width(), SALT_ELEMS);
                    absorb_matrix::<PF, PD, _, _, WIDTH, DIGEST_ELEMS>(
                        &mut leaf_states,
                        salt_matrix,
                        h,
                    );
                }

                // Squeeze the final hashes from the states
                leaf_states
                    .into_par_iter()
                    .map(|state| h.squeeze(&state))
                    .collect()
            });

        // Build digest layers by repeatedly compressing until we reach the root
        let digest_layers = debug_span!("compress tree layers").in_scope(|| {
            let mut digest_layers = vec![leaf_digests];
            loop {
                let prev_layer = digest_layers.last().unwrap();
                if prev_layer.len() == 1 {
                    break;
                }

                let next_layer = compress_uniform::<PD, C, DIGEST_ELEMS>(prev_layer, c);
                digest_layers.push(next_layer);
            }
            digest_layers
        });

        Self {
            leaves,
            digest_layers,
            salt,
            // Validated non-zero above; no clamping needed.
            alignment,
        }
    }

    /// Build a full opening proof for a single leaf index.
    ///
    /// Rows are padded to `alignment` and LMCS does not enforce that padding is zero.
    /// Panics if `index` is out of range for the tree height.
    pub fn single_proof(&self, index: usize) -> Proof<F, Hash<F, D, DIGEST_ELEMS>, SALT_ELEMS> {
        // One sibling per non-root layer.
        let mut siblings = Vec::with_capacity(self.digest_layers.len().saturating_sub(1));
        let mut layer_index = index;
        for layer in &self.digest_layers {
            if layer.len() == 1 {
                // Reached the root layer; it has no sibling.
                break;
            }
            // XOR with 1 flips the low bit: the other node of this pair.
            let sibling = layer[layer_index ^ 1];
            siblings.push(Hash::from(sibling));
            layer_index >>= 1;
        }

        Proof {
            rows: self.rows(index),
            salt: self.salt(index),
            siblings,
        }
    }

    /// Column alignment used when streaming openings.
    pub fn alignment(&self) -> usize {
        self.alignment
    }

    /// Extract the salt for the given leaf index.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of range, or if `SALT_ELEMS > 0` but the tree was
    /// constructed without salt.
    pub fn salt(&self, index: usize) -> [F; SALT_ELEMS] {
        match &self.salt {
            Some(salt_matrix) => {
                let row = salt_matrix.row_slice(index).expect("index must be valid");
                // Tree construction guarantees salt width == SALT_ELEMS
                array::from_fn(|i| row[i])
            }
            None => {
                // For SALT_ELEMS == 0, this returns an empty array.
                // For SALT_ELEMS > 0, this should never be reached if using safe constructors.
                debug_assert!(
                    SALT_ELEMS == 0,
                    "tree constructed without salt but SALT_ELEMS > 0"
                );
                [F::default(); SALT_ELEMS]
            }
        }
    }
}

/// Build leaf states using the upsampled view (nearest-neighbor upsampling).
///
/// Returns the sponge states after absorbing all matrix rows but **before squeezing**.
/// Callers must squeeze the states to obtain final leaf hashes.
///
/// Conceptually, each matrix is virtually extended to height `H` by repeating each row
/// `L = H / h` times (width unchanged), and the leaf `r` absorbs the `r`-th row from each
/// extended matrix in order. Each absorbed row is virtually padded with zeros to a multiple of the
/// hasher's padding width for absorption; see [`LiftedMerkleTree`](crate::LiftedMerkleTree) docs
/// for the equivalent single-matrix view.
///
/// Padding is implicit and not checked; callers that require zero padding must enforce
/// it elsewhere.
///
/// # Preconditions
/// - `matrices` is non-empty and sorted by non-decreasing power-of-two heights.
/// - `P::WIDTH` is a power of two.
///
/// Panics in debug builds if preconditions are violated.
fn build_leaf_states_upsampled<PF, PD, M, H, const WIDTH: usize, const DIGEST_ELEMS: usize>(
    matrices: &[M],
    sponge: &H,
) -> Vec<[PD::Value; WIDTH]>
where
    PF: PackedValue,
    PD: PackedValue,
    M: Matrix<PF::Value>,
    H: StatefulHasher<PF::Value, [PD::Value; DIGEST_ELEMS], State = [PD::Value; WIDTH]>
        + StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
        + Sync,
{
    const { assert!(PF::WIDTH.is_power_of_two()) };
    const { assert!(PD::WIDTH.is_power_of_two()) };
    // Also checks the non-empty / power-of-two / sorted-by-height invariants;
    // `final_height` is the tallest (last) matrix's height.
    let final_height = validate_heights(matrices.iter().map(|d| d.dimensions().height));

    // Memory buffers:
    // - states: Per-leaf scalar states (one per final row), maintained across matrices.
    // - scratch_states: Temporary buffer used when duplicating states during upsampling.
    let default_state = [PD::Value::default(); WIDTH];
    let mut states = vec![default_state; final_height];
    let mut scratch_states = vec![default_state; final_height];

    // Height of the most recently absorbed matrix: only `states[..active_height]`
    // currently carries meaningful sponge state.
    let mut active_height = matrices.first().unwrap().height();

    for matrix in matrices {
        let height = matrix.height();

        // Upsample states when height increases (applies to both scalar and packed paths).
        // Duplicate each existing state to fill the expanded height.
        // E.g., [s0, s1] with scaling_factor=2 → [s0, s0, s1, s1]
        if height > active_height {
            let scaling_factor = height / active_height;

            // Copy `states` into `scratch_states`, repeating each entry `scaling_factor` times
            // so we keep the accumulated sponge states aligned with the taller matrix.
            // Every entry of `scratch_states[..height]` is overwritten by `fill`, so any
            // stale contents left from a previous swap are never observed.
            scratch_states[..height]
                .par_chunks_mut(scaling_factor)
                .zip(states[..active_height].par_iter())
                .for_each(|(chunk, state)| chunk.fill(*state));

            // Copy upsampled states back to canonical buffer
            // (swap is O(1): just exchanges the two Vec headers).
            mem::swap(&mut scratch_states, &mut states);
        }

        // Absorb the rows of the matrix into the extended state vector
        absorb_matrix::<PF, PD, _, _, _, _>(&mut states[..height], matrix, sponge);

        active_height = height;
    }

    states
}

/// Absorb one matrix's rows into the running per-leaf sponge states.
///
/// Given `states` of length `h = matrix.height()`, updates `states[r]` by absorbing
/// matrix row `r` into that state, for every `r ∈ [0, h)`. Callers arrange `states`
/// so that it already reflects the correct lifted view for this matrix (e.g. the
/// nearest-neighbor or modulo duplication across the final height); this helper
/// performs exactly one absorption round and neither reshapes the lifting nor
/// squeezes hashes.
fn absorb_matrix<PF, PD, M, H, const WIDTH: usize, const DIGEST_ELEMS: usize>(
    states: &mut [[PD::Value; WIDTH]],
    matrix: &M,
    sponge: &H,
) where
    PF: PackedValue,
    PD: PackedValue,
    M: Matrix<PF::Value>,
    H: StatefulHasher<PF::Value, [PD::Value; DIGEST_ELEMS], State = [PD::Value; WIDTH]>
        + StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
        + Sync,
{
    let num_rows = matrix.height();
    assert_eq!(num_rows, states.len());

    if PF::WIDTH != 1 && num_rows >= PF::WIDTH {
        // SIMD path: for each run of `PF::WIDTH` consecutive leaves, transpose the
        // scalar states into packed form, absorb the vertically packed row, then
        // scatter the packed state back out.
        states
            .par_chunks_mut(PF::WIDTH)
            .enumerate()
            .for_each(|(chunk_idx, states_chunk)| {
                let base_row = chunk_idx * PF::WIDTH;
                let mut packed_state: [PD; WIDTH] = PD::pack_columns(states_chunk);
                let packed_row = matrix.vertically_packed_row::<PF>(base_row);
                sponge.absorb_into(&mut packed_state, packed_row);
                PD::unpack_into(&packed_state, states_chunk);
            });
    } else {
        // Scalar path: absorb each row directly into its corresponding leaf state.
        matrix
            .par_rows()
            .zip(states.par_iter_mut())
            .for_each(|(row, state)| sponge.absorb_into(state, row));
    }
}

/// Compress one layer of a uniform Merkle tree into the next.
///
/// Pairs adjacent digests from `prev_layer` (whose length must be a power of two)
/// and compresses each pair, producing a layer with half as many elements.
///
/// A pure scalar path handles outputs too small to fill a SIMD register; otherwise
/// the packed path is taken. Since both the output length and the packing width are
/// powers of two, the packed path covers the output exactly and no scalar fallback
/// for a remainder is ever needed.
fn compress_uniform<
    P: PackedValue,
    C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>
        + PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>
        + Sync,
    const DIGEST_ELEMS: usize,
>(
    prev_layer: &[[P::Value; DIGEST_ELEMS]],
    c: &C,
) -> Vec<[P::Value; DIGEST_ELEMS]> {
    assert!(
        prev_layer.len().is_power_of_two(),
        "previous layer length must be a power of 2"
    );

    let half = prev_layer.len() / 2;
    let mut compressed = vec![[P::Value::default(); DIGEST_ELEMS]; half];

    if P::WIDTH == 1 || half < P::WIDTH {
        // Scalar path: output `i` is the compression of the adjacent pair (2i, 2i+1).
        compressed
            .par_iter_mut()
            .enumerate()
            .for_each(|(i, digest)| {
                *digest = c.compress([prev_layer[2 * i], prev_layer[2 * i + 1]]);
            });
    } else {
        // Packed path: `half` is a power-of-two multiple of `P::WIDTH`, so
        // `par_chunks_exact_mut` covers the entire output with no remainder.
        compressed
            .par_chunks_exact_mut(P::WIDTH)
            .enumerate()
            .for_each(|(chunk_no, out_chunk)| {
                let base = chunk_no * P::WIDTH;
                // Gather the left/right children of each pair, lane by lane.
                let left: [P; DIGEST_ELEMS] =
                    array::from_fn(|e| P::from_fn(|lane| prev_layer[2 * (base + lane)][e]));
                let right: [P; DIGEST_ELEMS] =
                    array::from_fn(|e| P::from_fn(|lane| prev_layer[2 * (base + lane) + 1][e]));
                P::unpack_into(&c.compress([left, right]), out_chunk);
            });
    }
    compressed
}

/// Validate a sequence of matrix heights for LMCS.
///
/// Requirements enforced:
/// - Non-empty sequence (at least one matrix).
/// - Every height is a non-zero power of two.
/// - Heights are in non-decreasing order (sorted by height), so the returned last
///   height is the maximum `H` used by lifting.
///
/// # Panics
/// Panics if any requirement is violated.
fn validate_heights(heights: impl IntoIterator<Item = usize>) -> usize {
    let mut previous = 0;

    for (idx, height) in heights.into_iter().enumerate() {
        assert_ne!(height, 0, "zero height at matrix {}", idx);
        assert!(
            height.is_power_of_two(),
            "non-power-of-two height at matrix {}",
            idx
        );
        assert!(height >= previous, "matrices must be sorted by height");
        previous = height;
    }

    // `previous` is still 0 only if the iterator yielded nothing.
    assert_ne!(previous, 0, "empty batch");
    previous
}

#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use p3_matrix::{Matrix, dense::RowMajorMatrix};
    use p3_miden_dev_utils::configs::baby_bear_poseidon2 as bb;
    use p3_miden_stateful_hasher::StatefulHasher;
    use rand::{SeedableRng, rngs::SmallRng};

    use crate::{
        tests::{DIGEST, F, P, RATE, Sponge, build_leaves_single, concatenate_matrices},
        utils::upsample_matrix,
    };

    /// Run the incremental upsampled builder over `matrices`, then squeeze each
    /// per-leaf sponge state into a final digest.
    fn build_leaves_upsampled(matrices: &[RowMajorMatrix<F>], sponge: &Sponge) -> Vec<[F; DIGEST]> {
        let mut states = super::build_leaf_states_upsampled::<P, P, _, _, _, _>(matrices, sponge);
        states.iter_mut().map(|s| sponge.squeeze(s)).collect()
    }

    /// Test that upsampled lifting produces correct results:
    /// 1. Incremental lifting equals explicit lifting
    /// 2. Explicit lifting equals single-matrix concatenation baseline
    #[test]
    fn upsampled_equivalence() {
        let (_, sponge, _compressor) = bb::test_components();
        // Fixed seed keeps the random matrices (and thus the test) deterministic.
        let mut rng = SmallRng::seed_from_u64(42);

        for scenario in p3_miden_dev_utils::fixtures::matrix_scenarios::<P>(RATE) {
            let matrices: Vec<RowMajorMatrix<F>> = scenario
                .into_iter()
                .map(|(h, w)| RowMajorMatrix::rand(&mut rng, h, w))
                .collect();

            // The builder requires matrices sorted by height, so the last is tallest.
            let max_height = matrices.last().unwrap().height();

            // Upsampled path equivalence vs explicit upsampled lifting and single-concat baseline
            let leaves = build_leaves_upsampled(&matrices, &sponge);

            // (1) Pre-lift every matrix to `max_height` explicitly; the incremental
            // builder must produce identical leaf digests.
            let matrices_upsampled: Vec<_> = matrices
                .iter()
                .map(|m: &RowMajorMatrix<F>| upsample_matrix(m, max_height))
                .collect();
            let leaves_lifted = build_leaves_upsampled(&matrices_upsampled, &sponge);
            assert_eq!(leaves, leaves_lifted);

            // (2) Concatenate the lifted matrices side-by-side and hash each row once:
            // the equivalent single-matrix view described in the tree docs.
            let matrix_single = concatenate_matrices::<_, RATE>(&matrices_upsampled);
            let leaves_single = build_leaves_single(&matrix_single, &sponge);
            assert_eq!(leaves, leaves_single);
        }
    }
}