1use std::io::{Error, ErrorKind, Read, Result, Write};
6
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8
/// Four magic bytes (`"r1cs"`) that must open every file.
const MAGIC: &[u8; 4] = &[b'r', b'1', b'c', b's'];
/// The single format version accepted by `R1csFile::read` and emitted by `R1csFile::write`.
const VERSION: u32 = 1;
11
12#[derive(Debug, PartialEq, Eq)]
13pub struct R1csFile<const FS: usize> {
14 pub header: Header<FS>,
15 pub constraints: Constraints<FS>,
16 pub map: WireMap,
17}
18
19impl<const FS: usize> R1csFile<FS> {
20 pub fn read<R: Read>(mut r: R) -> Result<Self> {
21 let mut magic = [0u8; 4];
22 r.read_exact(&mut magic)?;
23 if magic != *MAGIC {
24 return Err(Error::new(ErrorKind::InvalidData, "Invalid magic number"));
25 }
26
27 let version = r.read_u32::<LittleEndian>()?;
28 if version != VERSION {
29 return Err(Error::new(ErrorKind::InvalidData, "Unsupported version"));
30 }
31
32 let num_sections = r.read_u32::<LittleEndian>()?;
36
37 let mut header = None;
38 let mut constraints = None;
39 let mut map = None;
40
41 for _ in 0..num_sections {
42 let section_header = SectionHeader::read(&mut r)?;
43
44 match section_header.ty {
45 SectionType::Header => {
46 if let None = header {
47 header = Some(Header::read(&mut r)?);
48 } else {
49 return Err(Error::new(
50 ErrorKind::InvalidData,
51 "Duplicated header section found",
52 ));
53 }
54 }
55 SectionType::Constraint => {
56 if let None = constraints {
57 constraints = Some(Constraints::read(&mut r, §ion_header)?);
58 } else {
59 return Err(Error::new(
60 ErrorKind::InvalidData,
61 "Duplicated constraints section found",
62 ));
63 }
64 }
65 SectionType::Wire2LabelIdMap => {
66 if let None = map {
67 map = Some(WireMap::read(&mut r, §ion_header)?);
68 } else {
69 return Err(Error::new(
70 ErrorKind::InvalidData,
71 "Duplicated wire map section found",
72 ));
73 }
74 }
75 SectionType::Unknown => {
76 return Err(Error::new(ErrorKind::InvalidData, "Unknown section"))
77 }
78 }
79 }
80
81 match (header, constraints, map) {
82 (Some(header), Some(constraints), Some(map)) => Ok(R1csFile {
83 header,
84 constraints,
85 map,
86 }),
87 (None, _, _) => Err(Error::new(ErrorKind::InvalidData, "Missing header section")),
88 (_, None, _) => Err(Error::new(
89 ErrorKind::InvalidData,
90 "Missing constraints section",
91 )),
92 (_, _, None) => Err(Error::new(
93 ErrorKind::InvalidData,
94 "Missing wire map section",
95 )),
96 }
97 }
98
99 pub fn write<W: Write>(&self, mut w: W) -> Result<()> {
100 w.write_all(MAGIC)?;
101 w.write_u32::<LittleEndian>(VERSION)?;
102 w.write_u32::<LittleEndian>(3)?; self.header.write(&mut w)?;
105 self.constraints.write(&mut w)?;
106 self.map.write(&mut w)?;
107
108 Ok(())
109 }
110}
111
112#[derive(Debug, PartialEq, Eq)]
113pub struct Header<const FS: usize> {
114 pub prime: FieldElement<FS>,
115 pub n_wires: u32,
116 pub n_pub_out: u32,
117 pub n_pub_in: u32,
118 pub n_prvt_in: u32,
119 pub n_labels: u64,
120 pub n_constraints: u32,
121}
122
123impl<const FS: usize> Header<FS> {
124 fn read<R: Read>(mut r: R) -> Result<Self> {
125 let field_size = r.read_u32::<LittleEndian>()?;
126 if field_size != FS as u32 {
127 return Err(Error::new(ErrorKind::InvalidData, "Wrong field size"));
128 }
129
130 let prime = FieldElement::read(&mut r)?;
131 let n_wires = r.read_u32::<LittleEndian>()?;
132 let n_pub_out = r.read_u32::<LittleEndian>()?;
133 let n_pub_in = r.read_u32::<LittleEndian>()?;
134 let n_prvt_in = r.read_u32::<LittleEndian>()?;
135 let n_labels = r.read_u64::<LittleEndian>()?;
136 let n_constraints = r.read_u32::<LittleEndian>()?;
137
138 Ok(Header {
139 prime,
140 n_wires,
141 n_pub_out,
142 n_pub_in,
143 n_prvt_in,
144 n_labels,
145 n_constraints,
146 })
147 }
148
149 fn write<W: Write>(&self, mut w: W) -> Result<()> {
150 let header = SectionHeader {
151 ty: SectionType::Header,
152 size: 6 * 4 + 8 + FS as u64,
153 };
154
155 header.write(&mut w)?;
156
157 w.write_u32::<LittleEndian>(FS as u32)?;
158 self.prime.write(&mut w)?;
159 w.write_u32::<LittleEndian>(self.n_wires)?;
160 w.write_u32::<LittleEndian>(self.n_pub_out)?;
161 w.write_u32::<LittleEndian>(self.n_pub_in)?;
162 w.write_u32::<LittleEndian>(self.n_prvt_in)?;
163 w.write_u64::<LittleEndian>(self.n_labels)?;
164 w.write_u32::<LittleEndian>(self.n_constraints)?;
165
166 Ok(())
167 }
168}
169
170#[derive(Debug, Default, PartialEq, Eq)]
171pub struct Constraints<const FS: usize>(pub Vec<Constraint<FS>>);
172
173impl<const FS: usize> Constraints<FS> {
174 fn read<R: Read>(r: R, section_header: &SectionHeader) -> Result<Self> {
175 let mut section_data = r.take(section_header.size);
176
177 let mut constraints = Vec::new();
178 while section_data.limit() > 0 {
179 let c = Constraint::read(&mut section_data)?;
180 constraints.push(c);
181 }
182
183 Ok(Constraints(constraints))
184 }
185
186 fn write<W: Write>(&self, mut w: W) -> Result<()> {
187 let header = SectionHeader {
188 ty: SectionType::Constraint,
189 size: self.0.iter().map(|c| c.size()).sum::<usize>() as u64,
190 };
191
192 header.write(&mut w)?;
193
194 for c in &self.0 {
195 c.write(&mut w)?;
196 }
197
198 Ok(())
199 }
200}
201
202#[derive(Debug, Default, PartialEq, Eq)]
203pub struct Constraint<const FS: usize>(
204 pub Vec<(FieldElement<FS>, u32)>,
205 pub Vec<(FieldElement<FS>, u32)>,
206 pub Vec<(FieldElement<FS>, u32)>,
207);
208
209impl<const FS: usize> Constraint<FS> {
210 fn read<R: Read>(mut r: R) -> Result<Self> {
211 let a = Self::read_combination(&mut r)?;
212 let b = Self::read_combination(&mut r)?;
213 let c = Self::read_combination(&mut r)?;
214
215 Ok(Constraint(a, b, c))
216 }
217
218 fn read_combination<R: Read>(mut r: R) -> Result<Vec<(FieldElement<FS>, u32)>> {
219 let n = r.read_u32::<LittleEndian>()?;
220 let mut factors = Vec::new();
221
222 for _ in 0..n {
223 let index = r.read_u32::<LittleEndian>()?;
224 let factor = FieldElement::read(&mut r)?;
225 factors.push((factor, index));
226 }
227
228 Ok(factors)
229 }
230
231 fn write<W: Write>(&self, mut w: W) -> Result<()> {
232 let mut write = |comb: &Vec<(FieldElement<FS>, u32)>| -> Result<()> {
233 w.write_u32::<LittleEndian>(comb.len() as u32)?;
234
235 for (factor, index) in comb {
236 w.write_u32::<LittleEndian>(*index)?;
237 factor.write(&mut w)?;
238 }
239
240 Ok(())
241 };
242
243 write(&self.0)?;
244 write(&self.1)?;
245 write(&self.2)?;
246
247 Ok(())
248 }
249
250 fn size(&self) -> usize {
251 let a = self.0.iter().map(|(f, _)| f.len()).sum::<usize>() + self.0.len() * 4;
252 let b = self.1.iter().map(|(f, _)| f.len()).sum::<usize>() + self.1.len() * 4;
253 let c = self.2.iter().map(|(f, _)| f.len()).sum::<usize>() + self.2.len() * 4;
254
255 a + b + c + 3 * 4
256 }
257}
258
259#[derive(Debug, Default, PartialEq, Eq)]
260pub struct WireMap(pub Vec<u64>);
261
262impl WireMap {
263 fn read<R: Read>(mut r: R, section_header: &SectionHeader) -> Result<Self> {
264 let num_labels = section_header.size / 8;
265 let mut label_ids = Vec::with_capacity(num_labels as usize);
266
267 for _ in 0..num_labels {
268 label_ids.push(r.read_u64::<LittleEndian>()?);
269 }
270
271 Ok(WireMap(label_ids))
272 }
273
274 fn write<W: Write>(&self, mut w: W) -> Result<()> {
275 let header = SectionHeader {
276 ty: SectionType::Wire2LabelIdMap,
277 size: self.0.len() as u64 * 8,
278 };
279
280 header.write(&mut w)?;
281
282 for label_id in &self.0 {
283 w.write_u64::<LittleEndian>(*label_id)?;
284 }
285
286 Ok(())
287 }
288}
289
290struct SectionHeader {
291 ty: SectionType,
292 size: u64,
293}
294
295impl SectionHeader {
296 fn read<R: Read>(mut r: R) -> Result<Self> {
297 let ty = SectionType::read(&mut r)?;
298 let size = r.read_u64::<LittleEndian>()?;
299
300 Ok(SectionHeader { ty, size })
301 }
302
303 fn write<W: Write>(&self, mut w: W) -> Result<()> {
304 w.write_u32::<LittleEndian>(self.ty as u32)?;
305 w.write_u64::<LittleEndian>(self.size)?;
306
307 Ok(())
308 }
309}
310
311#[derive(Debug, PartialEq, Eq, Clone, Copy)]
312#[repr(u32)]
313enum SectionType {
314 Header = 1,
315 Constraint = 2,
316 Wire2LabelIdMap = 3,
317 Unknown = u32::MAX,
318}
319
320impl SectionType {
321 fn read<R: Read>(mut r: R) -> Result<Self> {
322 let num = r.read_u32::<LittleEndian>()?;
323
324 let ty = match num {
325 1 => SectionType::Header,
326 2 => SectionType::Constraint,
327 3 => SectionType::Wire2LabelIdMap,
328 _ => SectionType::Unknown,
329 };
330
331 Ok(ty)
332 }
333}
334
/// A field element (or the field prime) as `FS` raw bytes, kept exactly as
/// they appear in the file.
#[derive(Debug, PartialEq, Eq)]
pub struct FieldElement<const FS: usize>([u8; FS]);

impl<const FS: usize> FieldElement<FS> {
    /// Borrows the raw bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Reads exactly `FS` bytes from `r`.
    fn read<R: Read>(mut r: R) -> Result<Self> {
        let mut bytes = [0u8; FS];
        r.read_exact(&mut bytes)?;
        Ok(Self::from(bytes))
    }

    /// Writes all `FS` bytes to `w`.
    fn write<W: Write>(&self, mut w: W) -> Result<()> {
        w.write_all(self.as_bytes())
    }
}

impl<const FS: usize> From<[u8; FS]> for FieldElement<FS> {
    /// Wraps the bytes as-is; no validation is performed.
    fn from(bytes: [u8; FS]) -> Self {
        Self(bytes)
    }
}

impl<const FS: usize> std::ops::Deref for FieldElement<FS> {
    type Target = [u8; FS];

    fn deref(&self) -> &[u8; FS] {
        &self.0
    }
}
368
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;

    /// A small in-memory file with 4-byte field elements, so round-trip and
    /// error-path tests do not depend on fixture files.
    fn sample_file() -> R1csFile<4> {
        R1csFile {
            header: Header {
                prime: FieldElement::from([0x0bu8, 0, 0, 0]),
                n_wires: 3,
                n_pub_out: 1,
                n_pub_in: 1,
                n_prvt_in: 0,
                n_labels: 3,
                n_constraints: 1,
            },
            constraints: Constraints(vec![Constraint(
                vec![(FieldElement::from([1u8, 0, 0, 0]), 1)],
                vec![(FieldElement::from([2u8, 0, 0, 0]), 2)],
                vec![],
            )]),
            map: WireMap(vec![0, 1, 2]),
        }
    }

    /// Serializes `sample_file()` into a byte vector.
    fn sample_bytes() -> Vec<u8> {
        let mut bytes = Vec::new();
        sample_file().write(&mut bytes).unwrap();
        bytes
    }

    #[test]
    fn test_parse() {
        let data = std::fs::read("tests/simple_circuit.r1cs").unwrap();
        let file = R1csFile::<32>::read(data.as_slice()).unwrap();

        assert_eq!(
            file.header.prime,
            FieldElement::from(hex!(
                "010000f093f5e1439170b97948e833285d588181b64550b829a031e1724e6430"
            ))
        );
        assert_eq!(file.header.n_wires, 7);
        assert_eq!(file.header.n_pub_out, 1);
        assert_eq!(file.header.n_pub_in, 2);
        assert_eq!(file.header.n_prvt_in, 3);
        assert_eq!(file.header.n_labels, 0x03e8);
        assert_eq!(file.header.n_constraints, 3);

        assert_eq!(file.constraints.0.len(), 3);
        assert_eq!(file.constraints.0[0].0.len(), 2);
        assert_eq!(file.constraints.0[0].0[0].1, 5);
        assert_eq!(
            file.constraints.0[0].0[0].0,
            FieldElement::from(hex!(
                "0300000000000000000000000000000000000000000000000000000000000000"
            )),
        );
        assert_eq!(file.constraints.0[2].1[0].1, 0);
        assert_eq!(
            file.constraints.0[2].1[0].0,
            FieldElement::from(hex!(
                "0600000000000000000000000000000000000000000000000000000000000000"
            )),
        );
        assert_eq!(file.constraints.0[1].2.len(), 0);

        assert_eq!(file.map.0.len(), 7);
        assert_eq!(file.map.0[1], 3);
    }

    #[test]
    fn test_serialize() {
        let data = std::fs::read("tests/test_circuit.r1cs").unwrap();
        let parsed_file = R1csFile::<32>::read(data.as_slice()).unwrap();
        let mut serialized_file = Vec::new();
        parsed_file.write(&mut serialized_file).unwrap();

        assert_eq!(data.len(), serialized_file.len());
        assert_eq!(data, serialized_file);
    }

    #[test]
    fn test_roundtrip_in_memory() {
        let file = sample_file();
        let bytes = sample_bytes();
        let parsed = R1csFile::<4>::read(bytes.as_slice()).unwrap();
        assert_eq!(parsed, file);
    }

    #[test]
    fn test_rejects_bad_magic() {
        let mut bytes = sample_bytes();
        bytes[0] = b'x';
        let err = R1csFile::<4>::read(bytes.as_slice()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_rejects_unsupported_version() {
        let mut bytes = sample_bytes();
        // Bytes 4..8 hold the little-endian version number.
        bytes[4] = 2;
        let err = R1csFile::<4>::read(bytes.as_slice()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_rejects_wrong_field_size() {
        let bytes = sample_bytes();
        let err = R1csFile::<8>::read(bytes.as_slice()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_rejects_missing_sections() {
        // Valid magic and version, but a section count of zero.
        let bytes = b"r1cs\x01\x00\x00\x00\x00\x00\x00\x00";
        let err = R1csFile::<4>::read(&bytes[..]).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}