use ark_ff::PrimeField;
use byteorder::{LittleEndian, ReadBytesExt};
use std::io::{Error, ErrorKind};
use thiserror::Error;
use ark_ec::pairing::Pairing;
use ark_serialize::{CanonicalDeserialize, SerializationError};
use ark_std::io::{Read, Seek, SeekFrom};
use std::collections::HashMap;
use crate::{
reader_utils::{self, InvalidHeaderError},
traits::CircomArkworksPairingBridge,
};
/// Expected file magic, checked by `reader_utils::read_header` before anything else is read.
const R1CS_HEADER: &str = "r1cs";
/// The only `.r1cs` format version accepted (the version check is an exact match).
const MAX_VERSION: u32 = 1;

/// Result type used throughout the R1CS parser.
type Result<T> = std::result::Result<T, R1CSParserError>;

/// A sparse linear combination: a list of `(wire index, coefficient)` terms.
pub(crate) type ConstraintVec<P> = Vec<(usize, <P as Pairing>::ScalarField)>;
/// A single constraint, stored as its three linear combinations in file order (`A`, `B`, `C`).
pub(crate) type Constraints<P> = (ConstraintVec<P>, ConstraintVec<P>, ConstraintVec<P>);
#[derive(Debug, Error)]
pub enum R1CSParserError {
#[error(transparent)]
SerializationError(#[from] SerializationError),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("Max supported version is {0}, but got {1}")]
VersionNotSupported(u32, u32),
#[error(transparent)]
WrongHeader(#[from] InvalidHeaderError),
#[error("ScalarField from curve does not match in witness file")]
WrongScalarField,
}
/// A rank-1 constraint system parsed from a circom `.r1cs` file.
///
/// The `n_*` fields mirror the counts stored in the file's header section. The
/// remaining fields are either derived from those counts or read from the
/// constraint and wire-to-label sections.
#[derive(Debug, Clone)]
pub struct R1CS<P: Pairing> {
    /// Number of public wires: `1 + n_pub_in + n_pub_out` (wire 0 plus public I/O).
    pub num_inputs: usize,
    /// Number of remaining wires: `num_variables - num_inputs`.
    pub num_aux: usize,
    /// Total number of wires declared in the header.
    pub num_variables: usize,
    /// All constraints, each as its `(A, B, C)` linear combinations.
    pub constraints: Vec<Constraints<P>>,
    /// Contents of the wire-to-label section; entry `i` belongs to wire `i`
    /// (entry 0 is always 0).
    pub wire_mapping: Vec<usize>,
    /// Number of public outputs, as stored in the header.
    pub n_pub_out: u32,
    /// Number of public inputs, as stored in the header.
    pub n_pub_in: u32,
    /// Number of private inputs, as stored in the header.
    pub n_prv_in: u32,
    /// Number of labels, as stored in the header.
    pub n_labels: u64,
    /// Number of constraints, as stored in the header.
    pub n_constraints: usize,
}
impl<P: Pairing + CircomArkworksPairingBridge> R1CS<P> {
    /// Parses a circom `.r1cs` file (format version [`MAX_VERSION`]) from `reader`.
    ///
    /// After the magic bytes and the version, the file contains a table of sections.
    /// Each section is prefixed by a little-endian `u32` type and `u64` byte size.
    /// Sections are located by type, so their order in the file does not matter; if a
    /// type occurs more than once, the last occurrence wins. Three section types are
    /// consumed: `1` (header), `2` (constraints) and `3` (wire-to-label map).
    ///
    /// # Errors
    ///
    /// * [`R1CSParserError::WrongHeader`] if the file magic is rejected.
    /// * [`R1CSParserError::VersionNotSupported`] if the version is not [`MAX_VERSION`].
    /// * [`R1CSParserError::WrongScalarField`] if the field size, header-section size or
    ///   prime do not match the scalar field of `P`.
    /// * [`R1CSParserError::IoError`] on read/seek failures and on structurally invalid
    ///   files: a missing section, a section size too large to seek over, a header
    ///   declaring more public wires than variables, or a wire label that does not fit
    ///   in `usize`.
    /// * Any error raised while reading the constraint or wire-to-label sections.
    pub fn from_reader<R: Read + Seek>(mut reader: R) -> Result<Self> {
        const HEADER_TYPE: u32 = 1;
        const CONSTRAINT_TYPE: u32 = 2;
        const WIRE2LABEL_TYPE: u32 = 3;

        /// Builds the `InvalidData` I/O error used for structurally broken files.
        fn invalid_data(msg: &'static str) -> Error {
            Error::new(ErrorKind::InvalidData, msg)
        }

        reader_utils::read_header(&mut reader, R1CS_HEADER)?;
        let version = reader.read_u32::<LittleEndian>()?;
        if version != MAX_VERSION {
            return Err(R1CSParserError::VersionNotSupported(MAX_VERSION, version));
        }

        // Index the section table: section type -> (payload offset, payload size).
        let num_sections = reader.read_u32::<LittleEndian>()?;
        let mut sections = HashMap::<u32, (u64, u64)>::new();
        for _ in 0..num_sections {
            let sec_type = reader.read_u32::<LittleEndian>()?;
            let sec_size = reader.read_u64::<LittleEndian>()?;
            let offset = reader.stream_position()?;
            sections.insert(sec_type, (offset, sec_size));
            // `sec_size as i64` would turn sizes above i64::MAX into a backwards seek.
            let skip =
                i64::try_from(sec_size).map_err(|_| invalid_data("Section size too large"))?;
            reader.seek(SeekFrom::Current(skip))?;
        }
        let section = |sec_type: u32, missing: &'static str| {
            sections
                .get(&sec_type)
                .copied()
                .ok_or_else(|| invalid_data(missing))
        };

        // --- Header section ---
        let (header_offset, header_size) =
            section(HEADER_TYPE, "No section offset for header type found")?;
        reader.seek(SeekFrom::Start(header_offset))?;
        let field_size =
            usize::try_from(reader.read_u32::<LittleEndian>()?).expect("u32 fits into usize");
        if field_size != P::SCALAR_FIELD_BYTE_SIZE {
            return Err(R1CSParserError::WrongScalarField);
        }
        // The fixed-width header fields (six `u32`s: field size, wire count, public
        // output / public input / private input counts, constraint count; plus the `u64`
        // label count) add up to 32 bytes. The prime takes the remaining `field_size`.
        if header_size != 32 + field_size as u64 {
            return Err(R1CSParserError::WrongScalarField);
        }
        let q = <P::ScalarField as PrimeField>::BigInt::deserialize_uncompressed(&mut reader)?;
        if q != P::ScalarField::MODULUS {
            return Err(R1CSParserError::WrongScalarField);
        }
        let num_variables =
            usize::try_from(reader.read_u32::<LittleEndian>()?).expect("u32 fits into usize");
        let n_pub_out = reader.read_u32::<LittleEndian>()?;
        let n_pub_in = reader.read_u32::<LittleEndian>()?;
        let n_prv_in = reader.read_u32::<LittleEndian>()?;
        let n_labels = reader.read_u64::<LittleEndian>()?;
        let n_constraints =
            usize::try_from(reader.read_u32::<LittleEndian>()?).expect("u32 fits into usize");

        // Public wires are wire 0 plus all public outputs and inputs. The sum is done in
        // u64 so it cannot overflow. Rejecting headers with more public wires than
        // variables keeps `num_aux` from underflowing on malformed input.
        let num_inputs = 1 + u64::from(n_pub_in) + u64::from(n_pub_out);
        let num_inputs = usize::try_from(num_inputs)
            .ok()
            .filter(|&n| n <= num_variables)
            .ok_or_else(|| invalid_data("Header declares more public wires than variables"))?;
        let num_aux = num_variables - num_inputs;

        // --- Constraint section ---
        let (constraint_offset, _) =
            section(CONSTRAINT_TYPE, "No section offset for constraint type found")?;
        reader.seek(SeekFrom::Start(constraint_offset))?;
        let constraints = read_constraints::<&mut R, P>(&mut reader, n_constraints, field_size)?;

        // --- Wire-to-label section ---
        let (wire2label_offset, wire2label_size) =
            section(WIRE2LABEL_TYPE, "No section offset for wire2label type found")?;
        reader.seek(SeekFrom::Start(wire2label_offset))?;
        // Labels are stored as u64; a plain `as usize` would truncate on 32-bit targets.
        let wire_mapping = read_map(&mut reader, wire2label_size, num_variables)?
            .into_iter()
            .map(|label| {
                usize::try_from(label)
                    .map_err(|_| invalid_data("Wire label does not fit into usize"))
            })
            .collect::<std::result::Result<Vec<_>, _>>()?;

        Ok(R1CS {
            num_aux,
            num_inputs,
            num_variables,
            constraints,
            wire_mapping,
            n_pub_out,
            n_pub_in,
            n_prv_in,
            n_labels,
            n_constraints,
        })
    }
}
/// Reads one sparse linear combination from the constraint section.
///
/// The encoding is a little-endian `u32` term count, followed by that many terms.
/// Each term is a little-endian `u32` wire index and a `field_size`-byte
/// coefficient parsed by `reader_utils::prime_field_from_reader`.
///
/// # Errors
///
/// Propagates read errors and coefficient-parsing errors.
fn read_constraint_vec<R: Read, P: Pairing>(
    mut reader: R,
    field_size: usize,
) -> Result<ConstraintVec<P>> {
    // The term count comes straight from the file. Cap the up-front allocation so a
    // corrupt or malicious count cannot force a huge allocation before any term is
    // read. The vector still grows as needed for legitimately large inputs.
    const MAX_PREALLOC: usize = 1 << 16;
    let n_vec = reader.read_u32::<LittleEndian>()? as usize;
    let mut vec = Vec::with_capacity(n_vec.min(MAX_PREALLOC));
    for _ in 0..n_vec {
        vec.push((
            reader.read_u32::<LittleEndian>()? as usize,
            reader_utils::prime_field_from_reader(&mut reader, field_size)?,
        ));
    }
    Ok(vec)
}
/// Reads `n_constraints` constraints from the constraint section.
///
/// Each constraint is three consecutive linear combinations (see
/// `read_constraint_vec`), stored in `(A, B, C)` order.
///
/// # Errors
///
/// Propagates any error from reading an individual linear combination.
fn read_constraints<R: Read, P: Pairing>(
    mut reader: R,
    n_constraints: usize,
    field_size: usize,
) -> Result<Vec<Constraints<P>>> {
    // `n_constraints` comes from the file header. Cap the up-front allocation and let
    // the vector grow as constraints are actually read.
    const MAX_PREALLOC: usize = 1 << 16;
    let mut constraints = Vec::with_capacity(n_constraints.min(MAX_PREALLOC));
    for _ in 0..n_constraints {
        let a = read_constraint_vec::<_, P>(&mut reader, field_size)?;
        let b = read_constraint_vec::<_, P>(&mut reader, field_size)?;
        let c = read_constraint_vec::<_, P>(&mut reader, field_size)?;
        constraints.push((a, b, c));
    }
    Ok(constraints)
}
/// Reads the wire-to-label section: one little-endian `u64` per wire.
///
/// `size` is the section's byte size from the section table. It must equal
/// `8 * n_wires`.
///
/// # Errors
///
/// Returns an `InvalidData` I/O error in two cases:
/// * the section size does not match `n_wires`;
/// * the map is empty, or its first entry is not 0.
///
/// Also propagates read errors.
fn read_map<R: Read>(mut reader: R, size: u64, n_wires: usize) -> Result<Vec<u64>> {
    // Both `size` and `n_wires` come from the file, so the size check alone does not
    // bound the allocation; cap the up-front capacity as well.
    const MAX_PREALLOC: usize = 1 << 16;
    let expected_size = u64::try_from(n_wires)
        .expect("usize fits into u64")
        .checked_mul(8);
    if expected_size != Some(size) {
        return Err(
            std::io::Error::new(ErrorKind::InvalidData, "Invalid map section size").into(),
        );
    }
    let mut vec = Vec::with_capacity(n_wires.min(MAX_PREALLOC));
    for _ in 0..n_wires {
        vec.push(reader.read_u64::<LittleEndian>()?);
    }
    // `first()` instead of `vec[0]`: an empty map (n_wires == 0) is reported as an
    // error rather than panicking.
    if vec.first() != Some(&0) {
        return Err(std::io::Error::new(
            ErrorKind::InvalidData,
            "Wire 0 should always be mapped to 0",
        )
        .into());
    }
    Ok(vec)
}
#[cfg(test)]
#[cfg(feature = "bls12-381")]
mod bls12_381_tests {
    use super::*;
    use crate::tests::groth16_bls12_381_kats;
    use ark_bls12_381::Bls12_381;
    use std::{fs::File, str::FromStr};

    /// Parses the BLS12-381 known-answer `circuit.r1cs`. Checks the header counts,
    /// the single term of the only constraint's `A` side, and the wire map.
    #[test]
    fn test_bls_12_381_mult2() {
        let kats_dir = groth16_bls12_381_kats();
        let file = File::open(kats_dir.join("circuit.r1cs")).unwrap();
        let r1cs = R1CS::<Bls12_381>::from_reader(file).unwrap();

        // Wire bookkeeping and raw header counts.
        assert_eq!((r1cs.num_inputs, r1cs.num_aux, r1cs.num_variables), (2, 2, 4));
        assert_eq!((r1cs.n_pub_out, r1cs.n_pub_in, r1cs.n_prv_in), (1, 0, 2));
        assert_eq!(r1cs.n_labels, 0x0004);
        assert_eq!(r1cs.n_constraints, 1);

        // Exactly one constraint. Its `A` side is a single term on wire 2 whose
        // coefficient is the scalar-field modulus minus one (i.e. -1).
        assert_eq!(r1cs.constraints.len(), 1);
        let lhs = &r1cs.constraints[0].0;
        assert_eq!(lhs.len(), 1);
        let minus_one = ark_bls12_381::Fr::from_str(
            "52435875175126190479447740508185965837690552500527637822603658699938581184512",
        )
        .unwrap();
        assert_eq!(lhs[0].0, 2);
        assert_eq!(lhs[0].1, minus_one);

        assert_eq!(r1cs.wire_mapping, vec![0, 1, 2, 3]);
    }
}
#[cfg(test)]
#[cfg(feature = "bn254")]
mod bn254_tests {
    use super::*;
    use crate::tests::groth16_bn254_kats;
    use ark_bn254::Bn254;
    use std::{fs::File, str::FromStr};

    /// Parses the BN254 known-answer `circuit.r1cs`. Checks the header counts, the
    /// single term of the only constraint's `A` side, and the wire map.
    #[test]
    fn test_bn254_mult2() {
        let kats_dir = groth16_bn254_kats();
        let file = File::open(kats_dir.join("circuit.r1cs")).unwrap();
        let r1cs = R1CS::<Bn254>::from_reader(file).unwrap();

        // Wire bookkeeping and raw header counts.
        assert_eq!((r1cs.num_inputs, r1cs.num_aux, r1cs.num_variables), (2, 2, 4));
        assert_eq!((r1cs.n_pub_out, r1cs.n_pub_in, r1cs.n_prv_in), (1, 0, 2));
        assert_eq!(r1cs.n_labels, 0x0004);
        assert_eq!(r1cs.n_constraints, 1);

        // Exactly one constraint. Its `A` side is a single term on wire 2 whose
        // coefficient is the scalar-field modulus minus one (i.e. -1).
        assert_eq!(r1cs.constraints.len(), 1);
        let lhs = &r1cs.constraints[0].0;
        assert_eq!(lhs.len(), 1);
        let minus_one = ark_bn254::Fr::from_str(
            "21888242871839275222246405745257275088548364400416034343698204186575808495616",
        )
        .unwrap();
        assert_eq!(lhs[0].0, 2);
        assert_eq!(lhs[0].1, minus_one);

        assert_eq!(r1cs.wire_mapping, vec![0, 1, 2, 3]);
    }
}