use alloc::vec::Vec;
use core::cell::RefCell;
use p3_field::PackedValue;
use p3_matrix::{Matrix, dense::RowMajorMatrix};
use p3_miden_stateful_hasher::{Alignable, StatefulHasher};
use p3_miden_transcript::VerifierChannel;
use p3_symmetric::{Hash, PseudoCompressionFunction};
use rand::{
Rng,
distr::{Distribution, StandardUniform},
};
use crate::{BatchProof, LiftedMerkleTree, Lmcs, LmcsConfig, LmcsError, OpenedRows};
/// LMCS configuration that pairs an inner [`LmcsConfig`] with a random number
/// generator used to sample salt rows whenever a tree is built.
///
/// Type parameters (as constrained by the `Lmcs` implementation for this type):
/// - `PF`, `PD`: packed value types; their scalar `Value` types are the field
///   element and digest element types respectively.
/// - `H`: stateful hasher stored in the inner configuration as the sponge.
/// - `C`: two-to-one compression function stored in the inner configuration.
/// - `R`: random number generator that supplies the salt.
/// - `WIDTH`: length of the hasher state array.
/// - `DIGEST`: length of a digest array.
/// - `SALT`: number of random salt elements sampled per tree row; the
///   constructor rejects `SALT == 0` at compile time.
#[derive(Clone, Debug)]
pub struct HidingLmcsConfig<
PF,
PD,
H,
C,
R,
const WIDTH: usize,
const DIGEST: usize,
const SALT: usize,
> {
/// Wrapped configuration holding the sponge and compression function.
pub inner: LmcsConfig<PF, PD, H, C, WIDTH, DIGEST, SALT>,
/// Salt source. Kept in a `RefCell` so that tree-building methods taking
/// `&self` can still advance the generator.
rng: RefCell<R>,
}
impl<PF, PD, H, C, R, const WIDTH: usize, const DIGEST: usize, const SALT: usize>
    HidingLmcsConfig<PF, PD, H, C, R, WIDTH, DIGEST, SALT>
{
    /// Builds a hiding configuration from a sponge, a compression function and
    /// the random number generator that will supply salt.
    ///
    /// `sponge` and `compress` are forwarded to [`LmcsConfig::new`]; `rng` is
    /// stored behind a `RefCell` so it can be advanced through `&self`.
    ///
    /// Instantiating this function with `SALT == 0` is a compile-time error.
    #[inline]
    pub fn new(sponge: H, compress: C, rng: R) -> Self {
        // Evaluated at compile time for each instantiation: a hiding
        // configuration without any salt columns is rejected outright.
        const { assert!(SALT > 0) };

        let inner = LmcsConfig::new(sponge, compress);
        let rng = RefCell::new(rng);
        Self { inner, rng }
    }
}
impl<PF, PD, H, C, R, const WIDTH: usize, const DIGEST: usize, const SALT: usize>
    HidingLmcsConfig<PF, PD, H, C, R, WIDTH, DIGEST, SALT>
where
    PF: PackedValue + Default,
    PD: PackedValue + Default,
    R: Rng,
    StandardUniform: Distribution<PF::Value>,
    H: StatefulHasher<PF::Value, [PD::Value; DIGEST], State = [PD::Value; WIDTH]>
        + StatefulHasher<PF, [PD; DIGEST], State = [PD; WIDTH]>
        + Alignable<PF::Value, PD::Value>
        + Sync,
    C: PseudoCompressionFunction<[PD::Value; DIGEST], 2>
        + PseudoCompressionFunction<[PD; DIGEST], 2>
        + Sync,
{
    /// Samples a fresh salt matrix and builds a salted tree over `leaves`
    /// using the given `alignment`.
    ///
    /// Shared implementation of `build_tree` and `build_aligned_tree`, which
    /// differ only in the alignment they request.
    ///
    /// # Panics
    ///
    /// Panics if the internal RNG cell is already mutably borrowed.
    fn build_salted_tree<M: Matrix<PF::Value>>(
        &self,
        leaves: Vec<M>,
        alignment: usize,
    ) -> LiftedMerkleTree<PF::Value, PD::Value, M, DIGEST, SALT> {
        // The salt gets one row per row of the last matrix, or zero rows when
        // no matrices are supplied.
        // NOTE(review): presumably the last matrix is the tallest one; confirm
        // against the ordering `LiftedMerkleTree` expects.
        let tree_height = leaves.last().map_or(0, |m| m.height());

        // The mutable borrow of the RNG ends with this statement, so it is
        // never held across the tree construction below.
        let salt = RowMajorMatrix::rand(&mut *self.rng.borrow_mut(), tree_height, SALT);

        LiftedMerkleTree::build_with_alignment::<PF, PD, H, C, WIDTH>(
            &self.inner.sponge,
            &self.inner.compress,
            leaves,
            Some(salt),
            alignment,
        )
    }
}

impl<PF, PD, H, C, R, const WIDTH: usize, const DIGEST: usize, const SALT: usize> Lmcs
    for HidingLmcsConfig<PF, PD, H, C, R, WIDTH, DIGEST, SALT>
where
    PF: PackedValue + Default,
    PD: PackedValue + Default,
    R: Rng + Clone,
    StandardUniform: Distribution<PF::Value>,
    H: StatefulHasher<PF::Value, [PD::Value; DIGEST], State = [PD::Value; WIDTH]>
        + StatefulHasher<PF, [PD; DIGEST], State = [PD; WIDTH]>
        + Alignable<PF::Value, PD::Value>
        + Sync,
    C: PseudoCompressionFunction<[PD::Value; DIGEST], 2>
        + PseudoCompressionFunction<[PD; DIGEST], 2>
        + Sync,
{
    type F = PF::Value;
    type Commitment = Hash<PF::Value, PD::Value, DIGEST>;
    type BatchProof = BatchProof<PF::Value, Self::Commitment, SALT>;
    type Tree<M: Matrix<PF::Value>> = LiftedMerkleTree<PF::Value, PD::Value, M, DIGEST, SALT>;

    /// Builds a salted tree over `leaves` with an alignment of `1`.
    ///
    /// Every call draws new salt from the stored RNG, advancing its state.
    fn build_tree<M: Matrix<Self::F>>(&self, leaves: Vec<M>) -> Self::Tree<M> {
        self.build_salted_tree(leaves, 1)
    }

    /// Builds a salted tree over `leaves` using the hasher's
    /// `Alignable::ALIGNMENT` as the alignment.
    ///
    /// Every call draws new salt from the stored RNG, advancing its state.
    fn build_aligned_tree<M: Matrix<Self::F>>(&self, leaves: Vec<M>) -> Self::Tree<M> {
        self.build_salted_tree(leaves, <H as Alignable<PF::Value, PD::Value>>::ALIGNMENT)
    }

    /// Hashes the given rows by delegating to the inner configuration.
    fn hash<'a, I>(&self, rows: I) -> Self::Commitment
    where
        I: IntoIterator<Item = &'a [Self::F]>,
        Self::F: 'a,
    {
        self.inner.hash(rows)
    }

    /// Compresses two commitments into one by delegating to the inner
    /// configuration.
    fn compress(&self, left: Self::Commitment, right: Self::Commitment) -> Self::Commitment {
        self.inner.compress(left, right)
    }

    /// Opens a batch at `indices` against `commitment`, reading from `channel`.
    ///
    /// All arguments are forwarded unchanged to the inner configuration.
    ///
    /// # Errors
    ///
    /// Returns whatever [`LmcsError`] the inner configuration reports.
    fn open_batch<Ch>(
        &self,
        commitment: &Self::Commitment,
        widths: &[usize],
        log_max_height: u8,
        indices: impl IntoIterator<Item = usize>,
        channel: &mut Ch,
    ) -> Result<OpenedRows<Self::F>, LmcsError>
    where
        Ch: VerifierChannel<F = Self::F, Commitment = Self::Commitment>,
    {
        self.inner
            .open_batch(commitment, widths, log_max_height, indices, channel)
    }

    /// Reads a batch proof for `indices` from `channel`.
    ///
    /// All arguments are forwarded unchanged to the inner configuration.
    ///
    /// # Errors
    ///
    /// Returns whatever [`LmcsError`] the inner configuration reports.
    fn read_batch_proof_from_channel<Ch>(
        &self,
        widths: &[usize],
        log_max_height: u8,
        indices: &[usize],
        channel: &mut Ch,
    ) -> Result<Self::BatchProof, LmcsError>
    where
        Ch: VerifierChannel<F = Self::F, Commitment = Self::Commitment>,
    {
        self.inner
            .read_batch_proof_from_channel(widths, log_max_height, indices, channel)
    }

    /// Returns the alignment reported by the inner configuration.
    fn alignment(&self) -> usize {
        self.inner.alignment()
    }
}