1#![allow(non_snake_case)]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![allow(clippy::assertions_on_result_states)]
5
6extern crate byteorder;
7extern crate core;
8extern crate curve25519_dalek;
9extern crate digest;
10extern crate merlin;
11extern crate rand;
12extern crate sha3;
13
14#[cfg(feature = "multicore")]
15extern crate rayon;
16
17mod commitments;
18mod dense_mlpoly;
19mod errors;
20mod group;
21mod math;
22mod nizk;
23mod product_tree;
24mod r1csinstance;
25mod r1csproof;
26mod random;
27mod scalar;
28mod sparse_mlpoly;
29mod sumcheck;
30mod timer;
31mod transcript;
32mod unipoly;
33
34use core::cmp::max;
35use errors::{ProofVerifyError, R1CSError};
36use merlin::Transcript;
37use r1csinstance::{
38 R1CSCommitment, R1CSCommitmentGens, R1CSDecommitment, R1CSEvalProof, R1CSInstance,
39};
40use r1csproof::{R1CSGens, R1CSProof};
41use random::RandomTape;
42use scalar::Scalar;
43use serde::{Deserialize, Serialize};
44use timer::Timer;
45use transcript::{AppendToTranscript, ProofTranscript};
46
47#[derive(Serialize, Deserialize)]
49pub struct ComputationCommitment {
50 comm: R1CSCommitment,
51}
52
53#[derive(Serialize, Deserialize)]
55pub struct ComputationDecommitment {
56 decomm: R1CSDecommitment,
57}
58
59#[derive(Clone, Serialize, Deserialize)]
61pub struct Assignment {
62 assignment: Vec<Scalar>,
63}
64
65impl Assignment {
66 pub fn new(assignment: &[[u8; 32]]) -> Result<Assignment, R1CSError> {
68 let bytes_to_scalar = |vec: &[[u8; 32]]| -> Result<Vec<Scalar>, R1CSError> {
69 let mut vec_scalar: Vec<Scalar> = Vec::new();
70 for v in vec {
71 let val = Scalar::from_bytes(v);
72 if val.is_some().unwrap_u8() == 1 {
73 vec_scalar.push(val.unwrap());
74 } else {
75 return Err(R1CSError::InvalidScalar);
76 }
77 }
78 Ok(vec_scalar)
79 };
80
81 let assignment_scalar = bytes_to_scalar(assignment);
82
83 if assignment_scalar.is_err() {
85 return Err(R1CSError::InvalidScalar);
86 }
87
88 Ok(Assignment {
89 assignment: assignment_scalar.unwrap(),
90 })
91 }
92
93 fn pad(&self, len: usize) -> VarsAssignment {
95 assert!(len > self.assignment.len());
97
98 let padded_assignment = {
99 let mut padded_assignment = self.assignment.clone();
100 padded_assignment.extend(vec![Scalar::zero(); len - self.assignment.len()]);
101 padded_assignment
102 };
103
104 VarsAssignment {
105 assignment: padded_assignment,
106 }
107 }
108}
109
110pub type VarsAssignment = Assignment;
112
113pub type InputsAssignment = Assignment;
115
116pub struct Instance {
118 inst: R1CSInstance,
119 digest: Vec<u8>,
120}
121
122impl Instance {
123 pub fn new(
125 num_cons: usize,
126 num_vars: usize,
127 num_inputs: usize,
128 A: &[(usize, usize, [u8; 32])],
129 B: &[(usize, usize, [u8; 32])],
130 C: &[(usize, usize, [u8; 32])],
131 ) -> Result<Instance, R1CSError> {
132 let (num_vars_padded, num_cons_padded) = {
133 let num_vars_padded = {
134 let mut num_vars_padded = num_vars;
135
136 num_vars_padded = max(num_vars_padded, num_inputs + 1);
138
139 if num_vars_padded.next_power_of_two() != num_vars_padded {
141 num_vars_padded = num_vars_padded.next_power_of_two();
142 }
143 num_vars_padded
144 };
145
146 let num_cons_padded = {
147 let mut num_cons_padded = num_cons;
148
149 if num_cons_padded == 0 || num_cons_padded == 1 {
151 num_cons_padded = 2;
152 }
153
154 if num_cons.next_power_of_two() != num_cons {
156 num_cons_padded = num_cons.next_power_of_two();
157 }
158 num_cons_padded
159 };
160
161 (num_vars_padded, num_cons_padded)
162 };
163
164 let bytes_to_scalar =
165 |tups: &[(usize, usize, [u8; 32])]| -> Result<Vec<(usize, usize, Scalar)>, R1CSError> {
166 let mut mat: Vec<(usize, usize, Scalar)> = Vec::new();
167 for &(row, col, val_bytes) in tups {
168 if row >= num_cons {
170 return Err(R1CSError::InvalidIndex);
171 }
172
173 if col >= num_vars + 1 + num_inputs {
175 return Err(R1CSError::InvalidIndex);
176 }
177
178 let val = Scalar::from_bytes(&val_bytes);
179 if val.is_some().unwrap_u8() == 1 {
180 if col >= num_vars {
183 mat.push((row, col + num_vars_padded - num_vars, val.unwrap()));
184 } else {
185 mat.push((row, col, val.unwrap()));
186 }
187 } else {
188 return Err(R1CSError::InvalidScalar);
189 }
190 }
191
192 if num_cons == 0 || num_cons == 1 {
195 for i in tups.len()..num_cons_padded {
196 mat.push((i, num_vars, Scalar::zero()));
197 }
198 }
199
200 Ok(mat)
201 };
202
203 let A_scalar = bytes_to_scalar(A);
204 if A_scalar.is_err() {
205 return Err(A_scalar.err().unwrap());
206 }
207
208 let B_scalar = bytes_to_scalar(B);
209 if B_scalar.is_err() {
210 return Err(B_scalar.err().unwrap());
211 }
212
213 let C_scalar = bytes_to_scalar(C);
214 if C_scalar.is_err() {
215 return Err(C_scalar.err().unwrap());
216 }
217
218 let inst = R1CSInstance::new(
219 num_cons_padded,
220 num_vars_padded,
221 num_inputs,
222 &A_scalar.unwrap(),
223 &B_scalar.unwrap(),
224 &C_scalar.unwrap(),
225 );
226
227 let digest = inst.get_digest();
228
229 Ok(Instance { inst, digest })
230 }
231
232 pub fn is_sat(
234 &self,
235 vars: &VarsAssignment,
236 inputs: &InputsAssignment,
237 ) -> Result<bool, R1CSError> {
238 if vars.assignment.len() > self.inst.get_num_vars() {
239 return Err(R1CSError::InvalidNumberOfInputs);
240 }
241
242 if inputs.assignment.len() != self.inst.get_num_inputs() {
243 return Err(R1CSError::InvalidNumberOfInputs);
244 }
245
246 let padded_vars = {
248 let num_padded_vars = self.inst.get_num_vars();
249 let num_vars = vars.assignment.len();
250 if num_padded_vars > num_vars {
251 vars.pad(num_padded_vars)
252 } else {
253 vars.clone()
254 }
255 };
256
257 Ok(
258 self
259 .inst
260 .is_sat(&padded_vars.assignment, &inputs.assignment),
261 )
262 }
263
264 pub fn produce_synthetic_r1cs(
266 num_cons: usize,
267 num_vars: usize,
268 num_inputs: usize,
269 ) -> (Instance, VarsAssignment, InputsAssignment) {
270 let (inst, vars, inputs) = R1CSInstance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs);
271 let digest = inst.get_digest();
272 (
273 Instance { inst, digest },
274 VarsAssignment { assignment: vars },
275 InputsAssignment { assignment: inputs },
276 )
277 }
278}
279
280#[derive(Serialize, Deserialize)]
282pub struct SNARKGens {
283 gens_r1cs_sat: R1CSGens,
284 gens_r1cs_eval: R1CSCommitmentGens,
285}
286
287impl SNARKGens {
288 pub fn new(num_cons: usize, num_vars: usize, num_inputs: usize, num_nz_entries: usize) -> Self {
291 let num_vars_padded = {
292 let mut num_vars_padded = max(num_vars, num_inputs + 1);
293 if num_vars_padded != num_vars_padded.next_power_of_two() {
294 num_vars_padded = num_vars_padded.next_power_of_two();
295 }
296 num_vars_padded
297 };
298
299 let gens_r1cs_sat = R1CSGens::new(b"gens_r1cs_sat", num_cons, num_vars_padded);
300 let gens_r1cs_eval = R1CSCommitmentGens::new(
301 b"gens_r1cs_eval",
302 num_cons,
303 num_vars_padded,
304 num_inputs,
305 num_nz_entries,
306 );
307 SNARKGens {
308 gens_r1cs_sat,
309 gens_r1cs_eval,
310 }
311 }
312}
313
314#[derive(Serialize, Deserialize, Debug)]
316pub struct SNARK {
317 r1cs_sat_proof: R1CSProof,
318 inst_evals: (Scalar, Scalar, Scalar),
319 r1cs_eval_proof: R1CSEvalProof,
320}
321
322impl SNARK {
323 fn protocol_name() -> &'static [u8] {
324 b"Spartan SNARK proof"
325 }
326
327 pub fn encode(
329 inst: &Instance,
330 gens: &SNARKGens,
331 ) -> (ComputationCommitment, ComputationDecommitment) {
332 let timer_encode = Timer::new("SNARK::encode");
333 let (comm, decomm) = inst.inst.commit(&gens.gens_r1cs_eval);
334 timer_encode.stop();
335 (
336 ComputationCommitment { comm },
337 ComputationDecommitment { decomm },
338 )
339 }
340
341 pub fn prove(
343 inst: &Instance,
344 comm: &ComputationCommitment,
345 decomm: &ComputationDecommitment,
346 vars: VarsAssignment,
347 inputs: &InputsAssignment,
348 gens: &SNARKGens,
349 transcript: &mut Transcript,
350 ) -> Self {
351 let timer_prove = Timer::new("SNARK::prove");
352
353 let mut random_tape = RandomTape::new(b"proof");
356
357 transcript.append_protocol_name(SNARK::protocol_name());
358 comm.comm.append_to_transcript(b"comm", transcript);
359
360 let (r1cs_sat_proof, rx, ry) = {
361 let (proof, rx, ry) = {
362 let padded_vars = {
364 let num_padded_vars = inst.inst.get_num_vars();
365 let num_vars = vars.assignment.len();
366 if num_padded_vars > num_vars {
367 vars.pad(num_padded_vars)
368 } else {
369 vars
370 }
371 };
372
373 R1CSProof::prove(
374 &inst.inst,
375 padded_vars.assignment,
376 &inputs.assignment,
377 &gens.gens_r1cs_sat,
378 transcript,
379 &mut random_tape,
380 )
381 };
382
383 let proof_encoded: Vec<u8> = bincode::serialize(&proof).unwrap();
384 Timer::print(&format!("len_r1cs_sat_proof {:?}", proof_encoded.len()));
385
386 (proof, rx, ry)
387 };
388
389 let timer_eval = Timer::new("eval_sparse_polys");
392 let inst_evals = {
393 let (Ar, Br, Cr) = inst.inst.evaluate(&rx, &ry);
394 Ar.append_to_transcript(b"Ar_claim", transcript);
395 Br.append_to_transcript(b"Br_claim", transcript);
396 Cr.append_to_transcript(b"Cr_claim", transcript);
397 (Ar, Br, Cr)
398 };
399 timer_eval.stop();
400
401 let r1cs_eval_proof = {
402 let proof = R1CSEvalProof::prove(
403 &decomm.decomm,
404 &rx,
405 &ry,
406 &inst_evals,
407 &gens.gens_r1cs_eval,
408 transcript,
409 &mut random_tape,
410 );
411
412 let proof_encoded: Vec<u8> = bincode::serialize(&proof).unwrap();
413 Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len()));
414 proof
415 };
416
417 timer_prove.stop();
418 SNARK {
419 r1cs_sat_proof,
420 inst_evals,
421 r1cs_eval_proof,
422 }
423 }
424
425 pub fn verify(
427 &self,
428 comm: &ComputationCommitment,
429 input: &InputsAssignment,
430 transcript: &mut Transcript,
431 gens: &SNARKGens,
432 ) -> Result<(), ProofVerifyError> {
433 let timer_verify = Timer::new("SNARK::verify");
434 transcript.append_protocol_name(SNARK::protocol_name());
435
436 comm.comm.append_to_transcript(b"comm", transcript);
438
439 let timer_sat_proof = Timer::new("verify_sat_proof");
440 assert_eq!(input.assignment.len(), comm.comm.get_num_inputs());
441 let (rx, ry) = self.r1cs_sat_proof.verify(
442 comm.comm.get_num_vars(),
443 comm.comm.get_num_cons(),
444 &input.assignment,
445 &self.inst_evals,
446 transcript,
447 &gens.gens_r1cs_sat,
448 )?;
449 timer_sat_proof.stop();
450
451 let timer_eval_proof = Timer::new("verify_eval_proof");
452 let (Ar, Br, Cr) = &self.inst_evals;
453 Ar.append_to_transcript(b"Ar_claim", transcript);
454 Br.append_to_transcript(b"Br_claim", transcript);
455 Cr.append_to_transcript(b"Cr_claim", transcript);
456 self.r1cs_eval_proof.verify(
457 &comm.comm,
458 &rx,
459 &ry,
460 &self.inst_evals,
461 &gens.gens_r1cs_eval,
462 transcript,
463 )?;
464 timer_eval_proof.stop();
465 timer_verify.stop();
466 Ok(())
467 }
468}
469
470pub struct NIZKGens {
472 gens_r1cs_sat: R1CSGens,
473}
474
475impl NIZKGens {
476 pub fn new(num_cons: usize, num_vars: usize, num_inputs: usize) -> Self {
478 let num_vars_padded = {
479 let mut num_vars_padded = max(num_vars, num_inputs + 1);
480 if num_vars_padded != num_vars_padded.next_power_of_two() {
481 num_vars_padded = num_vars_padded.next_power_of_two();
482 }
483 num_vars_padded
484 };
485
486 let gens_r1cs_sat = R1CSGens::new(b"gens_r1cs_sat", num_cons, num_vars_padded);
487 NIZKGens { gens_r1cs_sat }
488 }
489}
490
491#[derive(Serialize, Deserialize, Debug)]
493pub struct NIZK {
494 r1cs_sat_proof: R1CSProof,
495 r: (Vec<Scalar>, Vec<Scalar>),
496}
497
498impl NIZK {
499 fn protocol_name() -> &'static [u8] {
500 b"Spartan NIZK proof"
501 }
502
503 pub fn prove(
505 inst: &Instance,
506 vars: VarsAssignment,
507 input: &InputsAssignment,
508 gens: &NIZKGens,
509 transcript: &mut Transcript,
510 ) -> Self {
511 let timer_prove = Timer::new("NIZK::prove");
512 let mut random_tape = RandomTape::new(b"proof");
515
516 transcript.append_protocol_name(NIZK::protocol_name());
517 transcript.append_message(b"R1CSInstanceDigest", &inst.digest);
518
519 let (r1cs_sat_proof, rx, ry) = {
520 let padded_vars = {
522 let num_padded_vars = inst.inst.get_num_vars();
523 let num_vars = vars.assignment.len();
524 if num_padded_vars > num_vars {
525 vars.pad(num_padded_vars)
526 } else {
527 vars
528 }
529 };
530
531 let (proof, rx, ry) = R1CSProof::prove(
532 &inst.inst,
533 padded_vars.assignment,
534 &input.assignment,
535 &gens.gens_r1cs_sat,
536 transcript,
537 &mut random_tape,
538 );
539 let proof_encoded: Vec<u8> = bincode::serialize(&proof).unwrap();
540 Timer::print(&format!("len_r1cs_sat_proof {:?}", proof_encoded.len()));
541 (proof, rx, ry)
542 };
543
544 timer_prove.stop();
545 NIZK {
546 r1cs_sat_proof,
547 r: (rx, ry),
548 }
549 }
550
551 pub fn verify(
553 &self,
554 inst: &Instance,
555 input: &InputsAssignment,
556 transcript: &mut Transcript,
557 gens: &NIZKGens,
558 ) -> Result<(), ProofVerifyError> {
559 let timer_verify = Timer::new("NIZK::verify");
560
561 transcript.append_protocol_name(NIZK::protocol_name());
562 transcript.append_message(b"R1CSInstanceDigest", &inst.digest);
563
564 let timer_eval = Timer::new("eval_sparse_polys");
567 let (claimed_rx, claimed_ry) = &self.r;
568 let inst_evals = inst.inst.evaluate(claimed_rx, claimed_ry);
569 timer_eval.stop();
570
571 let timer_sat_proof = Timer::new("verify_sat_proof");
572 assert_eq!(input.assignment.len(), inst.inst.get_num_inputs());
573 let (rx, ry) = self.r1cs_sat_proof.verify(
574 inst.inst.get_num_vars(),
575 inst.inst.get_num_cons(),
576 &input.assignment,
577 &inst_evals,
578 transcript,
579 &gens.gens_r1cs_sat,
580 )?;
581
582 assert_eq!(rx, *claimed_rx);
584 assert_eq!(ry, *claimed_ry);
585 timer_sat_proof.stop();
586 timer_verify.stop();
587
588 Ok(())
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// All-zero 32-byte value used as a dummy matrix entry.
    const ZERO: [u8; 32] = [0u8; 32];

    /// End-to-end SNARK round trip on a synthetic instance.
    #[test]
    pub fn check_snark() {
        let num_vars = 256;
        let num_cons = num_vars;
        let num_inputs = 10;

        // Public parameters, a synthetic instance and its commitment.
        let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_cons);
        let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs);
        let (comm, decomm) = SNARK::encode(&inst, &gens);

        let mut prover_transcript = Transcript::new(b"example");
        let proof = SNARK::prove(
            &inst,
            &comm,
            &decomm,
            vars,
            &inputs,
            &gens,
            &mut prover_transcript,
        );

        let mut verifier_transcript = Transcript::new(b"example");
        let verdict = proof.verify(&comm, &inputs, &mut verifier_transcript, &gens);
        assert!(verdict.is_ok());
    }

    /// A row index beyond `num_cons` must be rejected as `InvalidIndex`.
    #[test]
    pub fn check_r1cs_invalid_index() {
        let (num_cons, num_vars, num_inputs) = (4, 8, 1);

        let A = vec![(0, 0, ZERO)];
        let B = vec![(100, 1, ZERO)]; // row 100 is out of range
        let C = vec![(1, 1, ZERO)];

        let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C);
        assert!(inst.is_err());
        assert_eq!(inst.err(), Some(R1CSError::InvalidIndex));
    }

    /// A value that does not decode to a field element must be rejected as
    /// `InvalidScalar`.
    #[test]
    pub fn check_r1cs_invalid_scalar() {
        let (num_cons, num_vars, num_inputs) = (4, 8, 1);

        let larger_than_mod = [
            3, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
        ];

        let A = vec![(0, 0, ZERO)];
        let B = vec![(1, 1, larger_than_mod)];
        let C = vec![(1, 1, ZERO)];

        let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C);
        assert!(inst.is_err());
        assert_eq!(inst.err(), Some(R1CSError::InvalidScalar));
    }

    /// A single-constraint, zero-variable instance must be padded and still
    /// prove and verify with both the SNARK and the NIZK.
    #[test]
    fn test_padded_constraints() {
        let num_cons = 1;
        let num_vars = 0;
        let num_inputs = 3;
        let num_non_zero_entries = 3;

        // With num_vars = 0, column 0 is the constant term and columns 1..=3 are
        // the inputs, so the constraint reads in1 * in1 = in0 - 13 - in2.
        let A: Vec<(usize, usize, [u8; 32])> = vec![(0, num_vars + 2, Scalar::one().to_bytes())];
        let B: Vec<(usize, usize, [u8; 32])> = vec![(0, num_vars + 2, Scalar::one().to_bytes())];
        let C: Vec<(usize, usize, [u8; 32])> = vec![
            (0, num_vars + 1, Scalar::one().to_bytes()),
            (0, num_vars, (-Scalar::from(13u64)).to_bytes()),
            (0, num_vars + 3, (-Scalar::one()).to_bytes()),
        ];

        let vars = vec![Scalar::zero().to_bytes(); num_vars];
        // inputs = (16, 1, 2): 1 * 1 = 16 - 13 - 2
        let inputs = vec![
            Scalar::from(16u64).to_bytes(),
            Scalar::from(1u64).to_bytes(),
            Scalar::from(2u64).to_bytes(),
        ];

        let assignment_inputs = InputsAssignment::new(&inputs).unwrap();
        let assignment_vars = VarsAssignment::new(&vars).unwrap();

        let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C).unwrap();
        let satisfied = inst.is_sat(&assignment_vars, &assignment_inputs);
        assert!(satisfied.unwrap(), "should be satisfied");

        // SNARK round trip.
        let snark_gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_non_zero_entries);
        let (comm, decomm) = SNARK::encode(&inst, &snark_gens);

        let mut snark_prover_transcript = Transcript::new(b"snark_example");
        let snark_proof = SNARK::prove(
            &inst,
            &comm,
            &decomm,
            assignment_vars.clone(),
            &assignment_inputs,
            &snark_gens,
            &mut snark_prover_transcript,
        );

        let mut snark_verifier_transcript = Transcript::new(b"snark_example");
        assert!(snark_proof
            .verify(
                &comm,
                &assignment_inputs,
                &mut snark_verifier_transcript,
                &snark_gens
            )
            .is_ok());

        // NIZK round trip.
        let nizk_gens = NIZKGens::new(num_cons, num_vars, num_inputs);

        let mut nizk_prover_transcript = Transcript::new(b"nizk_example");
        let nizk_proof = NIZK::prove(
            &inst,
            assignment_vars,
            &assignment_inputs,
            &nizk_gens,
            &mut nizk_prover_transcript,
        );

        let mut nizk_verifier_transcript = Transcript::new(b"nizk_example");
        assert!(nizk_proof
            .verify(
                &inst,
                &assignment_inputs,
                &mut nizk_verifier_transcript,
                &nizk_gens
            )
            .is_ok());
    }
}