Skip to main content

ark_r1cs_std/fields/emulated_fp/
field_var.rs

1use core::iter::Sum;
2
3use super::{params::OptimizationType, AllocatedEmulatedFpVar, MulResultVar};
4use crate::{
5    boolean::Boolean,
6    convert::{ToBitsGadget, ToBytesGadget, ToConstraintFieldGadget},
7    fields::{fp::FpVar, FieldVar},
8    prelude::*,
9    GR1CSVar,
10};
11use ark_ff::{BigInteger, PrimeField};
12use ark_relations::gr1cs::{ConstraintSystemRef, Namespace, Result as R1CSResult, SynthesisError};
13use ark_std::{
14    borrow::Borrow,
15    hash::{Hash, Hasher},
16    vec::Vec,
17};
18
19/// A gadget for representing non-native (`TargetF`) field elements over the
20/// constraint field (`BaseF`).
21#[derive(Clone, Debug)]
22#[must_use]
23pub enum EmulatedFpVar<TargetF: PrimeField, BaseF: PrimeField> {
24    /// Constant
25    Constant(TargetF),
26    /// Allocated gadget
27    Var(AllocatedEmulatedFpVar<TargetF, BaseF>),
28}
29
30impl<TargetF: PrimeField, BaseF: PrimeField> PartialEq for EmulatedFpVar<TargetF, BaseF> {
31    fn eq(&self, other: &Self) -> bool {
32        self.value()
33            .unwrap_or_default()
34            .eq(&other.value().unwrap_or_default())
35    }
36}
37
38impl<TargetF: PrimeField, BaseF: PrimeField> Eq for EmulatedFpVar<TargetF, BaseF> {}
39
40impl<TargetF: PrimeField, BaseF: PrimeField> Hash for EmulatedFpVar<TargetF, BaseF> {
41    fn hash<H: Hasher>(&self, state: &mut H) {
42        self.value().unwrap_or_default().hash(state);
43    }
44}
45
46impl<TargetF: PrimeField, BaseF: PrimeField> GR1CSVar<BaseF> for EmulatedFpVar<TargetF, BaseF> {
47    type Value = TargetF;
48
49    fn cs(&self) -> ConstraintSystemRef<BaseF> {
50        match self {
51            Self::Constant(_) => ConstraintSystemRef::None,
52            Self::Var(a) => a.cs(),
53        }
54    }
55
56    fn value(&self) -> R1CSResult<Self::Value> {
57        match self {
58            Self::Constant(v) => Ok(*v),
59            Self::Var(v) => v.value(),
60        }
61    }
62}
63
64impl<TargetF: PrimeField, BaseF: PrimeField> From<Boolean<BaseF>>
65    for EmulatedFpVar<TargetF, BaseF>
66{
67    fn from(other: Boolean<BaseF>) -> Self {
68        if let Boolean::Constant(b) = other {
69            Self::Constant(<TargetF as From<u128>>::from(b as u128))
70        } else {
71            // `other` is a variable
72            let one = Self::Constant(TargetF::one());
73            let zero = Self::Constant(TargetF::zero());
74            Self::conditionally_select(&other, &one, &zero).unwrap()
75        }
76    }
77}
78
79impl<TargetF: PrimeField, BaseF: PrimeField> From<AllocatedEmulatedFpVar<TargetF, BaseF>>
80    for EmulatedFpVar<TargetF, BaseF>
81{
82    fn from(other: AllocatedEmulatedFpVar<TargetF, BaseF>) -> Self {
83        Self::Var(other)
84    }
85}
86
87impl<'a, TargetF: PrimeField, BaseF: PrimeField> FieldOpsBounds<'a, TargetF, Self>
88    for EmulatedFpVar<TargetF, BaseF>
89{
90}
91
92impl<'a, TargetF: PrimeField, BaseF: PrimeField>
93    FieldOpsBounds<'a, TargetF, EmulatedFpVar<TargetF, BaseF>>
94    for &'a EmulatedFpVar<TargetF, BaseF>
95{
96}
97
98impl<TargetF: PrimeField, BaseF: PrimeField> FieldVar<TargetF, BaseF>
99    for EmulatedFpVar<TargetF, BaseF>
100{
101    fn zero() -> Self {
102        Self::Constant(TargetF::zero())
103    }
104
105    fn one() -> Self {
106        Self::Constant(TargetF::one())
107    }
108
109    fn constant(v: TargetF) -> Self {
110        Self::Constant(v)
111    }
112
113    #[tracing::instrument(target = "gr1cs")]
114    fn negate(&self) -> R1CSResult<Self> {
115        match self {
116            Self::Constant(c) => Ok(Self::Constant(-*c)),
117            Self::Var(v) => Ok(Self::Var(v.negate()?)),
118        }
119    }
120
121    #[tracing::instrument(target = "gr1cs")]
122    fn inverse(&self) -> R1CSResult<Self> {
123        match self {
124            Self::Constant(c) => Ok(Self::Constant(c.inverse().unwrap_or_default())),
125            Self::Var(v) => Ok(Self::Var(v.inverse()?)),
126        }
127    }
128
129    #[tracing::instrument(target = "gr1cs")]
130    fn frobenius_map(&self, power: usize) -> R1CSResult<Self> {
131        match self {
132            Self::Constant(c) => Ok(Self::Constant({
133                let mut tmp = *c;
134                tmp.frobenius_map_in_place(power);
135                tmp
136            })),
137            Self::Var(v) => Ok(Self::Var(v.frobenius_map(power)?)),
138        }
139    }
140}
141
142impl_bounded_ops!(
143    EmulatedFpVar<TargetF, BaseF>,
144    TargetF,
145    Add,
146    add,
147    AddAssign,
148    add_assign,
149    |this: &'a EmulatedFpVar<TargetF, BaseF>, other: &'a EmulatedFpVar<TargetF, BaseF>| {
150        use EmulatedFpVar::*;
151        match (this, other) {
152            (Constant(c1), Constant(c2)) => Constant(*c1 + c2),
153            (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()),
154            (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()),
155        }
156    },
157    |this: &'a EmulatedFpVar<TargetF, BaseF>, other: TargetF| { this + &EmulatedFpVar::Constant(other) },
158    (TargetF: PrimeField, BaseF: PrimeField),
159);
160
161impl_bounded_ops!(
162    EmulatedFpVar<TargetF, BaseF>,
163    TargetF,
164    Sub,
165    sub,
166    SubAssign,
167    sub_assign,
168    |this: &'a EmulatedFpVar<TargetF, BaseF>, other: &'a EmulatedFpVar<TargetF, BaseF>| {
169        use EmulatedFpVar::*;
170        match (this, other) {
171            (Constant(c1), Constant(c2)) => Constant(*c1 - c2),
172            (Var(v), Constant(c)) => Var(v.sub_constant(c).unwrap()),
173            (Constant(c), Var(v)) => Var(v.sub_constant(c).unwrap().negate().unwrap()),
174            (Var(v1), Var(v2)) => Var(v1.sub(v2).unwrap()),
175        }
176    },
177    |this: &'a EmulatedFpVar<TargetF, BaseF>, other: TargetF| {
178        this - &EmulatedFpVar::Constant(other)
179    },
180    (TargetF: PrimeField, BaseF: PrimeField),
181);
182
183impl_bounded_ops!(
184    EmulatedFpVar<TargetF, BaseF>,
185    TargetF,
186    Mul,
187    mul,
188    MulAssign,
189    mul_assign,
190    |this: &'a EmulatedFpVar<TargetF, BaseF>, other: &'a EmulatedFpVar<TargetF, BaseF>| {
191        use EmulatedFpVar::*;
192        match (this, other) {
193            (Constant(c1), Constant(c2)) => Constant(*c1 * c2),
194            (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.mul_constant(c).unwrap()),
195            (Var(v1), Var(v2)) => Var(v1.mul(v2).unwrap()),
196        }
197    },
198    |this: &'a EmulatedFpVar<TargetF, BaseF>, other: TargetF| {
199        if other.is_zero() {
200            EmulatedFpVar::zero()
201        } else {
202            this * &EmulatedFpVar::Constant(other)
203        }
204    },
205    (TargetF: PrimeField, BaseF: PrimeField),
206);
207
208/// *************************************************************************
209/// *************************************************************************
210
211impl<TargetF: PrimeField, BaseF: PrimeField> EqGadget<BaseF> for EmulatedFpVar<TargetF, BaseF> {
212    #[tracing::instrument(target = "gr1cs")]
213    fn is_eq(&self, other: &Self) -> R1CSResult<Boolean<BaseF>> {
214        let cs = self.cs().or(other.cs());
215
216        if cs == ConstraintSystemRef::None {
217            Ok(Boolean::Constant(self.value()? == other.value()?))
218        } else {
219            let should_enforce_equal =
220                Boolean::new_witness(cs, || Ok(self.value()? == other.value()?))?;
221
222            self.conditional_enforce_equal(other, &should_enforce_equal)?;
223            self.conditional_enforce_not_equal(other, &!&should_enforce_equal)?;
224
225            Ok(should_enforce_equal)
226        }
227    }
228
229    #[tracing::instrument(target = "gr1cs")]
230    fn conditional_enforce_equal(
231        &self,
232        other: &Self,
233        should_enforce: &Boolean<BaseF>,
234    ) -> R1CSResult<()> {
235        match (self, other) {
236            (Self::Constant(c1), Self::Constant(c2)) => {
237                if c1 != c2 {
238                    should_enforce.enforce_equal(&Boolean::FALSE)?;
239                }
240                Ok(())
241            },
242            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
243                let cs = v.cs();
244                let c = AllocatedEmulatedFpVar::new_constant(cs, c)?;
245                c.conditional_enforce_equal(v, should_enforce)
246            },
247            (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_equal(v2, should_enforce),
248        }
249    }
250
251    #[tracing::instrument(target = "gr1cs")]
252    fn conditional_enforce_not_equal(
253        &self,
254        other: &Self,
255        should_enforce: &Boolean<BaseF>,
256    ) -> R1CSResult<()> {
257        match (self, other) {
258            (Self::Constant(c1), Self::Constant(c2)) => {
259                if c1 == c2 {
260                    should_enforce.enforce_equal(&Boolean::FALSE)?;
261                }
262                Ok(())
263            },
264            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
265                let cs = v.cs();
266                let c = AllocatedEmulatedFpVar::new_constant(cs, c)?;
267                c.conditional_enforce_not_equal(v, should_enforce)
268            },
269            (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_not_equal(v2, should_enforce),
270        }
271    }
272}
273
274impl<TargetF: PrimeField, BaseF: PrimeField> ToBitsGadget<BaseF> for EmulatedFpVar<TargetF, BaseF> {
275    #[tracing::instrument(target = "gr1cs")]
276    fn to_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseF>>> {
277        match self {
278            Self::Constant(_) => self.to_non_unique_bits_le(),
279            Self::Var(v) => v.to_bits_le(),
280        }
281    }
282
283    #[tracing::instrument(target = "gr1cs")]
284    fn to_non_unique_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseF>>> {
285        use ark_ff::BitIteratorLE;
286        match self {
287            Self::Constant(c) => Ok(BitIteratorLE::new(&c.into_bigint())
288                .take((TargetF::MODULUS_BIT_SIZE) as usize)
289                .map(Boolean::constant)
290                .collect::<Vec<_>>()),
291            Self::Var(v) => v.to_non_unique_bits_le(),
292        }
293    }
294}
295
296impl<TargetF: PrimeField, BaseF: PrimeField> ToBytesGadget<BaseF>
297    for EmulatedFpVar<TargetF, BaseF>
298{
299    /// Outputs the unique byte decomposition of `self` in *little-endian*
300    /// form.
301    #[tracing::instrument(target = "gr1cs")]
302    fn to_bytes_le(&self) -> R1CSResult<Vec<UInt8<BaseF>>> {
303        match self {
304            Self::Constant(c) => Ok(UInt8::constant_vec(
305                c.into_bigint().to_bytes_le().as_slice(),
306            )),
307
308            Self::Var(v) => v.to_bytes_le(),
309        }
310    }
311
312    #[tracing::instrument(target = "gr1cs")]
313    fn to_non_unique_bytes_le(&self) -> R1CSResult<Vec<UInt8<BaseF>>> {
314        match self {
315            Self::Constant(c) => Ok(UInt8::constant_vec(
316                c.into_bigint().to_bytes_le().as_slice(),
317            )),
318            Self::Var(v) => v.to_non_unique_bytes_le(),
319        }
320    }
321}
322
323impl<TargetF: PrimeField, BaseF: PrimeField> CondSelectGadget<BaseF>
324    for EmulatedFpVar<TargetF, BaseF>
325{
326    #[tracing::instrument(target = "gr1cs")]
327    fn conditionally_select(
328        cond: &Boolean<BaseF>,
329        true_value: &Self,
330        false_value: &Self,
331    ) -> R1CSResult<Self> {
332        match cond {
333            &Boolean::Constant(true) => Ok(true_value.clone()),
334            &Boolean::Constant(false) => Ok(false_value.clone()),
335            _ => {
336                let cs = cond.cs();
337                let true_value = match true_value {
338                    Self::Constant(f) => AllocatedEmulatedFpVar::new_constant(cs.clone(), f)?,
339                    Self::Var(v) => v.clone(),
340                };
341                let false_value = match false_value {
342                    Self::Constant(f) => AllocatedEmulatedFpVar::new_constant(cs, f)?,
343                    Self::Var(v) => v.clone(),
344                };
345                cond.select(&true_value, &false_value).map(Self::Var)
346            },
347        }
348    }
349}
350
351/// Uses two bits to perform a lookup into a table
352/// `b` is little-endian: `b[0]` is LSB.
353impl<TargetF: PrimeField, BaseF: PrimeField> TwoBitLookupGadget<BaseF>
354    for EmulatedFpVar<TargetF, BaseF>
355{
356    type TableConstant = TargetF;
357
358    #[tracing::instrument(target = "gr1cs")]
359    fn two_bit_lookup(b: &[Boolean<BaseF>], c: &[Self::TableConstant]) -> R1CSResult<Self> {
360        debug_assert_eq!(b.len(), 2);
361        debug_assert_eq!(c.len(), 4);
362        if b.cs().is_none() {
363            // We're in the constant case
364
365            let lsb = b[0].value()? as usize;
366            let msb = b[1].value()? as usize;
367            let index = lsb + (msb << 1);
368            Ok(Self::Constant(c[index]))
369        } else {
370            AllocatedEmulatedFpVar::two_bit_lookup(b, c).map(Self::Var)
371        }
372    }
373}
374
375impl<TargetF: PrimeField, BaseF: PrimeField> ThreeBitCondNegLookupGadget<BaseF>
376    for EmulatedFpVar<TargetF, BaseF>
377{
378    type TableConstant = TargetF;
379
380    #[tracing::instrument(target = "gr1cs")]
381    fn three_bit_cond_neg_lookup(
382        b: &[Boolean<BaseF>],
383        b0b1: &Boolean<BaseF>,
384        c: &[Self::TableConstant],
385    ) -> R1CSResult<Self> {
386        debug_assert_eq!(b.len(), 3);
387        debug_assert_eq!(c.len(), 4);
388
389        if b.cs().or(b0b1.cs()).is_none() {
390            // We're in the constant case
391
392            let lsb = b[0].value()? as usize;
393            let msb = b[1].value()? as usize;
394            let index = lsb + (msb << 1);
395            let intermediate = c[index];
396
397            let is_negative = b[2].value()?;
398            let y = if is_negative {
399                -intermediate
400            } else {
401                intermediate
402            };
403            Ok(Self::Constant(y))
404        } else {
405            AllocatedEmulatedFpVar::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var)
406        }
407    }
408}
409
410impl<TargetF: PrimeField, BaseF: PrimeField> AllocVar<TargetF, BaseF>
411    for EmulatedFpVar<TargetF, BaseF>
412{
413    fn new_variable<T: Borrow<TargetF>>(
414        cs: impl Into<Namespace<BaseF>>,
415        f: impl FnOnce() -> Result<T, SynthesisError>,
416        mode: AllocationMode,
417    ) -> R1CSResult<Self> {
418        let ns = cs.into();
419        let cs = ns.cs();
420
421        if cs == ConstraintSystemRef::None || mode == AllocationMode::Constant {
422            Ok(Self::Constant(*f()?.borrow()))
423        } else {
424            AllocatedEmulatedFpVar::new_variable(cs, f, mode).map(Self::Var)
425        }
426    }
427}
428
429impl<TargetF: PrimeField, BaseF: PrimeField> ToConstraintFieldGadget<BaseF>
430    for EmulatedFpVar<TargetF, BaseF>
431{
432    #[tracing::instrument(target = "gr1cs")]
433    fn to_constraint_field(&self) -> R1CSResult<Vec<FpVar<BaseF>>> {
434        // Use one group element to represent the optimization type.
435        //
436        // By default, the constant is converted in the weight-optimized type, because
437        // it results in fewer elements.
438        match self {
439            Self::Constant(c) => Ok(AllocatedEmulatedFpVar::get_limbs_representations(
440                c,
441                OptimizationType::Weight,
442            )?
443            .into_iter()
444            .map(FpVar::constant)
445            .collect()),
446            Self::Var(v) => v.to_constraint_field(),
447        }
448    }
449}
450
451impl<TargetF: PrimeField, BaseF: PrimeField> EmulatedFpVar<TargetF, BaseF> {
452    /// The `mul_without_reduce` for `EmulatedFpVar`
453    #[tracing::instrument(target = "gr1cs")]
454    pub fn mul_without_reduce(&self, other: &Self) -> R1CSResult<MulResultVar<TargetF, BaseF>> {
455        match self {
456            Self::Constant(c) => match other {
457                Self::Constant(other_c) => Ok(MulResultVar::Constant(*c * other_c)),
458                Self::Var(other_v) => {
459                    let self_v =
460                        AllocatedEmulatedFpVar::<TargetF, BaseF>::new_constant(self.cs(), c)?;
461                    Ok(MulResultVar::Var(other_v.mul_without_reduce(&self_v)?))
462                },
463            },
464            Self::Var(v) => {
465                let other_v = match other {
466                    Self::Constant(other_c) => {
467                        AllocatedEmulatedFpVar::<TargetF, BaseF>::new_constant(self.cs(), other_c)?
468                    },
469                    Self::Var(other_v) => other_v.clone(),
470                };
471                Ok(MulResultVar::Var(v.mul_without_reduce(&other_v)?))
472            },
473        }
474    }
475}
476
477impl<TargetF: PrimeField, BaseF: PrimeField> Sum<Self> for EmulatedFpVar<TargetF, BaseF> {
478    #[tracing::instrument(target = "gr1cs", skip(iter))]
479    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
480        iter.fold(Self::zero(), |acc, x| acc + x)
481    }
482}
483
484impl<'a, TargetF: PrimeField, BaseF: PrimeField> Sum<&'a Self> for EmulatedFpVar<TargetF, BaseF> {
485    #[tracing::instrument(target = "gr1cs", skip(iter))]
486    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
487        iter.fold(Self::zero(), |acc, x| acc + x)
488    }
489}