// msm_webgpu/cuzk/utils.rs

1use crate::cuzk::msm::{P, calc_num_words};
2use ff::{Field, PrimeField};
3use halo2curves::CurveAffine;
4use num_bigint::{BigInt, BigUint, Sign};
5use num_traits::One;
6#[cfg(target_arch = "wasm32")]
7use web_sys::console;
8
9/// Convert a field element to bytes
10pub fn field_to_bytes<F: PrimeField>(value: &F) -> Vec<u8> {
11    let s_bytes = value.to_repr();
12    let s_bytes_ref = s_bytes.as_ref();
13    s_bytes_ref.to_vec()
14}
15
16/// Convert bytes to a field element
17pub fn bytes_to_field<F: PrimeField>(bytes: &[u8]) -> F {
18    let mut repr = F::Repr::default();
19    repr.as_mut()[..bytes.len()].copy_from_slice(bytes);
20    F::from_repr(repr).unwrap()
21}
22
/// Split a little-endian byte string into `num_words` limbs of `word_size` bits.
///
/// Limb `i` holds bits `[i * word_size, (i + 1) * word_size)` of the value, with
/// limb 0 least significant. Bits past the end of `val` read as zero, so the
/// result always has exactly `num_words` limbs.
///
/// # Panics
///
/// Panics if `word_size > 32`.
pub fn to_words_le_from_le_bytes(val: &[u8], num_words: usize, word_size: usize) -> Vec<u32> {
    assert!(word_size <= 32, "u32 supports up to 32 bits");

    (0..num_words)
        .map(|limb_index| {
            let first_bit = limb_index * word_size;
            (0..word_size)
                .map(|offset| (offset, first_bit + offset))
                // Stop as soon as we run off the supplied bytes; the rest is zero.
                .take_while(|&(_, bit)| bit / 8 < val.len())
                .fold(0u32, |acc, (offset, bit)| {
                    let bit_value = u32::from((val[bit / 8] >> (bit % 8)) & 1);
                    acc | (bit_value << offset)
                })
        })
        .collect()
}
50
51/// Convert a vector of u32 limbs into a BigUint
52pub fn to_biguint_le(limbs: &[u32], num_limbs: usize, log_limb_size: u32) -> BigUint {
53    assert!(limbs.len() == num_limbs);
54    let mut res = BigUint::from(0u32);
55    let max = 2u32.pow(log_limb_size);
56
57    for i in 0..num_limbs {
58        assert!(limbs[i] < max);
59        let idx = (num_limbs - 1 - i) as u32;
60        let a = idx * log_limb_size;
61        let b = BigUint::from(2u32).pow(a) * BigUint::from(limbs[idx as usize]);
62
63        res += b;
64    }
65
66    res
67}
68
69/// Convert a BigUint into u32 limbs
70pub fn to_words_le(val: &BigUint, num_words: usize, word_size: usize) -> Vec<u32> {
71    let mut limbs = vec![0u32; num_words];
72
73    let mask = BigUint::from((1u32 << word_size) - 1);
74    for i in 0..num_words {
75        let idx = num_words - 1 - i;
76        let shift = idx * word_size;
77        let w = (val >> shift) & mask.clone();
78        let digits = w.to_u32_digits();
79        if !digits.is_empty() {
80            limbs[idx] = digits[0];
81        }
82    }
83
84    limbs
85}
86
87/// Convert a field element into u32 limbs
88pub fn to_words_le_from_field<F: PrimeField>(
89    val: &F,
90    num_words: usize,
91    word_size: usize,
92) -> Vec<u32> {
93    let bytes = field_to_bytes(val);
94    to_words_le_from_le_bytes(&bytes, num_words, word_size)
95}
96
97/// Split each field element into limbs and convert each limb to a vector of bytes.
98pub fn fields_to_u8_vec_for_gpu<F: PrimeField>(
99    fields: &[F],
100    num_words: usize,
101    word_size: usize,
102) -> Vec<u8> {
103    fields
104        .iter()
105        .flat_map(|field| field_to_u8_vec_for_gpu(field, num_words, word_size))
106        .collect::<Vec<_>>()
107}
108
109/// Split a field element into limbs and convert each limb to a vector of bytes.
110pub fn field_to_u8_vec_for_gpu<F: PrimeField>(
111    field: &F,
112    num_words: usize,
113    word_size: usize,
114) -> Vec<u8> {
115    let bytes = field_to_bytes(field);
116    let limbs = to_words_le_from_le_bytes(&bytes, num_words, word_size);
117    let mut u8_vec = vec![0u8; num_words * 4];
118
119    for (i, limb) in limbs.iter().enumerate() {
120        let i4 = i * 4;
121        u8_vec[i4] = (limb & 255) as u8;
122        u8_vec[i4 + 1] = (limb >> 8) as u8;
123    }
124
125    u8_vec
126}
127
128/// Convert a vector of bytes into a vector of field elements
129pub fn u8s_to_fields_without_assertion<F: PrimeField>(
130    u8s: &[u8],
131    num_words: usize,
132    word_size: usize,
133) -> Vec<F> {
134    let num_u8s_per_scalar = num_words * 4;
135
136    let mut result = vec![];
137    for i in 0..(u8s.len() / num_u8s_per_scalar) {
138        let p = i * num_u8s_per_scalar;
139        let s = u8s[p..p + num_u8s_per_scalar].to_vec();
140        result.push(u8s_to_field_without_assertion(&s, num_words, word_size));
141    }
142    result
143}
144
145/// Convert a vector of bytes into a field element
146pub fn u8s_to_field_without_assertion<F: PrimeField>(
147    u8s: &[u8],
148    num_words: usize,
149    word_size: usize,
150) -> F {
151    let a = bytemuck::cast_slice::<u8, u16>(u8s);
152    let mut limbs = vec![];
153    for i in (0..a.len()).step_by(2) {
154        limbs.push(a[i]);
155    }
156    from_words_le_without_assertion(&limbs, num_words, word_size)
157}
158
159/// Convert u16 limbs into a field element
160pub fn from_words_le_without_assertion<F: PrimeField>(
161    limbs: &[u16],
162    num_words: usize,
163    word_size: usize,
164) -> F {
165    assert!(num_words == limbs.len());
166
167    let mut val = BigUint::ZERO;
168    for i in 0..num_words {
169        let exponent = (num_words - i - 1) * word_size;
170        let limb = limbs[num_words - i - 1];
171        val += BigUint::from(2u32).pow(exponent as u32) * BigUint::from(limb);
172        if val == *P {
173            val = BigUint::ZERO;
174        }
175    }
176    let bytes = val.to_bytes_le();
177    
178    bytes_to_field(&bytes)
179}
180
181/// Convert a vector of points to a vector of bytes
182pub fn points_to_bytes_for_gpu<C: CurveAffine>(
183    g: &[C],
184    num_words: usize,
185    word_size: usize,
186) -> Vec<u8> {
187    g.iter()
188        .flat_map(|affine| {
189            let coords = affine.coordinates().unwrap();
190            let x = field_to_u8_vec_for_gpu(coords.x(), num_words, word_size);
191            let y = field_to_u8_vec_for_gpu(coords.y(), num_words, word_size);
192            let z = field_to_u8_vec_for_gpu(&C::Base::ONE, num_words, word_size);
193            [x, y, z].concat()
194        })
195        .collect::<Vec<_>>()
196}
197
198/// Generate the GPU representation of the field characteristic
199pub fn gen_p_limbs(p: &BigUint, num_words: usize, word_size: usize) -> String {
200    let limbs = to_words_le(p, num_words, word_size);
201    let mut r = String::new();
202    for (i, limb) in limbs.iter().enumerate() {
203        r += &format!("    p.limbs[{i}u] = {limb}u;\n");
204    }
205    r
206}
207
208/// Generate the GPU representation of the field characteristic padded with a zero limb
209pub fn gen_p_limbs_plus_one(p: &BigUint, num_words: usize, word_size: usize) -> String {
210    let limbs = to_words_le(p, num_words, word_size);
211    let mut r = String::new();
212    for (i, limb) in limbs.iter().enumerate() {
213        r += &format!("    p.limbs[{i}u] = {limb}u;\n");
214    }
215    r += &format!("    p.limbs[{}u] = {}u;\n", limbs.len(), 0);
216    r
217}
218
/// Generate the GPU representation of zero: `num_words` comma-separated zero
/// limbs, e.g. `"0u, 0u, 0u"`.
///
/// Returns an empty string when `num_words` is 0. Previously `num_words - 1`
/// underflowed in that case.
pub fn gen_zero_limbs(num_words: usize) -> String {
    vec!["0u"; num_words].join(", ")
}
228
/// Generate the GPU representation of one: `"1u"` followed by `num_words - 1`
/// zero limbs, comma-separated (e.g. `"1u, 0u, 0u"`).
///
/// # Panics
///
/// Panics if `num_words` is 0, since one cannot be represented without limbs.
pub fn gen_one_limbs(num_words: usize) -> String {
    assert!(num_words > 0, "gen_one_limbs: need at least one limb");
    // Limbs are little-endian, so the 1 goes in the first (least significant) slot.
    let mut limbs = vec!["0u"; num_words];
    limbs[0] = "1u";
    limbs.join(", ")
}
239
240/// Generate the GPU representation of the Montgomery radix
241pub fn gen_r_limbs(r: &BigUint, num_words: usize, word_size: usize) -> String {
242    let limbs = to_words_le(r, num_words, word_size);
243    let mut r = String::new();
244    for (i, limb) in limbs.iter().enumerate() {
245        r += &format!("    r.limbs[{i}u] = {limb}u;\n");
246    }
247    r
248}
249
250/// Generate the GPU representation of the Montgomery radix inverse
251pub fn gen_rinv_limbs(rinv: &BigUint, num_words: usize, word_size: usize) -> String {
252    let limbs = to_words_le(rinv, num_words, word_size);
253    let mut r = String::new();
254    for (i, limb) in limbs.iter().enumerate() {
255        r += &format!("    rinv.limbs[{i}u] = {limb}u;\n");
256    }
257    r
258}
259
260/// Generate the Montgomery magic number
261pub fn gen_mu(p: &BigUint) -> BigUint {
262    let mut x = 1u32;
263    let two = BigUint::from(2u32);
264
265    while two.pow(x) < *p {
266        x += 1;
267    }
268
269    BigUint::from(4u32).pow(x) / p
270}
271
272/// Generate the GPU representation of the Montgomery magic number
273pub fn gen_mu_limbs(p: &BigUint, num_words: usize, word_size: usize) -> String {
274    let mu = gen_mu(p);
275    let limbs = to_words_le(&mu, num_words, word_size);
276    let mut r = String::new();
277    for (i, limb) in limbs.iter().enumerate() {
278        r += &format!("    mu.limbs[{i}u] = {limb}u;\n");
279    }
280    r
281}
282
283/// Calculate the bitwidth of the field characteristic
284pub fn calc_bitwidth(p: &BigUint) -> usize {
285    if *p == BigUint::from(0u32) {
286        return 0;
287    }
288
289    p.to_radix_le(2).len()
290}
291
292/// Extended Euclidean algorithm
293fn egcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
294    if *a == BigInt::from(0u32) {
295        return (b.clone(), BigInt::from(0u32), BigInt::from(1u32));
296    }
297    let (g, x, y) = egcd(&(b % a), a);
298
299    (g, y - (b / a) * x.clone(), x.clone())
300}
301
302/// Calculate the Montgomery inverse and the Montgomery reduction parameter
303pub fn calc_inv_and_pprime(p: &BigUint, r: &BigUint) -> (BigUint, BigUint) {
304    assert!(*r != BigUint::from(0u32));
305
306    let p_bigint = BigInt::from_biguint(Sign::Plus, p.clone());
307    let r_bigint = BigInt::from_biguint(Sign::Plus, r.clone());
308    let one = BigInt::from(1u32);
309    let (_, mut rinv, mut pprime) = egcd(
310        &BigInt::from_biguint(Sign::Plus, r.clone()),
311        &BigInt::from_biguint(Sign::Plus, p.clone()),
312    );
313
314    if rinv.sign() == Sign::Minus {
315        rinv = BigInt::from_biguint(Sign::Plus, p.clone()) + rinv;
316    }
317
318    if pprime.sign() == Sign::Minus {
319        pprime = BigInt::from_biguint(Sign::Plus, r.clone()) + pprime;
320    }
321
322    // r * rinv - p * pprime == 1
323    assert!(
324        (BigInt::from_biguint(Sign::Plus, r.clone()) * &rinv % &p_bigint)
325            - (&p_bigint * &pprime % &p_bigint)
326            == one
327    );
328
329    // r * rinv % p == 1
330    assert!((BigInt::from_biguint(Sign::Plus, r.clone()) * &rinv % &p_bigint) == one);
331
332    // p * pprime % r == 1
333    assert!((&p_bigint * &pprime % &r_bigint) == one);
334
335    (rinv.to_biguint().unwrap(), pprime.to_biguint().unwrap())
336}
337
338/// Calculate the Montgomery radix inverse and the Montgomery reduction parameter
339pub fn calc_rinv_and_n0(p: &BigUint, r: &BigUint, log_limb_size: u32) -> (BigUint, u32) {
340    let (rinv, pprime) = calc_inv_and_pprime(p, r);
341    let pprime = BigInt::from_biguint(Sign::Plus, pprime);
342
343    let neg_n_inv = BigInt::from_biguint(Sign::Plus, r.clone()) - pprime;
344    let n0 = neg_n_inv % BigInt::from(2u32.pow(log_limb_size));
345    let n0 = n0.to_biguint().unwrap().to_u32_digits()[0];
346
347    (rinv, n0)
348}
349
350/// Miscellaneous parameters for the WebGPU shader
351#[derive(Debug)]
352pub struct MiscParams {
353    pub num_words: usize,
354    pub n0: u32,
355    pub r: BigUint,
356    pub rinv: BigUint,
357}
358
359/// Compute miscellaneous parameters for the WebGPU shader
360pub fn compute_misc_params(p: &BigUint, word_size: usize) -> MiscParams {
361    assert!(word_size > 0);
362    let num_words = calc_num_words(word_size);
363    let r = BigUint::one() << (num_words * word_size);
364    let res = calc_rinv_and_n0(p, &r, word_size as u32);
365    let rinv = res.0;
366    let n0 = res.1;
367    MiscParams {
368        num_words,
369        n0,
370        r: r % p,
371        rinv,
372    }
373}
374
/// Print a debug message.
///
/// On `wasm32` the message goes to `console.log` via `web_sys`; on every other
/// target it is written to stdout.
pub fn debug(s: &str) {
    #[cfg(target_arch = "wasm32")]
    {
        console::log_1(&s.into());
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        println!("{s}");
    }
}
384
#[cfg(test)]
mod tests {
    use halo2curves::bn256::{Fq, Fr};
    use num_traits::Num;
    use rand::thread_rng;

    use super::*;
    use crate::cuzk::msm::{PARAMS, WORD_SIZE};
    use crate::sample_scalars;

    /// Parse the limb values out of shader lines of the form
    /// `    {name}.limbs[{i}u] = {limb}u;`, checking the indices run 0, 1, 2, ...
    fn parse_limb_assignments(shader: &str, name: &str) -> Vec<u32> {
        shader
            .lines()
            .enumerate()
            .map(|(i, line)| {
                let prefix = format!("    {name}.limbs[{i}u] = ");
                let value = line
                    .strip_prefix(&prefix)
                    .and_then(|rest| rest.strip_suffix("u;"))
                    .unwrap_or_else(|| panic!("unexpected line: {line:?}"));
                value.parse().expect("limb value is not a u32")
            })
            .collect()
    }

    #[test]
    fn test_to_words_le_from_le_bytes() {
        let val = sample_scalars::<Fr>(1)[0];
        let bytes = field_to_bytes(&val);
        for word_size in 13..17 {
            let num_words = calc_num_words(word_size);

            let v = BigUint::from_bytes_le(&bytes);
            let limbs = to_words_le(&v, num_words, word_size);
            let limbs_from_le_bytes = to_words_le_from_le_bytes(&bytes, num_words, word_size);
            assert_eq!(limbs, limbs_from_le_bytes);
        }
    }

    #[test]
    fn test_gen_p_limbs() {
        let p = P.clone();
        let num_words = calc_num_words(13);
        let p_limbs = gen_p_limbs(&p, num_words, 13);
        // The emitted limbs must reassemble to p.
        let limbs = parse_limb_assignments(&p_limbs, "p");
        assert_eq!(limbs.len(), num_words);
        assert_eq!(to_biguint_le(&limbs, num_words, 13), p);
    }

    #[test]
    fn test_gen_r_limbs() {
        let r = PARAMS.r.clone();
        let num_words = calc_num_words(WORD_SIZE);
        let r_limbs = gen_r_limbs(&r, num_words, WORD_SIZE);
        // The emitted limbs must reassemble to r.
        let limbs = parse_limb_assignments(&r_limbs, "r");
        assert_eq!(limbs.len(), num_words);
        assert_eq!(to_biguint_le(&limbs, num_words, WORD_SIZE as u32), r);
    }

    #[test]
    fn test_gen_zero_and_one_limbs() {
        assert_eq!(gen_zero_limbs(3), "0u, 0u, 0u");
        assert_eq!(gen_one_limbs(2), "1u, 0u");
        assert_eq!(gen_one_limbs(3), "1u, 0u, 0u");
    }

    #[test]
    fn test_field_to_u8_vec_for_gpu() {
        // random
        let mut rng = thread_rng();
        let a = Fq::random(&mut rng);
        for word_size in 13..17 {
            let num_words = calc_num_words(word_size);
            let bytes = field_to_u8_vec_for_gpu(&a, num_words, word_size);
            let a_from_bytes = u8s_to_field_without_assertion(&bytes, num_words, word_size);
            assert_eq!(a, a_from_bytes);
        }
    }

    #[test]
    fn test_to_words_le() {
        let a = BigUint::from_str_radix(
            "12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
            16,
        )
        .unwrap();
        let limbs = to_words_le(&a, 20, 13);
        let expected = vec![
            1, 0, 0, 768, 4257, 0, 0, 8154, 2678, 2765, 3072, 6255, 4581, 6694, 6530, 5290, 6700,
            2804, 2777, 37,
        ];
        assert_eq!(limbs, expected);
    }
}