1use crate::params::OptimizationType;
2use crate::{AllocatedNonNativeFieldVar, NonNativeFieldMulResultVar};
3use ark_ff::PrimeField;
4use ark_ff::{to_bytes, FpParameters};
5use ark_r1cs_std::boolean::Boolean;
6use ark_r1cs_std::fields::fp::FpVar;
7use ark_r1cs_std::fields::FieldVar;
8use ark_r1cs_std::prelude::*;
9use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget};
10use ark_relations::r1cs::Result as R1CSResult;
11use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError};
12use ark_std::hash::{Hash, Hasher};
13use ark_std::{borrow::Borrow, vec::Vec};
14
15#[derive(Clone, Debug)]
17#[must_use]
18pub enum NonNativeFieldVar<TargetField: PrimeField, BaseField: PrimeField> {
19 Constant(TargetField),
21 Var(AllocatedNonNativeFieldVar<TargetField, BaseField>),
23}
24
25impl<TargetField: PrimeField, BaseField: PrimeField> PartialEq
26 for NonNativeFieldVar<TargetField, BaseField>
27{
28 fn eq(&self, other: &Self) -> bool {
29 self.value()
30 .unwrap_or_default()
31 .eq(&other.value().unwrap_or_default())
32 }
33}
34
35impl<TargetField: PrimeField, BaseField: PrimeField> Eq
36 for NonNativeFieldVar<TargetField, BaseField>
37{
38}
39
40impl<TargetField: PrimeField, BaseField: PrimeField> Hash
41 for NonNativeFieldVar<TargetField, BaseField>
42{
43 fn hash<H: Hasher>(&self, state: &mut H) {
44 self.value().unwrap_or_default().hash(state);
45 }
46}
47
48impl<TargetField: PrimeField, BaseField: PrimeField> R1CSVar<BaseField>
49 for NonNativeFieldVar<TargetField, BaseField>
50{
51 type Value = TargetField;
52
53 fn cs(&self) -> ConstraintSystemRef<BaseField> {
54 match self {
55 Self::Constant(_) => ConstraintSystemRef::None,
56 Self::Var(a) => a.cs(),
57 }
58 }
59
60 fn value(&self) -> R1CSResult<Self::Value> {
61 match self {
62 Self::Constant(v) => Ok(*v),
63 Self::Var(v) => v.value(),
64 }
65 }
66}
67
68impl<TargetField: PrimeField, BaseField: PrimeField> From<Boolean<BaseField>>
69 for NonNativeFieldVar<TargetField, BaseField>
70{
71 fn from(other: Boolean<BaseField>) -> Self {
72 if let Boolean::Constant(b) = other {
73 Self::Constant(<TargetField as From<u128>>::from(b as u128))
74 } else {
75 let one = Self::Constant(TargetField::one());
77 let zero = Self::Constant(TargetField::zero());
78 Self::conditionally_select(&other, &one, &zero).unwrap()
79 }
80 }
81}
82
83impl<TargetField: PrimeField, BaseField: PrimeField>
84 From<AllocatedNonNativeFieldVar<TargetField, BaseField>>
85 for NonNativeFieldVar<TargetField, BaseField>
86{
87 fn from(other: AllocatedNonNativeFieldVar<TargetField, BaseField>) -> Self {
88 Self::Var(other)
89 }
90}
91
92impl<'a, TargetField: PrimeField, BaseField: PrimeField> FieldOpsBounds<'a, TargetField, Self>
93 for NonNativeFieldVar<TargetField, BaseField>
94{
95}
96
97impl<'a, TargetField: PrimeField, BaseField: PrimeField>
98 FieldOpsBounds<'a, TargetField, NonNativeFieldVar<TargetField, BaseField>>
99 for &'a NonNativeFieldVar<TargetField, BaseField>
100{
101}
102
103impl<TargetField: PrimeField, BaseField: PrimeField> FieldVar<TargetField, BaseField>
104 for NonNativeFieldVar<TargetField, BaseField>
105{
106 fn zero() -> Self {
107 Self::Constant(TargetField::zero())
108 }
109
110 fn one() -> Self {
111 Self::Constant(TargetField::one())
112 }
113
114 fn constant(v: TargetField) -> Self {
115 Self::Constant(v)
116 }
117
118 #[tracing::instrument(target = "r1cs")]
119 fn negate(&self) -> R1CSResult<Self> {
120 match self {
121 Self::Constant(c) => Ok(Self::Constant(-*c)),
122 Self::Var(v) => Ok(Self::Var(v.negate()?)),
123 }
124 }
125
126 #[tracing::instrument(target = "r1cs")]
127 fn inverse(&self) -> R1CSResult<Self> {
128 match self {
129 Self::Constant(c) => Ok(Self::Constant(c.inverse().unwrap_or_default())),
130 Self::Var(v) => Ok(Self::Var(v.inverse()?)),
131 }
132 }
133
134 #[tracing::instrument(target = "r1cs")]
135 fn frobenius_map(&self, power: usize) -> R1CSResult<Self> {
136 match self {
137 Self::Constant(c) => Ok(Self::Constant({
138 let mut tmp = *c;
139 tmp.frobenius_map(power);
140 tmp
141 })),
142 Self::Var(v) => Ok(Self::Var(v.frobenius_map(power)?)),
143 }
144 }
145}
146
147impl_bounded_ops!(
151 NonNativeFieldVar<TargetField, BaseField>,
152 TargetField,
153 Add,
154 add,
155 AddAssign,
156 add_assign,
157 |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
158 use NonNativeFieldVar::*;
159 match (this, other) {
160 (Constant(c1), Constant(c2)) => Constant(*c1 + c2),
161 (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()),
162 (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()),
163 }
164 },
165 |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| { this + &NonNativeFieldVar::Constant(other) },
166 (TargetField: PrimeField, BaseField: PrimeField),
167);
168
169impl_bounded_ops!(
170 NonNativeFieldVar<TargetField, BaseField>,
171 TargetField,
172 Sub,
173 sub,
174 SubAssign,
175 sub_assign,
176 |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
177 use NonNativeFieldVar::*;
178 match (this, other) {
179 (Constant(c1), Constant(c2)) => Constant(*c1 - c2),
180 (Var(v), Constant(c)) => Var(v.sub_constant(c).unwrap()),
181 (Constant(c), Var(v)) => Var(v.sub_constant(c).unwrap().negate().unwrap()),
182 (Var(v1), Var(v2)) => Var(v1.sub(v2).unwrap()),
183 }
184 },
185 |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
186 this - &NonNativeFieldVar::Constant(other)
187 },
188 (TargetField: PrimeField, BaseField: PrimeField),
189);
190
191impl_bounded_ops!(
192 NonNativeFieldVar<TargetField, BaseField>,
193 TargetField,
194 Mul,
195 mul,
196 MulAssign,
197 mul_assign,
198 |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
199 use NonNativeFieldVar::*;
200 match (this, other) {
201 (Constant(c1), Constant(c2)) => Constant(*c1 * c2),
202 (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.mul_constant(c).unwrap()),
203 (Var(v1), Var(v2)) => Var(v1.mul(v2).unwrap()),
204 }
205 },
206 |this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
207 if other.is_zero() {
208 NonNativeFieldVar::zero()
209 } else {
210 this * &NonNativeFieldVar::Constant(other)
211 }
212 },
213 (TargetField: PrimeField, BaseField: PrimeField),
214);
215
216impl<TargetField: PrimeField, BaseField: PrimeField> EqGadget<BaseField>
220 for NonNativeFieldVar<TargetField, BaseField>
221{
222 #[tracing::instrument(target = "r1cs")]
223 fn is_eq(&self, other: &Self) -> R1CSResult<Boolean<BaseField>> {
224 let cs = self.cs().or(other.cs());
225
226 if cs == ConstraintSystemRef::None {
227 Ok(Boolean::Constant(self.value()? == other.value()?))
228 } else {
229 let should_enforce_equal =
230 Boolean::new_witness(cs, || Ok(self.value()? == other.value()?))?;
231
232 self.conditional_enforce_equal(other, &should_enforce_equal)?;
233 self.conditional_enforce_not_equal(other, &should_enforce_equal.not())?;
234
235 Ok(should_enforce_equal)
236 }
237 }
238
239 #[tracing::instrument(target = "r1cs")]
240 fn conditional_enforce_equal(
241 &self,
242 other: &Self,
243 should_enforce: &Boolean<BaseField>,
244 ) -> R1CSResult<()> {
245 match (self, other) {
246 (Self::Constant(c1), Self::Constant(c2)) => {
247 if c1 != c2 {
248 should_enforce.enforce_equal(&Boolean::FALSE)?;
249 }
250 Ok(())
251 }
252 (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
253 let cs = v.cs();
254 let c = AllocatedNonNativeFieldVar::new_constant(cs, c)?;
255 c.conditional_enforce_equal(v, should_enforce)
256 }
257 (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_equal(v2, should_enforce),
258 }
259 }
260
261 #[tracing::instrument(target = "r1cs")]
262 fn conditional_enforce_not_equal(
263 &self,
264 other: &Self,
265 should_enforce: &Boolean<BaseField>,
266 ) -> R1CSResult<()> {
267 match (self, other) {
268 (Self::Constant(c1), Self::Constant(c2)) => {
269 if c1 == c2 {
270 should_enforce.enforce_equal(&Boolean::FALSE)?;
271 }
272 Ok(())
273 }
274 (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => {
275 let cs = v.cs();
276 let c = AllocatedNonNativeFieldVar::new_constant(cs, c)?;
277 c.conditional_enforce_not_equal(v, should_enforce)
278 }
279 (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_not_equal(v2, should_enforce),
280 }
281 }
282}
283
284impl<TargetField: PrimeField, BaseField: PrimeField> ToBitsGadget<BaseField>
285 for NonNativeFieldVar<TargetField, BaseField>
286{
287 #[tracing::instrument(target = "r1cs")]
288 fn to_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseField>>> {
289 match self {
290 Self::Constant(_) => self.to_non_unique_bits_le(),
291 Self::Var(v) => v.to_bits_le(),
292 }
293 }
294
295 #[tracing::instrument(target = "r1cs")]
296 fn to_non_unique_bits_le(&self) -> R1CSResult<Vec<Boolean<BaseField>>> {
297 use ark_ff::BitIteratorLE;
298 match self {
299 Self::Constant(c) => Ok(BitIteratorLE::new(&c.into_repr())
300 .take((TargetField::Params::MODULUS_BITS) as usize)
301 .map(Boolean::constant)
302 .collect::<Vec<_>>()),
303 Self::Var(v) => v.to_non_unique_bits_le(),
304 }
305 }
306}
307
308impl<TargetField: PrimeField, BaseField: PrimeField> ToBytesGadget<BaseField>
309 for NonNativeFieldVar<TargetField, BaseField>
310{
311 #[tracing::instrument(target = "r1cs")]
314 fn to_bytes(&self) -> R1CSResult<Vec<UInt8<BaseField>>> {
315 match self {
316 Self::Constant(c) => Ok(UInt8::constant_vec(&to_bytes![c].unwrap())),
317 Self::Var(v) => v.to_bytes(),
318 }
319 }
320
321 #[tracing::instrument(target = "r1cs")]
322 fn to_non_unique_bytes(&self) -> R1CSResult<Vec<UInt8<BaseField>>> {
323 match self {
324 Self::Constant(c) => Ok(UInt8::constant_vec(&to_bytes![c].unwrap())),
325 Self::Var(v) => v.to_non_unique_bytes(),
326 }
327 }
328}
329
330impl<TargetField: PrimeField, BaseField: PrimeField> CondSelectGadget<BaseField>
331 for NonNativeFieldVar<TargetField, BaseField>
332{
333 #[tracing::instrument(target = "r1cs")]
334 fn conditionally_select(
335 cond: &Boolean<BaseField>,
336 true_value: &Self,
337 false_value: &Self,
338 ) -> R1CSResult<Self> {
339 match cond {
340 Boolean::Constant(true) => Ok(true_value.clone()),
341 Boolean::Constant(false) => Ok(false_value.clone()),
342 _ => {
343 let cs = cond.cs();
344 let true_value = match true_value {
345 Self::Constant(f) => AllocatedNonNativeFieldVar::new_constant(cs.clone(), f)?,
346 Self::Var(v) => v.clone(),
347 };
348 let false_value = match false_value {
349 Self::Constant(f) => AllocatedNonNativeFieldVar::new_constant(cs, f)?,
350 Self::Var(v) => v.clone(),
351 };
352 cond.select(&true_value, &false_value).map(Self::Var)
353 }
354 }
355 }
356}
357
358impl<TargetField: PrimeField, BaseField: PrimeField> TwoBitLookupGadget<BaseField>
361 for NonNativeFieldVar<TargetField, BaseField>
362{
363 type TableConstant = TargetField;
364
365 #[tracing::instrument(target = "r1cs")]
366 fn two_bit_lookup(b: &[Boolean<BaseField>], c: &[Self::TableConstant]) -> R1CSResult<Self> {
367 debug_assert_eq!(b.len(), 2);
368 debug_assert_eq!(c.len(), 4);
369 if b.cs().is_none() {
370 let lsb = b[0].value()? as usize;
373 let msb = b[1].value()? as usize;
374 let index = lsb + (msb << 1);
375 Ok(Self::Constant(c[index]))
376 } else {
377 AllocatedNonNativeFieldVar::two_bit_lookup(b, c).map(Self::Var)
378 }
379 }
380}
381
382impl<TargetField: PrimeField, BaseField: PrimeField> ThreeBitCondNegLookupGadget<BaseField>
383 for NonNativeFieldVar<TargetField, BaseField>
384{
385 type TableConstant = TargetField;
386
387 #[tracing::instrument(target = "r1cs")]
388 fn three_bit_cond_neg_lookup(
389 b: &[Boolean<BaseField>],
390 b0b1: &Boolean<BaseField>,
391 c: &[Self::TableConstant],
392 ) -> R1CSResult<Self> {
393 debug_assert_eq!(b.len(), 3);
394 debug_assert_eq!(c.len(), 4);
395
396 if b.cs().or(b0b1.cs()).is_none() {
397 let lsb = b[0].value()? as usize;
400 let msb = b[1].value()? as usize;
401 let index = lsb + (msb << 1);
402 let intermediate = c[index];
403
404 let is_negative = b[2].value()?;
405 let y = if is_negative {
406 -intermediate
407 } else {
408 intermediate
409 };
410 Ok(Self::Constant(y))
411 } else {
412 AllocatedNonNativeFieldVar::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var)
413 }
414 }
415}
416
417impl<TargetField: PrimeField, BaseField: PrimeField> AllocVar<TargetField, BaseField>
418 for NonNativeFieldVar<TargetField, BaseField>
419{
420 fn new_variable<T: Borrow<TargetField>>(
421 cs: impl Into<Namespace<BaseField>>,
422 f: impl FnOnce() -> Result<T, SynthesisError>,
423 mode: AllocationMode,
424 ) -> R1CSResult<Self> {
425 let ns = cs.into();
426 let cs = ns.cs();
427
428 if cs == ConstraintSystemRef::None || mode == AllocationMode::Constant {
429 Ok(Self::Constant(*f()?.borrow()))
430 } else {
431 AllocatedNonNativeFieldVar::new_variable(cs, f, mode).map(Self::Var)
432 }
433 }
434}
435
436impl<TargetField: PrimeField, BaseField: PrimeField> ToConstraintFieldGadget<BaseField>
437 for NonNativeFieldVar<TargetField, BaseField>
438{
439 #[tracing::instrument(target = "r1cs")]
440 fn to_constraint_field(&self) -> R1CSResult<Vec<FpVar<BaseField>>> {
441 match self {
445 Self::Constant(c) => Ok(AllocatedNonNativeFieldVar::get_limbs_representations(
446 c,
447 OptimizationType::Weight,
448 )?
449 .into_iter()
450 .map(FpVar::constant)
451 .collect()),
452 Self::Var(v) => v.to_constraint_field(),
453 }
454 }
455}
456
457impl<TargetField: PrimeField, BaseField: PrimeField> NonNativeFieldVar<TargetField, BaseField> {
458 #[tracing::instrument(target = "r1cs")]
460 pub fn mul_without_reduce(
461 &self,
462 other: &Self,
463 ) -> R1CSResult<NonNativeFieldMulResultVar<TargetField, BaseField>> {
464 match self {
465 Self::Constant(c) => match other {
466 Self::Constant(other_c) => Ok(NonNativeFieldMulResultVar::Constant(*c * other_c)),
467 Self::Var(other_v) => {
468 let self_v =
469 AllocatedNonNativeFieldVar::<TargetField, BaseField>::new_constant(
470 self.cs(),
471 c,
472 )?;
473 Ok(NonNativeFieldMulResultVar::Var(
474 other_v.mul_without_reduce(&self_v)?,
475 ))
476 }
477 },
478 Self::Var(v) => {
479 let other_v = match other {
480 Self::Constant(other_c) => {
481 AllocatedNonNativeFieldVar::<TargetField, BaseField>::new_constant(
482 self.cs(),
483 other_c,
484 )?
485 }
486 Self::Var(other_v) => other_v.clone(),
487 };
488 Ok(NonNativeFieldMulResultVar::Var(
489 v.mul_without_reduce(&other_v)?,
490 ))
491 }
492 }
493 }
494}