use ark_std::{marker::PhantomData, Zero};
use super::{Fp, FpConfig};
use crate::{biginteger::arithmetic as fa, BigInt, BigInteger, PrimeField, SqrtPrecomputation};
use ark_ff_macros::unroll_for_loops;
/// Compile-time parameters of a prime field whose elements are stored in
/// Montgomery form across `N` 64-bit limbs (limb `N - 1` is most significant).
///
/// Implementors must provide `MODULUS`, `GENERATOR` and `TWO_ADIC_ROOT_OF_UNITY`;
/// every other constant defaults to a value derived from `MODULUS`, and the
/// default arithmetic methods below are what `MontBackend` forwards to.
pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
    /// The field modulus.
    const MODULUS: BigInt<N>;

    /// Montgomery radix constant from `BigInt::montgomery_r`; it is also the
    /// Montgomery representation of one (`MontBackend::ONE`).
    const R: BigInt<N> = Self::MODULUS.montgomery_r();

    /// Constant from `BigInt::montgomery_r2`; Montgomery-multiplying a canonical
    /// integer by it converts that integer into Montgomery form (`from_bigint`).
    const R2: BigInt<N> = Self::MODULUS.montgomery_r2();

    /// `-MODULUS^{-1} mod 2^64` (for odd `MODULUS`), computed by [`inv`].
    const INV: u64 = inv::<Self, N>();

    /// Multiplicative generator supplied by the implementor.
    const GENERATOR: Fp<MontBackend<Self, N>, N>;

    /// Whether the no-carry CIOS multiplication can be used
    /// (see [`can_use_no_carry_mul_optimization`]).
    #[doc(hidden)]
    const CAN_USE_NO_CARRY_MUL_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();

    /// Gate for the assembly squaring path in `square_in_place`.
    ///
    /// NOTE(review): this defaults to the *multiplication* predicate rather than
    /// [`can_use_no_carry_square_optimization`]; confirm that is intentional.
    #[doc(hidden)]
    const CAN_USE_NO_CARRY_SQUARE_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();

    /// Whether the most significant bit of `MODULUS` is clear.
    #[doc(hidden)]
    const MODULUS_HAS_SPARE_BIT: bool = modulus_has_spare_bit::<Self, N>();

    /// Two-adic root of unity supplied by the implementor; it is also used as the
    /// `quadratic_nonresidue_to_trace` input of the Tonelli-Shanks precomputation.
    const TWO_ADIC_ROOT_OF_UNITY: Fp<MontBackend<Self, N>, N>;

    /// Optional small-subgroup base; `None` unless overridden.
    const SMALL_SUBGROUP_BASE: Option<u32> = None;

    /// Optional adicity for `SMALL_SUBGROUP_BASE`; `None` unless overridden.
    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = None;

    /// Optional large-subgroup root of unity; `None` unless overridden.
    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<MontBackend<Self, N>, N>> = None;

    /// Square-root strategy chosen by [`sqrt_precomputation`].
    const SQRT_PRECOMP: Option<SqrtPrecomputation<Fp<MontBackend<Self, N>, N>>> =
        sqrt_precomputation::<N, Self>();

    /// `(MODULUS + 1) / 4` when `MODULUS ≡ 3 (mod 4)`, otherwise `None`.
    #[doc(hidden)]
    const MODULUS_PLUS_ONE_DIV_FOUR: Option<BigInt<N>> = {
        match Self::MODULUS.mod_4() == 3 {
            true => {
                let (modulus_plus_one, carry) =
                    Self::MODULUS.const_add_with_carry(&BigInt::<N>::one());
                let mut result = modulus_plus_one.divide_by_2_round_down();
                // `MODULUS + 1` may overflow `N` limbs. Halving leaves the top bit
                // clear, so the lost carry is restored there to obtain
                // `(MODULUS + 1) / 2` before the second halving.
                result.0[N - 1] |= (carry as u64) << 63;
                Some(result.divide_by_2_round_down())
            },
            false => None,
        }
    };

    /// Sets `a = a + b mod MODULUS`.
    ///
    /// For reduced inputs the limb sum can only carry out of the top limb when
    /// the modulus has no spare bit; in that case the carry forces the final
    /// subtraction.
    #[inline(always)]
    fn add_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
        let c = a.0.add_with_carry(&b.0);
        if Self::MODULUS_HAS_SPARE_BIT {
            a.subtract_modulus()
        } else {
            a.subtract_modulus_with_carry(c)
        }
    }

    /// Sets `a = a - b mod MODULUS`.
    ///
    /// If `b > a`, `MODULUS` is added first. Any carry or borrow is deliberately
    /// ignored: the limb arithmetic wraps modulo `2^(64N)`, and for reduced
    /// inputs the true value `a + MODULUS - b` lies in `[0, MODULUS)`.
    #[inline(always)]
    fn sub_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
        if b.0 > a.0 {
            a.0.add_with_carry(&Self::MODULUS);
        }
        a.0.sub_with_borrow(&b.0);
    }

    /// Sets `a = 2a mod MODULUS`, using the shift-out carry like `add_assign`.
    #[inline(always)]
    fn double_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
        let c = a.0.mul2();
        if Self::MODULUS_HAS_SPARE_BIT {
            a.subtract_modulus()
        } else {
            a.subtract_modulus_with_carry(c)
        }
    }

    /// Sets `a = -a mod MODULUS`; zero is left unchanged.
    #[inline(always)]
    fn neg_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
        if !a.is_zero() {
            let mut tmp = Self::MODULUS;
            tmp.sub_with_borrow(&a.0);
            a.0 = tmp;
        }
    }

    /// Sets `a` to the Montgomery product of `a` and `b`.
    ///
    /// When `CAN_USE_NO_CARRY_MUL_OPT` holds, this uses either the `ark_ff_asm`
    /// routine (x86_64 with the `asm` feature, BMI2 and ADX, `2 <= N <= 6`) or a
    /// portable no-carry CIOS loop. Otherwise it falls back to
    /// `mul_without_cond_subtract` followed by a carry-aware final subtraction.
    #[unroll_for_loops(12)]
    #[inline(always)]
    fn mul_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
        if Self::CAN_USE_NO_CARRY_MUL_OPT {
            if N <= 6
                && N > 1
                && cfg!(all(
                    feature = "asm",
                    target_feature = "bmi2",
                    target_feature = "adx",
                    target_arch = "x86_64"
                ))
            {
                #[cfg(
                    all(
                        feature = "asm",
                        target_feature = "bmi2",
                        target_feature = "adx",
                        target_arch = "x86_64"
                    )
                )]
                #[allow(unsafe_code, unused_mut)]
                #[rustfmt::skip]
                match N {
                    2 => { ark_ff_asm::x86_64_asm_mul!(2, (a.0).0, (b.0).0); },
                    3 => { ark_ff_asm::x86_64_asm_mul!(3, (a.0).0, (b.0).0); },
                    4 => { ark_ff_asm::x86_64_asm_mul!(4, (a.0).0, (b.0).0); },
                    5 => { ark_ff_asm::x86_64_asm_mul!(5, (a.0).0, (b.0).0); },
                    6 => { ark_ff_asm::x86_64_asm_mul!(6, (a.0).0, (b.0).0); },
                    // SAFETY: the enclosing `if` guarantees `1 < N <= 6`, so every
                    // reachable value of `N` is matched by an arm above.
                    _ => unsafe { ark_std::hint::unreachable_unchecked() },
                };
            } else {
                // Portable no-carry CIOS: interleave one row of `a * b[i]` with
                // one word of Montgomery reduction, shifting the buffer down a limb.
                let mut r = [0u64; N];
                for i in 0..N {
                    let mut carry1 = 0u64;
                    r[0] = fa::mac(r[0], (a.0).0[0], (b.0).0[i], &mut carry1);
                    let k = r[0].wrapping_mul(Self::INV);
                    let mut carry2 = 0u64;
                    fa::mac_discard(r[0], k, Self::MODULUS.0[0], &mut carry2);
                    for j in 1..N {
                        r[j] = fa::mac_with_carry(r[j], (a.0).0[j], (b.0).0[i], &mut carry1);
                        r[j - 1] = fa::mac_with_carry(r[j], k, Self::MODULUS.0[j], &mut carry2);
                    }
                    r[N - 1] = carry1 + carry2;
                }
                (a.0).0 = r;
            }
            a.subtract_modulus();
        } else {
            let (carry, res) = a.mul_without_cond_subtract(b);
            *a = res;
            // With a spare bit the CIOS result is below `2 * MODULUS < 2^(64N)`,
            // so `carry` is always false. Without one, the result can overflow
            // `N` limbs; that overflow is reported via `carry` and must force the
            // subtraction. (Same orientation as `square_in_place`.)
            if Self::MODULUS_HAS_SPARE_BIT {
                a.subtract_modulus();
            } else {
                a.subtract_modulus_with_carry(carry);
            }
        }
    }

    /// Sets `a` to its Montgomery square.
    ///
    /// `N == 1` defers to multiplication. Otherwise the optional assembly path is
    /// tried; the portable path accumulates the off-diagonal products, doubles
    /// them with a one-bit left shift, adds the diagonal squares, and then
    /// performs word-by-word Montgomery reduction into the upper half.
    #[inline(always)]
    #[unroll_for_loops(12)]
    fn square_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
        if N == 1 {
            *a *= *a;
            return;
        }
        if Self::CAN_USE_NO_CARRY_SQUARE_OPT
            && (2..=6).contains(&N)
            && cfg!(all(
                feature = "asm",
                target_feature = "bmi2",
                target_feature = "adx",
                target_arch = "x86_64"
            ))
        {
            #[cfg(all(
                feature = "asm",
                target_feature = "bmi2",
                target_feature = "adx",
                target_arch = "x86_64"
            ))]
            #[allow(unsafe_code, unused_mut)]
            #[rustfmt::skip]
            match N {
                2 => { ark_ff_asm::x86_64_asm_square!(2, (a.0).0); },
                3 => { ark_ff_asm::x86_64_asm_square!(3, (a.0).0); },
                4 => { ark_ff_asm::x86_64_asm_square!(4, (a.0).0); },
                5 => { ark_ff_asm::x86_64_asm_square!(5, (a.0).0); },
                6 => { ark_ff_asm::x86_64_asm_square!(6, (a.0).0); },
                // SAFETY: the enclosing `if` guarantees `2 <= N <= 6`, so every
                // reachable value of `N` is matched by an arm above.
                _ => unsafe { ark_std::hint::unreachable_unchecked() },
            };
            a.subtract_modulus();
            return;
        }
        let mut r = crate::const_helpers::MulBuffer::<N>::zeroed();
        let mut carry = 0;
        // Off-diagonal products a[i] * a[j] for i < j.
        for i in 0..(N - 1) {
            for j in (i + 1)..N {
                r[i + j] = fa::mac_with_carry(r[i + j], (a.0).0[i], (a.0).0[j], &mut carry);
            }
            r.b1[i] = carry;
            carry = 0;
        }
        // Double the off-diagonal sum with a one-bit left shift across limbs.
        r.b1[N - 1] = r.b1[N - 2] >> 63;
        for i in 2..(2 * N - 1) {
            r[2 * N - i] = (r[2 * N - i] << 1) | (r[2 * N - (i + 1)] >> 63);
        }
        r.b0[1] <<= 1;
        // Add the diagonal squares a[i]^2.
        for i in 0..N {
            r[2 * i] = fa::mac_with_carry(r[2 * i], (a.0).0[i], (a.0).0[i], &mut carry);
            carry = fa::adc(&mut r[2 * i + 1], 0, carry);
        }
        // Montgomery reduction; `carry2` tracks overflow out of the upper half.
        let mut carry2 = 0;
        for i in 0..N {
            let k = r[i].wrapping_mul(Self::INV);
            let mut carry = 0;
            fa::mac_discard(r[i], k, Self::MODULUS.0[0], &mut carry);
            for j in 1..N {
                r[j + i] = fa::mac_with_carry(r[j + i], k, Self::MODULUS.0[j], &mut carry);
            }
            carry2 = fa::adc(&mut r.b1[i], carry, carry2);
        }
        (a.0).0.copy_from_slice(&r.b1);
        if Self::MODULUS_HAS_SPARE_BIT {
            a.subtract_modulus();
        } else {
            a.subtract_modulus_with_carry(carry2 != 0);
        }
    }

    /// Returns the multiplicative inverse of `a`, or `None` if `a` is zero.
    ///
    /// Binary extended-Euclid inversion on the raw representation. `b` starts at
    /// `R2` instead of one so that the result is already in Montgomery form for
    /// a Montgomery-form input, which avoids an extra conversion step.
    fn inverse(a: &Fp<MontBackend<Self, N>, N>) -> Option<Fp<MontBackend<Self, N>, N>> {
        if a.is_zero() {
            None
        } else {
            let one = BigInt::from(1u64);
            let mut u = a.0;
            let mut v = Self::MODULUS;
            let mut b = Fp::new_unchecked(Self::R2);
            let mut c = Fp::zero();
            while u != one && v != one {
                while u.is_even() {
                    u.div2();
                    if b.0.is_even() {
                        b.0.div2();
                    } else {
                        // Make `b` even by adding MODULUS, then halve; a carry out
                        // of the top limb becomes the new top bit.
                        let carry = b.0.add_with_carry(&Self::MODULUS);
                        b.0.div2();
                        if !Self::MODULUS_HAS_SPARE_BIT && carry {
                            (b.0).0[N - 1] |= 1 << 63;
                        }
                    }
                }
                while v.is_even() {
                    v.div2();
                    if c.0.is_even() {
                        c.0.div2();
                    } else {
                        let carry = c.0.add_with_carry(&Self::MODULUS);
                        c.0.div2();
                        if !Self::MODULUS_HAS_SPARE_BIT && carry {
                            (c.0).0[N - 1] |= 1 << 63;
                        }
                    }
                }
                if v < u {
                    u.sub_with_borrow(&v);
                    b -= &c;
                } else {
                    v.sub_with_borrow(&u);
                    c -= &b;
                }
            }
            if u == one {
                Some(b)
            } else {
                Some(c)
            }
        }
    }

    /// Converts the canonical integer `r` into Montgomery form.
    ///
    /// Returns `None` when `r >= MODULUS`. Zero is returned without multiplying.
    fn from_bigint(r: BigInt<N>) -> Option<Fp<MontBackend<Self, N>, N>> {
        let mut r = Fp::new_unchecked(r);
        if r.is_zero() {
            Some(r)
        } else if r.is_geq_modulus() {
            None
        } else {
            r *= &Fp::new_unchecked(Self::R2);
            Some(r)
        }
    }

    /// Converts `a` out of Montgomery form by Montgomery-reducing its limbs.
    ///
    /// The buffer is used cyclically (`% N`); for `N == 1` this is a modulo by
    /// one, which is why `clippy::modulo_one` is allowed.
    #[inline]
    #[unroll_for_loops(12)]
    #[allow(clippy::modulo_one)]
    fn into_bigint(a: Fp<MontBackend<Self, N>, N>) -> BigInt<N> {
        let mut tmp = a.0;
        let mut r = tmp.0;
        for i in 0..N {
            let k = r[i].wrapping_mul(Self::INV);
            let mut carry = 0;
            fa::mac_with_carry(r[i], k, Self::MODULUS.0[0], &mut carry);
            for j in 1..N {
                r[(j + i) % N] =
                    fa::mac_with_carry(r[(j + i) % N], k, Self::MODULUS.0[j], &mut carry);
            }
            r[i % N] = carry;
        }
        tmp.0 = r;
        tmp
    }

    /// Returns `sum(a[i] * b[i])`.
    ///
    /// If the modulus leaves fewer than two spare bits, each product is computed
    /// separately. Otherwise the products are accumulated with one Montgomery
    /// reduction step per limb: for `M == 2` with explicit carry words, and for
    /// other `M` in chunks of `2 * spare_bits - 1` terms. In debug builds each
    /// lazily reduced result is checked against the naive sum.
    #[unroll_for_loops(12)]
    fn sum_of_products<const M: usize>(
        a: &[Fp<MontBackend<Self, N>, N>; M],
        b: &[Fp<MontBackend<Self, N>, N>; M],
    ) -> Fp<MontBackend<Self, N>, N> {
        let modulus_size = Self::MODULUS.const_num_bits() as usize;
        if modulus_size >= 64 * N - 1 {
            a.iter().zip(b).map(|(a, b)| *a * b).sum()
        } else if M == 2 {
            let result = (0..N).fold(BigInt::zero(), |mut result, j| {
                let mut carry_a = 0;
                let mut carry_b = 0;
                for (a, b) in a.iter().zip(b) {
                    let a = &a.0;
                    let b = &b.0;
                    let mut carry2 = 0;
                    result.0[0] = fa::mac(result.0[0], a.0[j], b.0[0], &mut carry2);
                    for k in 1..N {
                        result.0[k] = fa::mac_with_carry(result.0[k], a.0[j], b.0[k], &mut carry2);
                    }
                    carry_b = fa::adc(&mut carry_a, carry_b, carry2);
                }
                let k = result.0[0].wrapping_mul(Self::INV);
                let mut carry2 = 0;
                fa::mac_discard(result.0[0], k, Self::MODULUS.0[0], &mut carry2);
                for i in 1..N {
                    result.0[i - 1] =
                        fa::mac_with_carry(result.0[i], k, Self::MODULUS.0[i], &mut carry2);
                }
                result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &mut carry2);
                result
            });
            let mut result = Fp::new_unchecked(result);
            result.subtract_modulus();
            debug_assert_eq!(
                a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
                result
            );
            result
        } else {
            // At least 1, because this branch requires two or more spare bits.
            let chunk_size = 2 * (N * 64 - modulus_size) - 1;
            a.chunks(chunk_size)
                .zip(b.chunks(chunk_size))
                .map(|(a, b)| {
                    let result = (0..N).fold(BigInt::zero(), |mut result, j| {
                        let (temp, carry) = a.iter().zip(b).fold(
                            (result, 0),
                            |(mut temp, mut carry), (Fp(a, _), Fp(b, _))| {
                                let mut carry2 = 0;
                                temp.0[0] = fa::mac(temp.0[0], a.0[j], b.0[0], &mut carry2);
                                for k in 1..N {
                                    temp.0[k] =
                                        fa::mac_with_carry(temp.0[k], a.0[j], b.0[k], &mut carry2);
                                }
                                carry = fa::adc_no_carry(carry, 0, &mut carry2);
                                (temp, carry)
                            },
                        );
                        let k = temp.0[0].wrapping_mul(Self::INV);
                        let mut carry2 = 0;
                        fa::mac_discard(temp.0[0], k, Self::MODULUS.0[0], &mut carry2);
                        for i in 1..N {
                            result.0[i - 1] =
                                fa::mac_with_carry(temp.0[i], k, Self::MODULUS.0[i], &mut carry2);
                        }
                        result.0[N - 1] = fa::adc_no_carry(carry, 0, &mut carry2);
                        result
                    });
                    let mut result = Fp::new_unchecked(result);
                    result.subtract_modulus();
                    debug_assert_eq!(
                        a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
                        result
                    );
                    result
                })
                .sum()
        }
    }
}
pub const fn inv<T: MontConfig<N>, const N: usize>() -> u64 {
let mut inv = 1u64;
crate::const_for!((_i in 0..63) {
inv = inv.wrapping_mul(inv);
inv = inv.wrapping_mul(T::MODULUS.0[0]);
});
inv.wrapping_neg()
}
#[inline]
pub const fn can_use_no_carry_mul_optimization<T: MontConfig<N>, const N: usize>() -> bool {
let top_bit_is_zero = T::MODULUS.0[N - 1] >> 63 == 0;
let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 1;
crate::const_for!((i in 1..N) {
all_remaining_bits_are_one &= T::MODULUS.0[N - i - 1] == u64::MAX;
});
top_bit_is_zero && !all_remaining_bits_are_one
}
/// Returns `true` when the most significant bit of `T::MODULUS` is clear, so
/// that `2 * MODULUS` still fits in `N` limbs.
#[inline]
pub const fn modulus_has_spare_bit<T: MontConfig<N>, const N: usize>() -> bool {
    T::MODULUS.0[N - 1] < (1u64 << 63)
}
#[inline]
pub const fn can_use_no_carry_square_optimization<T: MontConfig<N>, const N: usize>() -> bool {
let top_two_bits_are_zero = T::MODULUS.0[N - 1] >> 62 == 0;
let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 2;
crate::const_for!((i in 1..N) {
all_remaining_bits_are_one &= T::MODULUS.0[N - i - 1] == u64::MAX;
});
top_two_bits_are_zero && !all_remaining_bits_are_one
}
/// Selects the square-root strategy for the field described by `T`, at compile time.
///
/// * `MODULUS ≡ 3 (mod 4)`: returns `Case3Mod4` holding the limbs of
///   `T::MODULUS_PLUS_ONE_DIV_FOUR`. If that constant is `None`, the result is
///   `None`.
/// * Otherwise: returns `TonelliShanks`, built from the backend's `TWO_ADICITY`,
///   `T::TWO_ADIC_ROOT_OF_UNITY` (as the non-residue raised to the trace), and the
///   limbs of `TRACE_MINUS_ONE_DIV_TWO`.
///
/// Note that the generic parameter order here is `<N, T>`, unlike the other
/// helpers in this file.
pub const fn sqrt_precomputation<const N: usize, T: MontConfig<N>>(
) -> Option<SqrtPrecomputation<Fp<MontBackend<T, N>, N>>> {
    match T::MODULUS.mod_4() {
        // p ≡ 3 (mod 4): a single exponentiation by (p + 1) / 4 suffices.
        3 => match T::MODULUS_PLUS_ONE_DIV_FOUR.as_ref() {
            Some(BigInt(modulus_plus_one_div_four)) => Some(SqrtPrecomputation::Case3Mod4 {
                modulus_plus_one_div_four,
            }),
            None => None,
        },
        // All other moduli use Tonelli-Shanks.
        _ => Some(SqrtPrecomputation::TonelliShanks {
            two_adicity: <MontBackend<T, N>>::TWO_ADICITY,
            quadratic_nonresidue_to_trace: T::TWO_ADIC_ROOT_OF_UNITY,
            trace_of_modulus_minus_one_div_two:
                &<Fp<MontBackend<T, N>, N>>::TRACE_MINUS_ONE_DIV_TWO.0,
        }),
    }
}
/// Builds a Montgomery-form field element from the single expression `$c0`.
///
/// The `to_sign_and_limbs!` proc macro splits `$c0` into a sign flag and a
/// list of 64-bit limbs. These are passed to `Fp::from_sign_and_limbs`, which
/// is a `const fn`. The target `Fp` type is taken from the surrounding context.
/// NOTE(review): this is presumably meant for const initializers; confirm which
/// input forms `to_sign_and_limbs!` accepts.
#[macro_export]
macro_rules! MontFp {
    ($c0:expr) => {{
        let (is_positive, limbs) = $crate::ark_ff_macros::to_sign_and_limbs!($c0);
        $crate::Fp::from_sign_and_limbs(is_positive, &limbs)
    }};
}
pub use ark_ff_macros::MontConfig;
pub use MontFp;
/// Zero-sized [`FpConfig`] implementation whose field operations are all
/// delegated to the `MontConfig<N>` implementor `T`.
pub struct MontBackend<T, const N: usize>(PhantomData<T>)
where
    T: MontConfig<N>;
/// Forwards every `FpConfig` constant and operation to the matching item of
/// `T: MontConfig<N>`. `ZERO`, `ONE` and `TWO_ADICITY` are derived here.
impl<T: MontConfig<N>, const N: usize> FpConfig<N> for MontBackend<T, N> {
    const MODULUS: crate::BigInt<N> = <T as MontConfig<N>>::MODULUS;
    const GENERATOR: Fp<Self, N> = <T as MontConfig<N>>::GENERATOR;
    // All-zero limbs represent zero in Montgomery form as well.
    const ZERO: Fp<Self, N> = Fp::new_unchecked(BigInt([0u64; N]));
    // `R` is used directly as the stored representation of one.
    const ONE: Fp<Self, N> = Fp::new_unchecked(<T as MontConfig<N>>::R);
    const TWO_ADICITY: u32 = <T as MontConfig<N>>::MODULUS.two_adic_valuation();
    const TWO_ADIC_ROOT_OF_UNITY: Fp<Self, N> = <T as MontConfig<N>>::TWO_ADIC_ROOT_OF_UNITY;
    const SMALL_SUBGROUP_BASE: Option<u32> = <T as MontConfig<N>>::SMALL_SUBGROUP_BASE;
    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> =
        <T as MontConfig<N>>::SMALL_SUBGROUP_BASE_ADICITY;
    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<Self, N>> =
        <T as MontConfig<N>>::LARGE_SUBGROUP_ROOT_OF_UNITY;
    const SQRT_PRECOMP: Option<crate::SqrtPrecomputation<Fp<Self, N>>> =
        <T as MontConfig<N>>::SQRT_PRECOMP;

    fn add_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
        <T as MontConfig<N>>::add_assign(a, b)
    }

    fn sub_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
        <T as MontConfig<N>>::sub_assign(a, b)
    }

    fn double_in_place(a: &mut Fp<Self, N>) {
        <T as MontConfig<N>>::double_in_place(a)
    }

    fn neg_in_place(a: &mut Fp<Self, N>) {
        <T as MontConfig<N>>::neg_in_place(a)
    }

    #[inline]
    fn mul_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
        <T as MontConfig<N>>::mul_assign(a, b)
    }

    fn sum_of_products<const M: usize>(a: &[Fp<Self, N>; M], b: &[Fp<Self, N>; M]) -> Fp<Self, N> {
        <T as MontConfig<N>>::sum_of_products(a, b)
    }

    #[inline]
    #[allow(unused_braces, clippy::absurd_extreme_comparisons)]
    fn square_in_place(a: &mut Fp<Self, N>) {
        <T as MontConfig<N>>::square_in_place(a)
    }

    fn inverse(a: &Fp<Self, N>) -> Option<Fp<Self, N>> {
        <T as MontConfig<N>>::inverse(a)
    }

    fn from_bigint(r: BigInt<N>) -> Option<Fp<Self, N>> {
        <T as MontConfig<N>>::from_bigint(r)
    }

    #[inline]
    #[allow(clippy::modulo_one)]
    fn into_bigint(a: Fp<Self, N>) -> BigInt<N> {
        <T as MontConfig<N>>::into_bigint(a)
    }
}
/// Const-evaluable helpers for Montgomery-backed field elements, used by
/// `new`, `MontFp!` and other compile-time construction.
impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
    /// Copy of `T::R`, the stored representation of one.
    #[doc(hidden)]
    pub const R: BigInt<N> = T::R;
    /// Copy of `T::R2`, used by [`Self::new`] to enter Montgomery form.
    #[doc(hidden)]
    pub const R2: BigInt<N> = T::R2;
    /// Copy of `T::INV` (`-MODULUS^{-1} mod 2^64`).
    #[doc(hidden)]
    pub const INV: u64 = T::INV;

    /// Converts the integer `element` into Montgomery form, usable in const context.
    ///
    /// Zero is returned unchanged. Any other value is Montgomery-multiplied by
    /// `R2`. The input is not range-checked against `MODULUS`.
    #[inline]
    pub const fn new(element: BigInt<N>) -> Self {
        let mut r = Self(element, PhantomData);
        if r.const_is_zero() {
            r
        } else {
            r = r.mul(&Fp(T::R2, PhantomData));
            r
        }
    }

    /// Wraps `element` as-is, assuming it is already a reduced Montgomery-form value.
    #[inline]
    pub const fn new_unchecked(element: BigInt<N>) -> Self {
        Self(element, PhantomData)
    }

    /// Returns `true` when every limb is zero.
    const fn const_is_zero(&self) -> bool {
        self.0.const_is_zero()
    }

    /// Const negation: `MODULUS - self` for nonzero `self`, zero otherwise.
    #[doc(hidden)]
    const fn const_neg(self) -> Self {
        if !self.const_is_zero() {
            Self::new_unchecked(Self::sub_with_borrow(&T::MODULUS, &self.0))
        } else {
            self
        }
    }

    /// Builds an element from a sign flag and little-endian 64-bit limbs; this
    /// is the target of the `MontFp!` macro.
    ///
    /// # Panics
    /// Panics (or fails const evaluation) if `limbs.len() > N`.
    #[doc(hidden)]
    pub const fn from_sign_and_limbs(is_positive: bool, limbs: &[u64]) -> Self {
        let mut repr = BigInt::<N>([0; N]);
        assert!(limbs.len() <= N);
        crate::const_for!((i in 0..(limbs.len())) {
            repr.0[i] = limbs[i];
        });
        let res = Self::new(repr);
        if is_positive {
            res
        } else {
            res.const_neg()
        }
    }

    /// Montgomery product without the final conditional subtraction.
    ///
    /// Computes the full `2N`-limb product into `(lo, hi)`, then performs
    /// word-by-word Montgomery reduction. Returns whether the reduction
    /// overflowed past `hi`, together with `hi` as the (possibly unreduced) result.
    const fn mul_without_cond_subtract(mut self, other: &Self) -> (bool, Self) {
        let (mut lo, mut hi) = ([0u64; N], [0u64; N]);
        crate::const_for!((i in 0..N) {
            let mut carry = 0;
            crate::const_for!((j in 0..N) {
                let k = i + j;
                if k >= N {
                    hi[k - N] = mac_with_carry!(hi[k - N], (self.0).0[i], (other.0).0[j], &mut carry);
                } else {
                    lo[k] = mac_with_carry!(lo[k], (self.0).0[i], (other.0).0[j], &mut carry);
                }
            });
            hi[i] = carry;
        });
        let mut carry2 = 0;
        crate::const_for!((i in 0..N) {
            let tmp = lo[i].wrapping_mul(T::INV);
            let mut carry;
            mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry);
            crate::const_for!((j in 1..N) {
                let k = i + j;
                if k >= N {
                    hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry);
                } else {
                    lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry);
                }
            });
            hi[i] = adc!(hi[i], carry, &mut carry2);
        });
        crate::const_for!((i in 0..N) {
            (self.0).0[i] = hi[i];
        });
        (carry2 != 0, self)
    }

    /// Const-evaluable Montgomery product of `self` and `other`.
    const fn mul(self, other: &Self) -> Self {
        let (carry, res) = self.mul_without_cond_subtract(other);
        // With a spare bit the intermediate result is below
        // `2 * MODULUS < 2^(64N)`, so `carry` is always false. Without one, the
        // result can overflow `N` limbs; that overflow is reported via `carry`
        // and must force the subtraction.
        if T::MODULUS_HAS_SPARE_BIT {
            res.const_subtract_modulus()
        } else {
            res.const_subtract_modulus_with_carry(carry)
        }
    }

    /// Returns `true` iff `self < MODULUS`, comparing limbs from the most
    /// significant down. Equality with `MODULUS` yields `false`.
    const fn const_is_valid(&self) -> bool {
        crate::const_for!((i in 0..N) {
            if (self.0).0[(N - i - 1)] < T::MODULUS.0[(N - i - 1)] {
                return true
            } else if (self.0).0[(N - i - 1)] > T::MODULUS.0[(N - i - 1)] {
                return false
            }
        });
        false
    }

    /// Subtracts `MODULUS` once if `self >= MODULUS`.
    #[inline]
    const fn const_subtract_modulus(mut self) -> Self {
        if !self.const_is_valid() {
            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
        }
        self
    }

    /// Subtracts `MODULUS` once if `carry` is set (the true value exceeds
    /// `N` limbs) or if `self >= MODULUS`.
    #[inline]
    const fn const_subtract_modulus_with_carry(mut self, carry: bool) -> Self {
        if carry || !self.const_is_valid() {
            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
        }
        self
    }

    /// `a - b` modulo `2^(64N)`; the borrow flag is discarded.
    const fn sub_with_borrow(a: &BigInt<N>, b: &BigInt<N>) -> BigInt<N> {
        a.const_sub_with_borrow(b).0
    }
}
#[cfg(test)]
mod test {
    use ark_std::{str::FromStr, vec::Vec};
    use ark_test_curves::secp256k1::Fr;
    use num_bigint::{BigInt, BigUint, Sign};

    /// Round-trips a positive decimal constant through `from_sign_and_limbs`
    /// and checks that the same integer comes back out.
    #[test]
    fn test_mont_macro_correctness() {
        const DECIMAL: &str =
            "111192936301596926984056301862066282284536849596023571352007112326586892541694";
        let (is_positive, limbs) = str_to_limbs_u64(DECIMAL);
        let element = Fr::from_sign_and_limbs(is_positive, &limbs);
        let actual: BigUint = element.into();
        let expected = BigUint::from_str(DECIMAL).unwrap();
        assert_eq!(actual, expected);
    }

    /// Splits a signed decimal string into a sign flag and little-endian
    /// 64-bit limbs, packing 16 hex digits per limb.
    fn str_to_limbs_u64(num: &str) -> (bool, Vec<u64>) {
        let parsed = BigInt::from_str(num).expect("could not parse to bigint");
        let (sign, hex_digits) = parsed.to_radix_le(16);
        let limbs: Vec<u64> = hex_digits
            .chunks(16)
            .map(|chunk| {
                chunk
                    .iter()
                    .rev()
                    .fold(0u64, |acc, &hexit| (acc << 4) | u64::from(hexit))
            })
            .collect();
        (sign != Sign::Minus, limbs)
    }
}