Skip to main content

ark_ff/fields/models/fp/montgomery_backend.rs

1use super::{Fp, FpConfig};
2use crate::{
3    biginteger::arithmetic as fa, BigInt, BigInteger, PrimeField, SqrtPrecomputation, Zero,
4};
5use ark_ff_macros::unroll_for_loops;
6use ark_std::marker::PhantomData;
7
8/// A trait that specifies the constants and arithmetic procedures
9/// for Montgomery arithmetic over the prime field defined by `MODULUS`.
10///
11/// # Note
12/// Manual implementation of this trait is not recommended unless one wishes
13/// to specialize arithmetic methods. Instead, the
14/// [`MontConfig`][`ark_ff_macros::MontConfig`] derive macro should be used.
15pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
16    /// The modulus of the field.
17    const MODULUS: BigInt<N>;
18
19    /// Let `M` be the power of 2^64 nearest to `Self::MODULUS_BITS`. Then
20    /// `R = M % Self::MODULUS`.
21    const R: BigInt<N> = Self::MODULUS.montgomery_r();
22
23    /// R2 = R^2 % Self::MODULUS
24    const R2: BigInt<N> = Self::MODULUS.montgomery_r2();
25
26    /// INV = -MODULUS^{-1} mod 2^64
27    const INV: u64 = inv::<Self, N>();
28
29    /// A multiplicative generator of the field.
30    /// `Self::GENERATOR` is an element having multiplicative order
31    /// `Self::MODULUS - 1`.
32    const GENERATOR: Fp<MontBackend<Self, N>, N>;
33
34    /// Can we use the no-carry optimization for multiplication
35    /// outlined [here](https://hackmd.io/@gnark/modular_multiplication)?
36    ///
37    /// This optimization applies if
38    /// (a) `Self::MODULUS[N-1] < u64::MAX >> 1`, and
39    /// (b) the bits of the modulus are not all 1.
40    #[doc(hidden)]
41    const CAN_USE_NO_CARRY_MUL_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();
42
43    /// Can we use the no-carry optimization for squaring
44    /// outlined [here](https://hackmd.io/@gnark/modular_multiplication)?
45    ///
46    /// This optimization applies if
47    /// (a) `Self::MODULUS[N-1] < u64::MAX >> 2`, and
48    /// (b) the bits of the modulus are not all 1.
49    #[doc(hidden)]
50    const CAN_USE_NO_CARRY_SQUARE_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();
51
52    /// Does the modulus have a spare unused bit
53    ///
54    /// This condition applies if
55    /// (a) `Self::MODULUS[N-1] >> 63 == 0`
56    #[doc(hidden)]
57    const MODULUS_HAS_SPARE_BIT: bool = modulus_has_spare_bit::<Self, N>();
58
59    /// 2^s root of unity computed by GENERATOR^t
60    const TWO_ADIC_ROOT_OF_UNITY: Fp<MontBackend<Self, N>, N>;
61
62    /// An integer `b` such that there exists a multiplicative subgroup
63    /// of size `b^k` for some integer `k`.
64    const SMALL_SUBGROUP_BASE: Option<u32> = None;
65
66    /// The integer `k` such that there exists a multiplicative subgroup
67    /// of size `Self::SMALL_SUBGROUP_BASE^k`.
68    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = None;
69
70    /// GENERATOR^((MODULUS-1) / (2^s *
71    /// SMALL_SUBGROUP_BASE^SMALL_SUBGROUP_BASE_ADICITY)).
72    /// Used for mixed-radix FFT.
73    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<MontBackend<Self, N>, N>> = None;
74
75    /// Precomputed material for use when computing square roots.
76    /// The default is to use the standard Tonelli-Shanks algorithm.
77    const SQRT_PRECOMP: Option<SqrtPrecomputation<Fp<MontBackend<Self, N>, N>>> =
78        sqrt_precomputation();
79
80    /// (MODULUS + 1) / 4 when MODULUS % 4 == 3. Used for square root precomputations.
81    #[doc(hidden)]
82    const MODULUS_PLUS_ONE_DIV_FOUR: Option<BigInt<N>> = {
83        match Self::MODULUS.mod_4() == 3 {
84            true => {
85                let (modulus_plus_one, carry) = Self::MODULUS.const_add_with_carry(&BigInt::one());
86                let mut result = modulus_plus_one.divide_by_2_round_down();
87                // Since modulus_plus_one is even, dividing by 2 results in a MSB of 0.
88                // Thus we can set MSB to `carry` to get the correct result of (MODULUS + 1) // 2:
89                result.0[N - 1] |= (carry as u64) << 63;
90                Some(result.divide_by_2_round_down())
91            },
92            false => None,
93        }
94    };
95
96    /// (MODULUS + 3) / 8 when MODULUS % 8 == 5. Used for square root precomputations.
97    #[doc(hidden)]
98    const MODULUS_PLUS_THREE_DIV_EIGHT: Option<BigInt<N>> = {
99        match Self::MODULUS.mod_8() == 5 {
100            true => {
101                let (modulus_plus_three, carry) = Self::MODULUS.const_add_with_carry(&BigInt!("3"));
102                let mut result = modulus_plus_three.divide_by_2_round_down();
103                // Since modulus_plus_one is even, dividing by 2 results in a MSB of 0.
104                // Thus we can set MSB to `carry` to get the correct result of (MODULUS + 1) // 2:
105                result.0[N - 1] |= (carry as u64) << 63;
106                result = result.divide_by_2_round_down();
107
108                Some(result.divide_by_2_round_down())
109            },
110            false => None,
111        }
112    };
113
114    /// (MODULUS - 1) / 4 when MODULUS % 8 == 5. Used for square root precomputations.
115    #[doc(hidden)]
116    const MODULUS_MINUS_ONE_DIV_FOUR: Option<BigInt<N>> = {
117        match Self::MODULUS.mod_8() == 5 {
118            true => {
119                let (modulus_plus_three, _) = Self::MODULUS.const_sub_with_borrow(&BigInt::one());
120                let result = modulus_plus_three.divide_by_2_round_down();
121                Some(result.divide_by_2_round_down())
122            },
123            false => None,
124        }
125    };
126
127    /// Sets `a = a + b`.
128    #[inline(always)]
129    fn add_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
130        // This cannot exceed the backing capacity.
131        let c = a.0.add_with_carry(&b.0);
132        // However, it may need to be reduced
133        if Self::MODULUS_HAS_SPARE_BIT {
134            a.subtract_modulus()
135        } else {
136            a.subtract_modulus_with_carry(c)
137        }
138    }
139
140    /// Sets `a = a - b`.
141    #[inline(always)]
142    fn sub_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
143        // If `other` is larger than `self`, add the modulus to self first.
144        if b.0 > a.0 {
145            a.0.add_with_carry(&Self::MODULUS);
146        }
147        a.0.sub_with_borrow(&b.0);
148    }
149
150    /// Sets `a = 2 * a`.
151    #[inline(always)]
152    fn double_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
153        // This cannot exceed the backing capacity.
154        let c = a.0.mul2();
155        // However, it may need to be reduced.
156        if Self::MODULUS_HAS_SPARE_BIT {
157            a.subtract_modulus()
158        } else {
159            a.subtract_modulus_with_carry(c)
160        }
161    }
162
163    /// Sets `a = -a`.
164    #[inline(always)]
165    fn neg_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
166        if !a.is_zero() {
167            let mut tmp = Self::MODULUS;
168            tmp.sub_with_borrow(&a.0);
169            a.0 = tmp;
170        }
171    }
172
173    /// This modular multiplication algorithm uses Montgomery
174    /// reduction for efficient implementation. It also additionally
175    /// uses the "no-carry optimization" outlined
176    /// [here](https://hackmd.io/@gnark/modular_multiplication) if
177    /// `Self::MODULUS` has (a) a non-zero MSB, and (b) at least one
178    /// zero bit in the rest of the modulus.
179    #[unroll_for_loops(12)]
180    #[inline(always)]
181    fn mul_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
182        // No-carry optimisation applied to CIOS
183        if Self::CAN_USE_NO_CARRY_MUL_OPT {
184            if N <= 6
185                && N > 1
186                && cfg!(all(
187                    feature = "asm",
188                    target_feature = "bmi2",
189                    target_feature = "adx",
190                    target_arch = "x86_64"
191                ))
192            {
193                #[cfg(
194                    all(
195                        feature = "asm",
196                        target_feature = "bmi2",
197                        target_feature = "adx",
198                        target_arch = "x86_64"
199                    )
200                )]
201                #[allow(unsafe_code)]
202                #[rustfmt::skip]
203
204                // Tentatively avoid using assembly for `N == 1`.
205                match N {
206                    2 => { ark_ff_asm::x86_64_asm_mul!(2, (a.0).0, (b.0).0); },
207                    3 => { ark_ff_asm::x86_64_asm_mul!(3, (a.0).0, (b.0).0); },
208                    4 => { ark_ff_asm::x86_64_asm_mul!(4, (a.0).0, (b.0).0); },
209                    5 => { ark_ff_asm::x86_64_asm_mul!(5, (a.0).0, (b.0).0); },
210                    6 => { ark_ff_asm::x86_64_asm_mul!(6, (a.0).0, (b.0).0); },
211                    _ => unsafe { ark_std::hint::unreachable_unchecked() },
212                };
213            } else {
214                let mut r = [0u64; N];
215
216                for i in 0..N {
217                    let mut carry1 = 0u64;
218                    r[0] = fa::mac(r[0], (a.0).0[0], (b.0).0[i], &mut carry1);
219
220                    let k = r[0].wrapping_mul(Self::INV);
221
222                    let mut carry2 = 0u64;
223                    fa::mac_discard(r[0], k, Self::MODULUS.0[0], &mut carry2);
224
225                    for j in 1..N {
226                        r[j] = fa::mac_with_carry(r[j], (a.0).0[j], (b.0).0[i], &mut carry1);
227                        r[j - 1] = fa::mac_with_carry(r[j], k, Self::MODULUS.0[j], &mut carry2);
228                    }
229                    r[N - 1] = carry1 + carry2;
230                }
231                (a.0).0.copy_from_slice(&r);
232            }
233            a.subtract_modulus();
234        } else {
235            // Alternative implementation
236            // Implements CIOS.
237            let (carry, res) = a.mul_without_cond_subtract(b);
238            *a = res;
239
240            if Self::MODULUS_HAS_SPARE_BIT {
241                a.subtract_modulus_with_carry(carry);
242            } else {
243                a.subtract_modulus();
244            }
245        }
246    }
247
248    #[inline(always)]
249    #[unroll_for_loops(12)]
250    fn square_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
251        if N == 1 {
252            // We default to multiplying with `a` using the `Mul` impl
253            // for the N == 1 case
254            *a *= *a;
255            return;
256        }
257        #[cfg(all(
258            feature = "asm",
259            target_feature = "bmi2",
260            target_feature = "adx",
261            target_arch = "x86_64"
262        ))]
263        #[allow(unsafe_code)]
264        if Self::CAN_USE_NO_CARRY_SQUARE_OPT && (2..=6).contains(&N) {
265            use ark_ff_asm::x86_64_asm_square;
266            #[rustfmt::skip]
267            match N {
268                2 => { x86_64_asm_square!(2, (a.0).0); },
269                3 => { x86_64_asm_square!(3, (a.0).0); },
270                4 => { x86_64_asm_square!(4, (a.0).0); },
271                5 => { x86_64_asm_square!(5, (a.0).0); },
272                6 => { x86_64_asm_square!(6, (a.0).0); },
273                _ => unsafe { ark_std::hint::unreachable_unchecked() },
274            };
275            a.subtract_modulus();
276            return;
277        }
278
279        let mut r = crate::const_helpers::MulBuffer::<N>::zeroed();
280
281        let mut carry = 0;
282        for i in 0..(N - 1) {
283            for j in (i + 1)..N {
284                r[i + j] = fa::mac_with_carry(r[i + j], (a.0).0[i], (a.0).0[j], &mut carry);
285            }
286            r.b1[i] = carry;
287            carry = 0;
288        }
289
290        r.b1[N - 1] = r.b1[N - 2] >> 63;
291        for i in 2..(2 * N - 1) {
292            r[2 * N - i] = (r[2 * N - i] << 1) | (r[2 * N - (i + 1)] >> 63);
293        }
294        r.b0[1] <<= 1;
295
296        for i in 0..N {
297            r[2 * i] = fa::mac_with_carry(r[2 * i], (a.0).0[i], (a.0).0[i], &mut carry);
298            carry = fa::adc(&mut r[2 * i + 1], 0, carry);
299        }
300        // Montgomery reduction
301        let mut carry2 = 0;
302        for i in 0..N {
303            let k = r[i].wrapping_mul(Self::INV);
304            carry = 0;
305            fa::mac_discard(r[i], k, Self::MODULUS.0[0], &mut carry);
306            for j in 1..N {
307                r[j + i] = fa::mac_with_carry(r[j + i], k, Self::MODULUS.0[j], &mut carry);
308            }
309            carry2 = fa::adc(&mut r.b1[i], carry, carry2);
310        }
311        (a.0).0.copy_from_slice(&r.b1);
312        if Self::MODULUS_HAS_SPARE_BIT {
313            a.subtract_modulus();
314        } else {
315            a.subtract_modulus_with_carry(carry2 != 0);
316        }
317    }
318
319    fn inverse(a: &Fp<MontBackend<Self, N>, N>) -> Option<Fp<MontBackend<Self, N>, N>> {
320        if a.is_zero() {
321            return None;
322        }
323        // Guajardo Kumar Paar Pelzl
324        // Efficient Software-Implementation of Finite Fields with Applications to
325        // Cryptography
326        // Algorithm 16 (BEA for Inversion in Fp)
327
328        let one = BigInt::from(1u64);
329
330        let mut u = a.0;
331        let mut v = Self::MODULUS;
332        let mut b = Fp::new_unchecked(Self::R2); // Avoids unnecessary reduction step.
333        let mut c = Fp::zero();
334
335        while u != one && v != one {
336            while u.is_even() {
337                u.div2();
338
339                if b.0.is_even() {
340                    b.0.div2();
341                } else {
342                    let carry = b.0.add_with_carry(&Self::MODULUS);
343                    b.0.div2();
344                    if !Self::MODULUS_HAS_SPARE_BIT && carry {
345                        (b.0).0[N - 1] |= 1 << 63;
346                    }
347                }
348            }
349
350            while v.is_even() {
351                v.div2();
352
353                if c.0.is_even() {
354                    c.0.div2();
355                } else {
356                    let carry = c.0.add_with_carry(&Self::MODULUS);
357                    c.0.div2();
358                    if !Self::MODULUS_HAS_SPARE_BIT && carry {
359                        (c.0).0[N - 1] |= 1 << 63;
360                    }
361                }
362            }
363
364            if v < u {
365                u.sub_with_borrow(&v);
366                b -= &c;
367            } else {
368                v.sub_with_borrow(&u);
369                c -= &b;
370            }
371        }
372
373        if u == one {
374            Some(b)
375        } else {
376            Some(c)
377        }
378    }
379
380    fn from_bigint(r: BigInt<N>) -> Option<Fp<MontBackend<Self, N>, N>> {
381        let mut r = Fp::new_unchecked(r);
382        if r.is_zero() {
383            Some(r)
384        } else if r.is_geq_modulus() {
385            None
386        } else {
387            r *= &Fp::new_unchecked(Self::R2);
388            Some(r)
389        }
390    }
391
392    #[inline]
393    #[cfg_attr(not(target_family = "wasm"), unroll_for_loops(12))]
394    #[cfg_attr(target_family = "wasm", unroll_for_loops(6))]
395    #[allow(clippy::modulo_one)]
396    fn into_bigint(a: Fp<MontBackend<Self, N>, N>) -> BigInt<N> {
397        let mut r = (a.0).0;
398        // Montgomery Reduction
399        for i in 0..N {
400            let k = r[i].wrapping_mul(Self::INV);
401            let mut carry = 0;
402
403            fa::mac_with_carry(r[i], k, Self::MODULUS.0[0], &mut carry);
404            for j in 1..N {
405                r[(j + i) % N] =
406                    fa::mac_with_carry(r[(j + i) % N], k, Self::MODULUS.0[j], &mut carry);
407            }
408            r[i] = carry;
409        }
410
411        BigInt::new(r)
412    }
413
414    #[unroll_for_loops(12)]
415    fn sum_of_products<const M: usize>(
416        a: &[Fp<MontBackend<Self, N>, N>; M],
417        b: &[Fp<MontBackend<Self, N>, N>; M],
418    ) -> Fp<MontBackend<Self, N>, N> {
419        // Adapted from https://github.com/zkcrypto/bls12_381/pull/84 by @str4d.
420
421        // For a single `a x b` multiplication, operand scanning (schoolbook) takes each
422        // limb of `a` in turn, and multiplies it by all of the limbs of `b` to compute
423        // the result as a double-width intermediate representation, which is then fully
424        // reduced at the carry. Here however we have pairs of multiplications (a_i, b_i),
425        // the results of which are summed.
426        //
427        // The intuition for this algorithm is two-fold:
428        // - We can interleave the operand scanning for each pair, by processing the jth
429        //   limb of each `a_i` together. As these have the same offset within the overall
430        //   operand scanning flow, their results can be summed directly.
431        // - We can interleave the multiplication and reduction steps, resulting in a
432        //   single bitshift by the limb size after each iteration. This means we only
433        //   need to store a single extra limb overall, instead of keeping around all the
434        //   intermediate results and eventually having twice as many limbs.
435
436        let modulus_size = Self::MODULUS.const_num_bits() as usize;
437        if modulus_size >= 64 * N - 1 {
438            a.iter().zip(b).map(|(a, b)| *a * b).sum()
439        } else if M == 2 {
440            // Algorithm 2, line 2
441            let result = (0..N).fold(BigInt::zero(), |mut result, j| {
442                // Algorithm 2, line 3
443                let mut carry_a = 0;
444                let mut carry_b = 0;
445                for (a, b) in a.iter().zip(b) {
446                    let a = &a.0;
447                    let b = &b.0;
448                    let mut carry2 = 0;
449                    result.0[0] = fa::mac(result.0[0], a.0[j], b.0[0], &mut carry2);
450                    for k in 1..N {
451                        result.0[k] = fa::mac_with_carry(result.0[k], a.0[j], b.0[k], &mut carry2);
452                    }
453                    carry_b = fa::adc(&mut carry_a, carry_b, carry2);
454                }
455
456                let k = result.0[0].wrapping_mul(Self::INV);
457                let mut carry2 = 0;
458                fa::mac_discard(result.0[0], k, Self::MODULUS.0[0], &mut carry2);
459                for i in 1..N {
460                    result.0[i - 1] =
461                        fa::mac_with_carry(result.0[i], k, Self::MODULUS.0[i], &mut carry2);
462                }
463                result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &carry2);
464                result
465            });
466            let mut result = Fp::new_unchecked(result);
467            result.subtract_modulus();
468            debug_assert_eq!(
469                a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
470                result
471            );
472            result
473        } else {
474            let chunk_size = 2 * (N * 64 - modulus_size) - 1;
475            // chunk_size is at least 1, since MODULUS_BIT_SIZE is at most N * 64 - 1.
476            a.chunks(chunk_size)
477                .zip(b.chunks(chunk_size))
478                .map(|(a, b)| {
479                    // Algorithm 2, line 2
480                    let result = (0..N).fold(BigInt::zero(), |mut result, j| {
481                        // Algorithm 2, line 3
482                        let (temp, carry) = a.iter().zip(b).fold(
483                            (result, 0),
484                            |(mut temp, mut carry), (Fp(a, _), Fp(b, _))| {
485                                let mut carry2 = 0;
486                                temp.0[0] = fa::mac(temp.0[0], a.0[j], b.0[0], &mut carry2);
487                                for k in 1..N {
488                                    temp.0[k] =
489                                        fa::mac_with_carry(temp.0[k], a.0[j], b.0[k], &mut carry2);
490                                }
491                                carry = fa::adc_no_carry(carry, 0, &carry2);
492                                (temp, carry)
493                            },
494                        );
495
496                        let k = temp.0[0].wrapping_mul(Self::INV);
497                        let mut carry2 = 0;
498                        fa::mac_discard(temp.0[0], k, Self::MODULUS.0[0], &mut carry2);
499                        for i in 1..N {
500                            result.0[i - 1] =
501                                fa::mac_with_carry(temp.0[i], k, Self::MODULUS.0[i], &mut carry2);
502                        }
503                        result.0[N - 1] = fa::adc_no_carry(carry, 0, &carry2);
504                        result
505                    });
506                    let mut result = Fp::new_unchecked(result);
507                    result.subtract_modulus();
508                    debug_assert_eq!(
509                        a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
510                        result
511                    );
512                    result
513                })
514                .sum()
515        }
516    }
517}
518
519/// Compute -M^{-1} mod 2^64.
520pub const fn inv<T: MontConfig<N>, const N: usize>() -> u64 {
521    // We compute this as follows.
522    // First, MODULUS mod 2^64 is just the lower 64 bits of MODULUS.
523    // Hence MODULUS mod 2^64 = MODULUS.0[0] mod 2^64.
524    //
525    // Next, computing the inverse mod 2^64 involves exponentiating by
526    // the multiplicative group order, which is euler_totient(2^64) - 1.
527    // Now, euler_totient(2^64) = 1 << 63, and so
528    // euler_totient(2^64) - 1 = (1 << 63) - 1 = 1111111... (63 digits).
529    // We compute this powering via standard square and multiply.
530    let mut inv = 1u64;
531    crate::const_for!((_i in 0..63) {
532        // Square
533        inv = inv.wrapping_mul(inv);
534        // Multiply
535        inv = inv.wrapping_mul(T::MODULUS.0[0]);
536    });
537    inv.wrapping_neg()
538}
539
540#[inline]
541pub const fn can_use_no_carry_mul_optimization<T: MontConfig<N>, const N: usize>() -> bool {
542    // Checking the modulus at compile time
543    let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 1;
544    crate::const_for!((i in 1..N) {
545        all_remaining_bits_are_one  &= T::MODULUS.0[N - i - 1] == u64::MAX;
546    });
547    modulus_has_spare_bit::<T, N>() && !all_remaining_bits_are_one
548}
549
550#[inline]
551pub const fn modulus_has_spare_bit<T: MontConfig<N>, const N: usize>() -> bool {
552    T::MODULUS.0[N - 1] >> 63 == 0
553}
554
555#[inline]
556pub const fn can_use_no_carry_square_optimization<T: MontConfig<N>, const N: usize>() -> bool {
557    // Checking the modulus at compile time
558    let top_two_bits_are_zero = T::MODULUS.0[N - 1] >> 62 == 0;
559    let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 2;
560    crate::const_for!((i in 1..N) {
561        all_remaining_bits_are_one  &= T::MODULUS.0[N - i - 1] == u64::MAX;
562    });
563    top_two_bits_are_zero && !all_remaining_bits_are_one
564}
565
566pub const fn sqrt_precomputation<const N: usize, T: MontConfig<N>>(
567) -> Option<SqrtPrecomputation<Fp<MontBackend<T, N>, N>>> {
568    match T::MODULUS.mod_4() {
569        3 => match T::MODULUS_PLUS_ONE_DIV_FOUR.as_ref() {
570            Some(BigInt(modulus_plus_one_div_four)) => Some(SqrtPrecomputation::Case3Mod4 {
571                modulus_plus_one_div_four,
572            }),
573            None => None,
574        },
575        _ => match T::MODULUS.mod_8() {
576            5 => match (
577                T::MODULUS_PLUS_THREE_DIV_EIGHT.as_ref(),
578                T::MODULUS_MINUS_ONE_DIV_FOUR.as_ref(),
579            ) {
580                (
581                    Some(BigInt(modulus_plus_three_div_eight)),
582                    Some(BigInt(modulus_minus_one_div_four)),
583                ) => Some(SqrtPrecomputation::Case5Mod8 {
584                    modulus_plus_three_div_eight,
585                    modulus_minus_one_div_four,
586                }),
587                _ => None,
588            },
589            _ => Some(SqrtPrecomputation::TonelliShanks {
590                two_adicity: <MontBackend<T, N>>::TWO_ADICITY,
591                quadratic_nonresidue_to_trace: T::TWO_ADIC_ROOT_OF_UNITY,
592                trace_of_modulus_minus_one_div_two:
593                    &<Fp<MontBackend<T, N>, N>>::TRACE_MINUS_ONE_DIV_TWO.0,
594            }),
595        },
596    }
597}
598
/// Construct a [`Fp<MontBackend<T, N>, N>`] element from a literal string.
///
/// This should be used primarily for constructing constant field elements; in a
/// non-const context, [`Fp::from_str`](`ark_std::str::FromStr::from_str`) is
/// preferable.
///
/// The literal is split into a sign and `u64` limbs at compile time, then handed
/// to `Fp::from_sign_and_limbs`.
///
/// # Panics
///
/// If the integer represented by the string cannot fit in the number
/// of limbs of the `Fp`, this macro results in a
/// * compile-time error if used in a const context
/// * run-time error otherwise.
///
/// # Usage
///
/// ```rust
/// # use ark_test_curves::MontFp;
/// # use ark_test_curves::bls12_381 as ark_bls12_381;
/// # use ark_std::{One, str::FromStr};
/// use ark_bls12_381::Fq;
/// const ONE: Fq = MontFp!("1");
/// const NEG_ONE: Fq = MontFp!("-1");
///
/// fn check_correctness() {
///     assert_eq!(ONE, Fq::one());
///     assert_eq!(Fq::from_str("1").unwrap(), ONE);
///     assert_eq!(NEG_ONE, -Fq::one());
/// }
/// ```
#[macro_export]
macro_rules! MontFp {
    ($c0:expr) => {{
        let (positive, limb_array) = $crate::ark_ff_macros::to_sign_and_limbs!($c0);
        $crate::Fp::from_sign_and_limbs(positive, &limb_array)
    }};
}
635
636pub use ark_ff_macros::MontConfig;
637
638pub use MontFp;
639
640pub struct MontBackend<T: MontConfig<N>, const N: usize>(PhantomData<T>);
641
642impl<T: MontConfig<N>, const N: usize> FpConfig<N> for MontBackend<T, N> {
643    /// The modulus of the field.
644    const MODULUS: crate::BigInt<N> = T::MODULUS;
645
646    /// A multiplicative generator of the field.
647    /// `Self::GENERATOR` is an element having multiplicative order
648    /// `Self::MODULUS - 1`.
649    const GENERATOR: Fp<Self, N> = T::GENERATOR;
650
651    /// Additive identity of the field, i.e. the element `e`
652    /// such that, for all elements `f` of the field, `e + f = f`.
653    const ZERO: Fp<Self, N> = Fp::new_unchecked(BigInt([0u64; N]));
654
655    /// Multiplicative identity of the field, i.e. the element `e`
656    /// such that, for all elements `f` of the field, `e * f = f`.
657    const ONE: Fp<Self, N> = Fp::new_unchecked(T::R);
658
659    /// Negation of `Self::ONE`.
660    const NEG_ONE: Fp<Self, N> = Fp::new_unchecked(Self::MODULUS.const_sub_with_borrow(&T::R).0);
661
662    const TWO_ADICITY: u32 = Self::MODULUS.two_adic_valuation();
663    const TWO_ADIC_ROOT_OF_UNITY: Fp<Self, N> = T::TWO_ADIC_ROOT_OF_UNITY;
664    const SMALL_SUBGROUP_BASE: Option<u32> = T::SMALL_SUBGROUP_BASE;
665    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = T::SMALL_SUBGROUP_BASE_ADICITY;
666    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<Self, N>> = T::LARGE_SUBGROUP_ROOT_OF_UNITY;
667    const SQRT_PRECOMP: Option<crate::SqrtPrecomputation<Fp<Self, N>>> = T::SQRT_PRECOMP;
668
669    fn add_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
670        T::add_assign(a, b)
671    }
672
673    fn sub_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
674        T::sub_assign(a, b)
675    }
676
677    fn double_in_place(a: &mut Fp<Self, N>) {
678        T::double_in_place(a)
679    }
680
681    fn neg_in_place(a: &mut Fp<Self, N>) {
682        T::neg_in_place(a)
683    }
684
685    /// This modular multiplication algorithm uses Montgomery
686    /// reduction for efficient implementation. It also additionally
687    /// uses the "no-carry optimization" outlined
688    /// [here](https://hackmd.io/@zkteam/modular_multiplication) if
689    /// `P::MODULUS` has (a) a non-zero MSB, and (b) at least one
690    /// zero bit in the rest of the modulus.
691    #[inline]
692    fn mul_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
693        T::mul_assign(a, b)
694    }
695
696    fn sum_of_products<const M: usize>(a: &[Fp<Self, N>; M], b: &[Fp<Self, N>; M]) -> Fp<Self, N> {
697        T::sum_of_products(a, b)
698    }
699
700    #[inline]
701    fn square_in_place(a: &mut Fp<Self, N>) {
702        T::square_in_place(a)
703    }
704
705    fn inverse(a: &Fp<Self, N>) -> Option<Fp<Self, N>> {
706        T::inverse(a)
707    }
708
709    fn from_bigint(r: BigInt<N>) -> Option<Fp<Self, N>> {
710        T::from_bigint(r)
711    }
712
713    #[inline]
714    fn into_bigint(a: Fp<Self, N>) -> BigInt<N> {
715        T::into_bigint(a)
716    }
717}
718
719impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
720    #[doc(hidden)]
721    pub const R: BigInt<N> = T::R;
722    #[doc(hidden)]
723    pub const R2: BigInt<N> = T::R2;
724    #[doc(hidden)]
725    pub const INV: u64 = T::INV;
726
727    /// Construct a new field element from its underlying
728    /// [`struct@BigInt`] data type.
729    #[inline]
730    pub const fn new(element: BigInt<N>) -> Self {
731        let mut r = Self(element, PhantomData);
732        if r.const_is_zero() {
733            r
734        } else {
735            r = r.mul(&Fp(T::R2, PhantomData));
736            r
737        }
738    }
739
740    /// Construct a new field element from its underlying
741    /// [`struct@BigInt`] data type.
742    ///
743    /// Unlike [`Self::new`], this method does not perform Montgomery reduction.
744    /// Thus, this method should be used only when constructing
745    /// an element from an integer that has already been put in
746    /// Montgomery form.
747    #[inline]
748    pub const fn new_unchecked(element: BigInt<N>) -> Self {
749        Self(element, PhantomData)
750    }
751
752    const fn const_is_zero(&self) -> bool {
753        self.0.const_is_zero()
754    }
755
756    #[doc(hidden)]
757    const fn const_neg(self) -> Self {
758        if !self.const_is_zero() {
759            Self::new_unchecked(Self::sub_with_borrow(&T::MODULUS, &self.0))
760        } else {
761            self
762        }
763    }
764
765    /// Interpret a set of limbs (along with a sign) as a field element.
766    /// For *internal* use only; please use the `ark_ff::MontFp` macro instead
767    /// of this method
768    #[doc(hidden)]
769    pub const fn from_sign_and_limbs(is_positive: bool, limbs: &[u64]) -> Self {
770        let mut repr = BigInt([0; N]);
771        assert!(limbs.len() <= N);
772        crate::const_for!((i in 0..(limbs.len())) {
773            repr.0[i] = limbs[i];
774        });
775        let res = Self::new(repr);
776        if is_positive {
777            res
778        } else {
779            res.const_neg()
780        }
781    }
782
783    const fn mul_without_cond_subtract(mut self, other: &Self) -> (bool, Self) {
784        let (mut lo, mut hi) = ([0u64; N], [0u64; N]);
785        crate::const_for!((i in 0..N) {
786            let mut carry = 0;
787            crate::const_for!((j in 0..N) {
788                let k = i + j;
789                if k >= N {
790                    hi[k - N] = fa::mac_with_carry(hi[k - N], (self.0).0[i], (other.0).0[j], &mut carry);
791                } else {
792                    lo[k] = fa::mac_with_carry(lo[k], (self.0).0[i], (other.0).0[j], &mut carry);
793                }
794            });
795            hi[i] = carry;
796        });
797        // Montgomery reduction
798        let mut carry2 = 0;
799        crate::const_for!((i in 0..N) {
800            let tmp = lo[i].wrapping_mul(T::INV);
801            let mut carry = 0;
802            fa::mac(lo[i], tmp, T::MODULUS.0[0], &mut carry);
803            crate::const_for!((j in 1..N) {
804                let k = i + j;
805                if k >= N {
806                    hi[k - N] = fa::mac_with_carry(hi[k - N], tmp, T::MODULUS.0[j], &mut carry);
807                }  else {
808                    lo[k] = fa::mac_with_carry(lo[k], tmp, T::MODULUS.0[j], &mut carry);
809                }
810            });
811            carry2 = fa::adc(&mut hi[i], carry, carry2);
812        });
813
814        crate::const_for!((i in 0..N) {
815            (self.0).0[i] = hi[i];
816        });
817        (carry2 != 0, self)
818    }
819
820    const fn mul(self, other: &Self) -> Self {
821        let (carry, res) = self.mul_without_cond_subtract(other);
822        if T::MODULUS_HAS_SPARE_BIT {
823            res.const_subtract_modulus()
824        } else {
825            res.const_subtract_modulus_with_carry(carry)
826        }
827    }
828
829    const fn const_is_valid(&self) -> bool {
830        crate::const_for!((i in 0..N) {
831            if (self.0).0[N - i - 1] < T::MODULUS.0[N - i - 1] {
832                return true
833            } else if (self.0).0[N - i - 1] > T::MODULUS.0[N - i - 1] {
834                return false
835            }
836        });
837        false
838    }
839
840    #[inline]
841    const fn const_subtract_modulus(mut self) -> Self {
842        if !self.const_is_valid() {
843            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
844        }
845        self
846    }
847
848    #[inline]
849    const fn const_subtract_modulus_with_carry(mut self, carry: bool) -> Self {
850        if carry || !self.const_is_valid() {
851            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
852        }
853        self
854    }
855
856    const fn sub_with_borrow(a: &BigInt<N>, b: &BigInt<N>) -> BigInt<N> {
857        a.const_sub_with_borrow(b).0
858    }
859}
860
#[cfg(test)]
mod test {
    use ark_std::{str::FromStr, vec::*};
    use ark_test_curves::secp256k1::Fr;
    use num_bigint::{BigInt, BigUint, Sign};

    /// Decimal literal shared by the input and the expected value.
    const DECIMAL: &str =
        "111192936301596926984056301862066282284536849596023571352007112326586892541694";

    /// Round-trips a decimal literal through `from_sign_and_limbs` and back.
    #[test]
    fn test_mont_macro_correctness() {
        let (is_positive, limbs) = str_to_limbs_u64(DECIMAL);
        let element = Fr::from_sign_and_limbs(is_positive, &limbs);

        let actual: BigUint = element.into();
        let expected = BigUint::from_str(DECIMAL).unwrap();

        assert_eq!(actual, expected);
    }

    /// Splits a decimal string into its sign and little-endian `u64` limbs.
    fn str_to_limbs_u64(num: &str) -> (bool, Vec<u64>) {
        let (sign, digits) = BigInt::from_str(num)
            .expect("could not parse to bigint")
            .to_radix_le(16);
        // Each group of 16 little-endian hex digits forms one limb; fold from the
        // most significant digit in the group down to the least.
        let limbs = digits
            .chunks(16)
            .map(|chunk| {
                chunk
                    .iter()
                    .rev()
                    .fold(0u64, |acc, &hexit| (acc << 4) | u64::from(hexit))
            })
            .collect::<Vec<_>>();
        (sign != Sign::Minus, limbs)
    }
}