use alloc::{vec, vec::Vec};
use core::{array, mem};
use p3_field::PackedValue;
use p3_matrix::{Matrix, dense::RowMajorMatrix};
use p3_maybe_rayon::prelude::*;
use p3_miden_stateful_hasher::StatefulHasher;
use p3_miden_transcript::ProverChannel;
use p3_symmetric::{Hash, PseudoCompressionFunction};
use p3_util::log2_strict_usize;
use serde::{Deserialize, Serialize};
use tracing::{debug_span, info_span};
use crate::{
LmcsTree, Proof,
utils::{PackedValueExt, RowList, aligned_widths, pad_row_to_alignment},
};
/// A Merkle tree committing to a batch of matrices whose (power-of-two)
/// heights may differ: shorter matrices are virtually "lifted" (upsampled)
/// to the tallest height before hashing, so every leaf absorbs one row from
/// each matrix.
#[derive(Debug, Serialize, Deserialize)]
pub struct LiftedMerkleTree<F, D, M, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize = 0> {
    /// The committed matrices, sorted by non-decreasing height (enforced at
    /// build time).
    pub(crate) leaves: Vec<M>,
    /// Digests for every tree layer, bottom-up: index 0 holds the leaf
    /// digests and the last layer holds the single root digest.
    #[serde(bound(
        serialize = "[D; DIGEST_ELEMS]: Serialize",
        deserialize = "[D; DIGEST_ELEMS]: Deserialize<'de>"
    ))]
    pub(crate) digest_layers: Vec<Vec<[D; DIGEST_ELEMS]>>,
    /// Optional per-leaf salt rows of width `SALT_ELEMS`; `None` for an
    /// unsalted tree.
    pub(crate) salt: Option<RowMajorMatrix<F>>,
    /// Width multiple that opened rows are padded to (see
    /// `pad_row_to_alignment`); always at least 1.
    pub(crate) alignment: usize,
}
impl<F, D, M, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
    LmcsTree<F, Hash<F, D, DIGEST_ELEMS>, M> for LiftedMerkleTree<F, D, M, DIGEST_ELEMS, SALT_ELEMS>
where
    F: Copy + Default + PartialEq + Send + Sync,
    D: Copy + Default + PartialEq + Send + Sync,
    M: Matrix<F>,
{
    /// Returns the tree's root digest (the single entry of the top layer).
    fn root(&self) -> Hash<F, D, DIGEST_ELEMS> {
        Hash::from(self.digest_layers.last().unwrap()[0])
    }

    /// Returns the tree height: the height of the tallest committed matrix
    /// (the last one, since matrices are sorted by height).
    fn height(&self) -> usize {
        self.leaves.last().unwrap().height()
    }

    /// Returns the committed matrices.
    fn leaves(&self) -> &[M] {
        &self.leaves
    }

    /// Returns the opened rows for leaf `index`: one row per matrix, where a
    /// shorter matrix contributes the row at `index` shifted down by its
    /// (power-of-two) upsampling factor. Rows are padded to `alignment`.
    fn rows(&self, index: usize) -> RowList<F> {
        let max_height = self.height();
        let rows_iter = self.leaves.iter().map(|m| {
            // Each matrix is virtually upsampled to `max_height`, so its own
            // row index is `index >> log2(max_height / height)`.
            let log_scaling = log2_strict_usize(max_height / m.height());
            let row_index = index >> log_scaling;
            m.row_slice(row_index)
                .expect("row_index must be valid after upsampling")
                .to_vec()
        });
        RowList::from_rows_aligned(rows_iter, self.alignment)
    }

    /// Returns the row alignment used when hashing and opening rows.
    fn alignment(&self) -> usize {
        self.alignment
    }

    /// Returns the matrix widths, adjusted per the tree's alignment.
    fn widths(&self) -> Vec<usize> {
        let alignment = self.alignment;
        let widths = self.leaves.iter().map(|m| m.width()).collect();
        aligned_widths(widths, alignment)
    }

    /// Sends to `channel` everything a verifier needs to check the openings
    /// at `indices`: the (padded) opened rows per matrix, the salt for each
    /// index (when salted), and the pruned set of sibling digests for the
    /// batched Merkle authentication paths.
    ///
    /// # Panics
    /// Panics if any index is `>= height()`.
    fn prove_batch<Ch>(&self, indices: impl IntoIterator<Item = usize>, channel: &mut Ch)
    where
        Ch: ProverChannel<F = F, Commitment = Hash<F, D, DIGEST_ELEMS>>,
    {
        use alloc::collections::BTreeSet;

        let final_height = self.leaves.last().unwrap().height();
        let depth = log2_strict_usize(final_height);
        let alignment = self.alignment;

        // Deduplicate and sort the indices so hints are emitted in a
        // deterministic order the verifier can mirror.
        let unique_indices: BTreeSet<usize> = indices.into_iter().collect();

        for &index in &unique_indices {
            assert!(
                index < final_height,
                "index {index} out of range {final_height}"
            );
            for m in self.leaves.iter() {
                let height = m.height();
                let log_scaling_factor = log2_strict_usize(final_height / height);
                let row_index = index >> log_scaling_factor;
                let row = m
                    .row_slice(row_index)
                    .expect("row_index must be valid after upsampling")
                    .to_vec();
                // Pad to the same alignment used when the row was hashed.
                let row = pad_row_to_alignment(row, alignment);
                channel.hint_field_slice(&row);
            }
            if SALT_ELEMS > 0 {
                let salt = self.salt(index);
                channel.hint_field_slice(&salt);
            }
        }

        // Walk up the tree layer by layer; at each layer, hint only sibling
        // digests the verifier cannot recompute from nodes it already knows
        // (paths of different indices share ancestors, so they are pruned).
        let mut known = unique_indices;
        for layer_idx in 0..depth {
            let mut parents = BTreeSet::new();
            for &pos in &known {
                let parent_pos = pos / 2;
                if !parents.insert(parent_pos) {
                    // This parent was already handled via its other child.
                    continue;
                }
                let left_pos = parent_pos * 2;
                let right_pos = left_pos + 1;
                let have_left = known.contains(&left_pos);
                let have_right = known.contains(&right_pos);
                if have_left && !have_right {
                    channel.hint_commitment(Hash::from(self.digest_layers[layer_idx][right_pos]));
                } else if !have_left && have_right {
                    channel.hint_commitment(Hash::from(self.digest_layers[layer_idx][left_pos]));
                }
                // When both children are known, the parent needs no hint.
            }
            known = parents;
        }
    }
}
impl<F, D, M, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
    LiftedMerkleTree<F, D, M, DIGEST_ELEMS, SALT_ELEMS>
where
    F: Copy + Default + PartialEq + Send + Sync,
    D: Copy + Default + PartialEq + Send + Sync,
    M: Matrix<F>,
{
    /// Builds a lifted Merkle tree over `leaves` (matrices sorted by
    /// ascending power-of-two height): hashes one row of each matrix into
    /// every leaf sponge state after virtual upsampling, optionally absorbs
    /// a per-leaf `salt` row, then compresses digest pairs layer by layer up
    /// to the root.
    ///
    /// `alignment` is the width multiple used when padding opened rows.
    ///
    /// # Panics
    /// Panics if `leaves` is empty; debug-panics if `alignment == 0` or the
    /// salt matrix's dimensions don't match the leaf count / `SALT_ELEMS`.
    pub(crate) fn build_with_alignment<PF, PD, H, C, const WIDTH: usize>(
        h: &H,
        c: &C,
        leaves: Vec<M>,
        salt: Option<RowMajorMatrix<F>>,
        alignment: usize,
    ) -> Self
    where
        PF: PackedValue<Value = F>,
        PD: PackedValue<Value = D>,
        H: StatefulHasher<F, [D; DIGEST_ELEMS], State = [D; WIDTH]>
            + StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
            + Sync,
        C: PseudoCompressionFunction<[D; DIGEST_ELEMS], 2>
            + PseudoCompressionFunction<[PD; DIGEST_ELEMS], 2>
            + Sync,
    {
        // Field and digest packings must agree on SIMD lane count.
        const { assert!(PF::WIDTH == PD::WIDTH) }
        assert!(!leaves.is_empty(), "cannot commit empty batch");
        debug_assert!(alignment > 0, "alignment must be non-zero");

        let leaf_digests: Vec<[PD::Value; DIGEST_ELEMS]> =
            info_span!("hash leaves").in_scope(|| {
                // One sponge state per leaf, fed with every matrix's rows.
                let mut leaf_states: Vec<[PD::Value; WIDTH]> =
                    build_leaf_states_upsampled::<PF, PD, M, H, WIDTH, DIGEST_ELEMS>(&leaves, h);
                if let Some(ref salt_matrix) = salt {
                    debug_assert_eq!(salt_matrix.height(), leaf_states.len());
                    debug_assert_eq!(salt_matrix.width(), SALT_ELEMS);
                    // The salt row is absorbed last, after all matrix rows.
                    absorb_matrix::<PF, PD, _, _, WIDTH, DIGEST_ELEMS>(
                        &mut leaf_states,
                        salt_matrix,
                        h,
                    );
                }
                leaf_states
                    .into_par_iter()
                    .map(|state| h.squeeze(&state))
                    .collect()
            });

        // Compress adjacent digest pairs, layer by layer, until a single
        // root digest remains.
        let digest_layers = debug_span!("compress tree layers").in_scope(|| {
            let mut digest_layers = vec![leaf_digests];
            loop {
                let prev_layer = digest_layers.last().unwrap();
                if prev_layer.len() == 1 {
                    break;
                }
                let next_layer = compress_uniform::<PD, C, DIGEST_ELEMS>(prev_layer, c);
                digest_layers.push(next_layer);
            }
            digest_layers
        });

        Self {
            leaves,
            digest_layers,
            salt,
            // `max(1)` guards release builds where the debug_assert above is
            // compiled out.
            alignment: alignment.max(1),
        }
    }

    /// Builds a single-index opening proof: the opened (aligned) rows, the
    /// salt for `index`, and the sibling digests from the leaf layer up to
    /// (but excluding) the root layer.
    pub fn single_proof(&self, index: usize) -> Proof<F, Hash<F, D, DIGEST_ELEMS>, SALT_ELEMS> {
        let mut siblings = Vec::with_capacity(self.digest_layers.len().saturating_sub(1));
        let mut layer_index = index;
        for layer in &self.digest_layers {
            // The root layer has no sibling.
            if layer.len() == 1 {
                break;
            }
            // `^ 1` flips the last bit: the other child of the same parent.
            let sibling = layer[layer_index ^ 1];
            siblings.push(Hash::from(sibling));
            layer_index >>= 1;
        }
        Proof {
            rows: self.rows(index),
            salt: self.salt(index),
            siblings,
        }
    }

    /// Returns the row alignment this tree was built with.
    pub fn alignment(&self) -> usize {
        self.alignment
    }

    /// Returns the salt row for leaf `index`, or an all-default array
    /// (empty when `SALT_ELEMS == 0`) for an unsalted tree.
    ///
    /// # Panics
    /// Panics if the tree is salted and `index` is out of range.
    pub fn salt(&self, index: usize) -> [F; SALT_ELEMS] {
        match &self.salt {
            Some(salt_matrix) => {
                let row = salt_matrix.row_slice(index).expect("index must be valid");
                array::from_fn(|i| row[i])
            }
            None => {
                debug_assert!(
                    SALT_ELEMS == 0,
                    "tree constructed without salt but SALT_ELEMS > 0"
                );
                [F::default(); SALT_ELEMS]
            }
        }
    }
}
/// Computes the leaf sponge states for a batch of matrices sorted by
/// ascending power-of-two height: whenever the next matrix is taller than
/// the currently active height, every existing state is replicated
/// (upsampled) to match before the matrix's rows are absorbed.
///
/// Returns one state per row of the tallest matrix.
fn build_leaf_states_upsampled<PF, PD, M, H, const WIDTH: usize, const DIGEST_ELEMS: usize>(
    matrices: &[M],
    sponge: &H,
) -> Vec<[PD::Value; WIDTH]>
where
    PF: PackedValue,
    PD: PackedValue,
    M: Matrix<PF::Value>,
    H: StatefulHasher<PF::Value, [PD::Value; DIGEST_ELEMS], State = [PD::Value; WIDTH]>
        + StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
        + Sync,
{
    // Packing widths must be powers of two so height ratios divide evenly.
    const { assert!(PF::WIDTH.is_power_of_two()) };
    const { assert!(PD::WIDTH.is_power_of_two()) };

    // Also asserts non-zero power-of-two heights, sorted ascending, and a
    // non-empty batch — so the `first().unwrap()` below cannot fail.
    let final_height = validate_heights(matrices.iter().map(|d| d.dimensions().height));

    let default_state = [PD::Value::default(); WIDTH];
    let mut states = vec![default_state; final_height];
    // Scratch buffer for the upsampling copy; swapped with `states` so no
    // per-matrix allocation is needed.
    let mut scratch_states = vec![default_state; final_height];
    let mut active_height = matrices.first().unwrap().height();

    for matrix in matrices {
        let height = matrix.height();
        if height > active_height {
            // Upsample: replicate each active state `scaling_factor` times
            // so the state vector matches this matrix's height.
            let scaling_factor = height / active_height;
            scratch_states[..height]
                .par_chunks_mut(scaling_factor)
                .zip(states[..active_height].par_iter())
                .for_each(|(chunk, state)| chunk.fill(*state));
            mem::swap(&mut scratch_states, &mut states);
        }
        absorb_matrix::<PF, PD, _, _, _, _>(&mut states[..height], matrix, sponge);
        active_height = height;
    }
    states
}
/// Absorbs each row of `matrix` into the sponge state at the same index.
///
/// Uses a vertically packed (SIMD) path when the matrix has at least
/// `PF::WIDTH` rows; otherwise falls back to scalar row-by-row absorption.
///
/// # Panics
/// Panics if `states.len() != matrix.height()`.
fn absorb_matrix<PF, PD, M, H, const WIDTH: usize, const DIGEST_ELEMS: usize>(
    states: &mut [[PD::Value; WIDTH]],
    matrix: &M,
    sponge: &H,
) where
    PF: PackedValue,
    PD: PackedValue,
    M: Matrix<PF::Value>,
    H: StatefulHasher<PF::Value, [PD::Value; DIGEST_ELEMS], State = [PD::Value; WIDTH]>
        + StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
        + Sync,
{
    let height = matrix.height();
    assert_eq!(height, states.len());

    if height < PF::WIDTH || PF::WIDTH == 1 {
        // Scalar path: too few rows to fill a pack (or packing is trivial).
        states
            .par_iter_mut()
            .zip(matrix.par_rows())
            .for_each(|(state, row)| {
                sponge.absorb_into(state, row);
            });
    } else {
        // Packed path: absorb `PF::WIDTH` consecutive rows at once. Callers
        // pass power-of-two heights, so the chunks divide `height` evenly.
        states
            .par_chunks_mut(PF::WIDTH)
            .enumerate()
            .for_each(|(packed_idx, states_chunk)| {
                // Transpose the chunk of scalar states into packed lanes.
                let mut packed_state: [PD; WIDTH] = PD::pack_columns(states_chunk);
                let row_idx = packed_idx * PF::WIDTH;
                let row = matrix.vertically_packed_row::<PF>(row_idx);
                sponge.absorb_into(&mut packed_state, row);
                // Scatter the packed lanes back into the scalar states.
                PD::unpack_into(&packed_state, states_chunk);
            });
    }
}
/// Compresses a digest layer into its parent layer by applying `c` to each
/// adjacent pair of digests, using a packed (SIMD) path when the parent
/// layer is large enough to fill whole packs.
///
/// # Panics
/// Panics if `prev_layer.len()` is not a power of two.
fn compress_uniform<
    P: PackedValue,
    C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>
        + PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>
        + Sync,
    const DIGEST_ELEMS: usize,
>(
    prev_layer: &[[P::Value; DIGEST_ELEMS]],
    c: &C,
) -> Vec<[P::Value; DIGEST_ELEMS]> {
    assert!(
        prev_layer.len().is_power_of_two(),
        "previous layer length must be a power of 2"
    );
    let next_len = prev_layer.len() / 2;
    let default_digest = [P::Value::default(); DIGEST_ELEMS];
    let mut next_digests = vec![default_digest; next_len];

    if next_len < P::WIDTH || P::WIDTH == 1 {
        // Scalar path: compress each adjacent pair individually.
        let (prev_layer_pairs, _) = prev_layer.as_chunks::<2>();
        next_digests
            .par_iter_mut()
            .zip(prev_layer_pairs.par_iter())
            .for_each(|(next_digest, prev_layer_pair)| {
                *next_digest = c.compress(*prev_layer_pair);
            });
    } else {
        // Packed path: build packed left/right child digests for `P::WIDTH`
        // parents at a time, compress once, then unpack the results.
        next_digests
            .par_chunks_exact_mut(P::WIDTH)
            .enumerate()
            .for_each(|(packed_chunk_idx, digests_chunk)| {
                let chunk_idx = packed_chunk_idx * P::WIDTH;
                // Lane `k` of element `j` holds component `j` of the left
                // (resp. right) child of parent `chunk_idx + k`.
                let left: [P; DIGEST_ELEMS] =
                    array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (chunk_idx + k)][j]));
                let right: [P; DIGEST_ELEMS] =
                    array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (chunk_idx + k) + 1][j]));
                let packed_digest = c.compress([left, right]);
                P::unpack_into(&packed_digest, digests_chunk);
            });
    }
    next_digests
}
/// Validates a batch of matrix heights and returns the final (largest) one.
///
/// Every height must be a non-zero power of two and the sequence must be
/// non-decreasing (matrices sorted by height).
///
/// # Panics
/// Panics if any height is zero or not a power of two, if the heights are
/// not sorted ascending, or if the batch is empty.
fn validate_heights(heights: impl IntoIterator<Item = usize>) -> usize {
    // Fold carries the previously seen height; seeding with 0 makes the
    // sortedness check vacuous for the first element.
    let final_height = heights
        .into_iter()
        .enumerate()
        .fold(0, |previous, (matrix, height)| {
            assert_ne!(height, 0, "zero height at matrix {matrix}");
            assert!(
                height.is_power_of_two(),
                "non-power-of-two height at matrix {matrix}"
            );
            assert!(height >= previous, "matrices must be sorted by height");
            height
        });
    // A final height of 0 means the iterator yielded nothing.
    assert_ne!(final_height, 0, "empty batch");
    final_height
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use p3_matrix::{Matrix, dense::RowMajorMatrix};
    use p3_miden_dev_utils::configs::baby_bear_poseidon2 as bb;
    use p3_miden_stateful_hasher::StatefulHasher;
    use rand::{SeedableRng, rngs::SmallRng};

    use crate::{
        tests::{DIGEST, F, P, RATE, Sponge, build_leaves_single, concatenate_matrices},
        utils::upsample_matrix,
    };

    /// Builds leaf digests for `matrices` via the upsampling state builder,
    /// squeezing each final sponge state into a digest.
    fn build_leaves_upsampled(matrices: &[RowMajorMatrix<F>], sponge: &Sponge) -> Vec<[F; DIGEST]> {
        // `squeeze` only reads the state, so no mutable binding/iteration is
        // needed (previously used a needless `mut` + `iter_mut`).
        let states = super::build_leaf_states_upsampled::<P, P, _, _, _, _>(matrices, sponge);
        states.iter().map(|s| sponge.squeeze(s)).collect()
    }

    /// Leaf digests must be identical whether matrices are upsampled lazily
    /// (inside the state builder), upsampled eagerly beforehand, or
    /// concatenated into a single wide matrix and hashed directly.
    #[test]
    fn upsampled_equivalence() {
        let (_, sponge, _compressor) = bb::test_components();
        let mut rng = SmallRng::seed_from_u64(42);
        for scenario in p3_miden_dev_utils::fixtures::matrix_scenarios::<P>(RATE) {
            let matrices: Vec<RowMajorMatrix<F>> = scenario
                .into_iter()
                .map(|(h, w)| RowMajorMatrix::rand(&mut rng, h, w))
                .collect();
            let max_height = matrices.last().unwrap().height();

            // Lazily upsampled digests (the production path).
            let leaves = build_leaves_upsampled(&matrices, &sponge);

            // Eagerly upsample every matrix to the max height, then hash.
            let matrices_upsampled: Vec<_> = matrices
                .iter()
                .map(|m: &RowMajorMatrix<F>| upsample_matrix(m, max_height))
                .collect();
            let leaves_lifted = build_leaves_upsampled(&matrices_upsampled, &sponge);
            assert_eq!(leaves, leaves_lifted);

            // Concatenate the upsampled matrices into one and hash directly.
            let matrix_single = concatenate_matrices::<_, RATE>(&matrices_upsampled);
            let leaves_single = build_leaves_single(&matrix_single, &sponge);
            assert_eq!(leaves, leaves_single);
        }
    }
}