use super::error::ProofError;
use crate::compute::init_backend;
use curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
};
use merlin::Transcript;
use serde::{Deserialize, Serialize};
/// Inner product proof generated by [`InnerProductProof::create`] and checked by
/// [`InnerProductProof::verify`].
///
/// The field declaration order is part of the serialized form for formats that are not
/// self-describing, so it should not be rearranged.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct InnerProductProof {
    /// Compressed points written by the backend prover. `create` allocates `ceil(log2(n))`
    /// of them, and `verify` rejects a proof holding any other count.
    pub(crate) l_vector: Vec<CompressedRistretto>,
    /// Companion to `l_vector`, with the same length requirement.
    pub(crate) r_vector: Vec<CompressedRistretto>,
    /// Scalar written by the backend prover alongside `l_vector` and `r_vector`.
    pub(crate) ap_value: Scalar,
}
impl InnerProductProof {
    /// Returns `ceil(log2(n))` (0 for `n == 1`).
    ///
    /// This is the length `create` allocates for `l_vector` and `r_vector`, and the length
    /// `verify` requires of them, so both methods must compute it the same way.
    fn num_rounds(n: usize) -> usize {
        n.next_power_of_two().trailing_zeros() as usize
    }

    /// Creates an inner product proof for the vectors `a` and `b` by calling the blitzar
    /// backend (`sxt_curve25519_prove_inner_product`).
    ///
    /// `transcript` is passed to the backend as a mutable pointer, so the backend may update
    /// it. `generators_offset` is forwarded unchanged to the backend (presumably the index of
    /// the first generator to use — confirm against the blitzar-sys documentation).
    ///
    /// # Panics
    ///
    /// Panics if `a` is empty or if `a` and `b` have different lengths.
    pub fn create(
        transcript: &mut Transcript,
        a: &[Scalar],
        b: &[Scalar],
        generators_offset: u64,
    ) -> InnerProductProof {
        init_backend();
        assert!(!a.is_empty(), "inner product proof requires non-empty input vectors");
        assert_eq!(a.len(), b.len(), "`a` and `b` must have the same length");

        let n = a.len() as u64;
        let num_rounds = Self::num_rounds(a.len());
        let mut ap_value = Scalar::default();
        let mut l_vector = vec![CompressedRistretto::default(); num_rounds];
        let mut r_vector = vec![CompressedRistretto::default(); num_rounds];

        // SAFETY: `a` and `b` are live slices of exactly `n` scalars (checked above).
        // `l_vector`/`r_vector` are live buffers of `num_rounds` elements; this assumes the
        // backend writes no more than that many entries to each. `ap_value` and `transcript`
        // are exclusive borrows that outlive the call. The pointer casts rely on the
        // blitzar-sys types sharing the layout of their dalek/merlin counterparts, which
        // this module already assumes and does not check here.
        unsafe {
            blitzar_sys::sxt_curve25519_prove_inner_product(
                l_vector.as_mut_ptr() as *mut blitzar_sys::sxt_ristretto255_compressed,
                r_vector.as_mut_ptr() as *mut blitzar_sys::sxt_ristretto255_compressed,
                &mut ap_value as *mut Scalar as *mut blitzar_sys::sxt_curve25519_scalar,
                transcript as *mut Transcript as *mut blitzar_sys::sxt_transcript,
                n,
                generators_offset,
                a.as_ptr() as *const blitzar_sys::sxt_curve25519_scalar,
                b.as_ptr() as *const blitzar_sys::sxt_curve25519_scalar,
            );
        }

        InnerProductProof {
            l_vector,
            r_vector,
            ap_value,
        }
    }

    /// Verifies this proof against `a_commit`, the claimed `product`, and the vector `b`
    /// by calling the blitzar backend (`sxt_curve25519_verify_inner_product`).
    ///
    /// `transcript` is passed to the backend as a mutable pointer, so the backend may update
    /// it. `generators_offset` is forwarded unchanged to the backend.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::VerificationError`] in two cases:
    /// - the proof's `l_vector` or `r_vector` length does not equal `ceil(log2(b.len()))`;
    /// - the backend returns anything other than `1`.
    ///
    /// # Panics
    ///
    /// Panics if `b` is empty.
    pub fn verify(
        &self,
        transcript: &mut Transcript,
        a_commit: &RistrettoPoint,
        product: &Scalar,
        b: &[Scalar],
        generators_offset: u64,
    ) -> Result<(), ProofError> {
        init_backend();
        let n = b.len();
        assert!(n > 0, "inner product verification requires a non-empty `b` vector");

        // A proof may come from deserialized, untrusted data. Reject a wrong round count
        // before passing the buffers to the backend, which presumably reads `num_rounds`
        // entries from each.
        let num_rounds = Self::num_rounds(n);
        if self.l_vector.len() != num_rounds || self.r_vector.len() != num_rounds {
            return Err(ProofError::VerificationError);
        }

        // SAFETY: `b` is a live slice of `n` scalars. `l_vector`/`r_vector` were just checked
        // to hold `num_rounds` elements each. `product`, `a_commit`, `ap_value` and
        // `transcript` are live borrows that outlive the call. As in `create`, the pointer
        // casts rely on the blitzar-sys types sharing the layout of their dalek/merlin
        // counterparts.
        let ret = unsafe {
            blitzar_sys::sxt_curve25519_verify_inner_product(
                transcript as *mut Transcript as *mut blitzar_sys::sxt_transcript,
                // Lossless: usize is at most 64 bits on every supported target.
                n as u64,
                generators_offset,
                b.as_ptr() as *const blitzar_sys::sxt_curve25519_scalar,
                product as *const Scalar as *const blitzar_sys::sxt_curve25519_scalar,
                a_commit as *const RistrettoPoint as *const blitzar_sys::sxt_ristretto255,
                self.l_vector.as_ptr() as *const blitzar_sys::sxt_ristretto255_compressed,
                self.r_vector.as_ptr() as *const blitzar_sys::sxt_ristretto255_compressed,
                &self.ap_value as *const Scalar as *const blitzar_sys::sxt_curve25519_scalar,
            )
        };

        if ret == 1 {
            Ok(())
        } else {
            Err(ProofError::VerificationError)
        }
    }
}