1use crate::cuzk::msm::{P, calc_num_words};
2use ff::{Field, PrimeField};
3use halo2curves::CurveAffine;
4use num_bigint::{BigInt, BigUint, Sign};
5use num_traits::One;
6#[cfg(target_arch = "wasm32")]
7use web_sys::console;
8
9pub fn field_to_bytes<F: PrimeField>(value: &F) -> Vec<u8> {
11 let s_bytes = value.to_repr();
12 let s_bytes_ref = s_bytes.as_ref();
13 s_bytes_ref.to_vec()
14}
15
16pub fn bytes_to_field<F: PrimeField>(bytes: &[u8]) -> F {
18 let mut repr = F::Repr::default();
19 repr.as_mut()[..bytes.len()].copy_from_slice(bytes);
20 F::from_repr(repr).unwrap()
21}
22
/// Splits a little-endian byte string into `num_words` words of `word_size`
/// bits each, least-significant word first.
///
/// Bit `k` of the input (bit `k % 8` of byte `k / 8`) becomes bit
/// `k % word_size` of word `k / word_size`. Bits beyond the end of `val` read
/// as zero, and input bits beyond `num_words * word_size` are ignored.
///
/// # Panics
///
/// Panics if `word_size` exceeds 32.
pub fn to_words_le_from_le_bytes(val: &[u8], num_words: usize, word_size: usize) -> Vec<u32> {
    assert!(word_size <= 32, "u32 supports up to 32 bits");

    let total_bits = val.len() * 8;
    (0..num_words)
        .map(|word_idx| {
            let first_bit = word_idx * word_size;
            (0..word_size)
                // Stop at the first bit that lies past the input.
                .take_while(|offset| first_bit + offset < total_bits)
                .fold(0u32, |word, offset| {
                    let pos = first_bit + offset;
                    let bit = u32::from((val[pos / 8] >> (pos % 8)) & 1);
                    word | (bit << offset)
                })
        })
        .collect()
}
50
51pub fn to_biguint_le(limbs: &[u32], num_limbs: usize, log_limb_size: u32) -> BigUint {
53 assert!(limbs.len() == num_limbs);
54 let mut res = BigUint::from(0u32);
55 let max = 2u32.pow(log_limb_size);
56
57 for i in 0..num_limbs {
58 assert!(limbs[i] < max);
59 let idx = (num_limbs - 1 - i) as u32;
60 let a = idx * log_limb_size;
61 let b = BigUint::from(2u32).pow(a) * BigUint::from(limbs[idx as usize]);
62
63 res += b;
64 }
65
66 res
67}
68
69pub fn to_words_le(val: &BigUint, num_words: usize, word_size: usize) -> Vec<u32> {
71 let mut limbs = vec![0u32; num_words];
72
73 let mask = BigUint::from((1u32 << word_size) - 1);
74 for i in 0..num_words {
75 let idx = num_words - 1 - i;
76 let shift = idx * word_size;
77 let w = (val >> shift) & mask.clone();
78 let digits = w.to_u32_digits();
79 if !digits.is_empty() {
80 limbs[idx] = digits[0];
81 }
82 }
83
84 limbs
85}
86
87pub fn to_words_le_from_field<F: PrimeField>(
89 val: &F,
90 num_words: usize,
91 word_size: usize,
92) -> Vec<u32> {
93 let bytes = field_to_bytes(val);
94 to_words_le_from_le_bytes(&bytes, num_words, word_size)
95}
96
97pub fn fields_to_u8_vec_for_gpu<F: PrimeField>(
99 fields: &[F],
100 num_words: usize,
101 word_size: usize,
102) -> Vec<u8> {
103 fields
104 .iter()
105 .flat_map(|field| field_to_u8_vec_for_gpu(field, num_words, word_size))
106 .collect::<Vec<_>>()
107}
108
109pub fn field_to_u8_vec_for_gpu<F: PrimeField>(
111 field: &F,
112 num_words: usize,
113 word_size: usize,
114) -> Vec<u8> {
115 let bytes = field_to_bytes(field);
116 let limbs = to_words_le_from_le_bytes(&bytes, num_words, word_size);
117 let mut u8_vec = vec![0u8; num_words * 4];
118
119 for (i, limb) in limbs.iter().enumerate() {
120 let i4 = i * 4;
121 u8_vec[i4] = (limb & 255) as u8;
122 u8_vec[i4 + 1] = (limb >> 8) as u8;
123 }
124
125 u8_vec
126}
127
128pub fn u8s_to_fields_without_assertion<F: PrimeField>(
130 u8s: &[u8],
131 num_words: usize,
132 word_size: usize,
133) -> Vec<F> {
134 let num_u8s_per_scalar = num_words * 4;
135
136 let mut result = vec![];
137 for i in 0..(u8s.len() / num_u8s_per_scalar) {
138 let p = i * num_u8s_per_scalar;
139 let s = u8s[p..p + num_u8s_per_scalar].to_vec();
140 result.push(u8s_to_field_without_assertion(&s, num_words, word_size));
141 }
142 result
143}
144
145pub fn u8s_to_field_without_assertion<F: PrimeField>(
147 u8s: &[u8],
148 num_words: usize,
149 word_size: usize,
150) -> F {
151 let a = bytemuck::cast_slice::<u8, u16>(u8s);
152 let mut limbs = vec![];
153 for i in (0..a.len()).step_by(2) {
154 limbs.push(a[i]);
155 }
156 from_words_le_without_assertion(&limbs, num_words, word_size)
157}
158
159pub fn from_words_le_without_assertion<F: PrimeField>(
161 limbs: &[u16],
162 num_words: usize,
163 word_size: usize,
164) -> F {
165 assert!(num_words == limbs.len());
166
167 let mut val = BigUint::ZERO;
168 for i in 0..num_words {
169 let exponent = (num_words - i - 1) * word_size;
170 let limb = limbs[num_words - i - 1];
171 val += BigUint::from(2u32).pow(exponent as u32) * BigUint::from(limb);
172 if val == *P {
173 val = BigUint::ZERO;
174 }
175 }
176 let bytes = val.to_bytes_le();
177
178 bytes_to_field(&bytes)
179}
180
181pub fn points_to_bytes_for_gpu<C: CurveAffine>(
183 g: &[C],
184 num_words: usize,
185 word_size: usize,
186) -> Vec<u8> {
187 g.iter()
188 .flat_map(|affine| {
189 let coords = affine.coordinates().unwrap();
190 let x = field_to_u8_vec_for_gpu(coords.x(), num_words, word_size);
191 let y = field_to_u8_vec_for_gpu(coords.y(), num_words, word_size);
192 let z = field_to_u8_vec_for_gpu(&C::Base::ONE, num_words, word_size);
193 [x, y, z].concat()
194 })
195 .collect::<Vec<_>>()
196}
197
198pub fn gen_p_limbs(p: &BigUint, num_words: usize, word_size: usize) -> String {
200 let limbs = to_words_le(p, num_words, word_size);
201 let mut r = String::new();
202 for (i, limb) in limbs.iter().enumerate() {
203 r += &format!(" p.limbs[{i}u] = {limb}u;\n");
204 }
205 r
206}
207
208pub fn gen_p_limbs_plus_one(p: &BigUint, num_words: usize, word_size: usize) -> String {
210 let limbs = to_words_le(p, num_words, word_size);
211 let mut r = String::new();
212 for (i, limb) in limbs.iter().enumerate() {
213 r += &format!(" p.limbs[{i}u] = {limb}u;\n");
214 }
215 r += &format!(" p.limbs[{}u] = {}u;\n", limbs.len(), 0);
216 r
217}
218
/// Generates a comma-separated list of `num_words` zero limbs, e.g.
/// `"0u, 0u, 0u"` for three words.
///
/// Returns an empty string when `num_words` is zero. The previous
/// implementation computed `num_words - 1`, which underflows in that case.
pub fn gen_zero_limbs(num_words: usize) -> String {
    vec!["0u"; num_words].join(", ")
}
228
/// Generates a comma-separated list of `num_words` limbs encoding the value
/// one: a leading `1u` followed by zeros, e.g. `"1u, 0u, 0u"`.
///
/// Returns `"1u"` for a single word and an empty string for zero words. The
/// previous implementation computed `num_words - 2`, which underflows for
/// fewer than two words.
pub fn gen_one_limbs(num_words: usize) -> String {
    std::iter::once("1u")
        .chain(std::iter::repeat("0u"))
        .take(num_words)
        .collect::<Vec<_>>()
        .join(", ")
}
239
240pub fn gen_r_limbs(r: &BigUint, num_words: usize, word_size: usize) -> String {
242 let limbs = to_words_le(r, num_words, word_size);
243 let mut r = String::new();
244 for (i, limb) in limbs.iter().enumerate() {
245 r += &format!(" r.limbs[{i}u] = {limb}u;\n");
246 }
247 r
248}
249
250pub fn gen_rinv_limbs(rinv: &BigUint, num_words: usize, word_size: usize) -> String {
252 let limbs = to_words_le(rinv, num_words, word_size);
253 let mut r = String::new();
254 for (i, limb) in limbs.iter().enumerate() {
255 r += &format!(" rinv.limbs[{i}u] = {limb}u;\n");
256 }
257 r
258}
259
260pub fn gen_mu(p: &BigUint) -> BigUint {
262 let mut x = 1u32;
263 let two = BigUint::from(2u32);
264
265 while two.pow(x) < *p {
266 x += 1;
267 }
268
269 BigUint::from(4u32).pow(x) / p
270}
271
272pub fn gen_mu_limbs(p: &BigUint, num_words: usize, word_size: usize) -> String {
274 let mu = gen_mu(p);
275 let limbs = to_words_le(&mu, num_words, word_size);
276 let mut r = String::new();
277 for (i, limb) in limbs.iter().enumerate() {
278 r += &format!(" mu.limbs[{i}u] = {limb}u;\n");
279 }
280 r
281}
282
283pub fn calc_bitwidth(p: &BigUint) -> usize {
285 if *p == BigUint::from(0u32) {
286 return 0;
287 }
288
289 p.to_radix_le(2).len()
290}
291
292fn egcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
294 if *a == BigInt::from(0u32) {
295 return (b.clone(), BigInt::from(0u32), BigInt::from(1u32));
296 }
297 let (g, x, y) = egcd(&(b % a), a);
298
299 (g, y - (b / a) * x.clone(), x.clone())
300}
301
302pub fn calc_inv_and_pprime(p: &BigUint, r: &BigUint) -> (BigUint, BigUint) {
304 assert!(*r != BigUint::from(0u32));
305
306 let p_bigint = BigInt::from_biguint(Sign::Plus, p.clone());
307 let r_bigint = BigInt::from_biguint(Sign::Plus, r.clone());
308 let one = BigInt::from(1u32);
309 let (_, mut rinv, mut pprime) = egcd(
310 &BigInt::from_biguint(Sign::Plus, r.clone()),
311 &BigInt::from_biguint(Sign::Plus, p.clone()),
312 );
313
314 if rinv.sign() == Sign::Minus {
315 rinv = BigInt::from_biguint(Sign::Plus, p.clone()) + rinv;
316 }
317
318 if pprime.sign() == Sign::Minus {
319 pprime = BigInt::from_biguint(Sign::Plus, r.clone()) + pprime;
320 }
321
322 assert!(
324 (BigInt::from_biguint(Sign::Plus, r.clone()) * &rinv % &p_bigint)
325 - (&p_bigint * &pprime % &p_bigint)
326 == one
327 );
328
329 assert!((BigInt::from_biguint(Sign::Plus, r.clone()) * &rinv % &p_bigint) == one);
331
332 assert!((&p_bigint * &pprime % &r_bigint) == one);
334
335 (rinv.to_biguint().unwrap(), pprime.to_biguint().unwrap())
336}
337
338pub fn calc_rinv_and_n0(p: &BigUint, r: &BigUint, log_limb_size: u32) -> (BigUint, u32) {
340 let (rinv, pprime) = calc_inv_and_pprime(p, r);
341 let pprime = BigInt::from_biguint(Sign::Plus, pprime);
342
343 let neg_n_inv = BigInt::from_biguint(Sign::Plus, r.clone()) - pprime;
344 let n0 = neg_n_inv % BigInt::from(2u32.pow(log_limb_size));
345 let n0 = n0.to_biguint().unwrap().to_u32_digits()[0];
346
347 (rinv, n0)
348}
349
350#[derive(Debug)]
352pub struct MiscParams {
353 pub num_words: usize,
354 pub n0: u32,
355 pub r: BigUint,
356 pub rinv: BigUint,
357}
358
359pub fn compute_misc_params(p: &BigUint, word_size: usize) -> MiscParams {
361 assert!(word_size > 0);
362 let num_words = calc_num_words(word_size);
363 let r = BigUint::one() << (num_words * word_size);
364 let res = calc_rinv_and_n0(p, &r, word_size as u32);
365 let rinv = res.0;
366 let n0 = res.1;
367 MiscParams {
368 num_words,
369 n0,
370 r: r % p,
371 rinv,
372 }
373}
374
/// Writes a diagnostic message to the platform's log output.
///
/// On `wasm32` targets the message goes through `console::log_1`; on every
/// other target it is printed to standard output with a trailing newline.
pub fn debug(s: &str) {
    #[cfg(target_arch = "wasm32")]
    {
        let message = s.into();
        console::log_1(&message);
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        println!("{s}");
    }
}
384
#[cfg(test)]
mod tests {
    use halo2curves::bn256::{Fq, Fr};
    use num_traits::Num;
    use rand::thread_rng;

    use super::*;
    use crate::cuzk::msm::{PARAMS, WORD_SIZE};
    use crate::sample_scalars;

    /// The byte-based and `BigUint`-based limb splitters must agree.
    #[test]
    fn test_to_words_le_from_le_bytes() {
        let val = sample_scalars::<Fr>(1)[0];
        let bytes = field_to_bytes(&val);
        for word_size in 13..17 {
            let num_words = calc_num_words(word_size);

            let v = BigUint::from_bytes_le(&bytes);
            let limbs = to_words_le(&v, num_words, word_size);
            let limbs_from_le_bytes = to_words_le_from_le_bytes(&bytes, num_words, word_size);
            assert_eq!(limbs, limbs_from_le_bytes);
        }
    }

    /// The generated code must contain one `p.limbs[..]` line per limb.
    #[test]
    fn test_gen_p_limbs() {
        let p = P.clone();
        let num_words = calc_num_words(13);
        let p_limbs = gen_p_limbs(&p, num_words, 13);
        println!("{p_limbs}");

        assert_eq!(p_limbs.lines().count(), num_words);
        // Leading whitespace is not part of the contract checked here.
        assert!(p_limbs
            .lines()
            .all(|line| line.trim_start().starts_with("p.limbs[") && line.ends_with("u;")));
    }

    /// The generated code must contain one `r.limbs[..]` line per limb.
    #[test]
    fn test_gen_r_limbs() {
        let r = PARAMS.r.clone();
        let num_words = calc_num_words(WORD_SIZE);
        let r_limbs = gen_r_limbs(&r, num_words, WORD_SIZE);
        println!("{r_limbs}");

        assert_eq!(r_limbs.lines().count(), num_words);
        // Leading whitespace is not part of the contract checked here.
        assert!(r_limbs
            .lines()
            .all(|line| line.trim_start().starts_with("r.limbs[") && line.ends_with("u;")));
    }

    /// Encoding a field element and decoding it again must round-trip.
    #[test]
    fn test_field_to_u8_vec_for_gpu() {
        let mut rng = thread_rng();
        let a = Fq::random(&mut rng);
        for word_size in 13..17 {
            let num_words = calc_num_words(word_size);
            let bytes = field_to_u8_vec_for_gpu(&a, num_words, word_size);
            // Every limb occupies four bytes.
            assert_eq!(bytes.len(), num_words * 4);
            let a_from_bytes = u8s_to_field_without_assertion(&bytes, num_words, word_size);
            assert_eq!(a, a_from_bytes);
        }
    }

    /// Known-answer test for splitting a 253-bit value into 13-bit limbs.
    #[test]
    fn test_to_words_le() {
        let a = BigUint::from_str_radix(
            "12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
            16,
        )
        .unwrap();
        let limbs = to_words_le(&a, 20, 13);
        let expected = vec![
            1, 0, 0, 768, 4257, 0, 0, 8154, 2678, 2765, 3072, 6255, 4581, 6694, 6530, 5290, 6700,
            2804, 2777, 37,
        ];
        assert_eq!(limbs, expected);
    }
}