1#![allow(non_snake_case)]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![allow(clippy::assertions_on_result_states)]
5
6extern crate core;
7extern crate digest;
8extern crate merlin;
9extern crate rand;
10extern crate sha3;
11
12#[cfg(feature = "multicore")]
13extern crate rayon;
14
15mod commitments;
16mod dense_mlpoly;
17mod errors;
18mod math;
19mod nizk;
20mod product_tree;
21mod r1csinstance;
22mod r1csproof;
23mod random;
24mod sparse_mlpoly;
25mod sumcheck;
26mod timer;
27mod transcript;
28mod unipoly;
29
30use ark_ec::CurveGroup;
31use ark_ff::PrimeField;
32use ark_serialize::*;
33use core::cmp::max;
34use errors::{ProofVerifyError, R1CSError};
35use merlin::Transcript;
36use r1csinstance::{
37 R1CSCommitment, R1CSCommitmentGens, R1CSDecommitment, R1CSEvalProof, R1CSInstance,
38};
39use r1csproof::{R1CSGens, R1CSProof};
40use random::RandomTape;
41use timer::Timer;
42use transcript::{AppendToTranscript, ProofTranscript};
43
/// Public commitment to an R1CS [`Instance`], produced by [`SNARK::encode`].
///
/// Both [`SNARK::prove`] and [`SNARK::verify`] absorb it into the transcript, and the verifier
/// reads the instance dimensions (constraints, variables, inputs) from it.
pub struct ComputationCommitment<G: CurveGroup> {
  comm: R1CSCommitment<G>,
}
48
/// Prover-side decommitment matching a [`ComputationCommitment`].
///
/// Produced by [`SNARK::encode`] and consumed by [`SNARK::prove`] when building the proof of
/// the instance evaluations.
pub struct ComputationDecommitment<F> {
  decomm: R1CSDecommitment<F>,
}
53
/// A list of field elements assigned either to the witness variables ([`VarsAssignment`]) or
/// to the public inputs ([`InputsAssignment`]) of an R1CS instance.
#[derive(Clone)]
pub struct Assignment<F> {
  assignment: Vec<F>,
}
59
60impl<F: PrimeField> Assignment<F> {
61 pub fn new(assignment: &[F]) -> Result<Self, R1CSError> {
63 let bytes_to_scalar = |vec: &[F]| -> Result<Vec<F>, R1CSError> {
64 let mut vec_scalar: Vec<F> = Vec::new();
65 for v in vec {
66 vec_scalar.push(*v);
67 }
68 Ok(vec_scalar)
69 };
70
71 let assignment_scalar = bytes_to_scalar(assignment)?;
72
73 Ok(Assignment {
74 assignment: assignment_scalar,
75 })
76 }
77
78 fn pad(&self, len: usize) -> VarsAssignment<F> {
80 assert!(len > self.assignment.len());
82
83 let padded_assignment = {
84 let mut padded_assignment = self.assignment.clone();
85 padded_assignment.extend(vec![F::zero(); len - self.assignment.len()]);
86 padded_assignment
87 };
88
89 VarsAssignment {
90 assignment: padded_assignment,
91 }
92 }
93}
94
/// Assignment to the witness variables of an R1CS instance (known to the prover only; the
/// verifiers in this crate never take one).
pub type VarsAssignment<F> = Assignment<F>;
97
/// Assignment to the public inputs of an R1CS instance, supplied to both prover and verifier.
pub type InputsAssignment<F> = Assignment<F>;
100
/// An R1CS instance, built by [`Instance::new`] (which pads its dimensions) or by
/// [`Instance::produce_synthetic_r1cs`].
pub struct Instance<F: PrimeField> {
  inst: R1CSInstance<F>,
}
105
106impl<F: PrimeField> Instance<F> {
107 pub fn new(
109 num_cons: usize,
110 num_vars: usize,
111 num_inputs: usize,
112 A: &[(usize, usize, F)],
113 B: &[(usize, usize, F)],
114 C: &[(usize, usize, F)],
115 ) -> Result<Self, R1CSError> {
116 let (num_vars_padded, num_cons_padded) = {
117 let num_vars_padded = {
118 let mut num_vars_padded = num_vars;
119
120 num_vars_padded = max(num_vars_padded, num_inputs + 1);
122
123 if num_vars_padded.next_power_of_two() != num_vars_padded {
125 num_vars_padded = num_vars_padded.next_power_of_two();
126 }
127 num_vars_padded
128 };
129
130 let num_cons_padded = {
131 let mut num_cons_padded = num_cons;
132
133 if num_cons_padded == 0 || num_cons_padded == 1 {
135 num_cons_padded = 2;
136 }
137
138 if num_cons.next_power_of_two() != num_cons {
140 num_cons_padded = num_cons.next_power_of_two();
141 }
142 num_cons_padded
143 };
144
145 (num_vars_padded, num_cons_padded)
146 };
147
148 let bytes_to_scalar =
149 |tups: &[(usize, usize, F)]| -> Result<Vec<(usize, usize, F)>, R1CSError> {
150 let mut mat: Vec<(usize, usize, F)> = Vec::new();
151 for &(row, col, val) in tups {
152 if row >= num_cons {
154 return Err(R1CSError::InvalidIndex);
155 }
156
157 if col >= num_vars + 1 + num_inputs {
159 return Err(R1CSError::InvalidIndex);
160 }
161
162 if col >= num_vars {
163 mat.push((row, col + num_vars_padded - num_vars, val));
164 } else {
165 mat.push((row, col, val));
166 }
167 }
168
169 if num_cons == 0 || num_cons == 1 {
172 for i in tups.len()..num_cons_padded {
173 mat.push((i, num_vars, F::zero()));
174 }
175 }
176
177 Ok(mat)
178 };
179
180 let A_scalar = bytes_to_scalar(A);
181 if A_scalar.is_err() {
182 return Err(A_scalar.err().unwrap());
183 }
184
185 let B_scalar = bytes_to_scalar(B);
186 if B_scalar.is_err() {
187 return Err(B_scalar.err().unwrap());
188 }
189
190 let C_scalar = bytes_to_scalar(C);
191 if C_scalar.is_err() {
192 return Err(C_scalar.err().unwrap());
193 }
194
195 let inst = R1CSInstance::<F>::new(
196 num_cons_padded,
197 num_vars_padded,
198 num_inputs,
199 &A_scalar.unwrap(),
200 &B_scalar.unwrap(),
201 &C_scalar.unwrap(),
202 );
203
204 Ok(Instance { inst })
205 }
206
207 pub fn is_sat(
209 &self,
210 vars: &VarsAssignment<F>,
211 inputs: &InputsAssignment<F>,
212 ) -> Result<bool, R1CSError> {
213 if vars.assignment.len() > self.inst.get_num_vars() {
214 return Err(R1CSError::InvalidNumberOfInputs);
215 }
216
217 if inputs.assignment.len() != self.inst.get_num_inputs() {
218 return Err(R1CSError::InvalidNumberOfInputs);
219 }
220
221 let padded_vars = {
223 let num_padded_vars = self.inst.get_num_vars();
224 let num_vars = vars.assignment.len();
225 if num_padded_vars > num_vars {
226 vars.pad(num_padded_vars)
227 } else {
228 vars.clone()
229 }
230 };
231
232 Ok(
233 self
234 .inst
235 .is_sat(&padded_vars.assignment, &inputs.assignment),
236 )
237 }
238
239 pub fn produce_synthetic_r1cs(
241 num_cons: usize,
242 num_vars: usize,
243 num_inputs: usize,
244 ) -> (Instance<F>, VarsAssignment<F>, InputsAssignment<F>) {
245 let (inst, vars, inputs) = R1CSInstance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs);
246 (
247 Instance { inst },
248 VarsAssignment { assignment: vars },
249 InputsAssignment { assignment: inputs },
250 )
251 }
252}
253
/// Public parameters for producing and verifying [`SNARK`] proofs.
///
/// It holds two sets of generators:
/// - `gens_r1cs_sat`, used by the satisfiability proof;
/// - `gens_r1cs_eval`, used to commit to the instance and to prove its evaluations.
pub struct SNARKGens<G> {
  gens_r1cs_sat: R1CSGens<G>,
  gens_r1cs_eval: R1CSCommitmentGens<G>,
}
259
260impl<G: CurveGroup> SNARKGens<G> {
261 pub fn new(num_cons: usize, num_vars: usize, num_inputs: usize, num_nz_entries: usize) -> Self {
264 let num_vars_padded = {
265 let mut num_vars_padded = max(num_vars, num_inputs + 1);
266 if num_vars_padded != num_vars_padded.next_power_of_two() {
267 num_vars_padded = num_vars_padded.next_power_of_two();
268 }
269 num_vars_padded
270 };
271
272 let gens_r1cs_sat = R1CSGens::<G>::new(b"gens_r1cs_sat", num_cons, num_vars_padded);
273 let gens_r1cs_eval = R1CSCommitmentGens::new(
274 b"gens_r1cs_eval",
275 num_cons,
276 num_vars_padded,
277 num_inputs,
278 num_nz_entries,
279 );
280 SNARKGens {
281 gens_r1cs_sat,
282 gens_r1cs_eval,
283 }
284 }
285}
286
/// A Spartan SNARK proof. It consists of:
/// - an R1CS satisfiability proof;
/// - the claimed evaluations of the instance's three matrices at the points `(rx, ry)`
///   produced by that proof;
/// - a proof of those evaluations against the instance commitment.
#[derive(CanonicalSerialize, CanonicalDeserialize, Debug)]
pub struct SNARK<G: CurveGroup> {
  r1cs_sat_proof: R1CSProof<G>,
  inst_evals: (G::ScalarField, G::ScalarField, G::ScalarField),
  r1cs_eval_proof: R1CSEvalProof<G>,
}
294
295impl<G: CurveGroup> SNARK<G> {
296 fn protocol_name() -> &'static [u8] {
297 b"Spartan SNARK proof"
298 }
299
300 pub fn encode(
302 inst: &Instance<G::ScalarField>,
303 gens: &SNARKGens<G>,
304 ) -> (
305 ComputationCommitment<G>,
306 ComputationDecommitment<G::ScalarField>,
307 ) {
308 let timer_encode = Timer::new("SNARK::encode");
309 let (comm, decomm) = inst.inst.commit(&gens.gens_r1cs_eval);
310 timer_encode.stop();
311 (
312 ComputationCommitment { comm },
313 ComputationDecommitment { decomm },
314 )
315 }
316
317 pub fn prove(
319 inst: &Instance<G::ScalarField>,
320 comm: &ComputationCommitment<G>,
321 decomm: &ComputationDecommitment<G::ScalarField>,
322 vars: VarsAssignment<G::ScalarField>,
323 inputs: &InputsAssignment<G::ScalarField>,
324 gens: &SNARKGens<G>,
325 transcript: &mut Transcript,
326 ) -> Self {
327 let timer_prove = Timer::new("SNARK::prove");
328
329 let mut random_tape = RandomTape::<G>::new(b"proof");
332 <Transcript as ProofTranscript<G>>::append_protocol_name(
333 transcript,
334 SNARK::<G>::protocol_name(),
335 );
336 comm.comm.append_to_transcript(b"comm", transcript);
337
338 let (r1cs_sat_proof, rx, ry) = {
339 let (proof, rx, ry) = {
340 let padded_vars = {
342 let num_padded_vars = inst.inst.get_num_vars();
343 let num_vars = vars.assignment.len();
344 if num_padded_vars > num_vars {
345 vars.pad(num_padded_vars)
346 } else {
347 vars
348 }
349 };
350
351 R1CSProof::prove(
352 &inst.inst,
353 padded_vars.assignment,
354 &inputs.assignment,
355 &gens.gens_r1cs_sat,
356 transcript,
357 &mut random_tape,
358 )
359 };
360
361 let mut proof_encoded = vec![];
362 proof.serialize_compressed(&mut proof_encoded).unwrap();
363
364 Timer::print(&format!("len_r1cs_sat_proof {:?}", proof_encoded.len()));
365
366 (proof, rx, ry)
367 };
368
369 let timer_eval = Timer::new("eval_sparse_polys");
372 let inst_evals = {
373 let (Ar, Br, Cr) = inst.inst.evaluate(&rx, &ry);
374 <Transcript as ProofTranscript<G>>::append_scalar(transcript, b"Ar_claim", &Ar);
375 <Transcript as ProofTranscript<G>>::append_scalar(transcript, b"Ar_claim", &Br);
376 <Transcript as ProofTranscript<G>>::append_scalar(transcript, b"Ar_claim", &Cr);
377 (Ar, Br, Cr)
378 };
379 timer_eval.stop();
380
381 let r1cs_eval_proof = {
382 let proof = R1CSEvalProof::prove(
383 &decomm.decomm,
384 &rx,
385 &ry,
386 &inst_evals,
387 &gens.gens_r1cs_eval,
388 transcript,
389 &mut random_tape,
390 );
391
392 let mut proof_encoded = vec![];
393 proof.serialize_compressed(&mut proof_encoded).unwrap();
394
395 Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len()));
396 proof
397 };
398
399 timer_prove.stop();
400 SNARK {
401 r1cs_sat_proof,
402 inst_evals,
403 r1cs_eval_proof,
404 }
405 }
406
407 pub fn verify(
409 &self,
410 comm: &ComputationCommitment<G>,
411 input: &InputsAssignment<G::ScalarField>,
412 transcript: &mut Transcript,
413 gens: &SNARKGens<G>,
414 ) -> Result<(), ProofVerifyError> {
415 let timer_verify = Timer::new("SNARK::verify");
416 <Transcript as ProofTranscript<G>>::append_protocol_name(
417 transcript,
418 SNARK::<G>::protocol_name(),
419 );
420
421 comm.comm.append_to_transcript(b"comm", transcript);
423
424 let timer_sat_proof = Timer::new("verify_sat_proof");
425 assert_eq!(input.assignment.len(), comm.comm.get_num_inputs());
426 let (rx, ry) = self.r1cs_sat_proof.verify(
427 comm.comm.get_num_vars(),
428 comm.comm.get_num_cons(),
429 &input.assignment,
430 &self.inst_evals,
431 transcript,
432 &gens.gens_r1cs_sat,
433 )?;
434 timer_sat_proof.stop();
435
436 let timer_eval_proof = Timer::new("verify_eval_proof");
437 let (Ar, Br, Cr) = &self.inst_evals;
438 <Transcript as ProofTranscript<G>>::append_scalar(transcript, b"Ar_claim", Ar);
439 <Transcript as ProofTranscript<G>>::append_scalar(transcript, b"Ar_claim", Br);
440 <Transcript as ProofTranscript<G>>::append_scalar(transcript, b"Ar_claim", Cr);
441 self.r1cs_eval_proof.verify(
442 &comm.comm,
443 &rx,
444 &ry,
445 &self.inst_evals,
446 &gens.gens_r1cs_eval,
447 transcript,
448 )?;
449 timer_eval_proof.stop();
450 timer_verify.stop();
451 Ok(())
452 }
453}
454
/// Public parameters for producing and verifying [`NIZK`] proofs.
///
/// It contains generators for the satisfiability proof only; no instance commitment is
/// involved.
pub struct NIZKGens<G> {
  gens_r1cs_sat: R1CSGens<G>,
}
459
460impl<G: CurveGroup> NIZKGens<G> {
461 pub fn new(num_cons: usize, num_vars: usize, num_inputs: usize) -> Self {
463 let num_vars_padded = {
464 let mut num_vars_padded = max(num_vars, num_inputs + 1);
465 if num_vars_padded != num_vars_padded.next_power_of_two() {
466 num_vars_padded = num_vars_padded.next_power_of_two();
467 }
468 num_vars_padded
469 };
470
471 let gens_r1cs_sat = R1CSGens::<G>::new(b"gens_r1cs_sat", num_cons, num_vars_padded);
472 NIZKGens { gens_r1cs_sat }
473 }
474}
475
/// A Spartan NIZK proof: an R1CS satisfiability proof plus the points `(rx, ry)` it produced.
///
/// Unlike [`SNARK`], the verifier evaluates the instance itself. It therefore needs the full
/// [`Instance`] rather than a commitment.
#[derive(CanonicalSerialize, CanonicalDeserialize, Debug)]
pub struct NIZK<G: CurveGroup> {
  r1cs_sat_proof: R1CSProof<G>,
  r: (Vec<G::ScalarField>, Vec<G::ScalarField>),
}
482
483impl<G: CurveGroup> NIZK<G> {
484 fn protocol_name() -> &'static [u8] {
485 b"Spartan NIZK proof"
486 }
487
488 pub fn prove(
490 inst: &Instance<G::ScalarField>,
491 vars: VarsAssignment<G::ScalarField>,
492 input: &InputsAssignment<G::ScalarField>,
493 gens: &NIZKGens<G>,
494 transcript: &mut Transcript,
495 ) -> Self {
496 let timer_prove = Timer::new("NIZK::prove");
497 let mut random_tape = RandomTape::new(b"proof");
500
501 <Transcript as ProofTranscript<G>>::append_protocol_name(
502 transcript,
503 NIZK::<G>::protocol_name(),
504 );
505 <R1CSInstance<G::ScalarField> as AppendToTranscript<G>>::append_to_transcript(
506 &inst.inst, b"inst", transcript,
507 );
508
509 let (r1cs_sat_proof, rx, ry) = {
510 let padded_vars = {
512 let num_padded_vars = inst.inst.get_num_vars();
513 let num_vars = vars.assignment.len();
514 if num_padded_vars > num_vars {
515 vars.pad(num_padded_vars)
516 } else {
517 vars
518 }
519 };
520
521 let (proof, rx, ry) = R1CSProof::prove(
522 &inst.inst,
523 padded_vars.assignment,
524 &input.assignment,
525 &gens.gens_r1cs_sat,
526 transcript,
527 &mut random_tape,
528 );
529
530 let mut proof_encoded = vec![];
531 proof.serialize_compressed(&mut proof_encoded).unwrap();
532
533 Timer::print(&format!("len_r1cs_sat_proof {:?}", proof_encoded.len()));
534 (proof, rx, ry)
535 };
536
537 timer_prove.stop();
538 NIZK {
539 r1cs_sat_proof,
540 r: (rx, ry),
541 }
542 }
543
544 pub fn verify(
546 &self,
547 inst: &Instance<G::ScalarField>,
548 input: &InputsAssignment<G::ScalarField>,
549 transcript: &mut Transcript,
550 gens: &NIZKGens<G>,
551 ) -> Result<(), ProofVerifyError> {
552 let timer_verify = Timer::new("NIZK::verify");
553
554 <Transcript as ProofTranscript<G>>::append_protocol_name(
555 transcript,
556 NIZK::<G>::protocol_name(),
557 );
558 <R1CSInstance<G::ScalarField> as AppendToTranscript<G>>::append_to_transcript(
559 &inst.inst, b"inst", transcript,
560 );
561
562 let timer_eval = Timer::new("eval_sparse_polys");
565 let (claimed_rx, claimed_ry) = &self.r;
566 let inst_evals = inst.inst.evaluate(claimed_rx, claimed_ry);
567 timer_eval.stop();
568
569 let timer_sat_proof = Timer::new("verify_sat_proof");
570 assert_eq!(input.assignment.len(), inst.inst.get_num_inputs());
571 let (rx, ry) = self.r1cs_sat_proof.verify(
572 inst.inst.get_num_vars(),
573 inst.inst.get_num_cons(),
574 &input.assignment,
575 &inst_evals,
576 transcript,
577 &gens.gens_r1cs_sat,
578 )?;
579
580 assert_eq!(rx, *claimed_rx);
582 assert_eq!(ry, *claimed_ry);
583 timer_sat_proof.stop();
584 timer_verify.stop();
585
586 Ok(())
587 }
588}
589
#[cfg(test)]
mod tests {
  use super::*;
  use ark_bls12_381::{Fr, G1Projective};
  use ark_std::One;
  use ark_std::Zero;

  /// Commits to `inst`, proves it with `vars` and `inputs`, and reports whether the SNARK
  /// verifies.
  ///
  /// Prover and verifier each start from a fresh transcript created with `label`.
  fn snark_round_trip<G: CurveGroup>(
    inst: &Instance<G::ScalarField>,
    vars: VarsAssignment<G::ScalarField>,
    inputs: &InputsAssignment<G::ScalarField>,
    gens: &SNARKGens<G>,
    label: &'static [u8],
  ) -> bool {
    let (comm, decomm) = SNARK::encode(inst, gens);

    let mut prover_transcript = Transcript::new(label);
    let proof = SNARK::prove(
      inst,
      &comm,
      &decomm,
      vars,
      inputs,
      gens,
      &mut prover_transcript,
    );

    let mut verifier_transcript = Transcript::new(label);
    proof
      .verify(&comm, inputs, &mut verifier_transcript, gens)
      .is_ok()
  }

  #[test]
  pub fn check_snark() {
    check_snark_helper::<G1Projective>()
  }

  /// SNARK round trip on a synthetic instance with 256 constraints, 256 variables and 10
  /// public inputs.
  pub fn check_snark_helper<G: CurveGroup>() {
    let (num_cons, num_vars, num_inputs) = (256, 256, 10);

    let gens = SNARKGens::<G>::new(num_cons, num_vars, num_inputs, num_cons);
    let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs);

    assert!(snark_round_trip(&inst, vars, &inputs, &gens, b"example"));
  }

  #[test]
  pub fn check_r1cs_invalid_index() {
    check_r1cs_invalid_index_helper::<Fr>();
  }

  /// `Instance::new` must reject an entry whose row lies outside `0..num_cons`.
  pub fn check_r1cs_invalid_index_helper<F: PrimeField>() {
    let (num_cons, num_vars, num_inputs) = (4, 8, 1);
    let zero = F::zero();

    let A = vec![(0, 0, zero)];
    let B = vec![(100, 1, zero)]; // row 100 >= num_cons
    let C = vec![(1, 1, zero)];

    assert!(Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C).is_err());
  }

  #[test]
  fn test_padded_constraints() {
    test_padded_constraints_helper::<G1Projective>()
  }

  /// One constraint, no witness variables and three public inputs.
  ///
  /// Checks the instance with `is_sat`, then proves and verifies it with both the SNARK and
  /// the NIZK.
  fn test_padded_constraints_helper<G: CurveGroup>() {
    let num_cons = 1;
    let num_vars = 0;
    let num_inputs = 3;
    let num_non_zero_entries = 3;

    let zero = G::ScalarField::zero();
    let one = G::ScalarField::one();

    // The single constraint reads:
    //   z[num_vars + 2] * z[num_vars + 2]
    //     = z[num_vars + 1] - 13 * z[num_vars] - z[num_vars + 3]
    let A: Vec<(usize, usize, G::ScalarField)> = vec![(0, num_vars + 2, one)];
    let B: Vec<(usize, usize, G::ScalarField)> = vec![(0, num_vars + 2, one)];
    let C: Vec<(usize, usize, G::ScalarField)> = vec![
      (0, num_vars + 1, one),
      (0, num_vars, -G::ScalarField::from(13u64)),
      (0, num_vars + 3, -G::ScalarField::one()),
    ];

    let vars = vec![zero; num_vars];
    let inputs: Vec<G::ScalarField> = [16u64, 1, 2]
      .iter()
      .map(|&v| G::ScalarField::from(v))
      .collect();

    let assignment_inputs = InputsAssignment::new(&inputs).unwrap();
    let assignment_vars = VarsAssignment::new(&vars).unwrap();

    let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C).unwrap();
    assert!(
      inst.is_sat(&assignment_vars, &assignment_inputs).unwrap(),
      "should be satisfied"
    );

    let snark_gens = SNARKGens::<G>::new(num_cons, num_vars, num_inputs, num_non_zero_entries);
    assert!(snark_round_trip(
      &inst,
      assignment_vars.clone(),
      &assignment_inputs,
      &snark_gens,
      b"snark_example",
    ));

    let nizk_gens = NIZKGens::<G>::new(num_cons, num_vars, num_inputs);

    let mut prover_transcript = Transcript::new(b"nizk_example");
    let proof = NIZK::prove(
      &inst,
      assignment_vars,
      &assignment_inputs,
      &nizk_gens,
      &mut prover_transcript,
    );

    let mut verifier_transcript = Transcript::new(b"nizk_example");
    assert!(proof
      .verify(&inst, &assignment_inputs, &mut verifier_transcript, &nizk_gens)
      .is_ok());
  }
}