ark_r1cs_std/fields/fp/mod.rs

1use ark_ff::{BigInteger, PrimeField};
2use ark_relations::r1cs::{
3    ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable,
4};
5
6use core::borrow::Borrow;
7
8use crate::{
9    boolean::AllocatedBool,
10    convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget},
11    fields::{FieldOpsBounds, FieldVar},
12    prelude::*,
13    Assignment, Vec,
14};
15use ark_std::iter::Sum;
16
17mod cmp;
18
19/// Represents a variable in the constraint system whose
20/// value can be an arbitrary field element.
21#[derive(Debug, Clone)]
22#[must_use]
23pub struct AllocatedFp<F: PrimeField> {
24    pub(crate) value: Option<F>,
25    /// The allocated variable corresponding to `self` in `self.cs`.
26    pub variable: Variable,
27    /// The constraint system that `self` was allocated in.
28    pub cs: ConstraintSystemRef<F>,
29}
30
31impl<F: PrimeField> AllocatedFp<F> {
32    /// Constructs a new `AllocatedFp` from a (optional) value, a low-level
33    /// Variable, and a `ConstraintSystemRef`.
34    pub fn new(value: Option<F>, variable: Variable, cs: ConstraintSystemRef<F>) -> Self {
35        Self {
36            value,
37            variable,
38            cs,
39        }
40    }
41}
42
43/// Represent variables corresponding to a field element in `F`.
44#[derive(Clone, Debug)]
45#[must_use]
46pub enum FpVar<F: PrimeField> {
47    /// Represents a constant in the constraint system, which means that
48    /// it does not have a corresponding variable.
49    Constant(F),
50    /// Represents an allocated variable constant in the constraint system.
51    Var(AllocatedFp<F>),
52}
53
54impl<F: PrimeField> FpVar<F> {
55    /// Decomposes `self` into a vector of `bits` and a remainder `rest` such that
56    /// * `bits.len() == size`, and
57    /// * `rest == 0`.
58    pub fn to_bits_le_with_top_bits_zero(
59        &self,
60        size: usize,
61    ) -> Result<(Vec<Boolean<F>>, Self), SynthesisError> {
62        assert!(size <= F::MODULUS_BIT_SIZE as usize - 1);
63        let cs = self.cs();
64        let mode = if self.is_constant() {
65            AllocationMode::Constant
66        } else {
67            AllocationMode::Witness
68        };
69
70        let value = self.value().map(|f| f.into_bigint());
71        let lower_bits = (0..size)
72            .map(|i| {
73                Boolean::new_variable(cs.clone(), || value.map(|v| v.get_bit(i as usize)), mode)
74            })
75            .collect::<Result<Vec<_>, _>>()?;
76        let lower_bits_fp = Boolean::le_bits_to_fp(&lower_bits)?;
77        let rest = self - &lower_bits_fp;
78        rest.enforce_equal(&Self::zero())?;
79        Ok((lower_bits, rest))
80    }
81}
82
83impl<F: PrimeField> R1CSVar<F> for FpVar<F> {
84    type Value = F;
85
86    fn cs(&self) -> ConstraintSystemRef<F> {
87        match self {
88            Self::Constant(_) => ConstraintSystemRef::None,
89            Self::Var(a) => a.cs.clone(),
90        }
91    }
92
93    fn value(&self) -> Result<Self::Value, SynthesisError> {
94        match self {
95            Self::Constant(v) => Ok(*v),
96            Self::Var(v) => v.value(),
97        }
98    }
99}
100
101impl<F: PrimeField> From<Boolean<F>> for FpVar<F> {
102    fn from(other: Boolean<F>) -> Self {
103        if let Boolean::Constant(b) = other {
104            Self::Constant(F::from(b as u8))
105        } else {
106            // `other` is a variable
107            let cs = other.cs();
108            let variable = cs.new_lc(other.lc()).unwrap();
109            Self::Var(AllocatedFp::new(
110                other.value().ok().map(|b| F::from(b as u8)),
111                variable,
112                cs,
113            ))
114        }
115    }
116}
117
118impl<F: PrimeField> From<AllocatedFp<F>> for FpVar<F> {
119    fn from(other: AllocatedFp<F>) -> Self {
120        Self::Var(other)
121    }
122}
123
124impl<'a, F: PrimeField> FieldOpsBounds<'a, F, Self> for FpVar<F> {}
125impl<'a, F: PrimeField> FieldOpsBounds<'a, F, FpVar<F>> for &'a FpVar<F> {}
126
127impl<F: PrimeField> AllocatedFp<F> {
128    /// Constructs `Self` from a `Boolean`: if `other` is false, this outputs
129    /// `zero`, else it outputs `one`.
130    pub fn from(other: Boolean<F>) -> Self {
131        let cs = other.cs();
132        let variable = cs.new_lc(other.lc()).unwrap();
133        Self::new(other.value().ok().map(|b| F::from(b as u8)), variable, cs)
134    }
135
136    /// Returns the value assigned to `self` in the underlying constraint system
137    /// (if a value was assigned).
138    pub fn value(&self) -> Result<F, SynthesisError> {
139        self.cs.assigned_value(self.variable).get()
140    }
141
142    /// Outputs `self + other`.
143    ///
144    /// This does not create any constraints.
145    #[tracing::instrument(target = "r1cs")]
146    pub fn add(&self, other: &Self) -> Self {
147        let value = match (self.value, other.value) {
148            (Some(val1), Some(val2)) => Some(val1 + &val2),
149            (..) => None,
150        };
151
152        let variable = self
153            .cs
154            .new_lc(lc!() + self.variable + other.variable)
155            .unwrap();
156        AllocatedFp::new(value, variable, self.cs.clone())
157    }
158
159    /// Add many allocated Fp elements together.
160    ///
161    /// This does not create any constraints and only creates one linear
162    /// combination.
163    pub fn add_many<B: Borrow<Self>, I: Iterator<Item = B>>(iter: I) -> Self {
164        let mut cs = ConstraintSystemRef::None;
165        let mut has_value = true;
166        let mut value = F::zero();
167        let mut new_lc = lc!();
168
169        let mut num_iters = 0;
170        for variable in iter {
171            let variable = variable.borrow();
172            if !variable.cs.is_none() {
173                cs = cs.or(variable.cs.clone());
174            }
175            if variable.value.is_none() {
176                has_value = false;
177            } else {
178                value += variable.value.unwrap();
179            }
180            new_lc = new_lc + variable.variable;
181            num_iters += 1;
182        }
183        assert_ne!(num_iters, 0);
184
185        let variable = cs.new_lc(new_lc).unwrap();
186
187        if has_value {
188            AllocatedFp::new(Some(value), variable, cs)
189        } else {
190            AllocatedFp::new(None, variable, cs)
191        }
192    }
193
194    /// Outputs `self - other`.
195    ///
196    /// This does not create any constraints.
197    #[tracing::instrument(target = "r1cs")]
198    pub fn sub(&self, other: &Self) -> Self {
199        let value = match (self.value, other.value) {
200            (Some(val1), Some(val2)) => Some(val1 - &val2),
201            (..) => None,
202        };
203
204        let variable = self
205            .cs
206            .new_lc(lc!() + self.variable - other.variable)
207            .unwrap();
208        AllocatedFp::new(value, variable, self.cs.clone())
209    }
210
211    /// Outputs `self * other`.
212    ///
213    /// This requires *one* constraint.
214    #[tracing::instrument(target = "r1cs")]
215    pub fn mul(&self, other: &Self) -> Self {
216        let product = AllocatedFp::new_witness(self.cs.clone(), || {
217            Ok(self.value.get()? * &other.value.get()?)
218        })
219        .unwrap();
220        self.cs
221            .enforce_constraint(
222                lc!() + self.variable,
223                lc!() + other.variable,
224                lc!() + product.variable,
225            )
226            .unwrap();
227        product
228    }
229
230    /// Output `self + other`
231    ///
232    /// This does not create any constraints.
233    #[tracing::instrument(target = "r1cs")]
234    pub fn add_constant(&self, other: F) -> Self {
235        if other.is_zero() {
236            self.clone()
237        } else {
238            let value = self.value.map(|val| val + other);
239            let variable = self
240                .cs
241                .new_lc(lc!() + self.variable + (other, Variable::One))
242                .unwrap();
243            AllocatedFp::new(value, variable, self.cs.clone())
244        }
245    }
246
247    /// Output `self - other`
248    ///
249    /// This does not create any constraints.
250    #[tracing::instrument(target = "r1cs")]
251    pub fn sub_constant(&self, other: F) -> Self {
252        self.add_constant(-other)
253    }
254
255    /// Output `self * other`
256    ///
257    /// This does not create any constraints.
258    #[tracing::instrument(target = "r1cs")]
259    pub fn mul_constant(&self, other: F) -> Self {
260        if other.is_one() {
261            self.clone()
262        } else {
263            let value = self.value.map(|val| val * other);
264            let variable = self.cs.new_lc(lc!() + (other, self.variable)).unwrap();
265            AllocatedFp::new(value, variable, self.cs.clone())
266        }
267    }
268
269    /// Output `self + self`
270    ///
271    /// This does not create any constraints.
272    #[tracing::instrument(target = "r1cs")]
273    pub fn double(&self) -> Result<Self, SynthesisError> {
274        let value = self.value.map(|val| val.double());
275        let variable = self.cs.new_lc(lc!() + self.variable + self.variable)?;
276        Ok(Self::new(value, variable, self.cs.clone()))
277    }
278
279    /// Output `-self`
280    ///
281    /// This does not create any constraints.
282    #[tracing::instrument(target = "r1cs")]
283    pub fn negate(&self) -> Self {
284        let mut result = self.clone();
285        result.negate_in_place();
286        result
287    }
288
289    /// Sets `self = -self`
290    ///
291    /// This does not create any constraints.
292    #[tracing::instrument(target = "r1cs")]
293    pub fn negate_in_place(&mut self) -> &mut Self {
294        if let Some(val) = self.value.as_mut() {
295            *val = -(*val);
296        }
297        self.variable = self.cs.new_lc(lc!() - self.variable).unwrap();
298        self
299    }
300
301    /// Outputs `self * self`
302    ///
303    /// This requires *one* constraint.
304    #[tracing::instrument(target = "r1cs")]
305    pub fn square(&self) -> Result<Self, SynthesisError> {
306        Ok(self.mul(self))
307    }
308
309    /// Outputs `result` such that `result * self = 1`.
310    ///
311    /// This requires *one* constraint.
312    #[tracing::instrument(target = "r1cs")]
313    pub fn inverse(&self) -> Result<Self, SynthesisError> {
314        let inverse = Self::new_witness(self.cs.clone(), || {
315            Ok(self.value.get()?.inverse().unwrap_or_else(F::zero))
316        })?;
317
318        self.cs.enforce_constraint(
319            lc!() + self.variable,
320            lc!() + inverse.variable,
321            lc!() + Variable::One,
322        )?;
323        Ok(inverse)
324    }
325
326    /// This is a no-op for prime fields.
327    #[tracing::instrument(target = "r1cs")]
328    pub fn frobenius_map(&self, _: usize) -> Result<Self, SynthesisError> {
329        Ok(self.clone())
330    }
331
332    /// Enforces that `self * other = result`.
333    ///
334    /// This requires *one* constraint.
335    #[tracing::instrument(target = "r1cs")]
336    pub fn mul_equals(&self, other: &Self, result: &Self) -> Result<(), SynthesisError> {
337        self.cs.enforce_constraint(
338            lc!() + self.variable,
339            lc!() + other.variable,
340            lc!() + result.variable,
341        )
342    }
343
344    /// Enforces that `self * self = result`.
345    ///
346    /// This requires *one* constraint.
347    #[tracing::instrument(target = "r1cs")]
348    pub fn square_equals(&self, result: &Self) -> Result<(), SynthesisError> {
349        self.cs.enforce_constraint(
350            lc!() + self.variable,
351            lc!() + self.variable,
352            lc!() + result.variable,
353        )
354    }
355
356    /// Outputs the bit `self == other`.
357    ///
358    /// This requires two constraints.
359    #[tracing::instrument(target = "r1cs")]
360    pub fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
361        Ok(!self.is_neq(other)?)
362    }
363
364    /// Outputs the bit `self != other`.
365    ///
366    /// This requires two constraints.
367    #[tracing::instrument(target = "r1cs")]
368    pub fn is_neq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
369        // We don't need to enforce `is_not_equal` to be boolean here;
370        // see the comments above the constraints below for why.
371        let is_not_equal = Boolean::from(AllocatedBool::new_witness_without_booleanity_check(
372            self.cs.clone(),
373            || Ok(self.value.get()? != other.value.get()?),
374        )?);
375        let multiplier = self.cs.new_witness_variable(|| {
376            if is_not_equal.value()? {
377                (self.value.get()? - other.value.get()?).inverse().get()
378            } else {
379                Ok(F::one())
380            }
381        })?;
382
383        // Completeness:
384        // Case 1: self != other:
385        // ----------------------
386        //   constraint 1:
387        //   (self - other) * multiplier = is_not_equal
388        //   => (non_zero) * multiplier = 1 (satisfied, because multiplier = 1/(self -
389        // other)
390        //
391        //   constraint 2:
392        //   (self - other) * not(is_not_equal) = 0
393        //   => (non_zero) * not(1) = 0
394        //   => (non_zero) * 0 = 0
395        //
396        // Case 2: self == other:
397        // ----------------------
398        //   constraint 1:
399        //   (self - other) * multiplier = is_not_equal
400        //   => 0 * multiplier = 0 (satisfied, because multiplier = 1
401        //
402        //   constraint 2:
403        //   (self - other) * not(is_not_equal) = 0
404        //   => 0 * not(0) = 0
405        //   => 0 * 1 = 0
406        //
407        // --------------------------------------------------------------------
408        //
409        // Soundness:
410        // Case 1: self != other, but is_not_equal != 1.
411        // --------------------------------------------
412        //   constraint 2:
413        //   (self - other) * not(is_not_equal) = 0
414        //   => (non_zero) * (1 - is_not_equal) = 0
415        //   => non_zero = 0 (contradiction) || 1 - is_not_equal = 0 (contradiction)
416        //
417        // Case 2: self == other, but is_not_equal != 0.
418        // --------------------------------------------
419        //   constraint 1:
420        //   (self - other) * multiplier = is_not_equal
421        //   0 * multiplier = is_not_equal != 0 (unsatisfiable)
422        //
423        // That is, constraint 1 enforces that if self == other, then `is_not_equal = 0`
424        // and constraint 2 enforces that if self != other, then `is_not_equal = 1`.
425        // Since these are the only possible two cases, `is_not_equal` is always
426        // constrained to 0 or 1.
427        self.cs.enforce_constraint(
428            lc!() + self.variable - other.variable,
429            lc!() + multiplier,
430            is_not_equal.lc(),
431        )?;
432        self.cs.enforce_constraint(
433            lc!() + self.variable - other.variable,
434            (!&is_not_equal).lc(),
435            lc!(),
436        )?;
437        Ok(is_not_equal)
438    }
439
440    /// Enforces that self == other if `should_enforce.is_eq(&Boolean::TRUE)`.
441    ///
442    /// This requires one constraint.
443    #[tracing::instrument(target = "r1cs")]
444    pub fn conditional_enforce_equal(
445        &self,
446        other: &Self,
447        should_enforce: &Boolean<F>,
448    ) -> Result<(), SynthesisError> {
449        self.cs.enforce_constraint(
450            lc!() + self.variable - other.variable,
451            lc!() + should_enforce.lc(),
452            lc!(),
453        )
454    }
455
456    /// Enforces that self != other if `should_enforce.is_eq(&Boolean::TRUE)`.
457    ///
458    /// This requires one constraint.
459    #[tracing::instrument(target = "r1cs")]
460    pub fn conditional_enforce_not_equal(
461        &self,
462        other: &Self,
463        should_enforce: &Boolean<F>,
464    ) -> Result<(), SynthesisError> {
465        // The high level logic is as follows:
466        // We want to check that self - other != 0. We do this by checking that
467        // (self - other).inverse() exists. In more detail, we check the following:
468        // If `should_enforce == true`, then we set `multiplier = (self - other).inverse()`,
469        // and check that (self - other) * multiplier == 1. (i.e., that the inverse exists)
470        //
471        // If `should_enforce == false`, then we set `multiplier == 0`, and check that
472        // (self - other) * 0 == 0, which is always satisfied.
473        let multiplier = Self::new_witness(self.cs.clone(), || {
474            if should_enforce.value()? {
475                (self.value.get()? - other.value.get()?).inverse().get()
476            } else {
477                Ok(F::zero())
478            }
479        })?;
480
481        self.cs.enforce_constraint(
482            lc!() + self.variable - other.variable,
483            lc!() + multiplier.variable,
484            should_enforce.lc(),
485        )?;
486        Ok(())
487    }
488}
489
490/// *************************************************************************
491/// *************************************************************************
492
493impl<F: PrimeField> ToBitsGadget<F> for AllocatedFp<F> {
494    /// Outputs the unique bit-wise decomposition of `self` in *little-endian*
495    /// form.
496    ///
497    /// This method enforces that the output is in the field, i.e.
498    /// it invokes `Boolean::enforce_in_field_le` on the bit decomposition.
499    #[tracing::instrument(target = "r1cs")]
500    fn to_bits_le(&self) -> Result<Vec<Boolean<F>>, SynthesisError> {
501        let bits = self.to_non_unique_bits_le()?;
502        Boolean::enforce_in_field_le(&bits)?;
503        Ok(bits)
504    }
505
506    #[tracing::instrument(target = "r1cs")]
507    fn to_non_unique_bits_le(&self) -> Result<Vec<Boolean<F>>, SynthesisError> {
508        let cs = self.cs.clone();
509        use ark_ff::BitIteratorBE;
510        let mut bits = if let Some(value) = self.value {
511            let field_char = BitIteratorBE::new(F::characteristic());
512            let bits: Vec<_> = BitIteratorBE::new(value.into_bigint())
513                .zip(field_char)
514                .skip_while(|(_, c)| !c)
515                .map(|(b, _)| Some(b))
516                .collect();
517            assert_eq!(bits.len(), F::MODULUS_BIT_SIZE as usize);
518            bits
519        } else {
520            vec![None; F::MODULUS_BIT_SIZE as usize]
521        };
522
523        // Convert to little-endian
524        bits.reverse();
525
526        let bits: Vec<_> = bits
527            .into_iter()
528            .map(|b| Boolean::new_witness(cs.clone(), || b.get()))
529            .collect::<Result<_, _>>()?;
530
531        let mut lc = LinearCombination::zero();
532        let mut coeff = F::one();
533
534        for bit in bits.iter() {
535            lc = &lc + bit.lc() * coeff;
536
537            coeff.double_in_place();
538        }
539
540        lc = lc - &self.variable;
541
542        cs.enforce_constraint(lc!(), lc!(), lc)?;
543
544        Ok(bits)
545    }
546}
547
548impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
549    /// Outputs the unique byte decomposition of `self` in *little-endian*
550    /// form.
551    ///
552    /// This method enforces that the decomposition represents
553    /// an integer that is less than `F::MODULUS`.
554    #[tracing::instrument(target = "r1cs")]
555    fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
556        let num_bits = F::BigInt::NUM_LIMBS * 64;
557        let mut bits = self.to_bits_le()?;
558        let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
559        bits.extend(remainder);
560        let bytes = bits
561            .chunks(8)
562            .map(|chunk| UInt8::from_bits_le(chunk))
563            .collect();
564        Ok(bytes)
565    }
566
567    #[tracing::instrument(target = "r1cs")]
568    fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
569        let num_bits = F::BigInt::NUM_LIMBS * 64;
570        let mut bits = self.to_non_unique_bits_le()?;
571        let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
572        bits.extend(remainder);
573        let bytes = bits
574            .chunks(8)
575            .map(|chunk| UInt8::from_bits_le(chunk))
576            .collect();
577        Ok(bytes)
578    }
579}
580
581impl<F: PrimeField> ToConstraintFieldGadget<F> for AllocatedFp<F> {
582    #[tracing::instrument(target = "r1cs")]
583    fn to_constraint_field(&self) -> Result<Vec<FpVar<F>>, SynthesisError> {
584        Ok(vec![self.clone().into()])
585    }
586}
587
588impl<F: PrimeField> CondSelectGadget<F> for AllocatedFp<F> {
589    #[inline]
590    #[tracing::instrument(target = "r1cs")]
591    fn conditionally_select(
592        cond: &Boolean<F>,
593        true_val: &Self,
594        false_val: &Self,
595    ) -> Result<Self, SynthesisError> {
596        match cond {
597            &Boolean::Constant(true) => Ok(true_val.clone()),
598            &Boolean::Constant(false) => Ok(false_val.clone()),
599            _ => {
600                let cs = cond.cs();
601                let result = Self::new_witness(cs.clone(), || {
602                    cond.value()
603                        .and_then(|c| if c { true_val } else { false_val }.value.get())
604                })?;
605                // a = self; b = other; c = cond;
606                //
607                // r = c * a + (1  - c) * b
608                // r = b + c * (a - b)
609                // c * (a - b) = r - b
610                cs.enforce_constraint(
611                    cond.lc(),
612                    lc!() + true_val.variable - false_val.variable,
613                    lc!() + result.variable - false_val.variable,
614                )?;
615
616                Ok(result)
617            },
618        }
619    }
620}
621
622/// Uses two bits to perform a lookup into a table
623/// `b` is little-endian: `b[0]` is LSB.
624impl<F: PrimeField> TwoBitLookupGadget<F> for AllocatedFp<F> {
625    type TableConstant = F;
626    #[tracing::instrument(target = "r1cs")]
627    fn two_bit_lookup(b: &[Boolean<F>], c: &[Self::TableConstant]) -> Result<Self, SynthesisError> {
628        debug_assert_eq!(b.len(), 2);
629        debug_assert_eq!(c.len(), 4);
630        let result = Self::new_witness(b.cs(), || {
631            let lsb = usize::from(b[0].value()?);
632            let msb = usize::from(b[1].value()?);
633            let index = lsb + (msb << 1);
634            Ok(c[index])
635        })?;
636        let one = Variable::One;
637        b.cs().enforce_constraint(
638            lc!() + b[1].lc() * (c[3] - &c[2] - &c[1] + &c[0]) + (c[1] - &c[0], one),
639            lc!() + b[0].lc(),
640            lc!() + result.variable - (c[0], one) + b[1].lc() * (c[0] - &c[2]),
641        )?;
642
643        Ok(result)
644    }
645}
646
647impl<F: PrimeField> ThreeBitCondNegLookupGadget<F> for AllocatedFp<F> {
648    type TableConstant = F;
649
650    #[tracing::instrument(target = "r1cs")]
651    fn three_bit_cond_neg_lookup(
652        b: &[Boolean<F>],
653        b0b1: &Boolean<F>,
654        c: &[Self::TableConstant],
655    ) -> Result<Self, SynthesisError> {
656        debug_assert_eq!(b.len(), 3);
657        debug_assert_eq!(c.len(), 4);
658        let result = Self::new_witness(b.cs(), || {
659            let lsb = usize::from(b[0].value()?);
660            let msb = usize::from(b[1].value()?);
661            let index = lsb + (msb << 1);
662            let intermediate = c[index];
663
664            let is_negative = b[2].value()?;
665            let y = if is_negative {
666                -intermediate
667            } else {
668                intermediate
669            };
670            Ok(y)
671        })?;
672
673        let y_lc = b0b1.lc() * (c[3] - &c[2] - &c[1] + &c[0])
674            + b[0].lc() * (c[1] - &c[0])
675            + b[1].lc() * (c[2] - &c[0])
676            + (c[0], Variable::One);
677        // enforce y * (1 - 2 * b_2) == res
678        b.cs().enforce_constraint(
679            y_lc.clone(),
680            b[2].lc() * F::from(2u64).neg() + (F::one(), Variable::One),
681            lc!() + result.variable,
682        )?;
683
684        Ok(result)
685    }
686}
687
688impl<F: PrimeField> AllocVar<F, F> for AllocatedFp<F> {
689    fn new_variable<T: Borrow<F>>(
690        cs: impl Into<Namespace<F>>,
691        f: impl FnOnce() -> Result<T, SynthesisError>,
692        mode: AllocationMode,
693    ) -> Result<Self, SynthesisError> {
694        let ns = cs.into();
695        let cs = ns.cs();
696        if mode == AllocationMode::Constant {
697            let v = *f()?.borrow();
698            let lc = cs.new_lc(lc!() + (v, Variable::One))?;
699            Ok(Self::new(Some(v), lc, cs))
700        } else {
701            let mut value = None;
702            let value_generator = || {
703                value = Some(*f()?.borrow());
704                value.ok_or(SynthesisError::AssignmentMissing)
705            };
706            let variable = if mode == AllocationMode::Input {
707                cs.new_input_variable(value_generator)?
708            } else {
709                cs.new_witness_variable(value_generator)?
710            };
711            Ok(Self::new(value, variable, cs))
712        }
713    }
714}
715
716impl<F: PrimeField> FieldVar<F, F> for FpVar<F> {
717    fn constant(f: F) -> Self {
718        Self::Constant(f)
719    }
720
721    fn zero() -> Self {
722        Self::Constant(F::zero())
723    }
724
725    fn one() -> Self {
726        Self::Constant(F::one())
727    }
728
729    #[tracing::instrument(target = "r1cs")]
730    fn double(&self) -> Result<Self, SynthesisError> {
731        match self {
732            Self::Constant(c) => Ok(Self::Constant(c.double())),
733            Self::Var(v) => Ok(Self::Var(v.double()?)),
734        }
735    }
736
737    #[tracing::instrument(target = "r1cs")]
738    fn negate(&self) -> Result<Self, SynthesisError> {
739        match self {
740            Self::Constant(c) => Ok(Self::Constant(-*c)),
741            Self::Var(v) => Ok(Self::Var(v.negate())),
742        }
743    }
744
745    #[tracing::instrument(target = "r1cs")]
746    fn square(&self) -> Result<Self, SynthesisError> {
747        match self {
748            Self::Constant(c) => Ok(Self::Constant(c.square())),
749            Self::Var(v) => Ok(Self::Var(v.square()?)),
750        }
751    }
752
753    /// Enforce that `self * other == result`.
754    #[tracing::instrument(target = "r1cs")]
755    fn mul_equals(&self, other: &Self, result: &Self) -> Result<(), SynthesisError> {
756        use FpVar::*;
757        match (self, other, result) {
758            (Constant(_), Constant(_), Constant(_)) => Ok(()),
759            (Constant(_), Constant(_), _) | (Constant(_), Var(_), _) | (Var(_), Constant(_), _) => {
760                result.enforce_equal(&(self * other))
761            }, // this multiplication should be free
762            (Var(v1), Var(v2), Var(v3)) => v1.mul_equals(v2, v3),
763            (Var(v1), Var(v2), Constant(f)) => {
764                let cs = v1.cs.clone();
765                let v3 = AllocatedFp::new_constant(cs, f).unwrap();
766                v1.mul_equals(v2, &v3)
767            },
768        }
769    }
770
771    /// Enforce that `self * self == result`.
772    #[tracing::instrument(target = "r1cs")]
773    fn square_equals(&self, result: &Self) -> Result<(), SynthesisError> {
774        use FpVar::*;
775        match (self, result) {
776            (Constant(_), Constant(_)) => Ok(()),
777            (Constant(f), Var(r)) => {
778                let cs = r.cs.clone();
779                let v = AllocatedFp::new_witness(cs, || Ok(f))?;
780                v.square_equals(&r)
781            },
782            (Var(v), Constant(f)) => {
783                let cs = v.cs.clone();
784                let r = AllocatedFp::new_witness(cs, || Ok(f))?;
785                v.square_equals(&r)
786            },
787            (Var(v1), Var(v2)) => v1.square_equals(v2),
788        }
789    }
790
791    #[tracing::instrument(target = "r1cs")]
792    fn inverse(&self) -> Result<Self, SynthesisError> {
793        match self {
794            FpVar::Var(v) => v.inverse().map(FpVar::Var),
795            FpVar::Constant(f) => f.inverse().get().map(FpVar::Constant),
796        }
797    }
798
799    #[tracing::instrument(target = "r1cs")]
800    fn frobenius_map(&self, power: usize) -> Result<Self, SynthesisError> {
801        match self {
802            FpVar::Var(v) => v.frobenius_map(power).map(FpVar::Var),
803            FpVar::Constant(f) => {
804                let mut f = *f;
805                f.frobenius_map_in_place(power);
806                Ok(FpVar::Constant(f))
807            },
808        }
809    }
810
811    #[tracing::instrument(target = "r1cs")]
812    fn frobenius_map_in_place(&mut self, power: usize) -> Result<&mut Self, SynthesisError> {
813        *self = self.frobenius_map(power)?;
814        Ok(self)
815    }
816}
817
818impl_ops!(
819    FpVar<F>,
820    F,
821    Add,
822    add,
823    AddAssign,
824    add_assign,
825    |this: &'a FpVar<F>, other: &'a FpVar<F>| {
826        use FpVar::*;
827        match (this, other) {
828            (Constant(c1), Constant(c2)) => Constant(*c1 + *c2),
829            (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(*c)),
830            (Var(v1), Var(v2)) => Var(v1.add(v2)),
831        }
832    },
833    |this: &'a FpVar<F>, other: F| { this + &FpVar::Constant(other) },
834    F: PrimeField,
835);
836
837impl_ops!(
838    FpVar<F>,
839    F,
840    Sub,
841    sub,
842    SubAssign,
843    sub_assign,
844    |this: &'a FpVar<F>, other: &'a FpVar<F>| {
845        use FpVar::*;
846        match (this, other) {
847            (Constant(c1), Constant(c2)) => Constant(*c1 - *c2),
848            (Var(v), Constant(c)) => Var(v.sub_constant(*c)),
849            (Constant(c), Var(v)) => Var(v.sub_constant(*c).negate()),
850            (Var(v1), Var(v2)) => Var(v1.sub(v2)),
851        }
852    },
853    |this: &'a FpVar<F>, other: F| { this - &FpVar::Constant(other) },
854    F: PrimeField
855);
856
857impl_ops!(
858    FpVar<F>,
859    F,
860    Mul,
861    mul,
862    MulAssign,
863    mul_assign,
864    |this: &'a FpVar<F>, other: &'a FpVar<F>| {
865        use FpVar::*;
866        match (this, other) {
867            (Constant(c1), Constant(c2)) => Constant(*c1 * *c2),
868            (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.mul_constant(*c)),
869            (Var(v1), Var(v2)) => Var(v1.mul(v2)),
870        }
871    },
872    |this: &'a FpVar<F>, other: F| {
873        if other.is_zero() {
874            FpVar::zero()
875        } else {
876            this * &FpVar::Constant(other)
877        }
878    },
879    F: PrimeField
880);
881
882/// *************************************************************************
883/// *************************************************************************
884
885impl<F: PrimeField> EqGadget<F> for FpVar<F> {
886    #[tracing::instrument(target = "r1cs")]
887    fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
888        match (self, other) {
889            (Self::Constant(c1), Self::Constant(c2)) => Ok(Boolean::Constant(c1 == c2)),
890            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
891                let cs = v.cs.clone();
892                let c = AllocatedFp::new_constant(cs, c)?;
893                c.is_eq(v)
894            },
895            (Self::Var(v1), Self::Var(v2)) => v1.is_eq(v2),
896        }
897    }
898
899    #[tracing::instrument(target = "r1cs")]
900    fn conditional_enforce_equal(
901        &self,
902        other: &Self,
903        should_enforce: &Boolean<F>,
904    ) -> Result<(), SynthesisError> {
905        match (self, other) {
906            (Self::Constant(_), Self::Constant(_)) => Ok(()),
907            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
908                let cs = v.cs.clone();
909                let c = AllocatedFp::new_constant(cs, c)?;
910                c.conditional_enforce_equal(v, should_enforce)
911            },
912            (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_equal(v2, should_enforce),
913        }
914    }
915
916    #[tracing::instrument(target = "r1cs")]
917    fn conditional_enforce_not_equal(
918        &self,
919        other: &Self,
920        should_enforce: &Boolean<F>,
921    ) -> Result<(), SynthesisError> {
922        match (self, other) {
923            (Self::Constant(_), Self::Constant(_)) => Ok(()),
924            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
925                let cs = v.cs.clone();
926                let c = AllocatedFp::new_constant(cs, c)?;
927                c.conditional_enforce_not_equal(v, should_enforce)
928            },
929            (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_not_equal(v2, should_enforce),
930        }
931    }
932}
933
934impl<F: PrimeField> ToBitsGadget<F> for FpVar<F> {
935    #[tracing::instrument(target = "r1cs")]
936    fn to_bits_le(&self) -> Result<Vec<Boolean<F>>, SynthesisError> {
937        match self {
938            Self::Constant(_) => self.to_non_unique_bits_le(),
939            Self::Var(v) => v.to_bits_le(),
940        }
941    }
942
943    #[tracing::instrument(target = "r1cs")]
944    fn to_non_unique_bits_le(&self) -> Result<Vec<Boolean<F>>, SynthesisError> {
945        use ark_ff::BitIteratorLE;
946        match self {
947            Self::Constant(c) => Ok(BitIteratorLE::new(&c.into_bigint())
948                .take((F::MODULUS_BIT_SIZE) as usize)
949                .map(Boolean::constant)
950                .collect::<Vec<_>>()),
951            Self::Var(v) => v.to_non_unique_bits_le(),
952        }
953    }
954}
955
956impl<F: PrimeField> ToBytesGadget<F> for FpVar<F> {
957    /// Outputs the unique byte decomposition of `self` in *little-endian*
958    /// form.
959    #[tracing::instrument(target = "r1cs")]
960    fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
961        match self {
962            Self::Constant(c) => Ok(UInt8::constant_vec(
963                c.into_bigint().to_bytes_le().as_slice(),
964            )),
965            Self::Var(v) => v.to_bytes_le(),
966        }
967    }
968
969    #[tracing::instrument(target = "r1cs")]
970    fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
971        match self {
972            Self::Constant(c) => Ok(UInt8::constant_vec(
973                c.into_bigint().to_bytes_le().as_slice(),
974            )),
975            Self::Var(v) => v.to_non_unique_bytes_le(),
976        }
977    }
978}
979
980impl<F: PrimeField> ToConstraintFieldGadget<F> for FpVar<F> {
981    #[tracing::instrument(target = "r1cs")]
982    fn to_constraint_field(&self) -> Result<Vec<FpVar<F>>, SynthesisError> {
983        Ok(vec![self.clone()])
984    }
985}
986
987impl<F: PrimeField> CondSelectGadget<F> for FpVar<F> {
988    #[tracing::instrument(target = "r1cs")]
989    fn conditionally_select(
990        cond: &Boolean<F>,
991        true_value: &Self,
992        false_value: &Self,
993    ) -> Result<Self, SynthesisError> {
994        match cond {
995            &Boolean::Constant(true) => Ok(true_value.clone()),
996            &Boolean::Constant(false) => Ok(false_value.clone()),
997            _ => {
998                match (true_value, false_value) {
999                    (Self::Constant(t), Self::Constant(f)) => {
1000                        let is = AllocatedFp::from(cond.clone());
1001                        let not = AllocatedFp::from(!cond);
1002                        // cond * t + (1 - cond) * f
1003                        Ok(is.mul_constant(*t).add(&not.mul_constant(*f)).into())
1004                    },
1005                    (..) => {
1006                        let cs = cond.cs();
1007                        let true_value = match true_value {
1008                            Self::Constant(f) => AllocatedFp::new_constant(cs.clone(), f)?,
1009                            Self::Var(v) => v.clone(),
1010                        };
1011                        let false_value = match false_value {
1012                            Self::Constant(f) => AllocatedFp::new_constant(cs, f)?,
1013                            Self::Var(v) => v.clone(),
1014                        };
1015                        cond.select(&true_value, &false_value).map(Self::Var)
1016                    },
1017                }
1018            },
1019        }
1020    }
1021}
1022
1023/// Uses two bits to perform a lookup into a table
1024/// `b` is little-endian: `b[0]` is LSB.
1025impl<F: PrimeField> TwoBitLookupGadget<F> for FpVar<F> {
1026    type TableConstant = F;
1027
1028    #[tracing::instrument(target = "r1cs")]
1029    fn two_bit_lookup(b: &[Boolean<F>], c: &[Self::TableConstant]) -> Result<Self, SynthesisError> {
1030        debug_assert_eq!(b.len(), 2);
1031        debug_assert_eq!(c.len(), 4);
1032        if b.is_constant() {
1033            let lsb = usize::from(b[0].value()?);
1034            let msb = usize::from(b[1].value()?);
1035            let index = lsb + (msb << 1);
1036            Ok(Self::Constant(c[index]))
1037        } else {
1038            AllocatedFp::two_bit_lookup(b, c).map(Self::Var)
1039        }
1040    }
1041}
1042
1043impl<F: PrimeField> ThreeBitCondNegLookupGadget<F> for FpVar<F> {
1044    type TableConstant = F;
1045
1046    #[tracing::instrument(target = "r1cs")]
1047    fn three_bit_cond_neg_lookup(
1048        b: &[Boolean<F>],
1049        b0b1: &Boolean<F>,
1050        c: &[Self::TableConstant],
1051    ) -> Result<Self, SynthesisError> {
1052        debug_assert_eq!(b.len(), 3);
1053        debug_assert_eq!(c.len(), 4);
1054
1055        if b.cs().or(b0b1.cs()).is_none() {
1056            // We only have constants
1057
1058            let lsb = usize::from(b[0].value()?);
1059            let msb = usize::from(b[1].value()?);
1060            let index = lsb + (msb << 1);
1061            let intermediate = c[index];
1062
1063            let is_negative = b[2].value()?;
1064            let y = if is_negative {
1065                -intermediate
1066            } else {
1067                intermediate
1068            };
1069            Ok(Self::Constant(y))
1070        } else {
1071            AllocatedFp::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var)
1072        }
1073    }
1074}
1075
1076impl<F: PrimeField> AllocVar<F, F> for FpVar<F> {
1077    fn new_variable<T: Borrow<F>>(
1078        cs: impl Into<Namespace<F>>,
1079        f: impl FnOnce() -> Result<T, SynthesisError>,
1080        mode: AllocationMode,
1081    ) -> Result<Self, SynthesisError> {
1082        if mode == AllocationMode::Constant {
1083            Ok(Self::Constant(*f()?.borrow()))
1084        } else {
1085            AllocatedFp::new_variable(cs, f, mode).map(Self::Var)
1086        }
1087    }
1088}
1089
1090impl<'a, F: PrimeField> Sum<&'a FpVar<F>> for FpVar<F> {
1091    fn sum<I: Iterator<Item = &'a FpVar<F>>>(iter: I) -> FpVar<F> {
1092        let mut sum_constants = F::zero();
1093        let sum_variables = FpVar::Var(AllocatedFp::<F>::add_many(iter.filter_map(|x| match x {
1094            FpVar::Constant(c) => {
1095                sum_constants += c;
1096                None
1097            },
1098            FpVar::Var(v) => Some(v),
1099        })));
1100
1101        let sum = sum_variables + sum_constants;
1102        sum
1103    }
1104}
1105
1106impl<'a, F: PrimeField> Sum<FpVar<F>> for FpVar<F> {
1107    fn sum<I: Iterator<Item = FpVar<F>>>(iter: I) -> FpVar<F> {
1108        let mut sum_constants = F::zero();
1109        let sum_variables = FpVar::Var(AllocatedFp::<F>::add_many(iter.filter_map(|x| match x {
1110            FpVar::Constant(c) => {
1111                sum_constants += c;
1112                None
1113            },
1114            FpVar::Var(v) => Some(v),
1115        })));
1116
1117        let sum = sum_variables + sum_constants;
1118        sum
1119    }
1120}
1121
#[cfg(test)]
mod test {
    use crate::{
        alloc::{AllocVar, AllocationMode},
        eq::EqGadget,
        fields::fp::FpVar,
        R1CSVar,
    };
    use ark_relations::r1cs::ConstraintSystem;
    use ark_std::{UniformRand, Zero};
    use ark_test_curves::bls12_381::Fr;

    /// Sums ten constants followed by ten witnesses with `Iterator::sum`. Then
    /// checks the enforced equality and the native value against a plain
    /// field-element sum.
    #[test]
    fn test_sum_fpvar() {
        let mut rng = ark_std::test_rng();
        let cs = ConstraintSystem::new_ref();

        let mut expected = Fr::zero();
        let mut terms = Vec::new();
        // All constants are drawn from `rng` first, then all witnesses.
        for &mode in &[AllocationMode::Constant, AllocationMode::Witness] {
            for _ in 0..10 {
                let value = Fr::rand(&mut rng);
                expected += &value;
                let term = FpVar::<Fr>::new_variable(cs.clone(), || Ok(value), mode).unwrap();
                terms.push(term);
            }
        }

        let total: FpVar<Fr> = terms.iter().sum();
        total.enforce_equal(&FpVar::Constant(expected)).unwrap();

        assert!(cs.is_satisfied().unwrap());
        assert_eq!(total.value().unwrap(), expected);
    }
}
1164}