use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use p3_miden_transcript::{TranscriptError, VerifierChannel};
use serde::{Deserialize, Serialize};
use crate::{Lmcs, utils::RowList};
/// Opening proof for a single leaf of an LMCS tree.
///
/// Carries the opened row data, the blinding salt (an empty array when
/// `SALT_ELEMS == 0`), and the sibling hashes forming the authentication
/// path from the leaf toward the root.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "F: Serialize, C: Serialize, [F; SALT_ELEMS]: Serialize",
    deserialize = "F: Deserialize<'de>, C: Deserialize<'de>, [F; SALT_ELEMS]: Deserialize<'de>"
))]
pub struct Proof<F, C, const SALT_ELEMS: usize = 0> {
    /// One opened row per committed matrix, in matrix order.
    pub rows: RowList<F>,
    /// Blinding salt hashed into the leaf when `SALT_ELEMS > 0`.
    pub salt: [F; SALT_ELEMS],
    /// Sibling hashes, one per tree level, ordered leaf-to-root.
    pub siblings: Vec<C>,
}
/// The per-leaf payload of a [`BatchProof`]: opened rows plus salt,
/// without any authentication-path data (siblings are shared across the
/// batch and stored separately on [`BatchProof`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "F: Serialize, [F; SALT_ELEMS]: Serialize",
    deserialize = "F: Deserialize<'de>, [F; SALT_ELEMS]: Deserialize<'de>"
))]
pub struct LeafOpening<F, const SALT_ELEMS: usize = 0> {
    /// One opened row per committed matrix, in matrix order.
    pub rows: RowList<F>,
    /// Blinding salt for this leaf (empty when `SALT_ELEMS == 0`).
    pub salt: [F; SALT_ELEMS],
}
/// A batched opening proof for several leaves of one LMCS tree.
///
/// Instead of a full authentication path per leaf, it stores each opened
/// leaf once plus only the sibling hashes that cannot be recomputed from
/// the openings themselves (see `required_siblings`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "F: Serialize, C: Serialize, [F; SALT_ELEMS]: Serialize",
    deserialize = "F: Deserialize<'de>, C: Deserialize<'de>, [F; SALT_ELEMS]: Deserialize<'de>"
))]
pub struct BatchProof<F, C, const SALT_ELEMS: usize = 0> {
    /// Opened leaf data, keyed by leaf index.
    pub openings: BTreeMap<usize, LeafOpening<F, SALT_ELEMS>>,
    /// Explicitly supplied sibling hashes, keyed by `(depth, index)` with
    /// depth 0 at the leaves.
    pub siblings: BTreeMap<(usize, usize), C>,
}
impl<F, C, const SALT_ELEMS: usize> BatchProof<F, C, SALT_ELEMS> {
    /// Parses a batch opening proof from the verifier channel's hint stream.
    ///
    /// Hints are consumed in a fixed order that the prover must match: for
    /// each *unique* queried index, in ascending order, `widths.iter().sum()`
    /// field elements (split back into per-matrix rows by `RowList::new`)
    /// followed by a `SALT_ELEMS` salt array; then one commitment per entry
    /// returned by `required_siblings`, in that function's depth-major,
    /// position-ascending order.
    ///
    /// # Errors
    /// Propagates any [`TranscriptError`] produced by the channel.
    pub fn read_from_channel<Ch>(
        widths: &[usize],
        log_max_height: u8,
        indices: &[usize],
        channel: &mut Ch,
    ) -> Result<Self, TranscriptError>
    where
        F: Copy,
        C: Clone + PartialEq,
        Ch: VerifierChannel<F = F, Commitment = C>,
    {
        // Deduplicate and sort the queried indices; each distinct leaf is
        // transmitted exactly once.
        let unique_indices: BTreeSet<usize> = indices.iter().copied().collect();
        let total_width: usize = widths.iter().sum();
        let openings: BTreeMap<usize, LeafOpening<F, SALT_ELEMS>> = unique_indices
            .iter()
            .copied()
            .map(|index| {
                // All rows of one leaf arrive as a single flat slice and are
                // re-partitioned according to `widths`.
                let elems = channel.receive_hint_field_slice(total_width)?.to_vec();
                let rows = RowList::new(elems, widths);
                let salt: [F; SALT_ELEMS] = channel.receive_hint_field_array()?;
                Ok((index, LeafOpening { rows, salt }))
            })
            .collect::<Result<_, _>>()?;
        // Sibling commitments arrive in the deterministic order computed by
        // `required_siblings` and are keyed here by `(depth, index)`.
        let siblings: BTreeMap<(usize, usize), C> =
            required_siblings(openings.keys().copied(), log_max_height.into())
                .into_iter()
                .map(|key| Ok((key, channel.receive_hint_commitment()?.clone())))
                .collect::<Result<_, TranscriptError>>()?;
        Ok(Self { openings, siblings })
    }

    /// Reconstructs a standalone [`Proof`] for every opened leaf index.
    ///
    /// Re-hashes each opening into a leaf digest, seeds a sparse tree with
    /// those digests plus the stored siblings, compresses node pairs upward
    /// level by level, and finally reads each leaf's authentication path out
    /// of the filled-in tree.
    ///
    /// Returns `None` if any opening's row count or row widths disagree with
    /// `widths`, if a required sibling commitment is missing from
    /// `self.siblings`, or if some node at any level lacks the partner needed
    /// to compute its parent.
    pub fn single_proofs<L>(
        &self,
        lmcs: &L,
        widths: &[usize],
        log_max_height: u8,
    ) -> Option<BTreeMap<usize, Proof<F, C, SALT_ELEMS>>>
    where
        F: Copy,
        C: Clone + PartialEq,
        L: Lmcs<F = F, Commitment = C>,
    {
        let mut proofs: BTreeMap<usize, Proof<F, C, SALT_ELEMS>> = BTreeMap::new();
        // Sparse tree of known node hashes, keyed by `(depth, index)` with
        // depth 0 at the leaves.
        let mut tree: BTreeMap<(usize, usize), C> = BTreeMap::new();
        for (&index, opening) in self.openings.iter() {
            // Reject openings whose shape disagrees with the committed widths.
            if opening.rows.num_rows() != widths.len() {
                return None;
            }
            for (row, &width) in opening.rows.iter_rows().zip(widths.iter()) {
                if row.len() != width {
                    return None;
                }
            }
            // Leaf hash covers all rows, with the salt appended as one extra
            // slice only when salting is enabled.
            let rows_iter = opening.rows.iter_rows();
            let leaf_hash = if SALT_ELEMS > 0 {
                lmcs.hash(rows_iter.chain([opening.salt.as_slice()]))
            } else {
                lmcs.hash(rows_iter)
            };
            // Pre-create the proof entry; its sibling path is filled in after
            // the tree has been completed.
            proofs.entry(index).or_insert_with(|| Proof {
                rows: opening.rows.clone(),
                salt: opening.salt,
                siblings: Vec::with_capacity(log_max_height as usize),
            });
            // NOTE(review): `self.openings` is a BTreeMap, so each `index` is
            // unique and this insert should never find an existing entry —
            // presumably a defensive check; confirm whether it can ever fire.
            if tree
                .insert((0, index), leaf_hash.clone())
                .is_some_and(|existing_hash| existing_hash != leaf_hash)
            {
                return None;
            }
        }
        let tree_depth = log_max_height as usize;
        // Seed the tree with the explicitly transmitted siblings; any entry
        // that `required_siblings` demands but the proof lacks aborts here.
        for (depth, index) in required_siblings(self.openings.keys().copied(), tree_depth) {
            tree.insert((depth, index), self.siblings.get(&(depth, index))?.clone());
        }
        // Hash pairs upward one level at a time until the root level.
        for current_depth in 0..tree_depth {
            let nodes_at_depth: Vec<(usize, C)> = tree
                .range((current_depth, 0)..=(current_depth, usize::MAX))
                .map(|(&(_, idx), hash)| (idx, hash.clone()))
                .collect();
            let mut nodes_iter = nodes_at_depth.into_iter().peekable();
            while let Some((index, hash)) = nodes_iter.next() {
                // Siblings differ only in the lowest bit; since the nodes are
                // sorted, a present sibling is the immediately next entry.
                let sibling_index = index ^ 1;
                let sibling_hash =
                    match nodes_iter.next_if(|(next_index, _)| *next_index == sibling_index) {
                        Some((_, hash)) => hash,
                        None => return None,
                    };
                let is_left_child = index & 1 == 0;
                let (left, right) = if is_left_child {
                    (hash, sibling_hash)
                } else {
                    (sibling_hash, hash)
                };
                let parent_depth = current_depth + 1;
                let parent_index = index / 2;
                let parent_hash = lmcs.compress(left, right);
                tree.insert((parent_depth, parent_index), parent_hash);
            }
        }
        // Walk each queried leaf back up, collecting one sibling per level.
        for (&index, proof) in proofs.iter_mut() {
            let mut current_index = index;
            for current_depth in 0..tree_depth {
                let sibling_index = current_index ^ 1;
                let sibling_hash = tree.get(&(current_depth, sibling_index)).cloned()?;
                proof.siblings.push(sibling_hash);
                current_index >>= 1;
            }
        }
        Some(proofs)
    }
}
/// Computes which sibling nodes must be supplied explicitly so that the
/// given leaf positions can be hashed all the way up a tree of depth
/// `log_max_height`.
///
/// Returns `(depth, position)` pairs in depth-major order (depth 0 at the
/// leaves), ascending by position within each depth. A position is reported
/// only when exactly one child of its parent is derivable from what is
/// already known; when both children are known the parent needs no help.
fn required_siblings<I>(indices: I, log_max_height: usize) -> Vec<(usize, usize)>
where
    I: IntoIterator<Item = usize>,
{
    let mut needed = Vec::new();
    let mut layer: BTreeSet<usize> = indices.into_iter().collect();
    for depth in 0..log_max_height {
        // Collapse the current layer into its set of parents, recording each
        // child slot that the layer itself cannot provide.
        let parents: BTreeSet<usize> = layer.iter().map(|&pos| pos >> 1).collect();
        for &parent in &parents {
            let left = parent << 1;
            let right = left | 1;
            match (layer.contains(&left), layer.contains(&right)) {
                // Exactly one child known: its sibling must be transmitted.
                (true, false) => needed.push((depth, right)),
                (false, true) => needed.push((depth, left)),
                // Both children known (parent computable); `(false, false)`
                // cannot occur since each parent comes from a known child.
                _ => {}
            }
        }
        layer = parents;
    }
    needed
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use p3_matrix::dense::RowMajorMatrix;
    use p3_miden_transcript::{VerifierChannel, VerifierTranscript};
    use p3_symmetric::Hash;
    use rand::{SeedableRng, rngs::SmallRng};
    use crate::{
        BatchProof, Lmcs, LmcsTree, log2_strict_u8,
        tests::{DIGEST, F, lmcs, roundtrip_open_batch},
    };

    /// Cross-checks `BatchProof` parsing and `single_proofs` reconstruction
    /// against the prover-side `open_batch` and the tree's own
    /// `single_proof` on a few matrix shapes and query sets.
    #[test]
    fn batch_proof_consistent_with_open_batch() {
        let committer = lmcs();
        let run_case = |seed: u64, shapes: &[(usize, usize)], indices: &[usize]| {
            // Commit to random matrices of the requested shapes.
            let mut rng = SmallRng::seed_from_u64(seed);
            let matrices = shapes
                .iter()
                .map(|&(height, width)| RowMajorMatrix::rand(&mut rng, height, width))
                .collect::<Vec<_>>();
            let tree = committer.build_tree(matrices);
            let widths = tree.widths();
            let log_max_height = log2_strict_u8(tree.height());

            // Produce and verify a transcript via the batch-opening roundtrip.
            let (transcript, opened_rows) =
                roundtrip_open_batch(&committer, &tree, indices).expect("open_batch should verify");

            // Re-parse the same transcript through a fresh verifier channel.
            let mut channel = VerifierTranscript::from_data(
                p3_miden_dev_utils::configs::baby_bear_poseidon2::test_challenger(),
                &transcript,
            );
            let batch = BatchProof::<F, Hash<F, F, DIGEST>>::read_from_channel(
                &widths,
                log_max_height,
                indices,
                &mut channel,
            )
            .expect("batch proof should parse");
            assert!(
                channel.is_empty(),
                "parse path should fully consume transcript"
            );

            // Every opening returned by open_batch must be parsed identically.
            assert_eq!(opened_rows.len(), batch.openings.len());
            for (&idx, verified_rows) in &opened_rows {
                let parsed = batch.openings.get(&idx).expect("parsed opening for index");
                assert_eq!(
                    *verified_rows, parsed.rows,
                    "row mismatch between open_batch and BatchProof at index {idx}"
                );
            }

            // Reconstructed single proofs must equal the tree's own proofs.
            let proofs = batch
                .single_proofs(&committer, &widths, log_max_height)
                .expect("single_proofs should reconstruct");
            for &idx in indices {
                let proof = proofs.get(&idx).expect("proof for index");
                let expected = tree.single_proof(idx);
                assert_eq!(proof, &expected, "single_proof mismatch at index {idx}");
            }
        };

        run_case(1, &[(8, 4)], &[0, 3, 7]);
        run_case(42, &[(4, 3), (8, 5), (16, 7)], &[0, 5, 10, 15]);
        run_case(99, &[(4, 2), (8, 6)], &[3, 1, 3, 0, 1]);
    }
}