use std::io;
use ark_serialize::{Read, SerializationError};
use byteorder::{LittleEndian, ReadBytesExt};
use thiserror::Error;
use ark_ff::{BigInteger, PrimeField};
use crate::reader_utils::{self, InvalidHeaderError};
/// Crate-local result type whose error is always [`WitnessParserError`].
type Result<T> = core::result::Result<T, WitnessParserError>;
/// Header string passed to `reader_utils::read_header` when parsing a witness file.
const WITNESS_HEADER: &str = "wtns";
/// Section count compared against the count declared in the witness file.
const N_SECTIONS: u32 = 2;
/// Highest witness file version accepted by the parser.
const MAX_VERSION: u32 = 2;
#[derive(Debug, Error)]
pub enum WitnessParserError {
#[error(transparent)]
IoError(#[from] io::Error),
#[error(transparent)]
SerializationError(#[from] SerializationError),
#[error("Max supported version is {0}, but got {1}")]
VersionNotSupported(u32, u32),
#[error("Wrong number of sections is {0}, but got {1}")]
InvalidSectionNumber(u32, u32),
#[error("ScalarField from curve does not match in witness file")]
WrongScalarField,
#[error(transparent)]
WrongHeader(#[from] InvalidHeaderError),
}
/// A parsed witness: the assignment values, in the order they appear in the file.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Witness<F> {
    /// Witness values as decoded from the file.
    pub values: Vec<F>,
}
impl<F: PrimeField> Witness<F> {
    /// Parses a `wtns` witness file from `reader` into field elements of `F`.
    ///
    /// The layout consumed here (all integers little-endian) is:
    ///
    /// 1. the `"wtns"` header, a `u32` version and a `u32` section count;
    /// 2. a section header (`u32` id, `u64` length), then `n8: u32`, the field
    ///    prime as `n8` bytes, and `n_witness: u32`;
    /// 3. a second section header, then `n_witness` elements of `n8` bytes each.
    ///
    /// Section ids and lengths are read but not validated.
    ///
    /// # Errors
    ///
    /// - Errors from `reader_utils::read_header` if the header is wrong
    ///   (presumably [`WitnessParserError::WrongHeader`]).
    /// - [`WitnessParserError::VersionNotSupported`] if the version exceeds `MAX_VERSION`.
    /// - [`WitnessParserError::InvalidSectionNumber`] if the section count is not `N_SECTIONS`.
    /// - [`WitnessParserError::WrongScalarField`] if the stored prime is not `F::MODULUS`.
    /// - [`WitnessParserError::IoError`] if the reader fails or ends early.
    /// - Element-decoding errors, converted via `From` into [`WitnessParserError`].
    pub fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
        tracing::trace!("trying to read witness");
        reader_utils::read_header(&mut reader, WITNESS_HEADER)?;

        let version = reader.read_u32::<LittleEndian>()?;
        if version > MAX_VERSION {
            return Err(WitnessParserError::VersionNotSupported(
                MAX_VERSION,
                version,
            ));
        }

        // The code below reads exactly two sections (field header, then values),
        // so any other declared count would be mis-parsed; reject it up front.
        let n_sections = reader.read_u32::<LittleEndian>()?;
        if n_sections != N_SECTIONS {
            return Err(WitnessParserError::InvalidSectionNumber(
                N_SECTIONS, n_sections,
            ));
        }

        // Section 1 header: id and byte length are skipped.
        let _section_id = reader.read_u32::<LittleEndian>()?;
        let _section_len = reader.read_u64::<LittleEndian>()?;

        let n8 = reader.read_u32::<LittleEndian>()?;
        let n8 = usize::try_from(n8).expect("u32 fits into usize");
        let modulus = F::MODULUS.to_bytes_le();
        // `n8` comes straight from the file: validate it against the expected
        // width before allocating, so a corrupt value cannot request a huge buffer.
        if n8 != modulus.len() {
            tracing::trace!("wrong scalar field");
            return Err(WitnessParserError::WrongScalarField);
        }
        let mut prime = vec![0; n8];
        reader.read_exact(prime.as_mut_slice())?;
        if prime != modulus {
            tracing::trace!("wrong scalar field");
            return Err(WitnessParserError::WrongScalarField);
        }

        let n_witness = reader.read_u32::<LittleEndian>()?;

        // Section 2 header: id and byte length are skipped.
        let _section_id = reader.read_u32::<LittleEndian>()?;
        let _section_len = reader.read_u64::<LittleEndian>()?;

        let values = (0..n_witness)
            .map(|_| {
                reader_utils::prime_field_from_reader(&mut reader, n8)
                    .map_err(WitnessParserError::from)
            })
            .collect::<Result<Vec<F>>>()?;
        Ok(Self { values })
    }
}
#[cfg(test)]
#[cfg(feature = "bn254")]
mod bn254_tests {
    use std::fs::File;

    use crate::tests::groth16_bn254_kats;

    use super::Witness;

    /// Parses the bn254 known-answer witness file and compares all four values.
    #[test]
    fn can_deser_witness_bn254() {
        let kats_dir = groth16_bn254_kats();
        let file = File::open(kats_dir.join("witness.wtns")).unwrap();
        let parsed = Witness::<ark_bn254::Fr>::from_reader(file).unwrap();
        let expected: Witness<ark_bn254::Fr> = Witness {
            values: [1u64, 33, 3, 11]
                .iter()
                .copied()
                .map(ark_bn254::Fr::from)
                .collect(),
        };
        assert_eq!(parsed, expected);
    }
}
#[cfg(test)]
#[cfg(feature = "bls12-381")]
mod bls12_381_tests {
    use std::fs::File;

    use crate::tests::groth16_bls12_381_kats;

    use super::Witness;

    /// Parses the BLS12-381 known-answer witness file and compares all four values.
    #[test]
    fn can_deser_witness_bls12381() {
        let kats_dir = groth16_bls12_381_kats();
        let file = File::open(kats_dir.join("witness.wtns")).unwrap();
        let parsed = Witness::<ark_bls12_381::Fr>::from_reader(file).unwrap();
        let expected: Witness<ark_bls12_381::Fr> = Witness {
            values: [1u64, 33, 3, 11]
                .iter()
                .copied()
                .map(ark_bls12_381::Fr::from)
                .collect(),
        };
        assert_eq!(parsed, expected);
    }
}