1use ff::{PrimeField, PrimeFieldBits};
4use serde::{Deserialize, Serialize};
5
6use crate::frontend::{ConstraintSystem, LinearCombination, SynthesisError, Variable};
7
8use crate::frontend::gadgets::boolean::{self, AllocatedBit, Boolean};
9
10#[derive(Debug, Serialize, Deserialize)]
12pub struct AllocatedNum<Scalar: PrimeField> {
13 value: Option<Scalar>,
14 variable: Variable,
15}
16
17impl<Scalar: PrimeField> Clone for AllocatedNum<Scalar> {
18 fn clone(&self) -> Self {
19 AllocatedNum {
20 value: self.value,
21 variable: self.variable,
22 }
23 }
24}
25
26impl<Scalar: PrimeField> AllocatedNum<Scalar> {
27 pub fn alloc<CS, F>(mut cs: CS, value: F) -> Result<Self, SynthesisError>
29 where
30 CS: ConstraintSystem<Scalar>,
31 F: FnOnce() -> Result<Scalar, SynthesisError>,
32 {
33 let mut new_value = None;
34 let var = cs.alloc(
35 || "num",
36 || {
37 let tmp = value()?;
38
39 new_value = Some(tmp);
40
41 Ok(tmp)
42 },
43 )?;
44
45 Ok(AllocatedNum {
46 value: new_value,
47 variable: var,
48 })
49 }
50
51 pub fn alloc_infallible<CS, F>(cs: CS, value: F) -> Self
54 where
55 CS: ConstraintSystem<Scalar>,
56 F: FnOnce() -> Scalar,
57 {
58 Self::alloc(cs, || Ok(value())).unwrap()
59 }
60
61 pub fn alloc_input<CS, F>(mut cs: CS, value: F) -> Result<Self, SynthesisError>
63 where
64 CS: ConstraintSystem<Scalar>,
65 F: FnOnce() -> Result<Scalar, SynthesisError>,
66 {
67 let mut new_value = None;
68 let var = cs.alloc_input(
69 || "input num",
70 || {
71 let tmp = value()?;
72
73 new_value = Some(tmp);
74
75 Ok(tmp)
76 },
77 )?;
78
79 Ok(AllocatedNum {
80 value: new_value,
81 variable: var,
82 })
83 }
84
85 pub fn alloc_maybe_input<CS, F>(cs: CS, is_input: bool, value: F) -> Result<Self, SynthesisError>
90 where
91 CS: ConstraintSystem<Scalar>,
92 F: FnOnce() -> Result<Scalar, SynthesisError>,
93 {
94 if is_input {
95 Self::alloc_input(cs, value)
96 } else {
97 Self::alloc(cs, value)
98 }
99 }
100
101 pub fn inputize<CS>(&self, mut cs: CS) -> Result<(), SynthesisError>
103 where
104 CS: ConstraintSystem<Scalar>,
105 {
106 let input = cs.alloc_input(
107 || "input variable",
108 || self.value.ok_or(SynthesisError::AssignmentMissing),
109 )?;
110
111 cs.enforce(
112 || "enforce input is correct",
113 |lc| lc + input,
114 |lc| lc + CS::one(),
115 |lc| lc + self.variable,
116 );
117
118 Ok(())
119 }
120
121 pub fn to_bits_le_strict<CS>(&self, mut cs: CS) -> Result<Vec<Boolean>, SynthesisError>
127 where
128 CS: ConstraintSystem<Scalar>,
129 Scalar: PrimeFieldBits,
130 {
131 pub fn kary_and<Scalar, CS>(
132 mut cs: CS,
133 v: &[AllocatedBit],
134 ) -> Result<AllocatedBit, SynthesisError>
135 where
136 Scalar: PrimeField,
137 CS: ConstraintSystem<Scalar>,
138 {
139 assert!(!v.is_empty());
140
141 let mut cur = None;
144
145 for (i, v) in v.iter().enumerate() {
146 if cur.is_none() {
147 cur = Some(v.clone());
148 } else {
149 cur = Some(AllocatedBit::and(
150 cs.namespace(|| format!("and {}", i)),
151 cur.as_ref().unwrap(),
152 v,
153 )?);
154 }
155 }
156
157 Ok(cur.expect("v.len() > 0"))
158 }
159
160 let a = self.value.map(|e| e.to_le_bits());
163 let b = (-Scalar::ONE).to_le_bits();
164
165 let mut a = a.as_ref().map(|e| e.into_iter().rev());
167
168 let mut result = vec![];
169
170 let mut last_run = None;
172 let mut current_run = vec![];
173
174 let mut found_one = false;
175 let mut i = 0;
176 for b in b.into_iter().rev() {
177 let a_bit: Option<bool> = a.as_mut().map(|e| *e.next().unwrap());
178
179 found_one |= b;
181 if !found_one {
182 if let Some(a_bit) = a_bit {
184 assert!(!a_bit);
185 }
186 continue;
187 }
188
189 if b {
190 let a_bit = AllocatedBit::alloc(cs.namespace(|| format!("bit {}", i)), a_bit)?;
193 current_run.push(a_bit.clone());
195 result.push(a_bit);
196 } else {
197 if !current_run.is_empty() {
198 if last_run.is_some() {
202 current_run.push(last_run.clone().unwrap());
203 }
204 last_run = Some(kary_and(
205 cs.namespace(|| format!("run ending at {}", i)),
206 ¤t_run,
207 )?);
208 current_run.truncate(0);
209 }
210
211 let a_bit = AllocatedBit::alloc_conditionally(
217 cs.namespace(|| format!("bit {}", i)),
218 a_bit,
219 last_run.as_ref().expect("char always starts with a one"),
220 )?;
221 result.push(a_bit);
222 }
223
224 i += 1;
225 }
226
227 assert_eq!(current_run.len(), 0);
230
231 let mut lc = LinearCombination::zero();
235 let mut coeff = Scalar::ONE;
236
237 for bit in result.iter().rev() {
238 lc = lc + (coeff, bit.get_variable());
239
240 coeff = coeff.double();
241 }
242
243 lc = lc - self.variable;
244
245 cs.enforce(|| "unpacking constraint", |lc| lc, |lc| lc, |_| lc);
246
247 Ok(result.into_iter().map(Boolean::from).rev().collect())
249 }
250
251 pub fn to_bits_le<CS>(&self, mut cs: CS) -> Result<Vec<Boolean>, SynthesisError>
255 where
256 CS: ConstraintSystem<Scalar>,
257 Scalar: PrimeFieldBits,
258 {
259 let bits = boolean::field_into_allocated_bits_le(&mut cs, self.value)?;
260
261 let mut lc = LinearCombination::zero();
262 let mut coeff = Scalar::ONE;
263
264 for bit in bits.iter() {
265 lc = lc + (coeff, bit.get_variable());
266
267 coeff = coeff.double();
268 }
269
270 lc = lc - self.variable;
271
272 cs.enforce(|| "unpacking constraint", |lc| lc, |lc| lc, |_| lc);
273
274 Ok(bits.into_iter().map(Boolean::from).collect())
275 }
276
277 pub fn add<CS>(&self, mut cs: CS, other: &Self) -> Result<Self, SynthesisError>
279 where
280 CS: ConstraintSystem<Scalar>,
281 {
282 let mut value = None;
283
284 let var = cs.alloc(
285 || "sum num",
286 || {
287 let mut tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?;
288 tmp.add_assign(other.value.ok_or(SynthesisError::AssignmentMissing)?);
289
290 value = Some(tmp);
291
292 Ok(tmp)
293 },
294 )?;
295
296 cs.enforce(
298 || "addition constraint",
299 |lc| lc + self.variable + other.variable,
300 |lc| lc + CS::one(),
301 |lc| lc + var,
302 );
303
304 Ok(AllocatedNum {
305 value,
306 variable: var,
307 })
308 }
309
310 pub fn mul<CS>(&self, mut cs: CS, other: &Self) -> Result<Self, SynthesisError>
312 where
313 CS: ConstraintSystem<Scalar>,
314 {
315 let mut value = None;
316
317 let var = cs.alloc(
318 || "product num",
319 || {
320 let mut tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?;
321 tmp.mul_assign(other.value.ok_or(SynthesisError::AssignmentMissing)?);
322
323 value = Some(tmp);
324
325 Ok(tmp)
326 },
327 )?;
328
329 cs.enforce(
331 || "multiplication constraint",
332 |lc| lc + self.variable,
333 |lc| lc + other.variable,
334 |lc| lc + var,
335 );
336
337 Ok(AllocatedNum {
338 value,
339 variable: var,
340 })
341 }
342
343 pub fn square<CS>(&self, mut cs: CS) -> Result<Self, SynthesisError>
345 where
346 CS: ConstraintSystem<Scalar>,
347 {
348 let mut value = None;
349
350 let var = cs.alloc(
351 || "squared num",
352 || {
353 let mut tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?;
354 tmp = tmp.square();
355
356 value = Some(tmp);
357
358 Ok(tmp)
359 },
360 )?;
361
362 cs.enforce(
364 || "squaring constraint",
365 |lc| lc + self.variable,
366 |lc| lc + self.variable,
367 |lc| lc + var,
368 );
369
370 Ok(AllocatedNum {
371 value,
372 variable: var,
373 })
374 }
375
376 pub fn assert_nonzero<CS>(&self, mut cs: CS) -> Result<(), SynthesisError>
378 where
379 CS: ConstraintSystem<Scalar>,
380 {
381 let inv = cs.alloc(
382 || "ephemeral inverse",
383 || {
384 let tmp = self.value.ok_or(SynthesisError::AssignmentMissing)?;
385
386 if tmp.is_zero().into() {
387 Err(SynthesisError::DivisionByZero)
388 } else {
389 Ok(tmp.invert().unwrap())
390 }
391 },
392 )?;
393
394 cs.enforce(
398 || "nonzero assertion constraint",
399 |lc| lc + self.variable,
400 |lc| lc + inv,
401 |lc| lc + CS::one(),
402 );
403
404 Ok(())
405 }
406
407 pub fn conditionally_reverse<CS>(
411 mut cs: CS,
412 a: &Self,
413 b: &Self,
414 condition: &Boolean,
415 ) -> Result<(Self, Self), SynthesisError>
416 where
417 CS: ConstraintSystem<Scalar>,
418 {
419 let c = Self::alloc(cs.namespace(|| "conditional reversal result 1"), || {
420 if condition
421 .get_value()
422 .ok_or(SynthesisError::AssignmentMissing)?
423 {
424 Ok(b.value.ok_or(SynthesisError::AssignmentMissing)?)
425 } else {
426 Ok(a.value.ok_or(SynthesisError::AssignmentMissing)?)
427 }
428 })?;
429
430 cs.enforce(
431 || "first conditional reversal",
432 |lc| lc + a.variable - b.variable,
433 |_| condition.lc(CS::one(), Scalar::ONE),
434 |lc| lc + a.variable - c.variable,
435 );
436
437 let d = Self::alloc(cs.namespace(|| "conditional reversal result 2"), || {
438 if condition
439 .get_value()
440 .ok_or(SynthesisError::AssignmentMissing)?
441 {
442 Ok(a.value.ok_or(SynthesisError::AssignmentMissing)?)
443 } else {
444 Ok(b.value.ok_or(SynthesisError::AssignmentMissing)?)
445 }
446 })?;
447
448 cs.enforce(
449 || "second conditional reversal",
450 |lc| lc + b.variable - a.variable,
451 |_| condition.lc(CS::one(), Scalar::ONE),
452 |lc| lc + b.variable - d.variable,
453 );
454
455 Ok((c, d))
456 }
457
458 pub fn get_value(&self) -> Option<Scalar> {
460 self.value
461 }
462
463 pub fn get_variable(&self) -> Variable {
465 self.variable
466 }
467}
468
469#[derive(Debug, Clone)]
471pub struct Num<Scalar: PrimeField> {
472 value: Option<Scalar>,
473 lc: LinearCombination<Scalar>,
474}
475
476impl<Scalar: PrimeField> From<AllocatedNum<Scalar>> for Num<Scalar> {
477 fn from(num: AllocatedNum<Scalar>) -> Num<Scalar> {
478 Num {
479 value: num.value,
480 lc: LinearCombination::<Scalar>::from_variable(num.variable),
481 }
482 }
483}
484
485impl<Scalar: PrimeField> Num<Scalar> {
486 pub fn zero() -> Self {
488 Num {
489 value: Some(Scalar::ZERO),
490 lc: LinearCombination::zero(),
491 }
492 }
493
494 pub fn get_value(&self) -> Option<Scalar> {
496 self.value
497 }
498
499 pub fn lc(&self, coeff: Scalar) -> LinearCombination<Scalar> {
501 LinearCombination::zero() + (coeff, &self.lc)
502 }
503
504 pub fn add_bool_with_coeff(self, one: Variable, bit: &Boolean, coeff: Scalar) -> Self {
506 let newval = match (self.value, bit.get_value()) {
507 (Some(mut curval), Some(bval)) => {
508 if bval {
509 curval.add_assign(&coeff);
510 }
511
512 Some(curval)
513 }
514 _ => None,
515 };
516
517 Num {
518 value: newval,
519 lc: self.lc + &bit.lc(one, coeff),
520 }
521 }
522
523 #[allow(clippy::should_implement_trait)]
525 pub fn add(self, other: &Self) -> Self {
526 let lc = self.lc + &other.lc;
527 let value = match (self.value, other.value) {
528 (Some(v1), Some(v2)) => {
529 let mut tmp = v1;
530 tmp.add_assign(&v2);
531 Some(tmp)
532 }
533 (Some(v), None) | (None, Some(v)) => Some(v),
534 (None, None) => None,
535 };
536
537 Num { value, lc }
538 }
539
540 pub fn scale(mut self, scalar: Scalar) -> Self {
542 for (_variable, fr) in self.lc.iter_mut() {
543 fr.mul_assign(&scalar);
544 }
545
546 if let Some(ref mut v) = self.value {
547 v.mul_assign(&scalar);
548 }
549
550 self
551 }
552}