ark_nonnative_field/allocated_nonnative_field_mul_result_var.rs

1use crate::params::{get_params, OptimizationType};
2use crate::reduce::{bigint_to_basefield, limbs_to_bigint, Reducer};
3use crate::AllocatedNonNativeFieldVar;
4use ark_ff::{FpParameters, PrimeField};
5use ark_r1cs_std::fields::fp::FpVar;
6use ark_r1cs_std::prelude::*;
7use ark_relations::r1cs::{OptimizationGoal, Result as R1CSResult};
8use ark_relations::{ns, r1cs::ConstraintSystemRef};
9use ark_std::marker::PhantomData;
10use ark_std::vec::Vec;
11use num_bigint::BigUint;
12
13/// The allocated form of `NonNativeFieldMulResultVar` (introduced below)
14#[derive(Debug)]
15#[must_use]
16pub struct AllocatedNonNativeFieldMulResultVar<TargetField: PrimeField, BaseField: PrimeField> {
17    /// Constraint system reference
18    pub cs: ConstraintSystemRef<BaseField>,
19    /// Limbs of the intermediate representations
20    pub limbs: Vec<FpVar<BaseField>>,
21    /// The cumulative num of additions
22    pub prod_of_num_of_additions: BaseField,
23    #[doc(hidden)]
24    pub target_phantom: PhantomData<TargetField>,
25}
26
27impl<TargetField: PrimeField, BaseField: PrimeField>
28    From<&AllocatedNonNativeFieldVar<TargetField, BaseField>>
29    for AllocatedNonNativeFieldMulResultVar<TargetField, BaseField>
30{
31    fn from(src: &AllocatedNonNativeFieldVar<TargetField, BaseField>) -> Self {
32        let params = get_params(
33            TargetField::size_in_bits(),
34            BaseField::size_in_bits(),
35            src.get_optimization_type(),
36        );
37
38        let mut limbs = src.limbs.clone();
39        limbs.reverse();
40        limbs.resize(2 * params.num_limbs - 1, FpVar::<BaseField>::zero());
41        limbs.reverse();
42
43        let prod_of_num_of_additions = src.num_of_additions_over_normal_form + &BaseField::one();
44
45        Self {
46            cs: src.cs(),
47            limbs,
48            prod_of_num_of_additions,
49            target_phantom: PhantomData,
50        }
51    }
52}
53
54impl<TargetField: PrimeField, BaseField: PrimeField>
55    AllocatedNonNativeFieldMulResultVar<TargetField, BaseField>
56{
57    /// Get the CS
58    pub fn cs(&self) -> ConstraintSystemRef<BaseField> {
59        self.cs.clone()
60    }
61
62    /// Get the value of the multiplication result
63    pub fn value(&self) -> R1CSResult<TargetField> {
64        let params = get_params(
65            TargetField::size_in_bits(),
66            BaseField::size_in_bits(),
67            self.get_optimization_type(),
68        );
69
70        let p_representations =
71            AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations_from_big_integer(
72                &<TargetField as PrimeField>::Params::MODULUS,
73                self.get_optimization_type()
74            )?;
75        let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);
76
77        let mut limbs_values = Vec::<BaseField>::new();
78        for limb in self.limbs.iter() {
79            limbs_values.push(limb.value().unwrap_or_default());
80        }
81        let value_bigint = limbs_to_bigint(params.bits_per_limb, &limbs_values);
82
83        let res = bigint_to_basefield::<TargetField>(&(value_bigint % p_bigint));
84        Ok(res)
85    }
86
87    /// Constraints for reducing the result of a multiplication mod p, to get an original representation.
88    pub fn reduce(&self) -> R1CSResult<AllocatedNonNativeFieldVar<TargetField, BaseField>> {
89        let params = get_params(
90            TargetField::size_in_bits(),
91            BaseField::size_in_bits(),
92            self.get_optimization_type(),
93        );
94
95        // Step 1: get p
96        let p_representations =
97            AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations_from_big_integer(
98                &<TargetField as PrimeField>::Params::MODULUS,
99                self.get_optimization_type()
100            )?;
101        let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);
102
103        let mut p_gadget_limbs = Vec::new();
104        for limb in p_representations.iter() {
105            p_gadget_limbs.push(FpVar::<BaseField>::new_constant(self.cs(), limb)?);
106        }
107        let p_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField> {
108            cs: self.cs(),
109            limbs: p_gadget_limbs,
110            num_of_additions_over_normal_form: BaseField::one(),
111            is_in_the_normal_form: false,
112            target_phantom: PhantomData,
113        };
114
115        // Step 2: compute surfeit
116        let surfeit = overhead!(self.prod_of_num_of_additions + BaseField::one()) + 1 + 1;
117
118        // Step 3: allocate k
119        let k_bits = {
120            let mut res = Vec::new();
121
122            let mut limbs_values = Vec::<BaseField>::new();
123            for limb in self.limbs.iter() {
124                limbs_values.push(limb.value().unwrap_or_default());
125            }
126
127            let value_bigint = limbs_to_bigint(params.bits_per_limb, &limbs_values);
128            let mut k_cur = value_bigint / p_bigint;
129
130            let total_len = TargetField::size_in_bits() + surfeit;
131
132            for _ in 0..total_len {
133                res.push(Boolean::<BaseField>::new_witness(self.cs(), || {
134                    Ok(&k_cur % 2u64 == BigUint::from(1u64))
135                })?);
136                k_cur /= 2u64;
137            }
138            res
139        };
140
141        let k_limbs = {
142            let zero = FpVar::Constant(BaseField::zero());
143            let mut limbs = Vec::new();
144
145            let mut k_bits_cur = k_bits.clone();
146
147            for i in 0..params.num_limbs {
148                let this_limb_size = if i != params.num_limbs - 1 {
149                    params.bits_per_limb
150                } else {
151                    k_bits.len() - (params.num_limbs - 1) * params.bits_per_limb
152                };
153
154                let this_limb_bits = k_bits_cur[0..this_limb_size].to_vec();
155                k_bits_cur = k_bits_cur[this_limb_size..].to_vec();
156
157                let mut limb = zero.clone();
158                let mut cur = BaseField::one();
159
160                for bit in this_limb_bits.iter() {
161                    limb += &(FpVar::<BaseField>::from(bit.clone()) * cur);
162                    cur.double_in_place();
163                }
164                limbs.push(limb);
165            }
166
167            limbs.reverse();
168            limbs
169        };
170
171        let k_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField> {
172            cs: self.cs(),
173            limbs: k_limbs,
174            num_of_additions_over_normal_form: self.prod_of_num_of_additions,
175            is_in_the_normal_form: false,
176            target_phantom: PhantomData,
177        };
178
179        let cs = self.cs();
180
181        let r_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField>::new_witness(
182            ns!(cs, "r"),
183            || Ok(self.value()?),
184        )?;
185
186        let params = get_params(
187            TargetField::size_in_bits(),
188            BaseField::size_in_bits(),
189            self.get_optimization_type(),
190        );
191
192        // Step 1: reduce `self` and `other` if neceessary
193        let mut prod_limbs = Vec::new();
194        let zero = FpVar::<BaseField>::zero();
195
196        for _ in 0..2 * params.num_limbs - 1 {
197            prod_limbs.push(zero.clone());
198        }
199
200        for i in 0..params.num_limbs {
201            for j in 0..params.num_limbs {
202                prod_limbs[i + j] = &prod_limbs[i + j] + (&p_gadget.limbs[i] * &k_gadget.limbs[j]);
203            }
204        }
205
206        let mut kp_plus_r_gadget = Self {
207            cs: cs,
208            limbs: prod_limbs,
209            prod_of_num_of_additions: (p_gadget.num_of_additions_over_normal_form
210                + BaseField::one())
211                * (k_gadget.num_of_additions_over_normal_form + BaseField::one()),
212            target_phantom: PhantomData,
213        };
214
215        let kp_plus_r_limbs_len = kp_plus_r_gadget.limbs.len();
216        for (i, limb) in r_gadget.limbs.iter().rev().enumerate() {
217            kp_plus_r_gadget.limbs[kp_plus_r_limbs_len - 1 - i] += limb;
218        }
219
220        Reducer::<TargetField, BaseField>::group_and_check_equality(
221            surfeit,
222            2 * params.bits_per_limb,
223            params.bits_per_limb,
224            &self.limbs,
225            &kp_plus_r_gadget.limbs,
226        )?;
227
228        Ok(r_gadget)
229    }
230
231    /// Add unreduced elements.
232    #[tracing::instrument(target = "r1cs")]
233    pub fn add(&self, other: &Self) -> R1CSResult<Self> {
234        assert_eq!(self.get_optimization_type(), other.get_optimization_type());
235
236        let mut new_limbs = Vec::new();
237
238        for (l1, l2) in self.limbs.iter().zip(other.limbs.iter()) {
239            let new_limb = l1 + l2;
240            new_limbs.push(new_limb);
241        }
242
243        Ok(Self {
244            cs: self.cs(),
245            limbs: new_limbs,
246            prod_of_num_of_additions: self.prod_of_num_of_additions
247                + other.prod_of_num_of_additions,
248            target_phantom: PhantomData,
249        })
250    }
251
252    /// Add native constant elem
253    #[tracing::instrument(target = "r1cs")]
254    pub fn add_constant(&self, other: &TargetField) -> R1CSResult<Self> {
255        let mut other_limbs =
256            AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations(
257                other,
258                self.get_optimization_type(),
259            )?;
260        other_limbs.reverse();
261
262        let mut new_limbs = Vec::new();
263
264        for (i, limb) in self.limbs.iter().rev().enumerate() {
265            if i < other_limbs.len() {
266                new_limbs.push(limb + other_limbs[i]);
267            } else {
268                new_limbs.push((*limb).clone());
269            }
270        }
271
272        new_limbs.reverse();
273
274        Ok(Self {
275            cs: self.cs(),
276            limbs: new_limbs,
277            prod_of_num_of_additions: self.prod_of_num_of_additions + BaseField::one(),
278            target_phantom: PhantomData,
279        })
280    }
281
282    pub(crate) fn get_optimization_type(&self) -> OptimizationType {
283        match self.cs().optimization_goal() {
284            OptimizationGoal::None => OptimizationType::Constraints,
285            OptimizationGoal::Constraints => OptimizationType::Constraints,
286            OptimizationGoal::Weight => OptimizationType::Weight,
287        }
288    }
289}