ark_nonnative_field/
nonnative_field_var.rs

1use crate::params::OptimizationType;
2use crate::{AllocatedNonNativeFieldVar, NonNativeFieldMulResultVar};
3use ark_ff::PrimeField;
4use ark_ff::{to_bytes, FpParameters};
5use ark_r1cs_std::boolean::Boolean;
6use ark_r1cs_std::fields::fp::FpVar;
7use ark_r1cs_std::fields::FieldVar;
8use ark_r1cs_std::prelude::*;
9use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget};
10use ark_relations::r1cs::Result as R1CSResult;
11use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError};
12use ark_std::hash::{Hash, Hasher};
13use ark_std::{borrow::Borrow, vec::Vec};
14
/// A gadget for representing non-native (`TargetField`) field elements over the constraint field (`BaseField`).
///
/// A value is either a plain `TargetField` constant, which is not attached to
/// any constraint system, or an [`AllocatedNonNativeFieldVar`] that carries
/// its own constraint system reference.
#[derive(Clone, Debug)]
#[must_use]
pub enum NonNativeFieldVar<TargetField: PrimeField, BaseField: PrimeField> {
    /// Constant
    Constant(TargetField),
    /// Allocated gadget
    Var(AllocatedNonNativeFieldVar<TargetField, BaseField>),
}
24
25impl<TargetField: PrimeField, BaseField: PrimeField> PartialEq
26    for NonNativeFieldVar<TargetField, BaseField>
27{
28    fn eq(&self, other: &Self) -> bool {
29        self.value()
30            .unwrap_or_default()
31            .eq(&other.value().unwrap_or_default())
32    }
33}
34
// Marker impl: equality itself is defined by the `PartialEq` impl, which
// compares the (defaulted) `TargetField` values of both operands.
impl<TargetField: PrimeField, BaseField: PrimeField> Eq
    for NonNativeFieldVar<TargetField, BaseField>
{
}
39
40impl<TargetField: PrimeField, BaseField: PrimeField> Hash
41    for NonNativeFieldVar<TargetField, BaseField>
42{
43    fn hash<H: Hasher>(&self, state: &mut H) {
44        self.value().unwrap_or_default().hash(state);
45    }
46}
47
48impl<TargetField: PrimeField, BaseField: PrimeField> R1CSVar<BaseField>
49    for NonNativeFieldVar<TargetField, BaseField>
50{
51    type Value = TargetField;
52
53    fn cs(&self) -> ConstraintSystemRef<BaseField> {
54        match self {
55            Self::Constant(_) => ConstraintSystemRef::None,
56            Self::Var(a) => a.cs(),
57        }
58    }
59
60    fn value(&self) -> R1CSResult<Self::Value> {
61        match self {
62            Self::Constant(v) => Ok(*v),
63            Self::Var(v) => v.value(),
64        }
65    }
66}
67
68impl<TargetField: PrimeField, BaseField: PrimeField> From<Boolean<BaseField>>
69    for NonNativeFieldVar<TargetField, BaseField>
70{
71    fn from(other: Boolean<BaseField>) -> Self {
72        if let Boolean::Constant(b) = other {
73            Self::Constant(<TargetField as From<u128>>::from(b as u128))
74        } else {
75            // `other` is a variable
76            let one = Self::Constant(TargetField::one());
77            let zero = Self::Constant(TargetField::zero());
78            Self::conditionally_select(&other, &one, &zero).unwrap()
79        }
80    }
81}
82
83impl<TargetField: PrimeField, BaseField: PrimeField>
84    From<AllocatedNonNativeFieldVar<TargetField, BaseField>>
85    for NonNativeFieldVar<TargetField, BaseField>
86{
87    fn from(other: AllocatedNonNativeFieldVar<TargetField, BaseField>) -> Self {
88        Self::Var(other)
89    }
90}
91
// Marker impl with an empty body for the owned type.
// NOTE(review): presumably satisfied by the operator impls generated by the
// `impl_bounded_ops!` invocations in this file — confirm against the trait.
impl<'a, TargetField: PrimeField, BaseField: PrimeField> FieldOpsBounds<'a, TargetField, Self>
    for NonNativeFieldVar<TargetField, BaseField>
{
}
96
// Marker impl with an empty body for shared references to the gadget; the
// output type of the bounded operations is the owned `NonNativeFieldVar`.
impl<'a, TargetField: PrimeField, BaseField: PrimeField>
    FieldOpsBounds<'a, TargetField, NonNativeFieldVar<TargetField, BaseField>>
    for &'a NonNativeFieldVar<TargetField, BaseField>
{
}
102
103impl<TargetField: PrimeField, BaseField: PrimeField> FieldVar<TargetField, BaseField>
104    for NonNativeFieldVar<TargetField, BaseField>
105{
106    fn zero() -> Self {
107        Self::Constant(TargetField::zero())
108    }
109
110    fn one() -> Self {
111        Self::Constant(TargetField::one())
112    }
113
114    fn constant(v: TargetField) -> Self {
115        Self::Constant(v)
116    }
117
118    #[tracing::instrument(target = "r1cs")]
119    fn negate(&self) -> R1CSResult<Self> {
120        match self {
121            Self::Constant(c) => Ok(Self::Constant(-*c)),
122            Self::Var(v) => Ok(Self::Var(v.negate()?)),
123        }
124    }
125
126    #[tracing::instrument(target = "r1cs")]
127    fn inverse(&self) -> R1CSResult<Self> {
128        match self {
129            Self::Constant(c) => Ok(Self::Constant(c.inverse().unwrap_or_default())),
130            Self::Var(v) => Ok(Self::Var(v.inverse()?)),
131        }
132    }
133
134    #[tracing::instrument(target = "r1cs")]
135    fn frobenius_map(&self, power: usize) -> R1CSResult<Self> {
136        match self {
137            Self::Constant(c) => Ok(Self::Constant({
138                let mut tmp = *c;
139                tmp.frobenius_map(power);
140                tmp
141            })),
142            Self::Var(v) => Ok(Self::Var(v.frobenius_map(power)?)),
143        }
144    }
145}
146
147/****************************************************************************/
148/****************************************************************************/
149
150impl_bounded_ops!(
151    NonNativeFieldVar<TargetField, BaseField>,
152    TargetField,
153    Add,
154    add,
155    AddAssign,
156    add_assign,
157    |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
158        use NonNativeFieldVar::*;
159        match (this, other) {
160            (Constant(c1), Constant(c2)) => Constant(*c1 + c2),
161            (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()),
162            (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()),
163        }
164    },
165    |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| { this + &NonNativeFieldVar::Constant(other) },
166    (TargetField: PrimeField, BaseField: PrimeField),
167);
168
169impl_bounded_ops!(
170    NonNativeFieldVar<TargetField, BaseField>,
171    TargetField,
172    Sub,
173    sub,
174    SubAssign,
175    sub_assign,
176    |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
177        use NonNativeFieldVar::*;
178        match (this, other) {
179            (Constant(c1), Constant(c2)) => Constant(*c1 - c2),
180            (Var(v), Constant(c)) => Var(v.sub_constant(c).unwrap()),
181            (Constant(c), Var(v)) => Var(v.sub_constant(c).unwrap().negate().unwrap()),
182            (Var(v1), Var(v2)) => Var(v1.sub(v2).unwrap()),
183        }
184    },
185    |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
186        this - &NonNativeFieldVar::Constant(other)
187    },
188    (TargetField: PrimeField, BaseField: PrimeField),
189);
190
191impl_bounded_ops!(
192    NonNativeFieldVar<TargetField, BaseField>,
193    TargetField,
194    Mul,
195    mul,
196    MulAssign,
197    mul_assign,
198    |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
199        use NonNativeFieldVar::*;
200        match (this, other) {
201            (Constant(c1), Constant(c2)) => Constant(*c1 * c2),
202            (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.mul_constant(c).unwrap()),
203            (Var(v1), Var(v2)) => Var(v1.mul(v2).unwrap()),
204        }
205    },
206    |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
207        if other.is_zero() {
208            NonNativeFieldVar::zero()
209        } else {
210            this * &NonNativeFieldVar::Constant(other)
211        }
212    },
213    (TargetField: PrimeField, BaseField: PrimeField),
214);
215
216/****************************************************************************/
217/****************************************************************************/
218
219impl<TargetField: PrimeField, BaseField: PrimeField> EqGadget<BaseField>
220    for NonNativeFieldVar<TargetField, BaseField>
221{
222    #[tracing::instrument(target = "r1cs")]
223    fn is_eq(&self, other: &Self) -> R1CSResult<Boolean<BaseField>> {
224        let cs = self.cs().or(other.cs());
225
226        if cs == ConstraintSystemRef::None {
227            Ok(Boolean::Constant(self.value()? == other.value()?))
228        } else {
229            let should_enforce_equal =
230                Boolean::new_witness(cs, || Ok(self.value()? == other.value()?))?;
231
232            self.conditional_enforce_equal(other, &should_enforce_equal)?;
233            self.conditional_enforce_not_equal(other, &should_enforce_equal.not())?;
234
235            Ok(should_enforce_equal)
236        }
237    }
238
239    #[tracing::instrument(target = "r1cs")]
240    fn conditional_enforce_equal(
241        &self,
242        other: &Self,
243        should_enforce: &Boolean<BaseField>,
244    ) -> R1CSResult<()> {
245        match (self, other) {
246            (Self::Constant(c1), Self::Constant(c2)) => {
247                if c1 != c2 {
248                    should_enforce.enforce_equal(&Boolean::FALSE)?;
249                }
250                Ok(())
251            }
252            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
253                let cs = v.cs();
254                let c = AllocatedNonNativeFieldVar::new_constant(cs, c)?;
255                c.conditional_enforce_equal(v, should_enforce)
256            }
257            (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_equal(v2, should_enforce),
258        }
259    }
260
261    #[tracing::instrument(target = "r1cs")]
262    fn conditional_enforce_not_equal(
263        &self,
264        other: &Self,
265        should_enforce: &Boolean<BaseField>,
266    ) -> R1CSResult<()> {
267        match (self, other) {
268            (Self::Constant(c1), Self::Constant(c2)) => {
269                if c1 == c2 {
270                    should_enforce.enforce_equal(&Boolean::FALSE)?;
271                }
272                Ok(())
273            }
274            (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
275                let cs = v.cs();
276                let c = AllocatedNonNativeFieldVar::new_constant(cs, c)?;
277                c.conditional_enforce_not_equal(v, should_enforce)
278            }
279            (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_not_equal(v2, should_enforce),
280        }
281    }
282}
283
284impl<TargetField: PrimeField, BaseField: PrimeField> ToBitsGadget<BaseField>
285    for NonNativeFieldVar<TargetField, BaseField>
286{
287    #[tracing::instrument(target = "r1cs")]
288    fn to_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseField>>> {
289        match self {
290            Self::Constant(_) => self.to_non_unique_bits_le(),
291            Self::Var(v) => v.to_bits_le(),
292        }
293    }
294
295    #[tracing::instrument(target = "r1cs")]
296    fn to_non_unique_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseField>>> {
297        use ark_ff::BitIteratorLE;
298        match self {
299            Self::Constant(c) => Ok(BitIteratorLE::new(&c.into_repr())
300                .take((TargetField::Params::MODULUS_BITS) as usize)
301                .map(Boolean::constant)
302                .collect::<Vec<_>>()),
303            Self::Var(v) => v.to_non_unique_bits_le(),
304        }
305    }
306}
307
308impl<TargetField: PrimeField, BaseField: PrimeField> ToBytesGadget<BaseField>
309    for NonNativeFieldVar<TargetField, BaseField>
310{
311    /// Outputs the unique byte decomposition of `self` in *little-endian*
312    /// form.
313    #[tracing::instrument(target = "r1cs")]
314    fn to_bytes(&self) -> R1CSResult<Vec<UInt8<BaseField>>> {
315        match self {
316            Self::Constant(c) => Ok(UInt8::constant_vec(&to_bytes![c].unwrap())),
317            Self::Var(v) => v.to_bytes(),
318        }
319    }
320
321    #[tracing::instrument(target = "r1cs")]
322    fn to_non_unique_bytes(&self) -> R1CSResult<Vec<UInt8<BaseField>>> {
323        match self {
324            Self::Constant(c) => Ok(UInt8::constant_vec(&to_bytes![c].unwrap())),
325            Self::Var(v) => v.to_non_unique_bytes(),
326        }
327    }
328}
329
330impl<TargetField: PrimeField, BaseField: PrimeField> CondSelectGadget<BaseField>
331    for NonNativeFieldVar<TargetField, BaseField>
332{
333    #[tracing::instrument(target = "r1cs")]
334    fn conditionally_select(
335        cond: &Boolean<BaseField>,
336        true_value: &Self,
337        false_value: &Self,
338    ) -> R1CSResult<Self> {
339        match cond {
340            Boolean::Constant(true) => Ok(true_value.clone()),
341            Boolean::Constant(false) => Ok(false_value.clone()),
342            _ => {
343                let cs = cond.cs();
344                let true_value = match true_value {
345                    Self::Constant(f) => AllocatedNonNativeFieldVar::new_constant(cs.clone(), f)?,
346                    Self::Var(v) => v.clone(),
347                };
348                let false_value = match false_value {
349                    Self::Constant(f) => AllocatedNonNativeFieldVar::new_constant(cs, f)?,
350                    Self::Var(v) => v.clone(),
351                };
352                cond.select(&true_value, &false_value).map(Self::Var)
353            }
354        }
355    }
356}
357
358/// Uses two bits to perform a lookup into a table
359/// `b` is little-endian: `b[0]` is LSB.
360impl<TargetField: PrimeField, BaseField: PrimeField> TwoBitLookupGadget<BaseField>
361    for NonNativeFieldVar<TargetField, BaseField>
362{
363    type TableConstant = TargetField;
364
365    #[tracing::instrument(target = "r1cs")]
366    fn two_bit_lookup(b: &[Boolean<BaseField>], c: &[Self::TableConstant]) -> R1CSResult<Self> {
367        debug_assert_eq!(b.len(), 2);
368        debug_assert_eq!(c.len(), 4);
369        if b.cs().is_none() {
370            // We're in the constant case
371
372            let lsb = b[0].value()? as usize;
373            let msb = b[1].value()? as usize;
374            let index = lsb + (msb << 1);
375            Ok(Self::Constant(c[index]))
376        } else {
377            AllocatedNonNativeFieldVar::two_bit_lookup(b, c).map(Self::Var)
378        }
379    }
380}
381
382impl<TargetField: PrimeField, BaseField: PrimeField> ThreeBitCondNegLookupGadget<BaseField>
383    for NonNativeFieldVar<TargetField, BaseField>
384{
385    type TableConstant = TargetField;
386
387    #[tracing::instrument(target = "r1cs")]
388    fn three_bit_cond_neg_lookup(
389        b: &[Boolean<BaseField>],
390        b0b1: &Boolean<BaseField>,
391        c: &[Self::TableConstant],
392    ) -> R1CSResult<Self> {
393        debug_assert_eq!(b.len(), 3);
394        debug_assert_eq!(c.len(), 4);
395
396        if b.cs().or(b0b1.cs()).is_none() {
397            // We're in the constant case
398
399            let lsb = b[0].value()? as usize;
400            let msb = b[1].value()? as usize;
401            let index = lsb + (msb << 1);
402            let intermediate = c[index];
403
404            let is_negative = b[2].value()?;
405            let y = if is_negative {
406                -intermediate
407            } else {
408                intermediate
409            };
410            Ok(Self::Constant(y))
411        } else {
412            AllocatedNonNativeFieldVar::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var)
413        }
414    }
415}
416
417impl<TargetField: PrimeField, BaseField: PrimeField> AllocVar<TargetField, BaseField>
418    for NonNativeFieldVar<TargetField, BaseField>
419{
420    fn new_variable<T: Borrow<TargetField>>(
421        cs: impl Into<Namespace<BaseField>>,
422        f: impl FnOnce() -> Result<T, SynthesisError>,
423        mode: AllocationMode,
424    ) -> R1CSResult<Self> {
425        let ns = cs.into();
426        let cs = ns.cs();
427
428        if cs == ConstraintSystemRef::None || mode == AllocationMode::Constant {
429            Ok(Self::Constant(*f()?.borrow()))
430        } else {
431            AllocatedNonNativeFieldVar::new_variable(cs, f, mode).map(Self::Var)
432        }
433    }
434}
435
436impl<TargetField: PrimeField, BaseField: PrimeField> ToConstraintFieldGadget<BaseField>
437    for NonNativeFieldVar<TargetField, BaseField>
438{
439    #[tracing::instrument(target = "r1cs")]
440    fn to_constraint_field(&self) -> R1CSResult<Vec<FpVar<BaseField>>> {
441        // Use one group element to represent the optimization type.
442        //
443        // By default, the constant is converted in the weight-optimized type, because it results in fewer elements.
444        match self {
445            Self::Constant(c) => Ok(AllocatedNonNativeFieldVar::get_limbs_representations(
446                c,
447                OptimizationType::Weight,
448            )?
449            .into_iter()
450            .map(FpVar::constant)
451            .collect()),
452            Self::Var(v) => v.to_constraint_field(),
453        }
454    }
455}
456
457impl<TargetField: PrimeField, BaseField: PrimeField> NonNativeFieldVar<TargetField, BaseField> {
458    /// The `mul_without_reduce` for `NonNativeFieldVar`
459    #[tracing::instrument(target = "r1cs")]
460    pub fn mul_without_reduce(
461        &self,
462        other: &Self,
463    ) -> R1CSResult<NonNativeFieldMulResultVar<TargetField, BaseField>> {
464        match self {
465            Self::Constant(c) => match other {
466                Self::Constant(other_c) => Ok(NonNativeFieldMulResultVar::Constant(*c * other_c)),
467                Self::Var(other_v) => {
468                    let self_v =
469                        AllocatedNonNativeFieldVar::<TargetField, BaseField>::new_constant(
470                            self.cs(),
471                            c,
472                        )?;
473                    Ok(NonNativeFieldMulResultVar::Var(
474                        other_v.mul_without_reduce(&self_v)?,
475                    ))
476                }
477            },
478            Self::Var(v) => {
479                let other_v = match other {
480                    Self::Constant(other_c) => {
481                        AllocatedNonNativeFieldVar::<TargetField, BaseField>::new_constant(
482                            self.cs(),
483                            other_c,
484                        )?
485                    }
486                    Self::Var(other_v) => other_v.clone(),
487                };
488                Ok(NonNativeFieldMulResultVar::Var(
489                    v.mul_without_reduce(&other_v)?,
490                ))
491            }
492        }
493    }
494}