fff_derive/
lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5extern crate syn;
6#[macro_use]
7extern crate quote;
8
9extern crate num_bigint;
10extern crate num_integer;
11extern crate num_traits;
12
13use num_bigint::BigUint;
14use num_integer::Integer;
15use num_traits::{One, ToPrimitive, Zero};
16use quote::TokenStreamExt;
17use std::str::FromStr;
18
/// Decimal spelling of the BLS12-381 scalar-field (`Fr`) modulus.
///
/// `mul_impl` compares the derived field's modulus string against this value (together with a
/// 4-limb representation) to decide whether to emit the x86_64 `::fff::mod_mul_4w_assign`
/// multiplication path instead of the generic Montgomery multiplication.
const BLS_381_FR_MODULUS: &str =
    "52435875175126190479447740508185965837690552500527637822603658699938581184513";
21
22#[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus, PrimeFieldGenerator))]
23pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    // Parse the type definition
25    let ast: syn::DeriveInput = syn::parse(input).unwrap();
26
27    // The struct we're deriving for is a wrapper around a "Repr" type we must construct.
28    let repr_ident = fetch_wrapped_ident(&ast.data)
29        .expect("PrimeField derive only operates over tuple structs of a single item");
30
31    // We're given the modulus p of the prime field
32    let modulus_raw = fetch_attr("PrimeFieldModulus", &ast.attrs)
33        .expect("Please supply a PrimeFieldModulus attribute");
34    let modulus: BigUint = modulus_raw
35        .parse()
36        .expect("PrimeFieldModulus should be a number");
37
38    // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
39    // nonresidue.
40    let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
41        .expect("Please supply a PrimeFieldGenerator attribute")
42        .parse()
43        .expect("PrimeFieldGenerator should be a number");
44
45    // The arithmetic in this library only works if the modulus*2 is smaller than the backing
46    // representation. Compute the number of limbs we need.
47    let mut limbs = 1;
48    {
49        let mod2 = (&modulus) << 1; // modulus * 2
50        let mut cur = BigUint::one() << 64; // always 64-bit limbs for now
51        while cur < mod2 {
52            limbs += 1;
53            cur = cur << 64;
54        }
55    }
56
57    let mut gen = proc_macro2::TokenStream::new();
58
59    let (constants_impl, sqrt_impl) =
60        prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs, generator);
61
62    gen.extend(constants_impl);
63    gen.extend(prime_field_repr_impl(&repr_ident, limbs));
64    gen.extend(prime_field_impl(
65        &ast.ident,
66        &repr_ident,
67        limbs,
68        &modulus_raw,
69    ));
70    gen.extend(sqrt_impl);
71
72    // Return the generated impl
73    gen.into()
74}
75
76/// Fetches the ident being wrapped by the type we're deriving.
77fn fetch_wrapped_ident(body: &syn::Data) -> Option<syn::Ident> {
78    match body {
79        &syn::Data::Struct(ref variant_data) => match variant_data.fields {
80            syn::Fields::Unnamed(ref fields) => {
81                if fields.unnamed.len() == 1 {
82                    match fields.unnamed[0].ty {
83                        syn::Type::Path(ref path) => {
84                            if path.path.segments.len() == 1 {
85                                return Some(path.path.segments[0].ident.clone());
86                            }
87                        }
88                        _ => {}
89                    }
90                }
91            }
92            _ => {}
93        },
94        _ => {}
95    };
96
97    None
98}
99
100/// Fetch an attribute string from the derived struct.
101fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
102    for attr in attrs {
103        if let Ok(meta) = attr.parse_meta() {
104            match meta {
105                syn::Meta::NameValue(nv) => {
106                    if nv.path.get_ident().map(|i| i.to_string()) == Some(name.to_string()) {
107                        match nv.lit {
108                            syn::Lit::Str(ref s) => return Some(s.value()),
109                            _ => {
110                                panic!("attribute {} should be a string", name);
111                            }
112                        }
113                    }
114                }
115                _ => {
116                    panic!("attribute {} should be a string", name);
117                }
118            }
119        }
120    }
121
122    None
123}
124
125// Implement PrimeFieldRepr for the wrapped ident `repr` with `limbs` limbs.
126fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
127    quote! {
128        #[derive(Copy, Clone, PartialEq, Eq, Default, ::serde::Serialize, ::serde::Deserialize)]
129        pub struct #repr(pub [u64; #limbs]);
130
131        impl ::std::fmt::Debug for #repr
132        {
133            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
134                write!(f, "0x")?;
135                for i in self.0.iter().rev() {
136                    write!(f, "{:016x}", *i)?;
137                }
138
139                Ok(())
140            }
141        }
142
143        impl ::std::fmt::Display for #repr {
144            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
145                write!(f, "0x")?;
146                for i in self.0.iter().rev() {
147                    write!(f, "{:016x}", *i)?;
148                }
149
150                Ok(())
151            }
152        }
153
154        impl AsRef<[u64]> for #repr {
155            #[inline(always)]
156            fn as_ref(&self) -> &[u64] {
157                &self.0
158            }
159        }
160
161        impl AsMut<[u64]> for #repr {
162            #[inline(always)]
163            fn as_mut(&mut self) -> &mut [u64] {
164                &mut self.0
165            }
166        }
167
168        impl From<u64> for #repr {
169            #[inline(always)]
170            fn from(val: u64) -> #repr {
171                use std::default::Default;
172
173                let mut repr = Self::default();
174                repr.0[0] = val;
175                repr
176            }
177        }
178
179        impl Ord for #repr {
180            #[inline(always)]
181            fn cmp(&self, other: &#repr) -> ::std::cmp::Ordering {
182                for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) {
183                    if a < b {
184                        return ::std::cmp::Ordering::Less
185                    } else if a > b {
186                        return ::std::cmp::Ordering::Greater
187                    }
188                }
189
190                ::std::cmp::Ordering::Equal
191            }
192        }
193
194        impl PartialOrd for #repr {
195            #[inline(always)]
196            fn partial_cmp(&self, other: &#repr) -> Option<::std::cmp::Ordering> {
197                Some(self.cmp(other))
198            }
199        }
200
201        impl ::fff::PrimeFieldRepr for #repr {
202            #[inline(always)]
203            fn is_odd(&self) -> bool {
204                self.0[0] & 1 == 1
205            }
206
207            #[inline(always)]
208            fn is_even(&self) -> bool {
209                !self.is_odd()
210            }
211
212            #[inline(always)]
213            fn is_zero(&self) -> bool {
214                self.0.iter().all(|&e| e == 0)
215            }
216
217            #[inline(always)]
218            fn shr(&mut self, mut n: u32) {
219                if n as usize >= 64 * #limbs {
220                    *self = Self::from(0);
221                    return;
222                }
223
224                while n >= 64 {
225                    let mut t = 0;
226                    for i in self.0.iter_mut().rev() {
227                        ::std::mem::swap(&mut t, i);
228                    }
229                    n -= 64;
230                }
231
232                if n > 0 {
233                    let mut t = 0;
234                    for i in self.0.iter_mut().rev() {
235                        let t2 = *i << (64 - n);
236                        *i >>= n;
237                        *i |= t;
238                        t = t2;
239                    }
240                }
241            }
242
243            #[inline(always)]
244            fn div2(&mut self) {
245                let mut t = 0;
246                for i in self.0.iter_mut().rev() {
247                    let t2 = *i << 63;
248                    *i >>= 1;
249                    *i |= t;
250                    t = t2;
251                }
252            }
253
254            #[inline(always)]
255            fn mul2(&mut self) {
256                let mut last = 0;
257                for i in &mut self.0 {
258                    let tmp = *i >> 63;
259                    *i <<= 1;
260                    *i |= last;
261                    last = tmp;
262                }
263            }
264
265            #[inline(always)]
266            fn shl(&mut self, mut n: u32) {
267                if n as usize >= 64 * #limbs {
268                    *self = Self::from(0);
269                    return;
270                }
271
272                while n >= 64 {
273                    let mut t = 0;
274                    for i in &mut self.0 {
275                        ::std::mem::swap(&mut t, i);
276                    }
277                    n -= 64;
278                }
279
280                if n > 0 {
281                    let mut t = 0;
282                    for i in &mut self.0 {
283                        let t2 = *i >> (64 - n);
284                        *i <<= n;
285                        *i |= t;
286                        t = t2;
287                    }
288                }
289            }
290
291            #[inline(always)]
292            fn num_bits(&self) -> u32 {
293                let mut ret = (#limbs as u32) * 64;
294                for i in self.0.iter().rev() {
295                    let leading = i.leading_zeros();
296                    ret -= leading;
297                    if leading != 64 {
298                        break;
299                    }
300                }
301
302                ret
303            }
304
305            #[inline(always)]
306            fn add_nocarry(&mut self, other: &#repr) {
307                let mut carry = 0;
308
309                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
310                    *a = ::fff::adc(*a, *b, &mut carry);
311                }
312            }
313
314            #[inline(always)]
315            fn sub_noborrow(&mut self, other: &#repr) {
316                let mut borrow = 0;
317
318                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
319                    *a = ::fff::sbb(*a, *b, &mut borrow);
320                }
321            }
322        }
323    }
324}
325
326/// Convert BigUint into a vector of 64-bit limbs.
327fn biguint_to_real_u64_vec(mut v: BigUint, limbs: usize) -> Vec<u64> {
328    let m = BigUint::one() << 64;
329    let mut ret = vec![];
330
331    while v > BigUint::zero() {
332        ret.push((&v % &m).to_u64().unwrap());
333        v = v >> 64;
334    }
335
336    while ret.len() < limbs {
337        ret.push(0);
338    }
339
340    assert!(ret.len() == limbs);
341
342    ret
343}
344
345/// Convert BigUint into a tokenized vector of 64-bit limbs.
346fn biguint_to_u64_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream {
347    let ret = biguint_to_real_u64_vec(v, limbs);
348    quote!([#(#ret,)*])
349}
350
351fn biguint_num_bits(mut v: BigUint) -> u32 {
352    let mut bits = 0;
353
354    while v != BigUint::zero() {
355        v = v >> 1;
356        bits += 1;
357    }
358
359    bits
360}
361
362/// BigUint modular exponentiation by square-and-multiply.
363fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
364    let mut ret = BigUint::one();
365
366    for i in exp
367        .to_bytes_be()
368        .into_iter()
369        .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
370    {
371        ret = (&ret * &ret) % modulus;
372        if i {
373            ret = (ret * &base) % modulus;
374        }
375    }
376
377    ret
378}
379
#[test]
fn test_exp() {
    // Parse a decimal literal; all inputs below are valid numbers.
    let num = |digits: &str| BigUint::from_str(digits).unwrap();

    let base = num("4398572349857239485729348572983472345");
    let power = num("5489673498567349856734895");
    let modulus = num(BLS_381_FR_MODULUS);
    let expected =
        num("4371221214068404307866768905142520595925044802278091865033317963560480051536");

    assert_eq!(exp(base, &power, &modulus), expected);
}
397
398fn prime_field_constants_and_sqrt(
399    name: &syn::Ident,
400    repr: &syn::Ident,
401    modulus: BigUint,
402    limbs: usize,
403    generator: BigUint,
404) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
405    let modulus_num_bits = biguint_num_bits(modulus.clone());
406
407    // The number of bits we should "shave" from a randomly sampled reputation, i.e.,
408    // if our modulus is 381 bits and our representation is 384 bits, we should shave
409    // 3 bits from the beginning of a randomly sampled 384 bit representation to
410    // reduce the cost of rejection sampling.
411    let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
412
413    // Compute R = 2**(64 * limbs) mod m
414    let r = (BigUint::one() << (limbs * 64)) % &modulus;
415
416    // modulus - 1 = 2^s * t
417    let mut s: u32 = 0;
418    let mut t = &modulus - BigUint::from_str("1").unwrap();
419    while t.is_even() {
420        t = t >> 1;
421        s += 1;
422    }
423
424    // Compute 2^s root of unity given the generator
425    let root_of_unity = biguint_to_u64_vec(
426        (exp(generator.clone(), &t, &modulus) * &r) % &modulus,
427        limbs,
428    );
429    let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs);
430
431    let mod_minus_1_over_2 =
432        biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs);
433    let legendre_impl = quote! {
434        fn legendre(&self) -> ::fff::LegendreSymbol {
435            // s = self^((modulus - 1) // 2)
436            let s = self.pow(#mod_minus_1_over_2);
437            if s == Self::zero() {
438                ::fff::LegendreSymbol::Zero
439            } else if s == Self::one() {
440                ::fff::LegendreSymbol::QuadraticResidue
441            } else {
442                ::fff::LegendreSymbol::QuadraticNonResidue
443            }
444        }
445    };
446
447    let sqrt_impl =
448        if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
449            let mod_minus_3_over_4 =
450                biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs);
451
452            // Compute -R as (m - r)
453            let rneg = biguint_to_u64_vec(&modulus - &r, limbs);
454
455            quote! {
456                impl ::fff::SqrtField for #name {
457                    #legendre_impl
458
459                    fn sqrt(&self) -> Option<Self> {
460                        // Shank's algorithm for q mod 4 = 3
461                        // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2)
462
463                        let mut a1 = self.pow(#mod_minus_3_over_4);
464
465                        let mut a0 = a1;
466                        a0.square();
467                        a0.mul_assign(self);
468
469                        if a0.0 == #repr(#rneg) {
470                            None
471                        } else {
472                            a1.mul_assign(self);
473                            Some(a1)
474                        }
475                    }
476                }
477            }
478        } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
479            let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs);
480            let t = biguint_to_u64_vec(t.clone(), limbs);
481
482            quote! {
483                impl ::fff::SqrtField for #name {
484                    #legendre_impl
485
486                    fn sqrt(&self) -> Option<Self> {
487                        // Tonelli-Shank's algorithm for q mod 16 = 1
488                        // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
489
490                        match self.legendre() {
491                            ::fff::LegendreSymbol::Zero => Some(*self),
492                            ::fff::LegendreSymbol::QuadraticNonResidue => None,
493                            ::fff::LegendreSymbol::QuadraticResidue => {
494                                let mut c = #name(ROOT_OF_UNITY);
495                                let mut r = self.pow(#t_plus_1_over_2);
496                                let mut t = self.pow(#t);
497                                let mut m = S;
498
499                                while t != Self::one() {
500                                    let mut i = 1;
501                                    {
502                                        let mut t2i = t;
503                                        t2i.square();
504                                        loop {
505                                            if t2i == Self::one() {
506                                                break;
507                                            }
508                                            t2i.square();
509                                            i += 1;
510                                        }
511                                    }
512
513                                    for _ in 0..(m - i - 1) {
514                                        c.square();
515                                    }
516                                    r.mul_assign(&c);
517                                    c.square();
518                                    t.mul_assign(&c);
519                                    m = i;
520                                }
521
522                                Some(r)
523                            }
524                        }
525                    }
526                }
527            }
528        } else {
529            quote! {}
530        };
531
532    // Compute R^2 mod m
533    let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs);
534
535    let r = biguint_to_u64_vec(r, limbs);
536    let modulus = biguint_to_real_u64_vec(modulus, limbs);
537
538    // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1
539    let mut inv = 1u64;
540    for _ in 0..63 {
541        inv = inv.wrapping_mul(inv);
542        inv = inv.wrapping_mul(modulus[0]);
543    }
544    inv = inv.wrapping_neg();
545
546    (
547        quote! {
548            /// This is the modulus m of the prime field
549            const MODULUS: #repr = #repr([#(#modulus,)*]);
550
551            /// The number of bits needed to represent the modulus.
552            const MODULUS_BITS: u32 = #modulus_num_bits;
553
554            /// The number of bits that must be shaved from the beginning of
555            /// the representation when randomly sampling.
556            const REPR_SHAVE_BITS: u32 = #repr_shave_bits;
557
558            /// 2^{limbs*64} mod m
559            const R: #repr = #repr(#r);
560
561            /// 2^{limbs*64*2} mod m
562            const R2: #repr = #repr(#r2);
563
564            /// -(m^{-1} mod m) mod m
565            const INV: u64 = #inv;
566
567            /// Multiplicative generator of `MODULUS` - 1 order, also quadratic
568            /// nonresidue.
569            const GENERATOR: #repr = #repr(#generator);
570
571            /// 2^s * t = MODULUS - 1 with t odd
572            const S: u32 = #s;
573
574            /// 2^s root of unity computed by GENERATOR^t
575            const ROOT_OF_UNITY: #repr = #repr(#root_of_unity);
576        },
577        sqrt_impl,
578    )
579}
580
581/// Implement PrimeField for the derived type.
582fn prime_field_impl(
583    name: &syn::Ident,
584    repr: &syn::Ident,
585    limbs: usize,
586    modulus_raw: &str,
587) -> proc_macro2::TokenStream {
588    // Returns r{n} as an ident.
589    fn get_temp(n: usize) -> syn::Ident {
590        syn::Ident::new(&format!("r{}", n), proc_macro2::Span::call_site())
591    }
592
593    // The parameter list for the mont_reduce() internal method.
594    // r0: u64, mut r1: u64, mut r2: u64, ...
595    let mut mont_paramlist = proc_macro2::TokenStream::new();
596    mont_paramlist.append_separated(
597        (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
598            if i != 0 {
599                quote! {mut #x: u64}
600            } else {
601                quote! {#x: u64}
602            }
603        }),
604        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
605    );
606
607    // Implement montgomery reduction for some number of limbs
608    fn mont_impl(limbs: usize) -> proc_macro2::TokenStream {
609        let mut gen = proc_macro2::TokenStream::new();
610
611        for i in 0..limbs {
612            {
613                let temp = get_temp(i);
614                gen.extend(quote! {
615                    let k = #temp.wrapping_mul(INV);
616                    let mut carry = 0;
617                    ::fff::mac_with_carry(#temp, k, MODULUS.0[0], &mut carry);
618                });
619            }
620
621            for j in 1..limbs {
622                let temp = get_temp(i + j);
623                gen.extend(quote! {
624                    #temp = ::fff::mac_with_carry(#temp, k, MODULUS.0[#j], &mut carry);
625                });
626            }
627
628            let temp = get_temp(i + limbs);
629
630            if i == 0 {
631                gen.extend(quote! {
632                    #temp = ::fff::adc(#temp, 0, &mut carry);
633                });
634            } else {
635                gen.extend(quote! {
636                    #temp = ::fff::adc(#temp, carry2, &mut carry);
637                });
638            }
639
640            if i != (limbs - 1) {
641                gen.extend(quote! {
642                    let carry2 = carry;
643                });
644            }
645        }
646
647        for i in 0..limbs {
648            let temp = get_temp(limbs + i);
649
650            gen.extend(quote! {
651                (self.0).0[#i] = #temp;
652            });
653        }
654
655        gen
656    }
657
658    fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
659        let mut gen = proc_macro2::TokenStream::new();
660
661        for i in 0..(limbs - 1) {
662            gen.extend(quote! {
663                let mut carry = 0;
664            });
665
666            for j in (i + 1)..limbs {
667                let temp = get_temp(i + j);
668                if i == 0 {
669                    gen.extend(quote! {
670                        let #temp = ::fff::mac_with_carry(0, (#a.0).0[#i], (#a.0).0[#j], &mut carry);
671                    });
672                } else {
673                    gen.extend(quote!{
674                        let #temp = ::fff::mac_with_carry(#temp, (#a.0).0[#i], (#a.0).0[#j], &mut carry);
675                    });
676                }
677            }
678
679            let temp = get_temp(i + limbs);
680
681            gen.extend(quote! {
682                let #temp = carry;
683            });
684        }
685
686        for i in 1..(limbs * 2) {
687            let temp0 = get_temp(limbs * 2 - i);
688            let temp1 = get_temp(limbs * 2 - i - 1);
689
690            if i == 1 {
691                gen.extend(quote! {
692                    let #temp0 = #temp1 >> 63;
693                });
694            } else if i == (limbs * 2 - 1) {
695                gen.extend(quote! {
696                    let #temp0 = #temp0 << 1;
697                });
698            } else {
699                gen.extend(quote! {
700                    let #temp0 = (#temp0 << 1) | (#temp1 >> 63);
701                });
702            }
703        }
704
705        gen.extend(quote! {
706            let mut carry = 0;
707        });
708
709        for i in 0..limbs {
710            let temp0 = get_temp(i * 2);
711            let temp1 = get_temp(i * 2 + 1);
712            if i == 0 {
713                gen.extend(quote! {
714                    let #temp0 = ::fff::mac_with_carry(0, (#a.0).0[#i], (#a.0).0[#i], &mut carry);
715                });
716            } else {
717                gen.extend(quote!{
718                    let #temp0 = ::fff::mac_with_carry(#temp0, (#a.0).0[#i], (#a.0).0[#i], &mut carry);
719                });
720            }
721
722            gen.extend(quote! {
723                let #temp1 = ::fff::adc(#temp1, 0, &mut carry);
724            });
725        }
726
727        let mut mont_calling = proc_macro2::TokenStream::new();
728        mont_calling.append_separated(
729            (0..(limbs * 2)).map(|i| get_temp(i)),
730            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
731        );
732
733        gen.extend(quote! {
734            self.mont_reduce(#mont_calling);
735        });
736
737        gen
738    }
739
740    fn mul_impl(
741        a: proc_macro2::TokenStream,
742        b: proc_macro2::TokenStream,
743        limbs: usize,
744        modulus_raw: &str,
745    ) -> proc_macro2::TokenStream {
746        if limbs == 4 && modulus_raw == BLS_381_FR_MODULUS {
747            mul_impl_asm4(a, b)
748        } else {
749            mul_impl_default(a, b, limbs)
750        }
751    }
752
753    fn mul_impl_asm4(
754        a: proc_macro2::TokenStream,
755        b: proc_macro2::TokenStream,
756    ) -> proc_macro2::TokenStream {
757        // x86_64 asm for four limbs
758        let default_impl = mul_impl_default(a.clone(), b.clone(), 4);
759
760        let mut gen = proc_macro2::TokenStream::new();
761        gen.extend(quote! {
762            #[cfg(target_arch = "x86_64")]
763            {
764                if *::fff::CPU_SUPPORTS_ADX_INSTRUCTION {
765                    ::fff::mod_mul_4w_assign(&mut (#a.0).0, &(#b.0).0);
766                } else {
767                    #default_impl
768                }
769            }
770            #[cfg(not(target_arch = "x86_64"))]
771            {
772                #default_impl
773            }
774        });
775
776        gen
777    }
778
779    fn mul_impl_default(
780        a: proc_macro2::TokenStream,
781        b: proc_macro2::TokenStream,
782        limbs: usize,
783    ) -> proc_macro2::TokenStream {
784        let mut gen = proc_macro2::TokenStream::new();
785
786        for i in 0..limbs {
787            gen.extend(quote! {
788                let mut carry = 0;
789            });
790
791            for j in 0..limbs {
792                let temp = get_temp(i + j);
793
794                if i == 0 {
795                    gen.extend(quote! {
796                        let #temp = ::fff::mac_with_carry(0, (#a.0).0[#i], (#b.0).0[#j], &mut carry);
797                    });
798                } else {
799                    gen.extend(quote!{
800                        let #temp = ::fff::mac_with_carry(#temp, (#a.0).0[#i], (#b.0).0[#j], &mut carry);
801                    });
802                }
803            }
804
805            let temp = get_temp(i + limbs);
806
807            gen.extend(quote! {
808                let #temp = carry;
809            });
810        }
811
812        let mut mont_calling = proc_macro2::TokenStream::new();
813        mont_calling.append_separated(
814            (0..(limbs * 2)).map(|i| get_temp(i)),
815            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
816        );
817
818        gen.extend(quote! {
819            self.mont_reduce(#mont_calling);
820        });
821
822        gen
823    }
824
825    fn add_assign_impl(
826        a: proc_macro2::TokenStream,
827        b: proc_macro2::TokenStream,
828        limbs: usize,
829    ) -> proc_macro2::TokenStream {
830        if limbs == 4 {
831            add_assign_asm_impl(a, b, limbs)
832        } else {
833            add_assign_default_impl(a, b, limbs)
834        }
835    }
836
837    fn add_assign_asm_impl(
838        a: proc_macro2::TokenStream,
839        b: proc_macro2::TokenStream,
840        limbs: usize,
841    ) -> proc_macro2::TokenStream {
842        let mut gen = proc_macro2::TokenStream::new();
843        let default_impl = add_assign_default_impl(a.clone(), b.clone(), limbs);
844
845        gen.extend(quote! {
846            #[cfg(target_arch = "x86_64")]
847            {
848                // This cannot exceed the backing capacity.
849                use std::arch::x86_64::*;
850                use std::mem;
851
852                unsafe {
853                    let mut carry = _addcarry_u64(
854                        0,
855                        (#a.0).0[0],
856                        (#b.0).0[0],
857                        &mut (#a.0).0[0]
858                    );
859                    carry = _addcarry_u64(
860                        carry, (#a.0).0[1],
861                        (#b.0).0[1],
862                        &mut (#a.0).0[1]
863                    );
864                    carry = _addcarry_u64(
865                        carry, (#a.0).0[2],
866                        (#b.0).0[2],
867                        &mut (#a.0).0[2]
868                    );
869                    _addcarry_u64(
870                        carry,
871                        (#a.0).0[3],
872                        (#b.0).0[3],
873                        &mut (#a.0).0[3]
874                    );
875
876                    let mut s_sub: [u64; 4] = mem::uninitialized();
877
878                    carry = _subborrow_u64(
879                        0,
880                        (#a.0).0[0],
881                        MODULUS.0[0],
882                        &mut s_sub[0]
883                    );
884                    carry = _subborrow_u64(
885                        carry,
886                        (#a.0).0[1],
887                        MODULUS.0[1],
888                        &mut s_sub[1]
889                    );
890                    carry = _subborrow_u64(
891                        carry,
892                        (#a.0).0[2],
893                        MODULUS.0[2],
894                        &mut s_sub[2]
895                    );
896                    carry = _subborrow_u64(
897                        carry,
898                        (#a.0).0[3],
899                        MODULUS.0[3],
900                        &mut s_sub[3]
901                    );
902
903                    if carry == 0 {
904                        // Direct assign fails since size can be 4 or 6
905                        // Obviously code doesn't work at all for size 6
906                        // (#a).0 = s_sub;
907                        (#a.0).0[0] = s_sub[0];
908                        (#a.0).0[1] = s_sub[1];
909                        (#a.0).0[2] = s_sub[2];
910                        (#a.0).0[3] = s_sub[3];
911                    }
912                }
913            }
914            #[cfg(not(target_arch = "x86_64"))]
915            {
916                #default_impl
917            }
918        });
919
920        gen
921    }
922
923    fn add_assign_default_impl(
924        a: proc_macro2::TokenStream,
925        b: proc_macro2::TokenStream,
926        _limbs: usize,
927    ) -> proc_macro2::TokenStream {
928        let mut gen = proc_macro2::TokenStream::new();
929
930        gen.extend(quote! {
931            // This cannot exceed the backing capacity.
932            #a.0.add_nocarry(&#b.0);
933
934            // However, it may need to be reduced.
935            #a.reduce();
936        });
937        gen
938    }
939
940    let squaring_impl = sqr_impl(quote! {self}, limbs);
941    let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs, modulus_raw);
942    let add_assign = add_assign_impl(quote! {self}, quote! {other}, limbs);
943    let montgomery_impl = mont_impl(limbs);
944
945    // (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ...
946    let mut into_repr_params = proc_macro2::TokenStream::new();
947    into_repr_params.append_separated(
948        (0..limbs)
949            .map(|i| quote! { (self.0).0[#i] })
950            .chain((0..limbs).map(|_| quote! {0})),
951        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
952    );
953
954    let top_limb_index = limbs - 1;
955
956    quote! {
957        impl ::std::marker::Copy for #name { }
958
959        impl ::std::clone::Clone for #name {
960            fn clone(&self) -> #name {
961                *self
962            }
963        }
964
965        impl ::std::cmp::PartialEq for #name {
966            fn eq(&self, other: &#name) -> bool {
967                self.0 == other.0
968            }
969        }
970
971        impl ::std::cmp::Eq for #name { }
972
973        impl ::std::fmt::Debug for #name
974        {
975            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
976                write!(f, "{}({:?})", stringify!(#name), self.into_repr())
977            }
978        }
979
980        /// Elements are ordered lexicographically.
981        impl Ord for #name {
982            #[inline(always)]
983            fn cmp(&self, other: &#name) -> ::std::cmp::Ordering {
984                self.into_repr().cmp(&other.into_repr())
985            }
986        }
987
988        impl PartialOrd for #name {
989            #[inline(always)]
990            fn partial_cmp(&self, other: &#name) -> Option<::std::cmp::Ordering> {
991                Some(self.cmp(other))
992            }
993        }
994
995        impl ::std::fmt::Display for #name {
996            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
997                write!(f, "{}({})", stringify!(#name), self.into_repr())
998            }
999        }
1000
1001        impl From<#name> for #repr {
1002            fn from(e: #name) -> #repr {
1003                e.into_repr()
1004            }
1005        }
1006
1007        impl ::serde::Serialize for #name {
1008            fn serialize<S: ::serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
1009                self.into_repr().serialize(s)
1010            }
1011        }
1012
1013        impl<'de> ::serde::Deserialize<'de> for #name {
1014            fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
1015                use serde::de::Error;
1016                #name::from_repr(#repr::deserialize(d)?)
1017                    .map_err(|_| D::Error::custom(stringify!(format!("deserialized bytes don't encode a valid {}", #name))))
1018            }
1019        }
1020
1021        impl ::fff::PrimeField for #name {
1022            type Repr = #repr;
1023
1024            fn from_repr(r: #repr) -> Result<#name, PrimeFieldDecodingError> {
1025                let mut r = #name(r);
1026                if r.is_valid() {
1027                    r.mul_assign(&#name(R2));
1028
1029                    Ok(r)
1030                } else {
1031                    Err(PrimeFieldDecodingError::NotInField(format!("{}", r.0)))
1032                }
1033            }
1034
1035            fn into_repr(&self) -> #repr {
1036                let mut r = *self;
1037                r.mont_reduce(
1038                    #into_repr_params
1039                );
1040
1041                r.0
1042            }
1043
1044            fn char() -> #repr {
1045                MODULUS
1046            }
1047
1048            const NUM_BITS: u32 = MODULUS_BITS;
1049
1050            const CAPACITY: u32 = Self::NUM_BITS - 1;
1051
1052            fn multiplicative_generator() -> Self {
1053                #name(GENERATOR)
1054            }
1055
1056            const S: u32 = S;
1057
1058            fn root_of_unity() -> Self {
1059                #name(ROOT_OF_UNITY)
1060            }
1061
1062
1063            fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
1064                use std::convert::TryInto;
1065
1066                let mut repr = #repr::default();
1067                for (limb, chunk) in repr.0.iter_mut().zip(bytes.chunks_exact(8)) {
1068                    *limb = u64::from_le_bytes(chunk.try_into().unwrap());
1069                }
1070
1071                // Mask away the unused most-significant bits.
1072                repr.0.as_mut()[#top_limb_index] &= 0xffffffffffffffff >> REPR_SHAVE_BITS;
1073
1074                #name::from_repr(repr).ok()
1075            }
1076        }
1077
1078        impl ::fff::Field for #name {
1079            /// Computes a uniformly random element using rejection sampling.
1080            fn random<R: ::rand_core::RngCore>(rng: &mut R) -> Self {
1081                loop {
1082                    let mut tmp = {
1083                        let mut repr = [0u64; #limbs];
1084                        for i in 0..#limbs {
1085                            repr[i] = rng.next_u64();
1086                        }
1087                        #name(#repr(repr))
1088                    };
1089
1090                    // Mask away the unused most-significant bits.
1091                    tmp.0.as_mut()[#top_limb_index] &= 0xffffffffffffffff >> REPR_SHAVE_BITS;
1092
1093                    if tmp.is_valid() {
1094                        return tmp
1095                    }
1096                }
1097            }
1098
1099            #[inline]
1100            fn zero() -> Self {
1101                #name(#repr::from(0))
1102            }
1103
1104            #[inline]
1105            fn one() -> Self {
1106                #name(R)
1107            }
1108
1109            #[inline]
1110            fn is_zero(&self) -> bool {
1111                self.0.is_zero()
1112            }
1113
1114            #[inline]
1115            fn add_assign(&mut self, other: &#name) {
1116                #add_assign
1117            }
1118
1119            #[inline]
1120            fn double(&mut self) {
1121                // This cannot exceed the backing capacity.
1122                self.0.mul2();
1123
1124                // However, it may need to be reduced.
1125                self.reduce();
1126            }
1127
1128            #[inline]
1129            fn sub_assign(&mut self, other: &#name) {
1130                // If `other` is larger than `self`, we'll need to add the modulus to self first.
1131                if other.0 > self.0 {
1132                    self.0.add_nocarry(&MODULUS);
1133                }
1134
1135                self.0.sub_noborrow(&other.0);
1136            }
1137
1138            #[inline]
1139            fn negate(&mut self) {
1140                if !self.is_zero() {
1141                    let mut tmp = MODULUS;
1142                    tmp.sub_noborrow(&self.0);
1143                    self.0 = tmp;
1144                }
1145            }
1146
1147            fn inverse(&self) -> Option<Self> {
1148                if self.is_zero() {
1149                    None
1150                } else {
1151                    // Guajardo Kumar Paar Pelzl
1152                    // Efficient Software-Implementation of Finite Fields with Applications to Cryptography
1153                    // Algorithm 16 (BEA for Inversion in Fp)
1154
1155                    let one = #repr::from(1);
1156
1157                    let mut u = self.0;
1158                    let mut v = MODULUS;
1159                    let mut b = #name(R2); // Avoids unnecessary reduction step.
1160                    let mut c = Self::zero();
1161
1162                    while u != one && v != one {
1163                        while u.is_even() {
1164                            u.div2();
1165
1166                            if b.0.is_even() {
1167                                b.0.div2();
1168                            } else {
1169                                b.0.add_nocarry(&MODULUS);
1170                                b.0.div2();
1171                            }
1172                        }
1173
1174                        while v.is_even() {
1175                            v.div2();
1176
1177                            if c.0.is_even() {
1178                                c.0.div2();
1179                            } else {
1180                                c.0.add_nocarry(&MODULUS);
1181                                c.0.div2();
1182                            }
1183                        }
1184
1185                        if v < u {
1186                            u.sub_noborrow(&v);
1187                            b.sub_assign(&c);
1188                        } else {
1189                            v.sub_noborrow(&u);
1190                            c.sub_assign(&b);
1191                        }
1192                    }
1193
1194                    if u == one {
1195                        Some(b)
1196                    } else {
1197                        Some(c)
1198                    }
1199                }
1200            }
1201
1202            #[inline(always)]
1203            fn frobenius_map(&mut self, _: usize) {
1204                // This has no effect in a prime field.
1205            }
1206
1207            #[inline]
1208            fn mul_assign(&mut self, other: &#name)
1209            {
1210                #multiply_impl
1211            }
1212
1213            #[inline]
1214            fn square(&mut self)
1215            {
1216                #squaring_impl
1217            }
1218        }
1219
1220        impl #name {
1221            /// Determines if the element is really in the field. This is only used
1222            /// internally.
1223            #[inline(always)]
1224            fn is_valid(&self) -> bool {
1225                self.0 < MODULUS
1226            }
1227
1228            /// Subtracts the modulus from this element if this element is not in the
1229            /// field. Only used interally.
1230            #[inline(always)]
1231            fn reduce(&mut self) {
1232                if !self.is_valid() {
1233                    self.0.sub_noborrow(&MODULUS);
1234                }
1235            }
1236
1237            #[inline(always)]
1238            fn mont_reduce(
1239                &mut self,
1240                #mont_paramlist
1241            )
1242            {
1243                // The Montgomery reduction here is based on Algorithm 14.32 in
1244                // Handbook of Applied Cryptography
1245                // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
1246
1247                #montgomery_impl
1248
1249                self.reduce();
1250            }
1251        }
1252    }
1253}