// nova_snark/frontend/gadgets/multieq.rs
use ff::PrimeField;

use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError, Variable};

7#[derive(Debug)]
9pub struct MultiEq<Scalar: PrimeField, CS: ConstraintSystem<Scalar>> {
10 cs: CS,
11 ops: usize,
12 bits_used: usize,
13 lhs: LinearCombination<Scalar>,
14 rhs: LinearCombination<Scalar>,
15}
16
17impl<Scalar: PrimeField, CS: ConstraintSystem<Scalar>> MultiEq<Scalar, CS> {
18 pub fn new(cs: CS) -> Self {
20 MultiEq {
21 cs,
22 ops: 0,
23 bits_used: 0,
24 lhs: LinearCombination::zero(),
25 rhs: LinearCombination::zero(),
26 }
27 }
28
29 fn accumulate(&mut self) {
30 let ops = self.ops;
31 let lhs = self.lhs.clone();
32 let rhs = self.rhs.clone();
33 self.cs.enforce(
34 || format!("multieq {ops}"),
35 |_| lhs,
36 |lc| lc + CS::one(),
37 |_| rhs,
38 );
39 self.lhs = LinearCombination::zero();
40 self.rhs = LinearCombination::zero();
41 self.bits_used = 0;
42 self.ops += 1;
43 }
44
45 pub fn enforce_equal(
48 &mut self,
49 num_bits: usize,
50 lhs: &LinearCombination<Scalar>,
51 rhs: &LinearCombination<Scalar>,
52 ) {
53 if (Scalar::CAPACITY as usize) <= (self.bits_used + num_bits) {
55 self.accumulate();
56 }
57
58 assert!((Scalar::CAPACITY as usize) > (self.bits_used + num_bits));
59
60 let coeff = Scalar::from(2u64).pow_vartime([self.bits_used as u64]);
61 self.lhs = self.lhs.clone() + (coeff, lhs);
62 self.rhs = self.rhs.clone() + (coeff, rhs);
63 self.bits_used += num_bits;
64 }
65}
66
67impl<Scalar: PrimeField, CS: ConstraintSystem<Scalar>> Drop for MultiEq<Scalar, CS> {
68 fn drop(&mut self) {
69 if self.bits_used > 0 {
70 self.accumulate();
71 }
72 }
73}
74
75impl<Scalar: PrimeField, CS: ConstraintSystem<Scalar>> ConstraintSystem<Scalar>
76 for MultiEq<Scalar, CS>
77{
78 type Root = Self;
79
80 fn one() -> Variable {
81 CS::one()
82 }
83
84 fn alloc<F, A, AR>(&mut self, annotation: A, f: F) -> Result<Variable, SynthesisError>
85 where
86 F: FnOnce() -> Result<Scalar, SynthesisError>,
87 A: FnOnce() -> AR,
88 AR: Into<String>,
89 {
90 self.cs.alloc(annotation, f)
91 }
92
93 fn alloc_input<F, A, AR>(&mut self, annotation: A, f: F) -> Result<Variable, SynthesisError>
94 where
95 F: FnOnce() -> Result<Scalar, SynthesisError>,
96 A: FnOnce() -> AR,
97 AR: Into<String>,
98 {
99 self.cs.alloc_input(annotation, f)
100 }
101
102 fn enforce<A, AR, LA, LB, LC>(&mut self, annotation: A, a: LA, b: LB, c: LC)
103 where
104 A: FnOnce() -> AR,
105 AR: Into<String>,
106 LA: FnOnce(LinearCombination<Scalar>) -> LinearCombination<Scalar>,
107 LB: FnOnce(LinearCombination<Scalar>) -> LinearCombination<Scalar>,
108 LC: FnOnce(LinearCombination<Scalar>) -> LinearCombination<Scalar>,
109 {
110 self.cs.enforce(annotation, a, b, c)
111 }
112
113 fn push_namespace<NR, N>(&mut self, name_fn: N)
114 where
115 NR: Into<String>,
116 N: FnOnce() -> NR,
117 {
118 self.cs.get_root().push_namespace(name_fn)
119 }
120
121 fn pop_namespace(&mut self) {
122 self.cs.get_root().pop_namespace()
123 }
124
125 fn get_root(&mut self) -> &mut Self::Root {
126 self
127 }
128}