Skip to main content

ff_derive_arcium_fork/
lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5
6use num_bigint::BigUint;
7use num_integer::Integer;
8use num_traits::{One, ToPrimitive, Zero};
9use quote::quote;
10use quote::TokenStreamExt;
11use std::iter;
12use std::str::FromStr;
13
14mod pow_fixed;
15
/// Byte order used by a field element's generated `Repr` type, selected with the
/// `PrimeFieldReprEndianness` attribute (`"big"` or `"little"`).
enum ReprEndianness {
    /// Most significant byte first.
    Big,
    /// Least significant byte first.
    Little,
}
20
21impl FromStr for ReprEndianness {
22    type Err = ();
23
24    fn from_str(s: &str) -> Result<Self, Self::Err> {
25        match s {
26            "big" => Ok(ReprEndianness::Big),
27            "little" => Ok(ReprEndianness::Little),
28            _ => Err(()),
29        }
30    }
31}
32
33impl ReprEndianness {
34    fn modulus_repr(&self, modulus: &BigUint, bytes: usize) -> Vec<u8> {
35        match self {
36            ReprEndianness::Big => {
37                let buf = modulus.to_bytes_be();
38                iter::repeat(0)
39                    .take(bytes - buf.len())
40                    .chain(buf.into_iter())
41                    .collect()
42            }
43            ReprEndianness::Little => {
44                let mut buf = modulus.to_bytes_le();
45                buf.extend(iter::repeat(0).take(bytes - buf.len()));
46                buf
47            }
48        }
49    }
50
51    fn from_repr(&self, name: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
52        let read_repr = match self {
53            ReprEndianness::Big => quote! {
54                ::ff::derive::byteorder::BigEndian::read_u64_into(r.as_ref(), &mut inner[..]);
55                inner.reverse();
56            },
57            ReprEndianness::Little => quote! {
58                ::ff::derive::byteorder::LittleEndian::read_u64_into(r.as_ref(), &mut inner[..]);
59            },
60        };
61
62        quote! {
63            use ::ff::derive::byteorder::ByteOrder;
64
65            let r = {
66                let mut inner = [0u64; #limbs];
67                #read_repr
68                #name(inner)
69            };
70        }
71    }
72
73    fn to_repr(
74        &self,
75        repr: proc_macro2::TokenStream,
76        mont_reduce_self_params: &proc_macro2::TokenStream,
77        limbs: usize,
78    ) -> proc_macro2::TokenStream {
79        let bytes = limbs * 8;
80
81        let write_repr = match self {
82            ReprEndianness::Big => quote! {
83                r.0.reverse();
84                ::ff::derive::byteorder::BigEndian::write_u64_into(&r.0, &mut repr[..]);
85            },
86            ReprEndianness::Little => quote! {
87                ::ff::derive::byteorder::LittleEndian::write_u64_into(&r.0, &mut repr[..]);
88            },
89        };
90
91        quote! {
92            use ::ff::derive::byteorder::ByteOrder;
93
94            let mut r = *self;
95            r.mont_reduce(
96                #mont_reduce_self_params
97            );
98
99            let mut repr = [0u8; #bytes];
100            #write_repr
101            #repr(repr)
102        }
103    }
104
105    fn iter_be(&self) -> proc_macro2::TokenStream {
106        match self {
107            ReprEndianness::Big => quote! {self.0.iter()},
108            ReprEndianness::Little => quote! {self.0.iter().rev()},
109        }
110    }
111}
112
113/// Derive the `PrimeField` trait.
114#[proc_macro_derive(
115    PrimeField,
116    attributes(PrimeFieldModulus, PrimeFieldGenerator, PrimeFieldReprEndianness)
117)]
118pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
119    prime_field_2(input.into()).into()
120}
121
122fn prime_field_2(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
123    // Parse the type definition
124    let ast: syn::DeriveInput = syn::parse2(input).unwrap();
125
126    // We're given the modulus p of the prime field
127    let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs)
128        .expect("Please supply a PrimeFieldModulus attribute")
129        .parse()
130        .expect("PrimeFieldModulus should be a number");
131
132    // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
133    // nonresidue.
134    // TODO: Compute this ourselves.
135    let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
136        .expect("Please supply a PrimeFieldGenerator attribute")
137        .parse()
138        .expect("PrimeFieldGenerator should be a number");
139
140    // Field element representations may be in little-endian or big-endian.
141    let endianness = fetch_attr("PrimeFieldReprEndianness", &ast.attrs)
142        .expect("Please supply a PrimeFieldReprEndianness attribute")
143        .parse()
144        .expect("PrimeFieldReprEndianness should be 'big' or 'little'");
145
146    // The arithmetic in this library only works if the modulus*2 is smaller than the backing
147    // representation. Compute the number of limbs we need.
148    let mut limbs = 1;
149    {
150        let mod2 = (&modulus) << 1; // modulus * 2
151        let mut cur = BigUint::one() << 64; // always 64-bit limbs for now
152        while cur < mod2 {
153            limbs += 1;
154            cur <<= 64;
155        }
156    }
157
158    // The struct we're deriving for must be a wrapper around `pub [u64; limbs]`.
159    if let Some(err) = validate_struct(&ast, limbs) {
160        return err.into();
161    }
162
163    // Generate the identifier for the "Repr" type we must construct.
164    let repr_ident = syn::Ident::new(
165        &format!("{}Repr", ast.ident),
166        proc_macro2::Span::call_site(),
167    );
168
169    let mut gen = proc_macro2::TokenStream::new();
170
171    let (constants_impl, sqrt_impl) =
172        prime_field_constants_and_sqrt(&ast.ident, &modulus, limbs, generator);
173
174    gen.extend(constants_impl);
175    gen.extend(prime_field_repr_impl(&repr_ident, &endianness, limbs * 8));
176    gen.extend(prime_field_impl(
177        &ast.ident,
178        &repr_ident,
179        &modulus,
180        &endianness,
181        limbs,
182        sqrt_impl,
183    ));
184
185    // Return the generated impl
186    gen.into()
187}
188
189/// Checks that `body` contains `pub [u64; limbs]`.
190fn validate_struct(ast: &syn::DeriveInput, limbs: usize) -> Option<proc_macro2::TokenStream> {
191    // The body should be a struct.
192    let variant_data = match &ast.data {
193        syn::Data::Struct(x) => x,
194        _ => {
195            return Some(
196                syn::Error::new_spanned(ast, "PrimeField derive only works for structs.")
197                    .to_compile_error(),
198            )
199        }
200    };
201
202    // The struct should contain a single unnamed field.
203    let fields = match &variant_data.fields {
204        syn::Fields::Unnamed(x) if x.unnamed.len() == 1 => x,
205        _ => {
206            return Some(
207                syn::Error::new_spanned(
208                    &ast.ident,
209                    format!(
210                        "The struct must contain an array of limbs. Change this to `{}([u64; {}])`",
211                        ast.ident, limbs,
212                    ),
213                )
214                .to_compile_error(),
215            )
216        }
217    };
218    let field = &fields.unnamed[0];
219
220    // The field should be an array.
221    let arr = match &field.ty {
222        syn::Type::Array(x) => x,
223        _ => {
224            return Some(
225                syn::Error::new_spanned(
226                    field,
227                    format!(
228                        "The inner field must be an array of limbs. Change this to `[u64; {}]`",
229                        limbs,
230                    ),
231                )
232                .to_compile_error(),
233            )
234        }
235    };
236
237    // The array's element type should be `u64`.
238    if match arr.elem.as_ref() {
239        syn::Type::Path(path) => path
240            .path
241            .get_ident()
242            .map(|x| x.to_string() != "u64")
243            .unwrap_or(true),
244        _ => true,
245    } {
246        return Some(
247            syn::Error::new_spanned(
248                arr,
249                format!(
250                    "PrimeField derive requires 64-bit limbs. Change this to `[u64; {}]",
251                    limbs
252                ),
253            )
254            .to_compile_error(),
255        );
256    }
257
258    // The array's length should be a literal int equal to `limbs`.
259    let expr_lit = match &arr.len {
260        syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
261        syn::Expr::Group(expr_group) => match &*expr_group.expr {
262            syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
263            _ => None,
264        },
265        _ => None,
266    };
267    let lit_int = match match expr_lit {
268        Some(syn::Lit::Int(lit_int)) => Some(lit_int),
269        _ => None,
270    } {
271        Some(x) => x,
272        _ => {
273            return Some(
274                syn::Error::new_spanned(
275                    arr,
276                    format!("To derive PrimeField, change this to `[u64; {}]`.", limbs),
277                )
278                .to_compile_error(),
279            )
280        }
281    };
282    if lit_int.base10_digits() != limbs.to_string() {
283        return Some(
284            syn::Error::new_spanned(
285                lit_int,
286                format!("The given modulus requires {} limbs.", limbs),
287            )
288            .to_compile_error(),
289        );
290    }
291
292    // The field should not be public.
293    match &field.vis {
294        syn::Visibility::Inherited => (),
295        _ => {
296            return Some(
297                syn::Error::new_spanned(&field.vis, "Field must not be public.").to_compile_error(),
298            )
299        }
300    }
301
302    // Valid!
303    None
304}
305
306/// Fetch an attribute string from the derived struct.
307fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
308    for attr in attrs {
309        if let Ok(meta) = attr.parse_meta() {
310            match meta {
311                syn::Meta::NameValue(nv) => {
312                    if nv.path.get_ident().map(|i| i.to_string()) == Some(name.to_string()) {
313                        match nv.lit {
314                            syn::Lit::Str(ref s) => return Some(s.value()),
315                            _ => {
316                                panic!("attribute {} should be a string", name);
317                            }
318                        }
319                    }
320                }
321                _ => {
322                    panic!("attribute {} should be a string", name);
323                }
324            }
325        }
326    }
327
328    None
329}
330
331// Implement the wrapped ident `repr` with `bytes` bytes.
332fn prime_field_repr_impl(
333    repr: &syn::Ident,
334    endianness: &ReprEndianness,
335    bytes: usize,
336) -> proc_macro2::TokenStream {
337    let repr_iter_be = endianness.iter_be();
338
339    quote! {
340        #[derive(Copy, Clone)]
341        pub struct #repr(pub [u8; #bytes]);
342
343        impl ::ff::derive::subtle::ConstantTimeEq for #repr {
344            fn ct_eq(&self, other: &#repr) -> ::ff::derive::subtle::Choice {
345                self.0
346                    .iter()
347                    .zip(other.0.iter())
348                    .map(|(a, b)| a.ct_eq(b))
349                    .fold(1.into(), |acc, x| acc & x)
350            }
351        }
352
353        impl ::core::cmp::PartialEq for #repr {
354            fn eq(&self, other: &#repr) -> bool {
355                use ::ff::derive::subtle::ConstantTimeEq;
356                self.ct_eq(other).into()
357            }
358        }
359
360        impl ::core::cmp::Eq for #repr { }
361
362        impl ::core::default::Default for #repr {
363            fn default() -> #repr {
364                #repr([0u8; #bytes])
365            }
366        }
367
368        impl ::core::fmt::Debug for #repr
369        {
370            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
371                write!(f, "0x")?;
372                for i in #repr_iter_be {
373                    write!(f, "{:02x}", *i)?;
374                }
375
376                Ok(())
377            }
378        }
379
380        impl AsRef<[u8]> for #repr {
381            #[inline(always)]
382            fn as_ref(&self) -> &[u8] {
383                &self.0
384            }
385        }
386
387        impl AsMut<[u8]> for #repr {
388            #[inline(always)]
389            fn as_mut(&mut self) -> &mut [u8] {
390                &mut self.0
391            }
392        }
393    }
394}
395
396/// Convert BigUint into a vector of 64-bit limbs.
397fn biguint_to_real_u64_vec(mut v: BigUint, limbs: usize) -> Vec<u64> {
398    let m = BigUint::one() << 64;
399    let mut ret = vec![];
400
401    while v > BigUint::zero() {
402        let limb: BigUint = &v % &m;
403        ret.push(limb.to_u64().unwrap());
404        v >>= 64;
405    }
406
407    while ret.len() < limbs {
408        ret.push(0);
409    }
410
411    assert!(ret.len() == limbs);
412
413    ret
414}
415
416/// Convert BigUint into a tokenized vector of 64-bit limbs.
417fn biguint_to_u64_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream {
418    let ret = biguint_to_real_u64_vec(v, limbs);
419    quote!([#(#ret,)*])
420}
421
422fn biguint_num_bits(mut v: BigUint) -> u32 {
423    let mut bits = 0;
424
425    while v != BigUint::zero() {
426        v >>= 1;
427        bits += 1;
428    }
429
430    bits
431}
432
433/// BigUint modular exponentiation by square-and-multiply.
434fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
435    let mut ret = BigUint::one();
436
437    for i in exp
438        .to_bytes_be()
439        .into_iter()
440        .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
441    {
442        ret = (&ret * &ret) % modulus;
443        if i {
444            ret = (ret * &base) % modulus;
445        }
446    }
447
448    ret
449}
450
#[test]
fn test_exp() {
    // Known-answer test for `exp` with large operands.
    let parse = |s: &str| s.parse::<BigUint>().unwrap();

    let base = parse("4398572349857239485729348572983472345");
    let exponent = parse("5489673498567349856734895");
    let modulus =
        parse("52435875175126190479447740508185965837690552500527637822603658699938581184513");
    let expected =
        parse("4371221214068404307866768905142520595925044802278091865033317963560480051536");

    assert_eq!(exp(base, &exponent, &modulus), expected);
}
468
469fn prime_field_constants_and_sqrt(
470    name: &syn::Ident,
471    modulus: &BigUint,
472    limbs: usize,
473    generator: BigUint,
474) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
475    let bytes = limbs * 8;
476    let modulus_num_bits = biguint_num_bits(modulus.clone());
477
478    // The number of bits we should "shave" from a randomly sampled reputation, i.e.,
479    // if our modulus is 381 bits and our representation is 384 bits, we should shave
480    // 3 bits from the beginning of a randomly sampled 384 bit representation to
481    // reduce the cost of rejection sampling.
482    let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
483
484    // Compute R = 2**(64 * limbs) mod m
485    let r = (BigUint::one() << (limbs * 64)) % modulus;
486    let to_mont = |v| (v * &r) % modulus;
487
488    let two = BigUint::from_str("2").unwrap();
489    let p_minus_2 = modulus - &two;
490    let invert = |v| exp(v, &p_minus_2, &modulus);
491
492    // 2^-1 mod m
493    let two_inv = biguint_to_u64_vec(to_mont(invert(two)), limbs);
494
495    // modulus - 1 = 2^s * t
496    let mut s: u32 = 0;
497    let mut t = modulus - BigUint::from_str("1").unwrap();
498    while t.is_even() {
499        t >>= 1;
500        s += 1;
501    }
502
503    // Compute 2^s root of unity given the generator
504    let root_of_unity = exp(generator.clone(), &t, &modulus);
505    let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs);
506    let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs);
507    let delta = biguint_to_u64_vec(
508        to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)),
509        limbs,
510    );
511    let generator = biguint_to_u64_vec(to_mont(generator), limbs);
512
513    let sqrt_impl =
514        if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
515            // Addition chain for (r + 1) // 4
516            let mod_plus_1_over_4 = pow_fixed::generate(
517                &quote! {self},
518                (modulus + BigUint::from_str("1").unwrap()) >> 2,
519            );
520
521            quote! {
522                use ::ff::derive::subtle::ConstantTimeEq;
523
524                // Because r = 3 (mod 4)
525                // sqrt can be done with only one exponentiation,
526                // via the computation of  self^((r + 1) // 4) (mod r)
527                let sqrt = {
528                    #mod_plus_1_over_4
529                };
530
531                ::ff::derive::subtle::CtOption::new(
532                    sqrt,
533                    (sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root.
534                )
535            }
536        } else {
537            // Addition chain for (t - 1) // 2
538            let t_minus_1_over_2 = if t == BigUint::one() {
539                quote!( #name::ONE )
540            } else {
541                pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
542            };
543
544            quote! {
545                // Tonelli-Shank's algorithm works for every odd prime.
546                // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
547                use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq};
548
549                // w = self^((t - 1) // 2)
550                let w = {
551                    #t_minus_1_over_2
552                };
553
554                let mut v = S;
555                let mut x = *self * &w;
556                let mut b = x * &w;
557
558                // Initialize z as the 2^S root of unity.
559                let mut z = ROOT_OF_UNITY;
560
561                for max_v in (1..=S).rev() {
562                    let mut k = 1;
563                    let mut tmp = b.square();
564                    let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();
565
566                    for j in 2..max_v {
567                        let tmp_is_one = tmp.ct_eq(&#name::ONE);
568                        let squared = #name::conditional_select(&tmp, &z, tmp_is_one).square();
569                        tmp = #name::conditional_select(&squared, &tmp, tmp_is_one);
570                        let new_z = #name::conditional_select(&z, &squared, tmp_is_one);
571                        j_less_than_v &= !j.ct_eq(&v);
572                        k = u32::conditional_select(&j, &k, tmp_is_one);
573                        z = #name::conditional_select(&z, &new_z, j_less_than_v);
574                    }
575
576                    let result = x * &z;
577                    x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
578                    z = z.square();
579                    b *= &z;
580                    v = k;
581                }
582
583                ::ff::derive::subtle::CtOption::new(
584                    x,
585                    (x * &x).ct_eq(self), // Only return Some if it's the square root.
586                )
587            }
588        };
589
590    // Compute R^2 mod m
591    let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs);
592
593    let r = biguint_to_u64_vec(r, limbs);
594    let modulus_le_bytes = ReprEndianness::Little.modulus_repr(modulus, limbs * 8);
595    let modulus_str = format!("0x{}", modulus.to_str_radix(16));
596    let modulus = biguint_to_real_u64_vec(modulus.clone(), limbs);
597
598    // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1
599    let mut inv = 1u64;
600    for _ in 0..63 {
601        inv = inv.wrapping_mul(inv);
602        inv = inv.wrapping_mul(modulus[0]);
603    }
604    inv = inv.wrapping_neg();
605
606    (
607        quote! {
608            type REPR_BYTES = [u8; #bytes];
609            type REPR_BITS = REPR_BYTES;
610
611            /// This is the modulus m of the prime field
612            const MODULUS: REPR_BITS = [#(#modulus_le_bytes,)*];
613
614            /// This is the modulus m of the prime field in limb form
615            const MODULUS_LIMBS: #name = #name([#(#modulus,)*]);
616
617            /// This is the modulus m of the prime field in hex string form
618            const MODULUS_STR: &'static str = #modulus_str;
619
620            /// The number of bits needed to represent the modulus.
621            const MODULUS_BITS: u32 = #modulus_num_bits;
622
623            /// The number of bits that must be shaved from the beginning of
624            /// the representation when randomly sampling.
625            const REPR_SHAVE_BITS: u32 = #repr_shave_bits;
626
627            /// 2^{limbs*64} mod m
628            const R: #name = #name(#r);
629
630            /// 2^{limbs*64*2} mod m
631            const R2: #name = #name(#r2);
632
633            /// -(m^{-1} mod m) mod m
634            const INV: u64 = #inv;
635
636            /// 2^{-1} mod m
637            const TWO_INV: #name = #name(#two_inv);
638
639            /// Multiplicative generator of `MODULUS` - 1 order, also quadratic
640            /// nonresidue.
641            const GENERATOR: #name = #name(#generator);
642
643            /// 2^s * t = MODULUS - 1 with t odd
644            const S: u32 = #s;
645
646            /// 2^s root of unity computed by GENERATOR^t
647            const ROOT_OF_UNITY: #name = #name(#root_of_unity);
648
649            /// (2^s)^{-1} mod m
650            const ROOT_OF_UNITY_INV: #name = #name(#root_of_unity_inv);
651
652            /// GENERATOR^{2^s}
653            const DELTA: #name = #name(#delta);
654        },
655        sqrt_impl,
656    )
657}
658
659/// Implement PrimeField for the derived type.
660fn prime_field_impl(
661    name: &syn::Ident,
662    repr: &syn::Ident,
663    modulus: &BigUint,
664    endianness: &ReprEndianness,
665    limbs: usize,
666    sqrt_impl: proc_macro2::TokenStream,
667) -> proc_macro2::TokenStream {
668    // Returns r{n} as an ident.
669    fn get_temp(n: usize) -> syn::Ident {
670        syn::Ident::new(&format!("r{}", n), proc_macro2::Span::call_site())
671    }
672
673    // The parameter list for the mont_reduce() internal method.
674    // r0: u64, mut r1: u64, mut r2: u64, ...
675    let mut mont_paramlist = proc_macro2::TokenStream::new();
676    mont_paramlist.append_separated(
677        (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
678            if i != 0 {
679                quote! {mut #x: u64}
680            } else {
681                quote! {#x: u64}
682            }
683        }),
684        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
685    );
686
687    // Implement montgomery reduction for some number of limbs
688    fn mont_impl(limbs: usize) -> proc_macro2::TokenStream {
689        let mut gen = proc_macro2::TokenStream::new();
690
691        for i in 0..limbs {
692            {
693                let temp = get_temp(i);
694                gen.extend(quote! {
695                    let k = #temp.wrapping_mul(INV);
696                    let (_, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[0], 0);
697                });
698            }
699
700            for j in 1..limbs {
701                let temp = get_temp(i + j);
702                gen.extend(quote! {
703                    let (#temp, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[#j], carry);
704                });
705            }
706
707            let temp = get_temp(i + limbs);
708
709            if i == 0 {
710                gen.extend(quote! {
711                    let (#temp, carry2) = ::ff::derive::adc(#temp, 0, carry);
712                });
713            } else {
714                gen.extend(quote! {
715                    let (#temp, carry2) = ::ff::derive::adc(#temp, carry2, carry);
716                });
717            }
718        }
719
720        for i in 0..limbs {
721            let temp = get_temp(limbs + i);
722
723            gen.extend(quote! {
724                self.0[#i] = #temp;
725            });
726        }
727
728        gen
729    }
730
731    fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
732        let mut gen = proc_macro2::TokenStream::new();
733
734        if limbs > 1 {
735            for i in 0..(limbs - 1) {
736                gen.extend(quote! {
737                    let carry = 0;
738                });
739
740                for j in (i + 1)..limbs {
741                    let temp = get_temp(i + j);
742                    if i == 0 {
743                        gen.extend(quote! {
744                            let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#j], carry);
745                        });
746                    } else {
747                        gen.extend(quote! {
748                            let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #a.0[#j], carry);
749                        });
750                    }
751                }
752
753                let temp = get_temp(i + limbs);
754
755                gen.extend(quote! {
756                    let #temp = carry;
757                });
758            }
759
760            for i in 1..(limbs * 2) {
761                let temp0 = get_temp(limbs * 2 - i);
762                let temp1 = get_temp(limbs * 2 - i - 1);
763
764                if i == 1 {
765                    gen.extend(quote! {
766                        let #temp0 = #temp1 >> 63;
767                    });
768                } else if i == (limbs * 2 - 1) {
769                    gen.extend(quote! {
770                        let #temp0 = #temp0 << 1;
771                    });
772                } else {
773                    gen.extend(quote! {
774                        let #temp0 = (#temp0 << 1) | (#temp1 >> 63);
775                    });
776                }
777            }
778        } else {
779            let temp1 = get_temp(1);
780            gen.extend(quote! {
781                let #temp1 = 0;
782            });
783        }
784
785        for i in 0..limbs {
786            let temp0 = get_temp(i * 2);
787            let temp1 = get_temp(i * 2 + 1);
788            if i == 0 {
789                gen.extend(quote! {
790                    let (#temp0, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#i], 0);
791                });
792            } else {
793                gen.extend(quote! {
794                    let (#temp0, carry) = ::ff::derive::mac(#temp0, #a.0[#i], #a.0[#i], carry);
795                });
796            }
797
798            gen.extend(quote! {
799                let (#temp1, carry) = ::ff::derive::adc(#temp1, 0, carry);
800            });
801        }
802
803        let mut mont_calling = proc_macro2::TokenStream::new();
804        mont_calling.append_separated(
805            (0..(limbs * 2)).map(get_temp),
806            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
807        );
808
809        gen.extend(quote! {
810            let mut ret = *self;
811            ret.mont_reduce(#mont_calling);
812            ret
813        });
814
815        gen
816    }
817
818    fn mul_impl(
819        a: proc_macro2::TokenStream,
820        b: proc_macro2::TokenStream,
821        limbs: usize,
822    ) -> proc_macro2::TokenStream {
823        let mut gen = proc_macro2::TokenStream::new();
824
825        for i in 0..limbs {
826            gen.extend(quote! {
827                let carry = 0;
828            });
829
830            for j in 0..limbs {
831                let temp = get_temp(i + j);
832
833                if i == 0 {
834                    gen.extend(quote! {
835                        let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #b.0[#j], carry);
836                    });
837                } else {
838                    gen.extend(quote! {
839                        let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #b.0[#j], carry);
840                    });
841                }
842            }
843
844            let temp = get_temp(i + limbs);
845
846            gen.extend(quote! {
847                let #temp = carry;
848            });
849        }
850
851        let mut mont_calling = proc_macro2::TokenStream::new();
852        mont_calling.append_separated(
853            (0..(limbs * 2)).map(get_temp),
854            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
855        );
856
857        gen.extend(quote! {
858            self.mont_reduce(#mont_calling);
859        });
860
861        gen
862    }
863
864    /// Generates an implementation of multiplicative inversion within the target prime
865    /// field.
866    fn inv_impl(a: proc_macro2::TokenStream, modulus: &BigUint) -> proc_macro2::TokenStream {
867        // Addition chain for p - 2
868        let mod_minus_2 = pow_fixed::generate(&a, modulus - BigUint::from(2u64));
869
870        quote! {
871            use ::ff::derive::subtle::ConstantTimeEq;
872
873            // By Euler's theorem, if `a` is coprime to `p` (i.e. `gcd(a, p) = 1`), then:
874            //     a^-1 ≡ a^(phi(p) - 1) mod p
875            //
876            // `ff_derive` requires that `p` is prime; in this case, `phi(p) = p - 1`, and
877            // thus:
878            //     a^-1 ≡ a^(p - 2) mod p
879            let inv = {
880                #mod_minus_2
881            };
882
883            ::ff::derive::subtle::CtOption::new(inv, !#a.is_zero())
884        }
885    }
886
887    let squaring_impl = sqr_impl(quote! {self}, limbs);
888    let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
889    let invert_impl = inv_impl(quote! {self}, modulus);
890    let montgomery_impl = mont_impl(limbs);
891
892    // self.0[0].ct_eq(&other.0[0]) & self.0[1].ct_eq(&other.0[1]) & ...
893    let mut ct_eq_impl = proc_macro2::TokenStream::new();
894    ct_eq_impl.append_separated(
895        (0..limbs).map(|i| quote! { self.0[#i].ct_eq(&other.0[#i]) }),
896        proc_macro2::Punct::new('&', proc_macro2::Spacing::Alone),
897    );
898
899    fn mont_reduce_params(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
900        // a.0[0], a.0[1], ..., 0, 0, 0, 0, ...
901        let mut mont_reduce_params = proc_macro2::TokenStream::new();
902        mont_reduce_params.append_separated(
903            (0..limbs)
904                .map(|i| quote! { #a.0[#i] })
905                .chain((0..limbs).map(|_| quote! {0})),
906            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
907        );
908        mont_reduce_params
909    }
910
911    let mont_reduce_self_params = mont_reduce_params(quote! {self}, limbs);
912    let mont_reduce_other_params = mont_reduce_params(quote! {other}, limbs);
913
914    let from_repr_impl = endianness.from_repr(name, limbs);
915    let to_repr_impl = endianness.to_repr(quote! {#repr}, &mont_reduce_self_params, limbs);
916
917    cfg_if::cfg_if! {
918        if #[cfg(feature = "bits")] {
919            let to_le_bits_impl = ReprEndianness::Little.to_repr(
920                quote! {::ff::derive::bitvec::array::BitArray::new},
921                &mont_reduce_self_params,
922                limbs,
923            );
924
925            let prime_field_bits_impl = quote! {
926                impl ::ff::PrimeFieldBits for #name {
927                    type ReprBits = REPR_BITS;
928
929                    fn to_le_bits(&self) -> ::ff::FieldBits<REPR_BITS> {
930                        #to_le_bits_impl
931                    }
932
933                    fn char_le_bits() -> ::ff::FieldBits<REPR_BITS> {
934                        ::ff::FieldBits::new(MODULUS)
935                    }
936                }
937            };
938        } else {
939            let prime_field_bits_impl = quote! {};
940        }
941    };
942
943    let top_limb_index = limbs - 1;
944
945    quote! {
946        impl ::core::marker::Copy for #name { }
947
948        impl ::core::clone::Clone for #name {
949            fn clone(&self) -> #name {
950                *self
951            }
952        }
953
954        impl ::core::default::Default for #name {
955            fn default() -> #name {
956                use ::ff::Field;
957                #name::ZERO
958            }
959        }
960
961        impl ::ff::derive::subtle::ConstantTimeEq for #name {
962            fn ct_eq(&self, other: &#name) -> ::ff::derive::subtle::Choice {
963                use ::ff::PrimeField;
964                self.to_repr().ct_eq(&other.to_repr())
965            }
966        }
967
968        impl ::core::cmp::PartialEq for #name {
969            fn eq(&self, other: &#name) -> bool {
970                use ::ff::derive::subtle::ConstantTimeEq;
971                self.ct_eq(other).into()
972            }
973        }
974
975        impl ::core::cmp::Eq for #name { }
976
977        impl ::core::fmt::Debug for #name
978        {
979            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
980                use ::ff::PrimeField;
981                write!(f, "{}({:?})", stringify!(#name), self.to_repr())
982            }
983        }
984
985        /// Elements are ordered lexicographically.
986        impl Ord for #name {
987            #[inline(always)]
988            fn cmp(&self, other: &#name) -> ::core::cmp::Ordering {
989                let mut a = *self;
990                a.mont_reduce(
991                    #mont_reduce_self_params
992                );
993
994                let mut b = *other;
995                b.mont_reduce(
996                    #mont_reduce_other_params
997                );
998
999                a.cmp_native(&b)
1000            }
1001        }
1002
1003        impl PartialOrd for #name {
1004            #[inline(always)]
1005            fn partial_cmp(&self, other: &#name) -> Option<::core::cmp::Ordering> {
1006                Some(self.cmp(other))
1007            }
1008        }
1009
1010        impl From<u64> for #name {
1011            #[inline(always)]
1012            fn from(val: u64) -> #name {
1013                let mut raw = [0u64; #limbs];
1014                raw[0] = val;
1015                #name(raw) * R2
1016            }
1017        }
1018
1019        impl From<#name> for #repr {
1020            fn from(e: #name) -> #repr {
1021                use ::ff::PrimeField;
1022                e.to_repr()
1023            }
1024        }
1025
1026        impl<'a> From<&'a #name> for #repr {
1027            fn from(e: &'a #name) -> #repr {
1028                use ::ff::PrimeField;
1029                e.to_repr()
1030            }
1031        }
1032
1033        impl ::ff::derive::subtle::ConditionallySelectable for #name {
1034            fn conditional_select(a: &#name, b: &#name, choice: ::ff::derive::subtle::Choice) -> #name {
1035                let mut res = [0u64; #limbs];
1036                for i in 0..#limbs {
1037                    res[i] = u64::conditional_select(&a.0[i], &b.0[i], choice);
1038                }
1039                #name(res)
1040            }
1041        }
1042
1043        impl ::core::ops::Neg for #name {
1044            type Output = #name;
1045
1046            #[inline]
1047            fn neg(self) -> #name {
1048                use ::ff::Field;
1049
1050                let mut ret = self;
1051                if !ret.is_zero_vartime() {
1052                    let mut tmp = MODULUS_LIMBS;
1053                    tmp.sub_noborrow(&ret);
1054                    ret = tmp;
1055                }
1056                ret
1057            }
1058        }
1059
1060        impl<'r> ::core::ops::Add<&'r #name> for #name {
1061            type Output = #name;
1062
1063            #[inline]
1064            fn add(self, other: &#name) -> #name {
1065                use ::core::ops::AddAssign;
1066
1067                let mut ret = self;
1068                ret.add_assign(other);
1069                ret
1070            }
1071        }
1072
1073        impl ::core::ops::Add for #name {
1074            type Output = #name;
1075
1076            #[inline]
1077            fn add(self, other: #name) -> Self {
1078                self + &other
1079            }
1080        }
1081
1082        impl<'r> ::core::ops::AddAssign<&'r #name> for #name {
1083            #[inline]
1084            fn add_assign(&mut self, other: &#name) {
1085                // This cannot exceed the backing capacity.
1086                self.add_nocarry(other);
1087
1088                // However, it may need to be reduced.
1089                self.reduce();
1090            }
1091        }
1092
1093        impl ::core::ops::AddAssign for #name {
1094            #[inline]
1095            fn add_assign(&mut self, other: #name) {
1096                self.add_assign(&other);
1097            }
1098        }
1099
1100        impl<'r> ::core::ops::Sub<&'r #name> for #name {
1101            type Output = #name;
1102
1103            #[inline]
1104            fn sub(self, other: &#name) -> Self {
1105                use ::core::ops::SubAssign;
1106
1107                let mut ret = self;
1108                ret.sub_assign(other);
1109                ret
1110            }
1111        }
1112
1113        impl ::core::ops::Sub for #name {
1114            type Output = #name;
1115
1116            #[inline]
1117            fn sub(self, other: #name) -> Self {
1118                self - &other
1119            }
1120        }
1121
1122        impl<'r> ::core::ops::SubAssign<&'r #name> for #name {
1123            #[inline]
1124            fn sub_assign(&mut self, other: &#name) {
1125                // If `other` is larger than `self`, we'll need to add the modulus to self first.
1126                if other.cmp_native(self) == ::core::cmp::Ordering::Greater {
1127                    self.add_nocarry(&MODULUS_LIMBS);
1128                }
1129
1130                self.sub_noborrow(other);
1131            }
1132        }
1133
1134        impl ::core::ops::SubAssign for #name {
1135            #[inline]
1136            fn sub_assign(&mut self, other: #name) {
1137                self.sub_assign(&other);
1138            }
1139        }
1140
1141        impl<'r> ::core::ops::Mul<&'r #name> for #name {
1142            type Output = #name;
1143
1144            #[inline]
1145            fn mul(self, other: &#name) -> Self {
1146                use ::core::ops::MulAssign;
1147
1148                let mut ret = self;
1149                ret.mul_assign(other);
1150                ret
1151            }
1152        }
1153
1154        impl ::core::ops::Mul for #name {
1155            type Output = #name;
1156
1157            #[inline]
1158            fn mul(self, other: #name) -> Self {
1159                self * &other
1160            }
1161        }
1162
1163        impl<'r> ::core::ops::MulAssign<&'r #name> for #name {
1164            #[inline]
1165            fn mul_assign(&mut self, other: &#name)
1166            {
1167                #multiply_impl
1168            }
1169        }
1170
1171        impl ::core::ops::MulAssign for #name {
1172            #[inline]
1173            fn mul_assign(&mut self, other: #name)
1174            {
1175                self.mul_assign(&other);
1176            }
1177        }
1178
1179        impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Sum<T> for #name {
1180            fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
1181                use ::ff::Field;
1182
1183                iter.fold(Self::ZERO, |acc, item| acc + item.borrow())
1184            }
1185        }
1186
1187        impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Product<T> for #name {
1188            fn product<I: Iterator<Item = T>>(iter: I) -> Self {
1189                use ::ff::Field;
1190
1191                iter.fold(Self::ONE, |acc, item| acc * item.borrow())
1192            }
1193        }
1194
1195        impl ::ff::PrimeField for #name {
1196            type Repr = #repr;
1197
1198            fn from_repr(r: #repr) -> ::ff::derive::subtle::CtOption<#name> {
1199                #from_repr_impl
1200
1201                // Try to subtract the modulus
1202                let borrow = r.0.iter().zip(MODULUS_LIMBS.0.iter()).fold(0, |borrow, (a, b)| {
1203                    ::ff::derive::sbb(*a, *b, borrow).1
1204                });
1205
1206                // If the element is smaller than MODULUS then the
1207                // subtraction will underflow, producing a borrow value
1208                // of 0xffff...ffff. Otherwise, it'll be zero.
1209                let is_some = ::ff::derive::subtle::Choice::from((borrow as u8) & 1);
1210
1211                // Convert to Montgomery form by computing
1212                // (a.R^0 * R^2) / R = a.R
1213                ::ff::derive::subtle::CtOption::new(r * &R2, is_some)
1214            }
1215
1216            fn from_repr_vartime(r: #repr) -> Option<#name> {
1217                #from_repr_impl
1218
1219                if r.is_valid() {
1220                    Some(r * R2)
1221                } else {
1222                    None
1223                }
1224            }
1225
1226            fn to_repr(&self) -> #repr {
1227                #to_repr_impl
1228            }
1229
1230            #[inline(always)]
1231            fn is_odd(&self) -> ::ff::derive::subtle::Choice {
1232                let mut r = *self;
1233                r.mont_reduce(
1234                    #mont_reduce_self_params
1235                );
1236
1237                // TODO: This looks like a constant-time result, but r.mont_reduce() is
1238                // currently implemented using variable-time code.
1239                ::ff::derive::subtle::Choice::from((r.0[0] & 1) as u8)
1240            }
1241
1242            const MODULUS: &'static str = MODULUS_STR;
1243
1244            const NUM_BITS: u32 = MODULUS_BITS;
1245
1246            const CAPACITY: u32 = Self::NUM_BITS - 1;
1247
1248            const TWO_INV: Self = TWO_INV;
1249
1250            const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;
1251
1252            const S: u32 = S;
1253
1254            const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
1255
1256            const ROOT_OF_UNITY_INV: Self = ROOT_OF_UNITY_INV;
1257
1258            const DELTA: Self = DELTA;
1259        }
1260
1261        #prime_field_bits_impl
1262
1263        impl ::ff::Field for #name {
1264            const ZERO: Self = #name([0; #limbs]);
1265            const ONE: Self = R;
1266
1267            /// Computes a uniformly random element using rejection sampling.
1268            fn random(mut rng: impl ::ff::derive::rand_core::RngCore) -> Self {
1269                loop {
1270                    let mut tmp = {
1271                        let mut repr = [0u64; #limbs];
1272                        for i in 0..#limbs {
1273                            repr[i] = rng.next_u64();
1274                        }
1275                        #name(repr)
1276                    };
1277
1278                    // Mask away the unused most-significant bits.
1279                    // Note: In some edge cases, `REPR_SHAVE_BITS` could be 64, in which case
1280                    // `0xfff... >> REPR_SHAVE_BITS` overflows. So use `checked_shr` instead.
1281                    // This is always sufficient because we will have at most one spare limb
1282                    // to accommodate values of up to twice the modulus.
1283                    tmp.0[#top_limb_index] &= 0xffffffffffffffffu64.checked_shr(REPR_SHAVE_BITS).unwrap_or(0);
1284
1285                    if tmp.is_valid() {
1286                        return tmp
1287                    }
1288                }
1289            }
1290
1291            #[inline]
1292            fn is_zero_vartime(&self) -> bool {
1293                self.0.iter().all(|&e| e == 0)
1294            }
1295
1296            #[inline]
1297            fn double(&self) -> Self {
1298                let mut ret = *self;
1299
1300                // This cannot exceed the backing capacity.
1301                let mut last = 0;
1302                for i in &mut ret.0 {
1303                    let tmp = *i >> 63;
1304                    *i <<= 1;
1305                    *i |= last;
1306                    last = tmp;
1307                }
1308
1309                // However, it may need to be reduced.
1310                ret.reduce();
1311
1312                ret
1313            }
1314
1315            fn invert(&self) -> ::ff::derive::subtle::CtOption<Self> {
1316                #invert_impl
1317            }
1318
1319            #[inline]
1320            fn square(&self) -> Self
1321            {
1322                #squaring_impl
1323            }
1324
1325            fn sqrt_ratio(num: &Self, div: &Self) -> (::ff::derive::subtle::Choice, Self) {
1326                ::ff::helpers::sqrt_ratio_generic(num, div)
1327            }
1328
1329            fn sqrt(&self) -> ::ff::derive::subtle::CtOption<Self> {
1330                #sqrt_impl
1331            }
1332        }
1333
1334        impl #name {
1335            /// Compares two elements in native representation. This is only used
1336            /// internally.
1337            #[inline(always)]
1338            fn cmp_native(&self, other: &#name) -> ::core::cmp::Ordering {
1339                for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) {
1340                    if a < b {
1341                        return ::core::cmp::Ordering::Less
1342                    } else if a > b {
1343                        return ::core::cmp::Ordering::Greater
1344                    }
1345                }
1346
1347                ::core::cmp::Ordering::Equal
1348            }
1349
1350            /// Determines if the element is really in the field. This is only used
1351            /// internally.
1352            #[inline(always)]
1353            fn is_valid(&self) -> bool {
1354                // The Ord impl calls `reduce`, which in turn calls `is_valid`, so we use
1355                // this internal function to eliminate the cycle.
1356                self.cmp_native(&MODULUS_LIMBS) == ::core::cmp::Ordering::Less
1357            }
1358
1359            #[inline(always)]
1360            fn add_nocarry(&mut self, other: &#name) {
1361                let mut carry = 0;
1362
1363                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
1364                    let (new_a, new_carry) = ::ff::derive::adc(*a, *b, carry);
1365                    *a = new_a;
1366                    carry = new_carry;
1367                }
1368            }
1369
1370            #[inline(always)]
1371            fn sub_noborrow(&mut self, other: &#name) {
1372                let mut borrow = 0;
1373
1374                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
1375                    let (new_a, new_borrow) = ::ff::derive::sbb(*a, *b, borrow);
1376                    *a = new_a;
1377                    borrow = new_borrow;
1378                }
1379            }
1380
1381            /// Subtracts the modulus from this element if this element is not in the
1382            /// field. Only used interally.
1383            #[inline(always)]
1384            fn reduce(&mut self) {
1385                if !self.is_valid() {
1386                    self.sub_noborrow(&MODULUS_LIMBS);
1387                }
1388            }
1389
1390            #[allow(clippy::too_many_arguments)]
1391            #[inline(always)]
1392            fn mont_reduce(
1393                &mut self,
1394                #mont_paramlist
1395            )
1396            {
1397                // The Montgomery reduction here is based on Algorithm 14.32 in
1398                // Handbook of Applied Cryptography
1399                // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
1400
1401                #montgomery_impl
1402
1403                self.reduce();
1404            }
1405        }
1406    }
1407}
1408
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: expanding the derive for a 4-limb, little-endian prime field
    /// must complete without panicking.
    ///
    /// The modulus is 7237005577332262213973186563042994240857116359379907606001950938285454250989.
    /// NOTE(review): this appears to be the Ed25519 group order
    /// l = 2^252 + 27742317777372353535851937790883648493 — confirm.
    #[test]
    fn expands_little_endian_four_limb_field() {
        let input = quote! {
            #[PrimeFieldModulus = "7237005577332262213973186563042994240857116359379907606001950938285454250989"]
            #[PrimeFieldGenerator = "2"]
            #[PrimeFieldReprEndianness = "little"]
            pub struct ScalarField([u64; 4]);
        };

        // Only successful expansion is exercised; the generated tokens are not
        // inspected. The underscore prefix marks the result as intentionally
        // unused, which avoids an `unused_variables` warning.
        let _expanded = prime_field_2(input);
    }
}