mod hiding;
#[cfg(test)]
mod tests;
use alloc::vec::Vec;
use core::iter::zip;
use p3_commit::{BatchOpening, BatchOpeningRef, Mmcs};
use p3_field::PackedValue;
use p3_matrix::{Dimensions, Matrix};
use p3_miden_stateful_hasher::{Alignable, StatefulHasher};
use p3_symmetric::{Hash, MerkleCap, PseudoCompressionFunction};
use p3_util::log2_ceil_usize;
use serde::{Deserialize, Serialize};
use crate::{Lmcs, LmcsConfig, LmcsError, LmcsTree, lifted_tree::LiftedMerkleTree};
/// [`Mmcs`] implementation backed by a [`LiftedMerkleTree`].
///
/// Commitments and authentication-path siblings are both represented as [`MerkleCap`]s; an
/// opening proof is the (possibly empty) leaf salt plus one sibling cap per tree level.
impl<PF, PD, H, C, const WIDTH: usize, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
    Mmcs<PF::Value> for LmcsConfig<PF, PD, H, C, WIDTH, DIGEST_ELEMS, SALT_ELEMS>
where
    PF: PackedValue + Default,
    PD: PackedValue + Default,
    PF::Value: PartialEq,
    H: StatefulHasher<PF, [PD; DIGEST_ELEMS], State = [PD; WIDTH]>
        + StatefulHasher<PF::Value, [PD::Value; DIGEST_ELEMS], State = [PD::Value; WIDTH]>
        + Alignable<PF::Value, PD::Value>
        + Sync,
    C: PseudoCompressionFunction<[PD::Value; DIGEST_ELEMS], 2>
        + PseudoCompressionFunction<[PD; DIGEST_ELEMS], 2>
        + Sync,
    [PF::Value; SALT_ELEMS]: Serialize + for<'de> Deserialize<'de>,
    [PD::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
{
    type ProverData<M> = LiftedMerkleTree<PF::Value, PD::Value, M, DIGEST_ELEMS, SALT_ELEMS>;
    type Commitment = MerkleCap<PF::Value, [PD::Value; DIGEST_ELEMS]>;
    type Proof = ([PF::Value; SALT_ELEMS], Vec<Self::Commitment>);
    type Error = LmcsError;

    /// Builds the lifted Merkle tree over `inputs` and returns its root wrapped in a
    /// [`MerkleCap`], together with the tree itself as prover data.
    fn commit<M: Matrix<PF::Value>>(
        &self,
        inputs: Vec<M>,
    ) -> (Self::Commitment, Self::ProverData<M>) {
        let tree = self.build_tree(inputs);
        (MerkleCap::from(tree.root()), tree)
    }

    /// Opens row `index` of every committed matrix.
    ///
    /// The proof consists of the leaf salt and the authentication-path siblings, each sibling
    /// wrapped as its own [`MerkleCap`].
    ///
    /// # Panics
    /// Panics if `index >= tree.height()`.
    fn open_batch<M: Matrix<PF::Value>>(
        &self,
        index: usize,
        tree: &Self::ProverData<M>,
    ) -> BatchOpening<PF::Value, Self> {
        let final_height = tree.height();
        assert!(
            index < final_height,
            "index {index} out of range {final_height}"
        );
        let crate::Proof {
            rows,
            salt,
            siblings,
        } = tree.single_proof(index);
        let siblings_cap: Vec<Self::Commitment> =
            siblings.into_iter().map(MerkleCap::from).collect();
        let rows_vec: Vec<Vec<PF::Value>> = rows.iter_rows().map(|r| r.to_vec()).collect();
        BatchOpening::new(rows_vec, (salt, siblings_cap))
    }

    /// Returns references to the committed matrices, in the order stored in `tree.leaves`.
    fn get_matrices<'a, M: Matrix<PF::Value>>(&self, tree: &'a Self::ProverData<M>) -> Vec<&'a M> {
        tree.leaves.iter().collect()
    }

    /// Verifies that `batch_opening` opens row `index` of a commitment to matrices with the
    /// given `dimensions`.
    ///
    /// The opening is untrusted input, so every structural mismatch is reported as an error
    /// instead of panicking. Cheap shape checks run before any hashing.
    ///
    /// # Errors
    /// - [`LmcsError::InvalidProof`] if the number of opened rows differs from
    ///   `dimensions.len()`, any row's length differs from its matrix width, `dimensions` is
    ///   empty, `index >= max_height`, the path length is not `log2_ceil(max_height)`, or any
    ///   sibling cap does not contain exactly one root.
    /// - [`LmcsError::RootMismatch`] if the recomputed root differs from `commitment`.
    fn verify_batch(
        &self,
        commitment: &Self::Commitment,
        dimensions: &[Dimensions],
        index: usize,
        batch_opening: BatchOpeningRef<'_, PF::Value, Self>,
    ) -> Result<(), Self::Error> {
        let (rows, (salt, siblings)) = batch_opening.unpack();

        // Shape checks: one opened row per matrix, each exactly as wide as its matrix.
        if rows.len() != dimensions.len() {
            return Err(LmcsError::InvalidProof);
        }
        for (row, dims) in zip(rows, dimensions) {
            if row.len() != dims.width {
                return Err(LmcsError::InvalidProof);
            }
        }

        // Path checks: the authentication path must match the tallest matrix's height.
        let max_height = dimensions
            .iter()
            .map(|d| d.height)
            .max()
            .ok_or(LmcsError::InvalidProof)?;
        if index >= max_height || siblings.len() != log2_ceil_usize(max_height) {
            return Err(LmcsError::InvalidProof);
        }

        // The salt is only absorbed when the scheme is configured with one.
        let rows_iter = rows.iter().map(|row| row.as_slice());
        let leaf_hash = if SALT_ELEMS > 0 {
            self.hash(rows_iter.chain([salt.as_slice()]))
        } else {
            self.hash(rows_iter)
        };

        let computed_commitment: Hash<PF::Value, PD::Value, DIGEST_ELEMS> = {
            let mut current = leaf_hash;
            let mut pos = index;
            for sibling_cap in siblings {
                // Honest proofs wrap each sibling digest as a single-root cap. Reject
                // anything else instead of indexing blindly: an empty cap would panic, and
                // extra roots would make proofs malleable.
                let roots = sibling_cap.roots();
                if roots.len() != 1 {
                    return Err(LmcsError::InvalidProof);
                }
                let sibling_hash = Hash::from(roots[0]);
                // The low bit of the position selects which side the current node is on.
                current = if pos & 1 == 0 {
                    self.compress(current, sibling_hash)
                } else {
                    self.compress(sibling_hash, current)
                };
                pos >>= 1;
            }
            current
        };

        if MerkleCap::from(computed_commitment) != *commitment {
            return Err(LmcsError::RootMismatch);
        }
        Ok(())
    }
}