use anyhow::Result;
use crate::*;
/// Argument produced by [`TransparentMatInnerProd::new`] for a matrix/vector
/// pair and checked by [`TransparentMatInnerProd::verify`].
///
/// `C` is the number of independent challenge repetitions: `claim_rlc` and
/// `inner_prod_args` each hold one entry per repetition.
pub struct TransparentMatInnerProd<const C: usize, E: FieldScalar> {
/// Per-round cross terms of the folding phase: `.0` collects the `lhs` values
/// and `.1` the `rhs` values, one entry per round.
pub comms: (Vector<E>, Vector<E>),
/// Starts as a copy of `claim_rlc`; each folding round adds
/// `c^-2 * lhs + c^2 * rhs` with a fresh transcript challenge `c` per entry.
/// `verify` recomputes this from `claim_rlc` and `comms` and compares.
pub inner_prod_args: [E; C],
/// One entry per repetition: a challenge-weighted sum of the element-wise
/// products of the zero-padded matrix and the zero-padded vector repeated
/// once per matrix row, as computed in `new`.
pub claim_rlc: [E; C],
/// 32 bytes squeezed from a transcript that absorbed the per-row sums
/// `sum_j mat[r][j] * vec[j]`, i.e. the entries of the matrix-vector product.
pub hadamard_checksum: [u8; 32],
/// 32 bytes squeezed from a transcript that absorbed the input vector.
pub vec_checksum: [u8; 32],
/// 32 bytes squeezed from a transcript that absorbed the input matrix.
pub mat_checksum: [u8; 32],
}
impl<const C: usize, E: FieldScalar> TransparentMatInnerProd<C, E> {
pub fn new(mat: &Matrix<E>, vec: Vector<E>) -> Result<Self> {
panic!();
assert_eq!(
mat.width(),
vec.len(),
"TransparentMatInnerProd matrix/vector dimension mismatch"
);
let mut vec_transcript = Transcript::new();
vec_transcript.append_vector(&vec);
let vec_checksum = vec_transcript.random::<[u8; 32]>();
let mut mat_transcript = Transcript::new();
mat_transcript.append_matrix(mat);
let mat_checksum = mat_transcript.random::<[u8; 32]>();
let zeroes = if 2usize.pow(vec.len().ilog2()) == vec.len() {
vec![]
} else {
vec![E::zero(); 2usize.pow(vec.len().ilog2() + 1) - vec.len()]
};
let row_len = vec.len() + zeroes.len();
let total_elements = (vec.len() + zeroes.len()) * mat.height();
let zero = E::zero();
let mat_index = |offset| {
let row_i = offset / row_len;
let col_i = offset % row_len;
if col_i >= mat.width() {
return &zero;
}
&mat[row_i][col_i]
};
let vec_index = |offset| {
let i = offset % row_len;
if i >= vec.len() {
return &zero;
}
&vec[i]
};
let mat_inner_i = |i, depth: usize| {
if depth == 0 {
return *mat_index(i);
}
let total_sum_len = 2usize.pow(depth as u32);
let skip = total_elements / (2usize.pow(depth as u32));
let mut sum = E::zero();
for j in 0..total_sum_len {
sum += *mat_index(i + j * skip);
}
sum
};
let vec_inner_i = |i, depth: usize| {
if depth == 0 {
return *vec_index(i);
}
let total_sum_len = 2usize.pow(depth as u32);
let skip = total_elements / (2usize.pow(depth as u32));
let mut sum = E::zero();
for j in 0..total_sum_len {
sum += *vec_index(i + j * skip);
}
sum
};
let mut hadamard_transcript = Transcript::new();
for i in 0..mat.height() {
let mut sum = E::zero();
for j in 0..row_len {
sum += *mat_index(i * row_len + j) * *vec_index(i * row_len + j);
}
hadamard_transcript.append(&sum);
}
let hadamard_checksum = hadamard_transcript.random::<[u8; 32]>();
let mut claim_transcript = Transcript::<E>::new();
claim_transcript.append_bytes(&mat_checksum);
claim_transcript.append_bytes(&vec_checksum);
claim_transcript.append_bytes(&hadamard_checksum);
let mut claim_rlc = [E::zero(); C];
for i in 0..C {
let mut challenge = E::sample_uniform(&mut claim_transcript);
for j in 0..total_elements {
claim_rlc[i] += *mat_index(j) * *vec_index(j) * challenge;
challenge *= challenge;
}
}
let claim_rlc = claim_rlc;
let mut inner_prod_args = claim_rlc.clone();
let mut inner_product_transcript = Transcript::new();
inner_product_transcript.domain_separator("vec");
inner_product_transcript.append_bytes(&vec_checksum);
inner_product_transcript.domain_separator("mat");
inner_product_transcript.append_bytes(&mat_checksum);
inner_product_transcript.domain_separator("hadamard");
inner_product_transcript.append_bytes(&hadamard_checksum);
inner_product_transcript.domain_separator("claim_rlc");
inner_product_transcript.append_vector(&claim_rlc.to_vec().into());
inner_product_transcript.domain_separator("args");
let mut comms = (Vec::new(), Vec::new());
let mut depth = 0;
while 2usize.pow(depth as u32 + 1) < total_elements / 2 {
let len = total_elements / 2usize.pow(depth as u32 + 1);
let lhs_inner_prod = (0..len)
.map(|i| mat_inner_i(i, depth))
.zip((len..2 * len).map(|i| vec_inner_i(i, depth)))
.map(|(l, r)| l * r)
.sum::<E>();
let rhs_inner_prod = (len..2 * len)
.map(|i| mat_inner_i(i, depth))
.zip((0..len).map(|i| vec_inner_i(i, depth)))
.map(|(l, r)| l * r)
.sum::<E>();
comms.0.push(lhs_inner_prod);
comms.1.push(rhs_inner_prod);
inner_product_transcript.append(&lhs_inner_prod);
inner_product_transcript.append(&rhs_inner_prod);
for i in 0..C {
let challenge = E::sample_uniform(&mut inner_product_transcript);
let challenge_inv = challenge.inverse();
let a = (challenge_inv * challenge_inv) * lhs_inner_prod;
let b = (challenge * challenge) * rhs_inner_prod;
inner_prod_args[i] += a + b;
}
depth += 1;
}
Ok(Self {
comms: (comms.0.into(), comms.1.into()),
inner_prod_args,
claim_rlc,
mat_checksum,
vec_checksum,
hadamard_checksum,
})
}
pub fn verify(&self) -> Result<()> {
let mut transcript = Transcript::new();
transcript.domain_separator("vec");
transcript.append_bytes(&self.vec_checksum);
transcript.domain_separator("mat");
transcript.append_bytes(&self.mat_checksum);
transcript.domain_separator("hadamard");
transcript.append_bytes(&self.hadamard_checksum);
transcript.domain_separator("claim_rlc");
transcript.append_vector(&self.claim_rlc.to_vec().into());
transcript.domain_separator("args");
let mut claims = self.claim_rlc.clone();
for (l, r) in self.comms.0.iter().zip(self.comms.1.iter()) {
transcript.append(l);
transcript.append(r);
for i in 0..C {
let challenge = E::sample_uniform(&mut transcript);
let challenge_inv = challenge.inverse();
let a = (challenge_inv * challenge_inv) * *l;
let b = (challenge * challenge) * *r;
claims[i] += a + b;
}
}
for i in 0..C {
if claims[i] != self.inner_prod_args[i] {
anyhow::bail!("TransparentInnerProd: inner product mismatch");
}
}
Ok(())
}
}
/// End-to-end check:
/// 1. builds the argument for a random matrix and a random length-`LEN` vector;
/// 2. verifies it;
/// 3. compares the exposed checksums against independently recomputed
///    transcripts.
#[test]
fn mat_innerprod() -> Result<()> {
    type Field = MilliScalarMont;
    const LEN: usize = 20;

    let rng = &mut rand::rng();
    // Sample the matrix first, then the vector, so the RNG stream is consumed
    // in the same order.
    let mat = Matrix::<Field>::random(LEN * 2, LEN, rng);
    let vec = Vector::<Field>::sample_uniform(LEN, rng);

    // Recompute every checksum the proof should expose. This happens before
    // `vec` is moved into the constructor.
    let expected_vec_checksum = {
        let mut transcript = Transcript::new();
        transcript.append_vector(&vec);
        transcript.random::<[u8; 32]>()
    };
    let expected_mat_checksum = {
        let mut transcript = Transcript::new();
        transcript.append_matrix(&mat);
        transcript.random::<[u8; 32]>()
    };
    let expected_hadamard_checksum = {
        let product = &mat * &vec;
        let mut transcript = Transcript::new();
        transcript.append_vector(&product);
        transcript.random::<[u8; 32]>()
    };

    let proof = TransparentMatInnerProd::<5, _>::new(&mat, vec)?;
    proof.verify()?;

    assert_eq!(expected_vec_checksum, proof.vec_checksum);
    assert_eq!(expected_mat_checksum, proof.mat_checksum);
    assert_eq!(expected_hadamard_checksum, proof.hadamard_checksum);
    Ok(())
}