1use crate::params::{get_params, OptimizationType};
2use crate::reduce::{bigint_to_basefield, limbs_to_bigint, Reducer};
3use crate::AllocatedNonNativeFieldVar;
4use ark_ff::{FpParameters, PrimeField};
5use ark_r1cs_std::fields::fp::FpVar;
6use ark_r1cs_std::prelude::*;
7use ark_relations::r1cs::{OptimizationGoal, Result as R1CSResult};
8use ark_relations::{ns, r1cs::ConstraintSystemRef};
9use ark_std::marker::PhantomData;
10use ark_std::vec::Vec;
11use num_bigint::BigUint;
12
13#[derive(Debug)]
15#[must_use]
16pub struct AllocatedNonNativeFieldMulResultVar<TargetField: PrimeField, BaseField: PrimeField> {
17 pub cs: ConstraintSystemRef<BaseField>,
19 pub limbs: Vec<FpVar<BaseField>>,
21 pub prod_of_num_of_additions: BaseField,
23 #[doc(hidden)]
24 pub target_phantom: PhantomData<TargetField>,
25}
26
27impl<TargetField: PrimeField, BaseField: PrimeField>
28 From<&AllocatedNonNativeFieldVar<TargetField, BaseField>>
29 for AllocatedNonNativeFieldMulResultVar<TargetField, BaseField>
30{
31 fn from(src: &AllocatedNonNativeFieldVar<TargetField, BaseField>) -> Self {
32 let params = get_params(
33 TargetField::size_in_bits(),
34 BaseField::size_in_bits(),
35 src.get_optimization_type(),
36 );
37
38 let mut limbs = src.limbs.clone();
39 limbs.reverse();
40 limbs.resize(2 * params.num_limbs - 1, FpVar::<BaseField>::zero());
41 limbs.reverse();
42
43 let prod_of_num_of_additions = src.num_of_additions_over_normal_form + &BaseField::one();
44
45 Self {
46 cs: src.cs(),
47 limbs,
48 prod_of_num_of_additions,
49 target_phantom: PhantomData,
50 }
51 }
52}
53
54impl<TargetField: PrimeField, BaseField: PrimeField>
55 AllocatedNonNativeFieldMulResultVar<TargetField, BaseField>
56{
57 pub fn cs(&self) -> ConstraintSystemRef<BaseField> {
59 self.cs.clone()
60 }
61
62 pub fn value(&self) -> R1CSResult<TargetField> {
64 let params = get_params(
65 TargetField::size_in_bits(),
66 BaseField::size_in_bits(),
67 self.get_optimization_type(),
68 );
69
70 let p_representations =
71 AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations_from_big_integer(
72 &<TargetField as PrimeField>::Params::MODULUS,
73 self.get_optimization_type()
74 )?;
75 let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);
76
77 let mut limbs_values = Vec::<BaseField>::new();
78 for limb in self.limbs.iter() {
79 limbs_values.push(limb.value().unwrap_or_default());
80 }
81 let value_bigint = limbs_to_bigint(params.bits_per_limb, &limbs_values);
82
83 let res = bigint_to_basefield::<TargetField>(&(value_bigint % p_bigint));
84 Ok(res)
85 }
86
87 pub fn reduce(&self) -> R1CSResult<AllocatedNonNativeFieldVar<TargetField, BaseField>> {
89 let params = get_params(
90 TargetField::size_in_bits(),
91 BaseField::size_in_bits(),
92 self.get_optimization_type(),
93 );
94
95 let p_representations =
97 AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations_from_big_integer(
98 &<TargetField as PrimeField>::Params::MODULUS,
99 self.get_optimization_type()
100 )?;
101 let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);
102
103 let mut p_gadget_limbs = Vec::new();
104 for limb in p_representations.iter() {
105 p_gadget_limbs.push(FpVar::<BaseField>::new_constant(self.cs(), limb)?);
106 }
107 let p_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField> {
108 cs: self.cs(),
109 limbs: p_gadget_limbs,
110 num_of_additions_over_normal_form: BaseField::one(),
111 is_in_the_normal_form: false,
112 target_phantom: PhantomData,
113 };
114
115 let surfeit = overhead!(self.prod_of_num_of_additions + BaseField::one()) + 1 + 1;
117
118 let k_bits = {
120 let mut res = Vec::new();
121
122 let mut limbs_values = Vec::<BaseField>::new();
123 for limb in self.limbs.iter() {
124 limbs_values.push(limb.value().unwrap_or_default());
125 }
126
127 let value_bigint = limbs_to_bigint(params.bits_per_limb, &limbs_values);
128 let mut k_cur = value_bigint / p_bigint;
129
130 let total_len = TargetField::size_in_bits() + surfeit;
131
132 for _ in 0..total_len {
133 res.push(Boolean::<BaseField>::new_witness(self.cs(), || {
134 Ok(&k_cur % 2u64 == BigUint::from(1u64))
135 })?);
136 k_cur /= 2u64;
137 }
138 res
139 };
140
141 let k_limbs = {
142 let zero = FpVar::Constant(BaseField::zero());
143 let mut limbs = Vec::new();
144
145 let mut k_bits_cur = k_bits.clone();
146
147 for i in 0..params.num_limbs {
148 let this_limb_size = if i != params.num_limbs - 1 {
149 params.bits_per_limb
150 } else {
151 k_bits.len() - (params.num_limbs - 1) * params.bits_per_limb
152 };
153
154 let this_limb_bits = k_bits_cur[0..this_limb_size].to_vec();
155 k_bits_cur = k_bits_cur[this_limb_size..].to_vec();
156
157 let mut limb = zero.clone();
158 let mut cur = BaseField::one();
159
160 for bit in this_limb_bits.iter() {
161 limb += &(FpVar::<BaseField>::from(bit.clone()) * cur);
162 cur.double_in_place();
163 }
164 limbs.push(limb);
165 }
166
167 limbs.reverse();
168 limbs
169 };
170
171 let k_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField> {
172 cs: self.cs(),
173 limbs: k_limbs,
174 num_of_additions_over_normal_form: self.prod_of_num_of_additions,
175 is_in_the_normal_form: false,
176 target_phantom: PhantomData,
177 };
178
179 let cs = self.cs();
180
181 let r_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField>::new_witness(
182 ns!(cs, "r"),
183 || Ok(self.value()?),
184 )?;
185
186 let params = get_params(
187 TargetField::size_in_bits(),
188 BaseField::size_in_bits(),
189 self.get_optimization_type(),
190 );
191
192 let mut prod_limbs = Vec::new();
194 let zero = FpVar::<BaseField>::zero();
195
196 for _ in 0..2 * params.num_limbs - 1 {
197 prod_limbs.push(zero.clone());
198 }
199
200 for i in 0..params.num_limbs {
201 for j in 0..params.num_limbs {
202 prod_limbs[i + j] = &prod_limbs[i + j] + (&p_gadget.limbs[i] * &k_gadget.limbs[j]);
203 }
204 }
205
206 let mut kp_plus_r_gadget = Self {
207 cs: cs,
208 limbs: prod_limbs,
209 prod_of_num_of_additions: (p_gadget.num_of_additions_over_normal_form
210 + BaseField::one())
211 * (k_gadget.num_of_additions_over_normal_form + BaseField::one()),
212 target_phantom: PhantomData,
213 };
214
215 let kp_plus_r_limbs_len = kp_plus_r_gadget.limbs.len();
216 for (i, limb) in r_gadget.limbs.iter().rev().enumerate() {
217 kp_plus_r_gadget.limbs[kp_plus_r_limbs_len - 1 - i] += limb;
218 }
219
220 Reducer::<TargetField, BaseField>::group_and_check_equality(
221 surfeit,
222 2 * params.bits_per_limb,
223 params.bits_per_limb,
224 &self.limbs,
225 &kp_plus_r_gadget.limbs,
226 )?;
227
228 Ok(r_gadget)
229 }
230
231 #[tracing::instrument(target = "r1cs")]
233 pub fn add(&self, other: &Self) -> R1CSResult<Self> {
234 assert_eq!(self.get_optimization_type(), other.get_optimization_type());
235
236 let mut new_limbs = Vec::new();
237
238 for (l1, l2) in self.limbs.iter().zip(other.limbs.iter()) {
239 let new_limb = l1 + l2;
240 new_limbs.push(new_limb);
241 }
242
243 Ok(Self {
244 cs: self.cs(),
245 limbs: new_limbs,
246 prod_of_num_of_additions: self.prod_of_num_of_additions
247 + other.prod_of_num_of_additions,
248 target_phantom: PhantomData,
249 })
250 }
251
252 #[tracing::instrument(target = "r1cs")]
254 pub fn add_constant(&self, other: &TargetField) -> R1CSResult<Self> {
255 let mut other_limbs =
256 AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations(
257 other,
258 self.get_optimization_type(),
259 )?;
260 other_limbs.reverse();
261
262 let mut new_limbs = Vec::new();
263
264 for (i, limb) in self.limbs.iter().rev().enumerate() {
265 if i < other_limbs.len() {
266 new_limbs.push(limb + other_limbs[i]);
267 } else {
268 new_limbs.push((*limb).clone());
269 }
270 }
271
272 new_limbs.reverse();
273
274 Ok(Self {
275 cs: self.cs(),
276 limbs: new_limbs,
277 prod_of_num_of_additions: self.prod_of_num_of_additions + BaseField::one(),
278 target_phantom: PhantomData,
279 })
280 }
281
282 pub(crate) fn get_optimization_type(&self) -> OptimizationType {
283 match self.cs().optimization_goal() {
284 OptimizationGoal::None => OptimizationType::Constraints,
285 OptimizationGoal::Constraints => OptimizationType::Constraints,
286 OptimizationGoal::Weight => OptimizationType::Weight,
287 }
288 }
289}