#[cfg(feature = "prover")]
use crate::data_structures::RecursiveLigeroWitness;
use crate::utils::{eval_sk_at_vks, evaluate_lagrange_basis, evaluate_scaled_basis_inplace};
use binary_fields::{BinaryFieldElement, BinaryPolynomial};
use merkle_tree::{build_merkle_tree, Hash};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
#[cfg(feature = "prover")]
use reed_solomon::ReedSolomon;
use sha2::{Digest, Sha256};
// Debug logging helper used throughout this file.
// With the `std` feature enabled it forwards all of its arguments to `std::println!`.
#[cfg(feature = "std")]
macro_rules! debug_println {
($($arg:tt)*) => { std::println!($($arg)*) }
}
// Without the `std` feature the macro accepts the same arguments but expands to nothing,
// so the arguments are neither evaluated nor type-checked.
#[cfg(not(feature = "std"))]
macro_rules! debug_println {
($($arg:tt)*) => {};
}
/// Lays the coefficients of `poly` out as a matrix with `m * inv_rate` rows and `n` columns.
///
/// The cell at row `i`, column `j` receives `poly[j * m + i]` whenever that index lies
/// inside `poly`; every other cell keeps the value `F::zero()`.
///
/// With the `parallel` feature the rows are filled through rayon, otherwise sequentially.
/// Both paths apply the same per-row routine.
///
/// NOTE(review): rows with `i >= m` are also filled from `poly` when `j * m + i` is in
/// range, which re-reads coefficients that belong to the following column. Presumably
/// those rows are overwritten by the later column encoding — confirm.
#[cfg(feature = "prover")]
pub fn poly2mat<F: BinaryFieldElement>(
    poly: &[F],
    m: usize,
    n: usize,
    inv_rate: usize,
) -> Vec<Vec<F>> {
    let total_rows = m * inv_rate;
    let mut mat = vec![vec![F::zero(); n]; total_rows];

    // Per-row fill routine shared by the parallel and the sequential driver below.
    let fill_row = |(i, row): (usize, &mut Vec<F>)| {
        for (j, cell) in row.iter_mut().enumerate() {
            // `get` returns `None` for indices past the end of `poly`, leaving the zero in place.
            if let Some(&coeff) = poly.get(j * m + i) {
                *cell = coeff;
            }
        }
    };

    #[cfg(feature = "parallel")]
    {
        mat.par_iter_mut().enumerate().for_each(fill_row);
    }
    #[cfg(not(feature = "parallel"))]
    {
        mat.iter_mut().enumerate().for_each(fill_row);
    }

    mat
}
/// Encodes every column of `poly_mat` in place with the Reed–Solomon code `rs`
/// (variant compiled when both the `prover` and `parallel` features are enabled).
///
/// The column count is taken from the first row. Each column is copied out into a
/// temporary vector, encoded, and written back into the matrix.
///
/// * `parallel == true`: columns are encoded concurrently through rayon, calling
///   `encode_in_place_with_parallel` with its last argument set to `false`, and the
///   results are written back afterwards.
/// * `parallel == false`: columns are encoded one after another with `encode_in_place`.
///
/// An empty matrix is a no-op.
///
/// # Panics
///
/// Panics if a row is shorter than the first row, or if an encoded column does not
/// line up with the number of rows when it is written back.
#[cfg(all(feature = "prover", feature = "parallel"))]
pub fn encode_cols<F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static>(
    poly_mat: &mut Vec<Vec<F>>,
    rs: &ReedSolomon<F>,
    parallel: bool,
) {
    // An empty matrix has no columns; indexing `poly_mat[0]` would panic here.
    let n = poly_mat.first().map_or(0, Vec::len);

    if parallel {
        // Encode all columns first (read-only access to the matrix), then write back.
        let cols: Vec<Vec<F>> = (0..n)
            .into_par_iter()
            .map(|j| {
                let mut col: Vec<F> = poly_mat.iter().map(|row| row[j]).collect();
                reed_solomon::encode_in_place_with_parallel(rs, &mut col, false);
                col
            })
            .collect();

        for (i, row) in poly_mat.iter_mut().enumerate() {
            for (j, col) in cols.iter().enumerate() {
                row[j] = col[i];
            }
        }
    } else {
        for j in 0..n {
            let mut col: Vec<F> = poly_mat.iter().map(|row| row[j]).collect();
            reed_solomon::encode_in_place(rs, &mut col);
            for (i, val) in col.iter().enumerate() {
                poly_mat[i][j] = *val;
            }
        }
    }
}
/// Encodes every column of `poly_mat` in place with the Reed–Solomon code `rs`
/// (variant compiled when `prover` is enabled and `parallel` is not).
///
/// The column count is taken from the first row. Each column is copied out into a
/// temporary vector, encoded with `encode_in_place`, and written back into the matrix.
/// The `_parallel` flag is accepted for signature compatibility with the parallel
/// build and is ignored.
///
/// An empty matrix is a no-op.
///
/// # Panics
///
/// Panics if a row is shorter than the first row, or if an encoded column has more
/// entries than the matrix has rows.
#[cfg(all(feature = "prover", not(feature = "parallel")))]
pub fn encode_cols<F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static>(
    poly_mat: &mut Vec<Vec<F>>,
    rs: &ReedSolomon<F>,
    _parallel: bool,
) {
    // An empty matrix has no columns; indexing `poly_mat[0]` would panic here.
    let n = poly_mat.first().map_or(0, Vec::len);

    for j in 0..n {
        let mut col: Vec<F> = poly_mat.iter().map(|row| row[j]).collect();
        reed_solomon::encode_in_place(rs, &mut col);
        for (i, val) in col.iter().enumerate() {
            poly_mat[i][j] = *val;
        }
    }
}
/// Hashes one matrix row with SHA-256.
///
/// The digest covers the row length (cast to `u32`, little-endian) followed by the
/// in-memory bytes of the row's elements.
///
/// NOTE(review): the length prefix is produced with `as u32`, so rows longer than
/// `u32::MAX` elements would have their length truncated in the prefix.
#[inline(always)]
pub fn hash_row<F: BinaryFieldElement>(row: &[F]) -> Hash {
    let mut hasher = Sha256::new();
    hasher.update((row.len() as u32).to_le_bytes());
    // `core::mem` rather than `std::mem`: this function is not gated on the `std`
    // feature, so it must also build when `std` is unavailable.
    let byte_len = core::mem::size_of_val(row);
    // SAFETY: `row` is a valid, initialized slice, so its pointer is non-null, valid
    // for reads of `size_of_val(row)` bytes and outlives `row_bytes`; `u8` has
    // alignment 1. This additionally assumes `F` contains no padding bytes —
    // TODO confirm for every `BinaryFieldElement` implementation.
    let row_bytes = unsafe { core::slice::from_raw_parts(row.as_ptr() as *const u8, byte_len) };
    hasher.update(row_bytes);
    hasher.finalize().into()
}
/// Builds the commitment witness for `poly`.
///
/// Steps, in order:
/// 1. arrange `poly` as a matrix via `poly2mat(poly, m, n, 4)`,
/// 2. encode its columns in place via `encode_cols(.., rs, true)`,
/// 3. hash every row with `hash_row` (through rayon when the `parallel` feature is on),
/// 4. build a Merkle tree over the row hashes.
///
/// Returns the encoded matrix together with the tree.
#[cfg(feature = "prover")]
pub fn ligero_commit<F>(
    poly: &[F],
    m: usize,
    n: usize,
    rs: &ReedSolomon<F>,
) -> RecursiveLigeroWitness<F>
where
    F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
{
    let mut mat = poly2mat(poly, m, n, 4);
    encode_cols(&mut mat, rs, true);

    // Pick the row iterator once; the hashing pipeline itself is the same in both builds.
    #[cfg(feature = "parallel")]
    let row_iter = mat.par_iter();
    #[cfg(not(feature = "parallel"))]
    let row_iter = mat.iter();
    let leaves: Vec<Hash> = row_iter.map(|row| hash_row(row)).collect();

    let tree = build_merkle_tree(&leaves);
    RecursiveLigeroWitness { mat, tree }
}
/// Checks the consistency relation for the first query and logs the outcome.
///
/// For the first entry of `queries` it computes
/// * `dot`: the inner product of that query's opened row with the Lagrange basis
///   evaluated at `challenges`, and
/// * `e`: the inner product of `yr` with the basis produced by
///   `evaluate_scaled_basis_inplace` for the query index reduced modulo `2^n`,
///   where `n = yr.len().trailing_zeros()`,
///
/// and reports through `debug_println!` whether `e == dot`.
///
/// `opened_rows` is read positionally: `opened_rows[0]` is taken to be the row opened
/// for `queries[0]`.
///
/// NOTE(review): only the first query is examined, and a mismatch is only logged —
/// nothing is returned or asserted. Confirm this diagnostic-only behavior is intended.
///
/// # Panics
///
/// Panics if `queries` is non-empty while `opened_rows` is empty.
pub fn verify_ligero<T, U>(queries: &[usize], opened_rows: &[Vec<T>], yr: &[T], challenges: &[U])
where
    T: BinaryFieldElement + Send + Sync,
    U: BinaryFieldElement + Send + Sync + From<T>,
{
    debug_println!("verify_ligero: challenges = {:?}", challenges);
    let gr = evaluate_lagrange_basis(challenges);
    let n = yr.len().trailing_zeros() as usize;
    let sks_vks: Vec<T> = eval_sk_at_vks(1 << n);

    if let Some(&query) = queries.first() {
        // The row for the first query sits at position 0 of `opened_rows`. Indexing by
        // the query value itself would select an unrelated row or run out of bounds.
        let row = &opened_rows[0];
        let dot = row.iter().zip(gr.iter()).fold(U::zero(), |acc, (&r, &g)| {
            let r_u = U::from(r);
            acc.add(&r_u.mul(&g))
        });

        // Reduce the query index into the range covered by the `2^n`-element basis.
        let query_for_basis = query % (1 << n);
        let qf = T::from_poly(<T as BinaryFieldElement>::Poly::from_value(
            query_for_basis as u64,
        ));

        let mut local_sks_x = vec![T::zero(); sks_vks.len()];
        let mut local_basis = vec![U::zero(); 1 << n];
        let scale = U::from(T::one());
        evaluate_scaled_basis_inplace(&mut local_sks_x, &mut local_basis, &sks_vks, qf, scale);

        let e = yr
            .iter()
            .zip(local_basis.iter())
            .fold(U::zero(), |acc, (&y, &b)| {
                let y_u = U::from(y);
                acc.add(&y_u.mul(&b))
            });

        debug_println!(
            "verify_ligero: Query {} -> e = {:?}, dot = {:?}",
            query,
            e,
            dot
        );
        debug_println!("verify_ligero: Equal? {}", e == dot);
        if e != dot {
            debug_println!(
                "verify_ligero: mathematical relationship mismatch for query {}",
                query
            );
            debug_println!(" e = {:?}", e);
            debug_println!(" dot = {:?}", dot);
            debug_println!(" this might be expected in certain contexts");
        } else {
            debug_println!(
                "verify_ligero: mathematical relationship holds for query {}",
                query
            );
        }
    }
}