Skip to main content

ark_r1cs_std/fields/emulated_fp/
allocated_field_var.rs

1use super::{
2    params::{get_params, OptimizationType},
3    reduce::{bigint_to_basefield, limbs_to_bigint, Reducer},
4    AllocatedMulResultVar,
5};
6use crate::{convert::ToConstraintFieldGadget, fields::fp::FpVar, prelude::*};
7use ark_ff::{BigInteger, PrimeField};
8use ark_relations::{
9    gr1cs::{
10        ConstraintSystemRef, Namespace, OptimizationGoal, Result as R1CSResult, SynthesisError,
11    },
12    ns,
13};
14use ark_std::{
15    borrow::Borrow,
16    cmp::{max, min},
17    marker::PhantomData,
18    vec,
19    vec::Vec,
20};
21
22/// The allocated version of `EmulatedFpVar` (introduced below)
23#[derive(Debug)]
24#[must_use]
25pub struct AllocatedEmulatedFpVar<TargetF: PrimeField, BaseF: PrimeField> {
26    /// Constraint system reference
27    pub cs: ConstraintSystemRef<BaseF>,
28    /// The limbs, each of which is a BaseF gadget.
29    pub limbs: Vec<FpVar<BaseF>>,
30    /// Number of additions done over this gadget, using which the gadget
31    /// decides when to reduce.
32    pub num_of_additions_over_normal_form: BaseF,
33    /// Whether the limb representation is the normal form (using only the bits
34    /// specified in the parameters, and the representation is strictly within
35    /// the range of TargetF).
36    pub is_in_the_normal_form: bool,
37    #[doc(hidden)]
38    pub target_phantom: PhantomData<TargetF>,
39}
40
41impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedEmulatedFpVar<TargetF, BaseF> {
42    /// Return cs
43    pub fn cs(&self) -> ConstraintSystemRef<BaseF> {
44        self.cs.clone()
45    }
46
47    /// Obtain the value of limbs
48    pub fn limbs_to_value(limbs: Vec<BaseF>, optimization_type: OptimizationType) -> TargetF {
49        let params = get_params(
50            TargetF::MODULUS_BIT_SIZE as usize,
51            BaseF::MODULUS_BIT_SIZE as usize,
52            optimization_type,
53        );
54
55        // Convert 2^{(params.bits_per_limb - 1)} into the TargetF and then double
56        // the base This is because 2^{(params.bits_per_limb)} might indeed be
57        // larger than the target field's prime.
58        let base_repr = TargetF::ONE.into_bigint() << (params.bits_per_limb - 1) as u32;
59
60        let mut base = TargetF::from_bigint(base_repr).unwrap();
61        base.double_in_place();
62
63        let mut result = TargetF::zero();
64        let mut power = TargetF::one();
65
66        for limb in limbs.iter().rev() {
67            let mut val = TargetF::zero();
68            let mut cur = TargetF::one();
69
70            for bit in limb.into_bigint().to_bits_be().iter().rev() {
71                if *bit {
72                    val += &cur;
73                }
74                cur.double_in_place();
75            }
76
77            result += &(val * power);
78            power *= &base;
79        }
80
81        result
82    }
83
84    /// Obtain the value of a emulated field element
85    pub fn value(&self) -> R1CSResult<TargetF> {
86        let mut limbs = Vec::new();
87        for limb in self.limbs.iter() {
88            limbs.push(limb.value()?);
89        }
90
91        Ok(Self::limbs_to_value(limbs, self.get_optimization_type()))
92    }
93
94    /// Obtain the emulated field element of a constant value
95    pub fn constant(cs: ConstraintSystemRef<BaseF>, value: TargetF) -> R1CSResult<Self> {
96        let optimization_type = match cs.optimization_goal() {
97            OptimizationGoal::None => OptimizationType::Constraints,
98            OptimizationGoal::Constraints => OptimizationType::Constraints,
99            OptimizationGoal::Weight => OptimizationType::Weight,
100        };
101
102        let limbs_value = Self::get_limbs_representations(&value, optimization_type)?;
103
104        let mut limbs = Vec::new();
105
106        for limb_value in limbs_value.iter() {
107            limbs.push(FpVar::<BaseF>::new_constant(ns!(cs, "limb"), limb_value)?);
108        }
109
110        Ok(Self {
111            cs,
112            limbs,
113            num_of_additions_over_normal_form: BaseF::zero(),
114            is_in_the_normal_form: true,
115            target_phantom: PhantomData,
116        })
117    }
118
119    /// Obtain the emulated field element of one
120    pub fn one(cs: ConstraintSystemRef<BaseF>) -> R1CSResult<Self> {
121        Self::constant(cs, TargetF::one())
122    }
123
124    /// Obtain the emulated field element of zero
125    pub fn zero(cs: ConstraintSystemRef<BaseF>) -> R1CSResult<Self> {
126        Self::constant(cs, TargetF::zero())
127    }
128
129    /// Add a emulated field element
130    #[tracing::instrument(target = "gr1cs")]
131    pub fn add(&self, other: &Self) -> R1CSResult<Self> {
132        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
133
134        let mut limbs = Vec::new();
135        for (this_limb, other_limb) in self.limbs.iter().zip(other.limbs.iter()) {
136            limbs.push(this_limb + other_limb);
137        }
138
139        let mut res = Self {
140            cs: self.cs(),
141            limbs,
142            num_of_additions_over_normal_form: self
143                .num_of_additions_over_normal_form
144                .add(&other.num_of_additions_over_normal_form)
145                .add(&BaseF::one()),
146            is_in_the_normal_form: false,
147            target_phantom: PhantomData,
148        };
149
150        Reducer::<TargetF, BaseF>::post_add_reduce(&mut res)?;
151        Ok(res)
152    }
153
154    /// Add a constant
155    #[tracing::instrument(target = "gr1cs")]
156    pub fn add_constant(&self, other: &TargetF) -> R1CSResult<Self> {
157        let other_limbs = Self::get_limbs_representations(other, self.get_optimization_type())?;
158
159        let mut limbs = Vec::new();
160        for (this_limb, other_limb) in self.limbs.iter().zip(other_limbs.iter()) {
161            limbs.push(this_limb + *other_limb);
162        }
163
164        let mut res = Self {
165            cs: self.cs(),
166            limbs,
167            num_of_additions_over_normal_form: self
168                .num_of_additions_over_normal_form
169                .add(&BaseF::one()),
170            is_in_the_normal_form: false,
171            target_phantom: PhantomData,
172        };
173
174        Reducer::<TargetF, BaseF>::post_add_reduce(&mut res)?;
175
176        Ok(res)
177    }
178
179    /// Subtract a emulated field element, without the final reduction step
180    #[tracing::instrument(target = "gr1cs")]
181    pub fn sub_without_reduce(&self, other: &Self) -> R1CSResult<Self> {
182        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
183
184        let params = get_params(
185            TargetF::MODULUS_BIT_SIZE as usize,
186            BaseF::MODULUS_BIT_SIZE as usize,
187            self.get_optimization_type(),
188        );
189
190        // Step 1: reduce the `other` if needed
191        let mut surfeit = overhead!(other.num_of_additions_over_normal_form + BaseF::one()) + 1;
192        let mut other = other.clone();
193        if (surfeit + params.bits_per_limb > BaseF::MODULUS_BIT_SIZE as usize - 1)
194            || (surfeit
195                + (TargetF::MODULUS_BIT_SIZE as usize
196                    - params.bits_per_limb * (params.num_limbs - 1))
197                > BaseF::MODULUS_BIT_SIZE as usize - 1)
198        {
199            Reducer::reduce(&mut other)?;
200            surfeit = overhead!(other.num_of_additions_over_normal_form + BaseF::ONE) + 1;
201        }
202
203        // Step 2: construct the padding
204        let mut pad_non_top_limb = BaseF::ONE.into_bigint();
205        let mut pad_top_limb = pad_non_top_limb;
206
207        pad_non_top_limb <<= (surfeit + params.bits_per_limb) as u32;
208        let pad_non_top_limb = BaseF::from_bigint(pad_non_top_limb).unwrap();
209
210        pad_top_limb <<= (surfeit + TargetF::MODULUS_BIT_SIZE as usize
211            - params.bits_per_limb * (params.num_limbs - 1)) as u32;
212        let pad_top_limb = BaseF::from_bigint(pad_top_limb).unwrap();
213
214        let mut pad_limbs = Vec::with_capacity(self.limbs.len());
215        pad_limbs.push(pad_top_limb);
216        for _ in 0..self.limbs.len() - 1 {
217            pad_limbs.push(pad_non_top_limb);
218        }
219
220        // Step 3: prepare to pad the padding to k * p for some k
221        let pad_to_kp_gap = Self::limbs_to_value(pad_limbs, self.get_optimization_type()).neg();
222        let pad_to_kp_limbs =
223            Self::get_limbs_representations(&pad_to_kp_gap, self.get_optimization_type())?;
224
225        // Step 4: the result is self + pad + pad_to_kp - other
226        let mut limbs = Vec::with_capacity(self.limbs.len());
227        for (i, ((this_limb, other_limb), pad_to_kp_limb)) in self
228            .limbs
229            .iter()
230            .zip(&other.limbs)
231            .zip(&pad_to_kp_limbs)
232            .enumerate()
233        {
234            if i != 0 {
235                limbs.push(this_limb + pad_non_top_limb + *pad_to_kp_limb - other_limb);
236            } else {
237                limbs.push(this_limb + pad_top_limb + *pad_to_kp_limb - other_limb);
238            }
239        }
240
241        let padding_bit_len = {
242            let mut one = BaseF::ONE.into_bigint();
243            one <<= surfeit as u32;
244            BaseF::from(one)
245        };
246        let result = AllocatedEmulatedFpVar::<TargetF, BaseF> {
247            cs: self.cs(),
248            limbs,
249            num_of_additions_over_normal_form: self.num_of_additions_over_normal_form // this_limb
250                + padding_bit_len // pad_non_top_limb / pad_top_limb
251                + BaseF::one(), // pad_to_kp_limb
252            is_in_the_normal_form: false,
253            target_phantom: PhantomData,
254        };
255
256        Ok(result)
257    }
258
259    /// Subtract a emulated field element
260    #[tracing::instrument(target = "gr1cs")]
261    pub fn sub(&self, other: &Self) -> R1CSResult<Self> {
262        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
263
264        let mut result = self.sub_without_reduce(other)?;
265        Reducer::<TargetF, BaseF>::post_add_reduce(&mut result)?;
266        Ok(result)
267    }
268
269    /// Subtract a constant
270    #[tracing::instrument(target = "gr1cs")]
271    pub fn sub_constant(&self, other: &TargetF) -> R1CSResult<Self> {
272        self.sub(&Self::constant(self.cs(), *other)?)
273    }
274
275    /// Multiply a emulated field element
276    #[tracing::instrument(target = "gr1cs")]
277    pub fn mul(&self, other: &Self) -> R1CSResult<Self> {
278        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
279
280        self.mul_without_reduce(&other)?.reduce()
281    }
282
283    /// Multiply a constant
284    pub fn mul_constant(&self, other: &TargetF) -> R1CSResult<Self> {
285        self.mul(&Self::constant(self.cs(), *other)?)
286    }
287
288    /// Compute the negate of a emulated field element
289    #[tracing::instrument(target = "gr1cs")]
290    pub fn negate(&self) -> R1CSResult<Self> {
291        Self::zero(self.cs())?.sub(self)
292    }
293
294    /// Compute the inverse of a emulated field element
295    #[tracing::instrument(target = "gr1cs")]
296    pub fn inverse(&self) -> R1CSResult<Self> {
297        let inverse = Self::new_witness(self.cs(), || {
298            Ok(self.value()?.inverse().unwrap_or(TargetF::ZERO))
299        })?;
300
301        let actual_result = self.clone().mul(&inverse)?;
302        actual_result.conditional_enforce_equal(&Self::one(self.cs())?, &Boolean::TRUE)?;
303        Ok(inverse)
304    }
305
306    /// Convert a `TargetF` element into limbs (not constraints).
307    /// This is a utility function intended
308    /// to be reused by a number of other functions.
309    pub fn get_limbs_representations(
310        elem: &TargetF,
311        optimization_type: OptimizationType,
312    ) -> R1CSResult<Vec<BaseF>> {
313        Self::get_limbs_representations_from_big_integer(&elem.into_bigint(), optimization_type)
314    }
315
316    /// Obtain the limbs directly from a big int
317    pub fn get_limbs_representations_from_big_integer(
318        elem: &<TargetF as PrimeField>::BigInt,
319        optimization_type: OptimizationType,
320    ) -> R1CSResult<Vec<BaseF>> {
321        let params = get_params(
322            TargetF::MODULUS_BIT_SIZE as usize,
323            BaseF::MODULUS_BIT_SIZE as usize,
324            optimization_type,
325        );
326
327        // push the lower limbs first
328        let mut limbs: Vec<BaseF> = Vec::new();
329        let mut cur = *elem;
330        for _ in 0..params.num_limbs {
331            let cur_bits = cur.to_bits_be(); // `to_bits` is big endian
332            let cur_mod_r = <BaseF as PrimeField>::BigInt::from_bits_be(
333                &cur_bits[cur_bits.len() - params.bits_per_limb..],
334            ); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
335            limbs.push(BaseF::from_bigint(cur_mod_r).unwrap());
336            cur >>= params.bits_per_limb as u32;
337        }
338
339        // then we reserve, so that the limbs are ``big limb first''
340        limbs.reverse();
341
342        Ok(limbs)
343    }
344
345    /// for advanced use, multiply and output the intermediate representations
346    /// (without reduction) This intermediate representations can be added
347    /// with each other, and they can later be reduced back to the
348    /// `EmulatedFpVar`.
349    #[tracing::instrument(target = "gr1cs")]
350    pub fn mul_without_reduce(
351        &self,
352        other: &Self,
353    ) -> R1CSResult<AllocatedMulResultVar<TargetF, BaseF>> {
354        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
355
356        let params = get_params(
357            TargetF::MODULUS_BIT_SIZE as usize,
358            BaseF::MODULUS_BIT_SIZE as usize,
359            self.get_optimization_type(),
360        );
361
362        // Step 1: reduce `self` and `other` if neceessary
363        let mut self_reduced = self.clone();
364        let mut other_reduced = other.clone();
365        Reducer::<TargetF, BaseF>::pre_mul_reduce(&mut self_reduced, &mut other_reduced)?;
366
367        let mut prod_limbs = Vec::new();
368        if self.get_optimization_type() == OptimizationType::Weight {
369            let zero = FpVar::<BaseF>::zero();
370
371            for _ in 0..2 * params.num_limbs - 1 {
372                prod_limbs.push(zero.clone());
373            }
374
375            for i in 0..params.num_limbs {
376                for j in 0..params.num_limbs {
377                    prod_limbs[i + j] =
378                        &prod_limbs[i + j] + (&self_reduced.limbs[i] * &other_reduced.limbs[j]);
379                }
380            }
381        } else {
382            let cs = self.cs().or(other.cs());
383
384            for z_index in 0..2 * params.num_limbs - 1 {
385                prod_limbs.push(FpVar::new_witness(ns!(cs, "limb product"), || {
386                    let mut z_i = BaseF::zero();
387                    for i in 0..=min(params.num_limbs - 1, z_index) {
388                        let j = z_index - i;
389                        if j < params.num_limbs {
390                            z_i += &self_reduced.limbs[i]
391                                .value()?
392                                .mul(&other_reduced.limbs[j].value()?);
393                        }
394                    }
395
396                    Ok(z_i)
397                })?);
398            }
399
400            for c in 0..(2 * params.num_limbs - 1) {
401                let c_pows: Vec<_> = (0..(2 * params.num_limbs - 1))
402                    .map(|i| BaseF::from((c + 1) as u128).pow(&vec![i as u64]))
403                    .collect();
404
405                let x = self_reduced
406                    .limbs
407                    .iter()
408                    .zip(c_pows.iter())
409                    .map(|(var, c_pow)| var * *c_pow)
410                    .fold(FpVar::zero(), |sum, i| sum + i);
411
412                let y = other_reduced
413                    .limbs
414                    .iter()
415                    .zip(c_pows.iter())
416                    .map(|(var, c_pow)| var * *c_pow)
417                    .fold(FpVar::zero(), |sum, i| sum + i);
418
419                let z = prod_limbs
420                    .iter()
421                    .zip(c_pows.iter())
422                    .map(|(var, c_pow)| var * *c_pow)
423                    .fold(FpVar::zero(), |sum, i| sum + i);
424
425                z.enforce_equal(&(x * y))?;
426            }
427        }
428
429        Ok(AllocatedMulResultVar {
430            cs: self.cs(),
431            limbs: prod_limbs,
432            // New number is upper bounded by:
433            //
434            // (a+1)2^{bits_per_limb} * (b+1)2^{bits_per_limb} * m = (ab+a+b+1)*m*2^{2*bits_per_limb}
435            //
436            // where `a = self_reduced.num_of_additions_over_normal_form` and
437            //       `b = other_reduced.num_of_additions_over_normal_form`
438            // - why m pair: at cell m, there are m possible pairs (one limb from each var) that can add to cell m
439            //
440            // In theory, we can let `prod_of_num_of_additions = (m(ab+a+b+1)-1)`. But below, we use an overestimation.
441            prod_of_num_of_additions: (self_reduced.num_of_additions_over_normal_form
442                + BaseF::one())
443                * (other_reduced.num_of_additions_over_normal_form + BaseF::one())
444                * BaseF::from((params.num_limbs) as u32),
445            target_phantom: PhantomData,
446        })
447    }
448
449    pub(crate) fn frobenius_map(&self, _power: usize) -> R1CSResult<Self> {
450        Ok(self.clone())
451    }
452
453    pub(crate) fn conditional_enforce_equal(
454        &self,
455        other: &Self,
456        should_enforce: &Boolean<BaseF>,
457    ) -> R1CSResult<()> {
458        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
459
460        let params = get_params(
461            TargetF::MODULUS_BIT_SIZE as usize,
462            BaseF::MODULUS_BIT_SIZE as usize,
463            self.get_optimization_type(),
464        );
465
466        // Get p
467        let p_representations = Self::get_limbs_representations_from_big_integer(
468            &<TargetF as PrimeField>::MODULUS,
469            self.get_optimization_type(),
470        )?;
471        let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);
472
473        let mut p_gadget_limbs = Vec::new();
474        for limb in p_representations.iter() {
475            p_gadget_limbs.push(FpVar::<BaseF>::Constant(*limb));
476        }
477
478        // Get delta = self - other
479        let cs = self.cs().or(other.cs()).or(should_enforce.cs());
480        let delta = self.sub_without_reduce(other)?;
481        let delta = should_enforce.select(&delta, &Self::zero(cs.clone())?)?;
482
483        // Allocate k = delta / p
484        let k_gadget = FpVar::<BaseF>::new_witness(ns!(cs, "k"), || {
485            let mut delta_limbs_values = Vec::<BaseF>::new();
486            for limb in delta.limbs.iter() {
487                delta_limbs_values.push(limb.value()?);
488            }
489
490            let delta_bigint = limbs_to_bigint(params.bits_per_limb, &delta_limbs_values);
491
492            Ok(bigint_to_basefield::<BaseF>(&(delta_bigint / p_bigint)))
493        })?;
494
495        let surfeit = overhead!(delta.num_of_additions_over_normal_form + BaseF::one()) + 1;
496        Reducer::<TargetF, BaseF>::limb_to_bits(&k_gadget, surfeit)?;
497
498        // Compute k * p
499        let mut kp_gadget_limbs = Vec::new();
500        for limb in p_gadget_limbs.iter() {
501            kp_gadget_limbs.push(limb * &k_gadget);
502        }
503
504        // Enforce delta = kp
505        Reducer::<TargetF, BaseF>::group_and_check_equality(
506            surfeit,
507            params.bits_per_limb,
508            params.bits_per_limb,
509            &delta.limbs,
510            &kp_gadget_limbs,
511        )?;
512
513        Ok(())
514    }
515
516    #[tracing::instrument(target = "gr1cs")]
517    pub(crate) fn conditional_enforce_not_equal(
518        &self,
519        other: &Self,
520        should_enforce: &Boolean<BaseF>,
521    ) -> R1CSResult<()> {
522        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
523
524        let cs = self.cs().or(other.cs()).or(should_enforce.cs());
525
526        let _ = should_enforce
527            .select(&self.sub(other)?, &Self::one(cs)?)?
528            .inverse()?;
529
530        Ok(())
531    }
532
533    pub(crate) fn get_optimization_type(&self) -> OptimizationType {
534        match self.cs().optimization_goal() {
535            OptimizationGoal::None => OptimizationType::Constraints,
536            OptimizationGoal::Constraints => OptimizationType::Constraints,
537            OptimizationGoal::Weight => OptimizationType::Weight,
538        }
539    }
540
541    /// Allocates a new variable, but does not check that the allocation's limbs
542    /// are in-range.
543    fn new_variable_unchecked<T: Borrow<TargetF>>(
544        cs: impl Into<Namespace<BaseF>>,
545        f: impl FnOnce() -> Result<T, SynthesisError>,
546        mode: AllocationMode,
547    ) -> R1CSResult<Self> {
548        let ns = cs.into();
549        let cs = ns.cs();
550
551        let optimization_type = match cs.optimization_goal() {
552            OptimizationGoal::None => OptimizationType::Constraints,
553            OptimizationGoal::Constraints => OptimizationType::Constraints,
554            OptimizationGoal::Weight => OptimizationType::Weight,
555        };
556
557        let zero = TargetF::zero();
558
559        let elem = match f() {
560            Ok(t) => *(t.borrow()),
561            Err(_) => zero,
562        };
563        let elem_representations = Self::get_limbs_representations(&elem, optimization_type)?;
564        let mut limbs = Vec::new();
565
566        for limb in elem_representations.iter() {
567            limbs.push(FpVar::<BaseF>::new_variable(
568                ark_relations::ns!(cs, "alloc"),
569                || Ok(limb),
570                mode,
571            )?);
572        }
573
574        let num_of_additions_over_normal_form = if mode != AllocationMode::Witness {
575            BaseF::zero()
576        } else {
577            BaseF::one()
578        };
579
580        Ok(Self {
581            cs,
582            limbs,
583            num_of_additions_over_normal_form,
584            is_in_the_normal_form: mode != AllocationMode::Witness,
585            target_phantom: PhantomData,
586        })
587    }
588
589    /// Check that this element is in-range; i.e., each limb is in-range, and
590    /// the whole number is less than the modulus.
591    ///
592    /// Returns the bits of the element, in little-endian form
593    fn enforce_in_range(&self, cs: impl Into<Namespace<BaseF>>) -> R1CSResult<Vec<Boolean<BaseF>>> {
594        let ns = cs.into();
595        let cs = ns.cs();
596        let optimization_type = match cs.optimization_goal() {
597            OptimizationGoal::None => OptimizationType::Constraints,
598            OptimizationGoal::Constraints => OptimizationType::Constraints,
599            OptimizationGoal::Weight => OptimizationType::Weight,
600        };
601        let params = get_params(
602            TargetF::MODULUS_BIT_SIZE as usize,
603            BaseF::MODULUS_BIT_SIZE as usize,
604            optimization_type,
605        );
606        let mut bits = Vec::new();
607        for limb in self.limbs.iter().rev().take(params.num_limbs - 1) {
608            bits.extend(
609                Reducer::<TargetF, BaseF>::limb_to_bits(limb, params.bits_per_limb)?
610                    .into_iter()
611                    .rev(),
612            );
613        }
614
615        bits.extend(
616            Reducer::<TargetF, BaseF>::limb_to_bits(
617                &self.limbs[0],
618                TargetF::MODULUS_BIT_SIZE as usize - (params.num_limbs - 1) * params.bits_per_limb,
619            )?
620            .into_iter()
621            .rev(),
622        );
623        Ok(bits)
624    }
625
626    /// Allocates a new non-native field witness with value given by the
627    /// function `f`. Enforces that the field element has value
628    /// in `[0, modulus)`, and returns the bits of its binary representation.
629    /// The bits are in little-endian (i.e., the bit at index 0 is the LSB) and
630    /// the bit-vector is empty in non-witness allocation modes.
631    pub fn new_witness_with_le_bits<T: Borrow<TargetF>>(
632        cs: impl Into<Namespace<BaseF>>,
633        f: impl FnOnce() -> Result<T, SynthesisError>,
634    ) -> R1CSResult<(Self, Vec<Boolean<BaseF>>)> {
635        let ns = cs.into();
636        let cs = ns.cs();
637        let this = Self::new_variable_unchecked(ns!(cs, "alloc"), f, AllocationMode::Witness)?;
638        let bits = this.enforce_in_range(ns!(cs, "bits"))?;
639        Ok((this, bits))
640    }
641}
642
643impl<TargetF: PrimeField, BaseF: PrimeField> ToBitsGadget<BaseF>
644    for AllocatedEmulatedFpVar<TargetF, BaseF>
645{
646    #[tracing::instrument(target = "gr1cs")]
647    fn to_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseF>>> {
648        let params = get_params(
649            TargetF::MODULUS_BIT_SIZE as usize,
650            BaseF::MODULUS_BIT_SIZE as usize,
651            self.get_optimization_type(),
652        );
653
654        // Reduce to the normal form
655        // Though, a malicious prover can make it slightly larger than p
656        let mut self_normal = self.clone();
657        Reducer::<TargetF, BaseF>::pre_eq_reduce(&mut self_normal)?;
658
659        // Therefore, we convert it to bits and enforce that it is in the field
660        let mut bits = Vec::<Boolean<BaseF>>::new();
661        for limb in self_normal.limbs.iter() {
662            bits.extend_from_slice(&Reducer::<TargetF, BaseF>::limb_to_bits(
663                &limb,
664                params.bits_per_limb,
665            )?);
666        }
667        bits.reverse();
668
669        let mut b = TargetF::characteristic().to_vec();
670        assert_eq!(b[0] % 2, 1);
671        b[0] -= 1; // This works, because the LSB is one, so there's no borrows.
672        let run = Boolean::<BaseF>::enforce_smaller_or_equal_than_le(&bits, b)?;
673
674        // We should always end in a "run" of zeros, because
675        // the characteristic is an odd prime. So, this should
676        // be empty.
677        assert!(run.is_empty());
678
679        Ok(bits)
680    }
681}
682
683impl<TargetF: PrimeField, BaseF: PrimeField> ToBytesGadget<BaseF>
684    for AllocatedEmulatedFpVar<TargetF, BaseF>
685{
686    #[tracing::instrument(target = "gr1cs")]
687    fn to_bytes_le(&self) -> R1CSResult<Vec<UInt8<BaseF>>> {
688        let mut bits = self.to_bits_le()?;
689
690        let num_bits = TargetF::BigInt::NUM_LIMBS * 64;
691        assert!(bits.len() <= num_bits);
692        bits.resize_with(num_bits, || Boolean::FALSE);
693
694        let bytes = bits.chunks(8).map(UInt8::from_bits_le).collect();
695        Ok(bytes)
696    }
697}
698
699impl<TargetF: PrimeField, BaseF: PrimeField> CondSelectGadget<BaseF>
700    for AllocatedEmulatedFpVar<TargetF, BaseF>
701{
702    #[tracing::instrument(target = "gr1cs")]
703    fn conditionally_select(
704        cond: &Boolean<BaseF>,
705        true_value: &Self,
706        false_value: &Self,
707    ) -> R1CSResult<Self> {
708        assert_eq!(
709            true_value.get_optimization_type(),
710            false_value.get_optimization_type()
711        );
712
713        let mut limbs_sel = Vec::with_capacity(true_value.limbs.len());
714
715        for (x, y) in true_value.limbs.iter().zip(&false_value.limbs) {
716            limbs_sel.push(FpVar::<BaseF>::conditionally_select(cond, x, y)?);
717        }
718
719        Ok(Self {
720            cs: true_value.cs().or(false_value.cs()),
721            limbs: limbs_sel,
722            num_of_additions_over_normal_form: max(
723                true_value.num_of_additions_over_normal_form,
724                false_value.num_of_additions_over_normal_form,
725            ),
726            is_in_the_normal_form: true_value.is_in_the_normal_form
727                && false_value.is_in_the_normal_form,
728            target_phantom: PhantomData,
729        })
730    }
731}
732
733impl<TargetF: PrimeField, BaseF: PrimeField> TwoBitLookupGadget<BaseF>
734    for AllocatedEmulatedFpVar<TargetF, BaseF>
735{
736    type TableConstant = TargetF;
737
738    #[tracing::instrument(target = "gr1cs")]
739    fn two_bit_lookup(
740        bits: &[Boolean<BaseF>],
741        constants: &[Self::TableConstant],
742    ) -> R1CSResult<Self> {
743        debug_assert!(bits.len() == 2);
744        debug_assert!(constants.len() == 4);
745
746        let cs = bits.cs();
747
748        let optimization_type = match cs.optimization_goal() {
749            OptimizationGoal::None => OptimizationType::Constraints,
750            OptimizationGoal::Constraints => OptimizationType::Constraints,
751            OptimizationGoal::Weight => OptimizationType::Weight,
752        };
753
754        let params = get_params(
755            TargetF::MODULUS_BIT_SIZE as usize,
756            BaseF::MODULUS_BIT_SIZE as usize,
757            optimization_type,
758        );
759        let mut limbs_constants = Vec::new();
760        for _ in 0..params.num_limbs {
761            limbs_constants.push(Vec::new());
762        }
763
764        for constant in constants.iter() {
765            let representations =
766                AllocatedEmulatedFpVar::<TargetF, BaseF>::get_limbs_representations(
767                    constant,
768                    optimization_type,
769                )?;
770
771            for (i, representation) in representations.iter().enumerate() {
772                limbs_constants[i].push(*representation);
773            }
774        }
775
776        let mut limbs = Vec::new();
777        for limbs_constant in limbs_constants.iter() {
778            limbs.push(FpVar::<BaseF>::two_bit_lookup(bits, limbs_constant)?);
779        }
780
781        Ok(AllocatedEmulatedFpVar::<TargetF, BaseF> {
782            cs,
783            limbs,
784            num_of_additions_over_normal_form: BaseF::zero(),
785            is_in_the_normal_form: true,
786            target_phantom: PhantomData,
787        })
788    }
789}
790
791impl<TargetF: PrimeField, BaseF: PrimeField> ThreeBitCondNegLookupGadget<BaseF>
792    for AllocatedEmulatedFpVar<TargetF, BaseF>
793{
794    type TableConstant = TargetF;
795
796    #[tracing::instrument(target = "gr1cs")]
797    fn three_bit_cond_neg_lookup(
798        bits: &[Boolean<BaseF>],
799        b0b1: &Boolean<BaseF>,
800        constants: &[Self::TableConstant],
801    ) -> R1CSResult<Self> {
802        debug_assert!(bits.len() == 3);
803        debug_assert!(constants.len() == 4);
804
805        let cs = bits.cs().or(b0b1.cs());
806
807        let optimization_type = match cs.optimization_goal() {
808            OptimizationGoal::None => OptimizationType::Constraints,
809            OptimizationGoal::Constraints => OptimizationType::Constraints,
810            OptimizationGoal::Weight => OptimizationType::Weight,
811        };
812
813        let params = get_params(
814            TargetF::MODULUS_BIT_SIZE as usize,
815            BaseF::MODULUS_BIT_SIZE as usize,
816            optimization_type,
817        );
818
819        let mut limbs_constants = Vec::new();
820        for _ in 0..params.num_limbs {
821            limbs_constants.push(Vec::new());
822        }
823
824        for constant in constants.iter() {
825            let representations =
826                AllocatedEmulatedFpVar::<TargetF, BaseF>::get_limbs_representations(
827                    constant,
828                    optimization_type,
829                )?;
830
831            for (i, representation) in representations.iter().enumerate() {
832                limbs_constants[i].push(*representation);
833            }
834        }
835
836        let mut limbs = Vec::new();
837        for limbs_constant in limbs_constants.iter() {
838            limbs.push(FpVar::<BaseF>::three_bit_cond_neg_lookup(
839                bits,
840                b0b1,
841                limbs_constant,
842            )?);
843        }
844
845        Ok(AllocatedEmulatedFpVar::<TargetF, BaseF> {
846            cs,
847            limbs,
848            num_of_additions_over_normal_form: BaseF::zero(),
849            is_in_the_normal_form: true,
850            target_phantom: PhantomData,
851        })
852    }
853}
854
855impl<TargetF: PrimeField, BaseF: PrimeField> AllocVar<TargetF, BaseF>
856    for AllocatedEmulatedFpVar<TargetF, BaseF>
857{
858    fn new_variable<T: Borrow<TargetF>>(
859        cs: impl Into<Namespace<BaseF>>,
860        f: impl FnOnce() -> Result<T, SynthesisError>,
861        mode: AllocationMode,
862    ) -> R1CSResult<Self> {
863        let ns = cs.into();
864        let cs = ns.cs();
865        let this = Self::new_variable_unchecked(ns!(cs, "alloc"), f, mode)?;
866        if mode == AllocationMode::Witness {
867            this.enforce_in_range(ns!(cs, "bits"))?;
868        }
869        Ok(this)
870    }
871}
872
873impl<TargetF: PrimeField, BaseF: PrimeField> ToConstraintFieldGadget<BaseF>
874    for AllocatedEmulatedFpVar<TargetF, BaseF>
875{
876    fn to_constraint_field(&self) -> R1CSResult<Vec<FpVar<BaseF>>> {
877        // provide a unique representation of the emulated variable
878        // step 1: convert it into a bit sequence
879        let bits = self.to_bits_le()?;
880
881        // step 2: obtain the parameters for weight-optimized (often, fewer limbs)
882        let params = get_params(
883            TargetF::MODULUS_BIT_SIZE as usize,
884            BaseF::MODULUS_BIT_SIZE as usize,
885            OptimizationType::Weight,
886        );
887
888        // step 3: assemble the limbs
889        let mut limbs = bits
890            .chunks(params.bits_per_limb)
891            .map(|chunk| {
892                let mut limb = FpVar::<BaseF>::zero();
893                let mut w = BaseF::one();
894                for b in chunk.iter() {
895                    limb += FpVar::from(b.clone()) * w;
896                    w.double_in_place();
897                }
898                limb
899            })
900            .collect::<Vec<FpVar<BaseF>>>();
901
902        limbs.reverse();
903
904        // step 4: output the limbs
905        Ok(limbs)
906    }
907}
908
909// Implementation of a few traits
910
911impl<TargetF: PrimeField, BaseF: PrimeField> Clone for AllocatedEmulatedFpVar<TargetF, BaseF> {
912    fn clone(&self) -> Self {
913        AllocatedEmulatedFpVar {
914            cs: self.cs(),
915            limbs: self.limbs.clone(),
916            num_of_additions_over_normal_form: self.num_of_additions_over_normal_form,
917            is_in_the_normal_form: self.is_in_the_normal_form,
918            target_phantom: PhantomData,
919        }
920    }
921}
922
#[cfg(test)]
mod test {
    use ark_ec::{bls12::Bls12Config, pairing::Pairing};
    use ark_ff::PrimeField;
    use ark_relations::gr1cs::ConstraintSystem;

    use crate::{
        alloc::AllocVar,
        fields::{
            emulated_fp::{test::check_constraint, AllocatedEmulatedFpVar},
            fp::FpVar,
        },
    };

    /// Regression test; the name presumably refers to upstream PR #157 (verify).
    ///
    /// Emulates the BLS12-381 base field over the BLS12-377 scalar field. It
    /// grows a value's surfeit through repeated `sub`/`add` calls, performs a
    /// final `mul`, and asserts that the constraint system is still satisfied.
    #[test]
    fn pr_157() {
        type TargetF = <ark_bls12_381::Config as Bls12Config>::Fp;
        type BaseF = <ark_bls12_377::Bls12_377 as Pairing>::ScalarField;

        let cs = ConstraintSystem::new_ref();

        // Public input with value 2^(MODULUS_BIT_SIZE - 1) - 1 in TargetF.
        let l0: AllocatedEmulatedFpVar<TargetF, BaseF> =
            AllocatedEmulatedFpVar::new_input(cs.clone(), || {
                Ok(TargetF::from(
                    TargetF::from(1).into_bigint()
                        << (<TargetF as PrimeField>::MODULUS_BIT_SIZE - 1),
                ) + TargetF::from(-1))
            })
            .unwrap();

        // Accumulate errors
        // (each step is x - x, so the value stays 0 while the representation's
        // bookkeeping from `sub` accumulates).
        let l1 = l0.sub(&l0).unwrap();
        let l1 = l1.sub(&l1).unwrap();
        let l1 = l1.sub(&l1).unwrap();
        let l1 = l1.sub(&l1).unwrap();
        let l1 = l1.sub(&l1).unwrap();

        // l2 = l1 + l1, which is also 0 as a TargetF value.
        let l2 = l1.add(&l1).unwrap();

        // Increase l1's surfeit
        // - The goal is to ensure that the number of limbs, as grouped by `group_and_check_equality`,
        //   falls near the threshold between 17 and 18 limbs.
        // - This increases the chance that accumulated error in `sub` causes an overflow within
        //   `group_and_check_equality`.
        // - The iteration count (293 - 242) is a tuned constant for this field pair.
        let mut l1 = l1;
        for _ in 0..(293 - 242) {
            l1 = l1.add(&l0).unwrap();
        }

        // The product is discarded; the test only checks that the generated
        // constraints are satisfiable.
        let _ = l1.mul(&l2).unwrap();
        assert!(cs.is_satisfied().unwrap());
    }

    /// Regression test for `sub_without_reduce` on hand-built, non-normal-form operands.
    ///
    /// Both operands are constructed directly from 32 raw limb values (allocated as
    /// public inputs). Each is marked as one addition over normal form. The test
    /// asserts that `check_constraint` (defined in the parent module's tests,
    /// presumably validating the representation's bound invariants) holds for
    /// both inputs and for the result.
    #[test]
    fn pr_157_sub() {
        type TargetF = <ark_bls12_381::Config as Bls12Config>::Fp;
        type BaseF = <ark_bls12_377::Bls12_377 as Pairing>::ScalarField;

        // Limb values of the left operand; the specific numbers are a captured
        // failing case, not derived in this test.
        let self_limb_values = [
            100, 2618, 1428, 2152, 2602, 1242, 2823, 511, 1752, 2058, 3599, 1113, 3207, 3601, 2736,
            435, 1108, 2965, 2685, 1705, 1016, 1343, 1760, 2039, 1355, 1767, 2355, 1945, 3594,
            4066, 1913, 2646,
        ];
        let self_num_of_additions_over_normal_form = 1;
        let self_is_in_the_normal_form = false;
        // Right operand: every limb zero except the last one.
        let other_limb_values = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 4,
        ];
        let other_num_of_additions_over_normal_form = 1;
        let other_is_in_the_normal_form = false;

        let cs = ConstraintSystem::new_ref();

        // Build the operands field-by-field, bypassing the normal allocation path
        // so the exact limb layout above is used.
        let left_limb = self_limb_values
            .iter()
            .map(|v| FpVar::new_input(cs.clone(), || Ok(BaseF::from(*v))).unwrap())
            .collect();
        let left: AllocatedEmulatedFpVar<TargetF, BaseF> = AllocatedEmulatedFpVar {
            cs: cs.clone(),
            limbs: left_limb,
            num_of_additions_over_normal_form: BaseF::from(self_num_of_additions_over_normal_form),
            is_in_the_normal_form: self_is_in_the_normal_form,
            target_phantom: std::marker::PhantomData,
        };

        let other_limb = other_limb_values
            .iter()
            .map(|v| FpVar::new_input(cs.clone(), || Ok(BaseF::from(*v))).unwrap())
            .collect();
        let right: AllocatedEmulatedFpVar<TargetF, BaseF> = AllocatedEmulatedFpVar {
            cs: cs.clone(),
            limbs: other_limb,
            num_of_additions_over_normal_form: BaseF::from(other_num_of_additions_over_normal_form),
            is_in_the_normal_form: other_is_in_the_normal_form,
            target_phantom: std::marker::PhantomData,
        };

        let result = left.sub_without_reduce(&right).unwrap();
        assert!(check_constraint(&left));
        assert!(check_constraint(&right));
        assert!(check_constraint(&result));
    }
}