use super::*;
use crate::small_fp::utils::{
compute_large_subgroup_root, compute_two_adic_root_of_unity, compute_two_adicity,
generate_montgomery_bigint_casts, generate_sqrt_precomputation, mod_mul_const, pow_mod_const,
};
use crate::utils::find_conservative_subgroup_base;
/// Emits the body of a `SmallFpConfig` implementation for a Montgomery-form
/// small prime field over the machine integer type `repr_type`
/// (u8/u16/u32/u64).
///
/// Montgomery parameters use R = 2^k_bits, where k_bits is the bit length of
/// the modulus, so the modulus may occupy every bit of the representation
/// type (no spare-bit requirement for multiplication).
///
/// Returns the generated token stream together with `R mod modulus` (the
/// Montgomery image of 1), which the caller reuses for further codegen
/// (see `exit_impl`).
pub(crate) fn backend_impl(
    repr_type: &proc_macro2::TokenStream,
    modulus: u128,
    generator: u128,
) -> (proc_macro2::TokenStream, u128) {
    assert!(modulus > 1, "modulus must be greater than 1");
    // REDC requires gcd(modulus, R) == 1 for a power-of-two R.
    assert!(
        modulus % 2 == 1,
        "modulus must be odd for Montgomery multiplication"
    );
    assert!(
        modulus < (1u128 << 64),
        "modulus must be < 2^64 for SmallFp"
    );
    let repr_type_str = repr_type.to_string();
    // R = 2^k_bits is the smallest power of two exceeding the modulus.
    let k_bits = 128 - modulus.leading_zeros();
    let r: u128 = 1u128 << k_bits;
    let r_mod_n = r % modulus;
    let r_mask = r - 1;
    // N' = -modulus^{-1} mod 2^k_bits, the REDC folding constant.
    let n_prime = mod_inverse_pow2(modulus, k_bits);
    // Montgomery form of x is x*R mod n, so 1 maps to R mod n.
    let one_mont = r_mod_n;
    // NOTE(review): the second `% modulus` is redundant — r_mod_n is
    // already reduced. Harmless, left as-is.
    let generator_mont = mod_mul_const(generator % modulus, r_mod_n % modulus, modulus);
    // 2-adicity v of modulus - 1 and a primitive 2^v-th root of unity,
    // stored in Montgomery form (used by radix-2 FFT code).
    let two_adicity = compute_two_adicity(modulus);
    let two_adic_root = compute_two_adic_root_of_unity(modulus, two_adicity, generator);
    let two_adic_root_mont = mod_mul_const(two_adic_root, r_mod_n, modulus);
    let neg_one_mont = mod_mul_const(modulus - 1, r_mod_n, modulus);
    let modulus_big = num_bigint::BigUint::from(modulus);
    // Optional mixed-radix subgroup constants, emitted only when a suitable
    // small subgroup base/power pair is found for this modulus.
    let mixed_radix_impl = if let Some((base, power)) =
        find_conservative_subgroup_base(&modulus_big)
    {
        let large_root = compute_large_subgroup_root(modulus, generator, two_adicity, base, power);
        let large_root_mont = mod_mul_const(large_root, r_mod_n, modulus);
        quote! {
            const SMALL_SUBGROUP_BASE: Option<u32> = Some(#base);
            const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = Some(#power);
            const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<SmallFp<Self>> = Some(SmallFp::from_raw(#large_root_mont as Self::T));
        }
    } else {
        quote! {}
    };
    let (from_bigint_impl, into_bigint_impl) = generate_montgomery_bigint_casts();
    let sqrt_precomp_impl = generate_sqrt_precomputation(modulus, two_adicity, Some(r_mod_n));
    // R^2 mod n: REDC-multiplying a raw value by R^2 enters Montgomery form.
    let r2 = mod_mul_const(r_mod_n, r_mod_n, modulus);
    let mul_impl = generate_mul_impl(repr_type, modulus, k_bits, r_mask, n_prime);
    let inverse_impl = generate_inverse_impl(repr_type, modulus, r_mod_n, r2);
    let type_bits = match repr_type_str.as_str() {
        "u8" => 8u32,
        "u16" => 16u32,
        "u32" => 32u32,
        "u64" => 64u32,
        _ => panic!("unsupported type"),
    };
    // A "spare bit" means modulus < 2^(type_bits - 1): a.value + b.value
    // then cannot wrap the representation type, so addition needs no
    // explicit carry handling.
    let has_spare_bit = modulus.leading_zeros() >= (128 - type_bits + 1);
    let add_assign_impl = if has_spare_bit {
        quote! {
            #[inline(always)]
            fn add_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
                let val = a.value.wrapping_add(b.value);
                a.value = if val >= Self::MODULUS { val - Self::MODULUS } else { val };
            }
        }
    } else {
        // No spare bit: detect wrap-around explicitly and re-add
        // 2^type_bits mod MODULUS (== T::MAX - MODULUS + 1) on overflow.
        quote! {
            #[inline(always)]
            fn add_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
                let (mut val, overflow) = a.value.overflowing_add(b.value);
                if overflow {
                    val += Self::T::MAX - Self::MODULUS + 1
                }
                if val >= Self::MODULUS {
                    val -= Self::MODULUS;
                }
                a.value = val;
            }
        }
    };
    let ts = quote! {
        type T = #repr_type;
        const MODULUS: Self::T = #modulus as Self::T;
        const MODULUS_U128: u128 = #modulus;
        const GENERATOR: SmallFp<Self> = SmallFp::from_raw(#generator_mont as Self::T);
        const ZERO: SmallFp<Self> = SmallFp::from_raw(0 as Self::T);
        const ONE: SmallFp<Self> = SmallFp::from_raw(#one_mont as Self::T);
        const NEG_ONE: SmallFp<Self> = SmallFp::from_raw(#neg_one_mont as Self::T);
        const TWO_ADICITY: u32 = #two_adicity;
        const TWO_ADIC_ROOT_OF_UNITY: SmallFp<Self> = SmallFp::from_raw(#two_adic_root_mont as Self::T);
        #mixed_radix_impl
        #sqrt_precomp_impl
        #add_assign_impl
        #[inline(always)]
        fn sub_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
            // Values are kept canonical in [0, MODULUS), so a single
            // borrow correction suffices.
            if a.value >= b.value {
                a.value -= b.value;
            } else {
                a.value = Self::MODULUS - (b.value - a.value);
            }
        }
        #[inline(always)]
        fn double_in_place(a: &mut SmallFp<Self>) {
            let tmp = *a;
            Self::add_assign(a, &tmp);
        }
        #[inline(always)]
        fn neg_in_place(a: &mut SmallFp<Self>) {
            // Zero is its own negation; all other values flip to p - v.
            if a.value != Self::ZERO.value {
                a.value = Self::MODULUS - a.value;
            }
        }
        #mul_impl
        #inverse_impl
        #[inline(always)]
        fn sum_of_products<const T: usize>(
            a: &[SmallFp<Self>; T],
            b: &[SmallFp<Self>; T],) -> SmallFp<Self> {
            // Lengths 1 and 2 are unrolled explicitly; the generic arm
            // folds multiply-accumulate over the zipped slices.
            match T {
                1 => {
                    let mut prod = a[0];
                    Self::mul_assign(&mut prod, &b[0]);
                    prod
                },
                2 => {
                    let mut prod1 = a[0];
                    Self::mul_assign(&mut prod1, &b[0]);
                    let mut prod2 = a[1];
                    Self::mul_assign(&mut prod2, &b[1]);
                    Self::add_assign(&mut prod1, &prod2);
                    prod1
                },
                _ => {
                    let mut acc = Self::ZERO;
                    for (x, y) in a.iter().zip(b.iter()) {
                        let mut prod = *x;
                        Self::mul_assign(&mut prod, y);
                        Self::add_assign(&mut acc, &prod);
                    }
                    acc
                }
            }
        }
        #[inline(always)]
        fn square_in_place(a: &mut SmallFp<Self>) {
            let tmp = *a;
            Self::mul_assign(a, &tmp);
        }
        #[inline]
        fn new(value: Self::T) -> SmallFp<Self> {
            // Enter Montgomery form: reduce, then REDC-multiply by R^2
            // (REDC(x * R^2) = x * R mod n).
            let reduced_value = value % Self::MODULUS;
            let mut tmp = SmallFp::from_raw(reduced_value);
            let r2_elem = SmallFp::from_raw(#r2 as Self::T);
            Self::mul_assign(&mut tmp, &r2_elem);
            tmp
        }
        #from_bigint_impl
        #into_bigint_impl
    };
    (ts, r_mod_n)
}
/// Emits the `inverse` implementation: a constant-iteration binary
/// extended GCD, followed by a single Montgomery multiplication with a
/// precomputed correction constant that undoes both the 2^num_iters
/// scaling accumulated by the halving loop and the Montgomery factors,
/// leaving `a^{-1}` in Montgomery form.
///
/// NOTE(review): the fixed iteration bound `2 * field_bits - 2` and the
/// Bezout-coefficient bookkeeping appear to follow the fixed-iteration
/// binary-GCD inversion technique (cf. Pornin, "Optimized Binary GCD for
/// Modular Inversion") — confirm against the original derivation.
fn generate_inverse_impl(
    repr_type: &proc_macro2::TokenStream,
    modulus: u128,
    r_mod_n: u128,
    r2: u128,
) -> proc_macro2::TokenStream {
    let repr_type_str = repr_type.to_string();
    let field_bits = 128 - modulus.leading_zeros();
    // Iteration count sufficient for operands below 2^field_bits.
    let num_iters = 2 * field_bits - 2;
    // half = 2^{-1} mod modulus, so two_neg_iters = 2^{-num_iters} mod modulus.
    let half = (modulus + 1) / 2;
    let two_neg_iters = pow_mod_const(half, num_iters as u128, modulus);
    // corr = R^3 * 2^{-num_iters} mod modulus: one Montgomery multiply by
    // corr converts the loop's raw Bezout coefficient into a^{-1} in
    // Montgomery form.
    let r3 = mod_mul_const(r2, r_mod_n, modulus);
    let corr = mod_mul_const(r3, two_neg_iters, modulus);
    #[allow(clippy::if_not_else)]
    if repr_type_str != "u64" {
        // u8/u16/u32: field_bits <= 32, so num_iters <= 62 and the Bezout
        // coefficients (which grow by at most one bit per iteration) fit i64
        // for the whole loop.
        quote! {
            #[inline]
            fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
                if a.value == 0 {
                    return None;
                }
                let mut rem_a: u64 = a.value as u64;
                let mut rem_b: u64 = Self::MODULUS as u64;
                let mut bezout_a: i64 = 1;
                let mut bezout_b: i64 = 0;
                let mut i = 0u32;
                while i < #num_iters {
                    if rem_a & 1 != 0 {
                        // Swap so the subtraction below cannot underflow.
                        if rem_a < rem_b {
                            (rem_a, rem_b) = (rem_b, rem_a);
                            (bezout_a, bezout_b) = (bezout_b, bezout_a);
                        }
                        rem_a -= rem_b;
                        bezout_a -= bezout_b;
                    }
                    // Halve the active remainder; doubling bezout_b keeps
                    // the pair consistent under the global 2^i scaling.
                    rem_a >>= 1;
                    bezout_b <<= 1;
                    i += 1;
                }
                // Canonicalize the possibly-negative coefficient mod MODULUS.
                let bezout_canonical = bezout_b.rem_euclid(Self::MODULUS as i64) as u64;
                let mut bezout_b_field = SmallFp::from_raw(bezout_canonical as Self::T);
                let corr_field = SmallFp::from_raw(#corr as Self::T);
                // One Montgomery multiply by the precomputed correction
                // yields the inverse in Montgomery form.
                Self::mul_assign(&mut bezout_b_field, &corr_field);
                Some(bezout_b_field)
            }
        }
    } else {
        // u64: num_iters can reach 126, which would overflow i64 Bezout
        // coefficients. The loop is therefore split into two halves: each
        // half runs (num_iters / 2 - 1) iterations in i64, widens to i128
        // for its final iteration, and the per-half coefficient rows are
        // recombined at the end.
        let half_iters = num_iters / 2;
        // One iteration of each half is unrolled in i128, hence the -1.
        let half_iters_i64 = half_iters - 1;
        quote! {
            #[inline]
            fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>> {
                if a.value == 0 {
                    return None;
                }
                let mut rem_a: u64 = a.value;
                let mut rem_b: u64 = Self::MODULUS;
                // First half: accumulate transition coefficients
                // (bezout_* = "a" row, mod_* = "modulus" row).
                let (r1_bezout_a, r1_bezout_b): (i128, i128);
                {
                    let (mut bezout_a, mut mod_a, mut bezout_b, mut mod_b): (i64, i64, i64, i64) = (1, 0, 0, 1);
                    let mut i = 0u32;
                    while i < #half_iters_i64 {
                        if rem_a & 1 != 0 {
                            if rem_a < rem_b {
                                (rem_a, rem_b) = (rem_b, rem_a);
                                (bezout_a, bezout_b) = (bezout_b, bezout_a);
                                (mod_a, mod_b) = (mod_b, mod_a);
                            }
                            rem_a -= rem_b; rem_a >>= 1;
                            bezout_a -= bezout_b; mod_a -= mod_b;
                        } else {
                            rem_a >>= 1;
                        }
                        bezout_b <<= 1; mod_b <<= 1;
                        i += 1;
                    }
                    // Final iteration of this half in i128 to avoid i64
                    // overflow on the last doubling/subtraction.
                    let (mut bezout_a, mut mod_a, mut bezout_b, mut mod_b) =
                        (bezout_a as i128, mod_a as i128, bezout_b as i128, mod_b as i128);
                    if rem_a & 1 != 0 {
                        if rem_a < rem_b {
                            (rem_a, rem_b) = (rem_b, rem_a);
                            (bezout_a, bezout_b) = (bezout_b, bezout_a);
                            (mod_a, mod_b) = (mod_b, mod_a);
                        }
                        rem_a -= rem_b; rem_a >>= 1;
                        bezout_a -= bezout_b; mod_a -= mod_b;
                    } else {
                        rem_a >>= 1;
                    }
                    bezout_b <<= 1; mod_b <<= 1;
                    r1_bezout_a = bezout_a; r1_bezout_b = bezout_b;
                }
                // Second half: identical loop shape, coefficients restart
                // at the identity; the remainders carry over.
                let (r2_bezout_b, r2_mod_b): (i128, i128);
                {
                    let (mut bezout_a, mut mod_a, mut bezout_b, mut mod_b): (i64, i64, i64, i64) = (1, 0, 0, 1);
                    let mut i = 0u32;
                    while i < #half_iters_i64 {
                        if rem_a & 1 != 0 {
                            if rem_a < rem_b {
                                (rem_a, rem_b) = (rem_b, rem_a);
                                (bezout_a, bezout_b) = (bezout_b, bezout_a);
                                (mod_a, mod_b) = (mod_b, mod_a);
                            }
                            rem_a -= rem_b; rem_a >>= 1;
                            bezout_a -= bezout_b; mod_a -= mod_b;
                        } else {
                            rem_a >>= 1;
                        }
                        bezout_b <<= 1; mod_b <<= 1;
                        i += 1;
                    }
                    let (mut bezout_a, mut mod_a, mut bezout_b, mut mod_b) =
                        (bezout_a as i128, mod_a as i128, bezout_b as i128, mod_b as i128);
                    if rem_a & 1 != 0 {
                        if rem_a < rem_b {
                            (rem_a, rem_b) = (rem_b, rem_a);
                            (bezout_a, bezout_b) = (bezout_b, bezout_a);
                            (mod_a, mod_b) = (mod_b, mod_a);
                        }
                        rem_a -= rem_b; rem_a >>= 1;
                        bezout_a -= bezout_b; mod_a -= mod_b;
                    } else {
                        rem_a >>= 1;
                    }
                    bezout_b <<= 1; mod_b <<= 1;
                    r2_bezout_b = bezout_b; r2_mod_b = mod_b;
                }
                // Compose the halves: apply the second half's row to the
                // first half's accumulated coefficients.
                let bezout_raw = r2_bezout_b * r1_bezout_a + r2_mod_b * r1_bezout_b;
                let p = Self::MODULUS as i128;
                // Canonicalize into [0, MODULUS).
                let bezout_canonical = ((bezout_raw % p) + p) as u128 % Self::MODULUS as u128;
                let mut bezout_field = SmallFp::from_raw(bezout_canonical as Self::T);
                let corr_field = SmallFp::from_raw(#corr as Self::T);
                // Final Montgomery multiply by the correction constant.
                Self::mul_assign(&mut bezout_field, &corr_field);
                Some(bezout_field)
            }
        }
    }
}
/// Emits the `mul_assign` implementation for the backend.
///
/// Strategies, selected by representation type and modulus shape:
/// * Mersenne moduli (2^k - 1): fold `prod = hi*2^k + lo` as `hi + lo`.
/// * u8/u16/u32 Montgomery: REDC in carry-decomposed form so every
///   intermediate fits the double-width type even when the modulus uses
///   all bits of the representation type.
/// * u64 Montgomery: direct REDC in u128 with an explicit carry-out bit.
///
/// Inputs and outputs are in Montgomery form, canonical in [0, modulus).
fn generate_mul_impl(
    repr_type: &proc_macro2::TokenStream,
    modulus: u128,
    k_bits: u32,
    r_mask: u128,
    n_prime: u128,
) -> proc_macro2::TokenStream {
    let repr_type_str = repr_type.to_string();
    let field_bits = 128 - modulus.leading_zeros();
    // Mersenne moduli admit a cheap fold-based reduction (2^K ≡ 1 mod p).
    let is_mersenne = field_bits >= 2 && modulus == (1u128 << field_bits) - 1;
    // Double-width working type that holds the full product.
    let mul_ty = match repr_type_str.as_str() {
        "u8" => quote! { u16 },
        "u16" => quote! { u32 },
        "u32" => quote! { u64 },
        _ => quote! { u128 },
    };
    match (repr_type_str.as_str(), is_mersenne) {
        ("u8" | "u16", false) => {
            // Carry-decomposed REDC: split a*b = carry1*R + r, then
            // (a*b + m*p) / R = carry1 + (r + m*p) / R, where
            // m = r * N' mod R makes r + m*p divisible by R.
            quote! {
                #[inline(always)]
                fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
                    const MODULUS_MUL_TY: #mul_ty = #modulus as #mul_ty;
                    const N_PRIME: #repr_type = #n_prime as #repr_type;
                    const MASK: #mul_ty = #r_mask as #mul_ty;
                    const MASK_REPR: #repr_type = #r_mask as #repr_type;
                    const K_BITS: u32 = #k_bits;
                    let tmp = (a.value as #mul_ty) * (b.value as #mul_ty);
                    let carry1 = (tmp >> K_BITS) as #repr_type;
                    let r = (tmp & MASK) as #repr_type;
                    // Reduce m modulo R: when k_bits is smaller than the
                    // repr-type width the wrapping product can exceed R,
                    // which would leave the result non-canonical in
                    // [p, 2p). Masking keeps m < R.
                    let m = r.wrapping_mul(N_PRIME) & MASK_REPR;
                    let tmp = (r as #mul_ty) + ((m as #mul_ty) * MODULUS_MUL_TY);
                    let carry2 = (tmp >> K_BITS) as #repr_type;
                    // (a*b + m*p) / R = carry1 + carry2 < 2p: one
                    // conditional subtraction restores canonical form.
                    let mut r = (carry1 as #mul_ty) + (carry2 as #mul_ty);
                    if r >= MODULUS_MUL_TY { r -= MODULUS_MUL_TY; }
                    a.value = r as #repr_type;
                }
            }
        },
        ("u8" | "u16" | "u32", true) => {
            // Mersenne reduction: prod = hi*2^K + lo ≡ hi + lo (mod 2^K - 1).
            quote! {
                #[inline(always)]
                fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
                    const K: u32 = #field_bits;
                    const MODULUS: #mul_ty = #modulus as #mul_ty;
                    let prod = (a.value as #mul_ty) * (b.value as #mul_ty);
                    let mut r = (prod & MODULUS) + (prod >> K);
                    if r >= MODULUS { r -= MODULUS; }
                    a.value = r as #repr_type;
                }
            }
        },
        ("u32", false) => {
            // Carry-decomposed REDC (same shape as the u8/u16 case).
            // The direct form `(t + k*MODULUS) >> k_bits` can overflow u64
            // when the modulus fills all 32 bits (a*b + m*p approaches
            // 2^65); splitting off carry1 bounds every intermediate:
            // r + m*p <= (R-1) + (R-1)*(2^32 - 1) < 2^64.
            quote! {
                #[inline(always)]
                fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
                    const MODULUS_MUL_TY: u64 = #modulus as u64;
                    const N_PRIME: u64 = #n_prime as u64;
                    const R_MASK: u64 = #r_mask as u64;
                    const K_BITS: u32 = #k_bits;
                    let t = (a.value as u64) * (b.value as u64);
                    // Split t = carry1 * R + r.
                    let carry1 = t >> K_BITS;
                    let r = t & R_MASK;
                    // m = r * N' mod R; the mask keeps m < R.
                    let m = r.wrapping_mul(N_PRIME) & R_MASK;
                    // Exact division: r + m*p ≡ 0 (mod R) by choice of N'.
                    let carry2 = (r + m * MODULUS_MUL_TY) >> K_BITS;
                    // (a*b + m*p) / R = carry1 + carry2 < 2p.
                    let mut res = carry1 + carry2;
                    if res >= MODULUS_MUL_TY { res -= MODULUS_MUL_TY; }
                    a.value = res as u32;
                }
            }
        },
        _ => {
            // u64 REDC in u128 with explicit carry-out: t + k*MODULUS can
            // exceed 128 bits, so the overflow flag re-injects the lost
            // 2^128 term as 2^(128 - k_bits) after the shift.
            let shift_bits = 128 - k_bits;
            quote! {
                #[inline(always)]
                fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>) {
                    const MODULUS_MUL_TY: u128 = #modulus as u128;
                    const N_PRIME: u128 = #n_prime as u128;
                    const R_MASK: u128 = #r_mask as u128;
                    let mut t = (a.value as u128) * (b.value as u128);
                    let k = t.wrapping_mul(N_PRIME) & R_MASK;
                    let (t, overflow) = t.overflowing_add(k * MODULUS_MUL_TY);
                    let mut r = (t >> #k_bits) + ((overflow as u128) << #shift_bits);
                    if r >= MODULUS_MUL_TY { r -= MODULUS_MUL_TY; }
                    a.value = r as u64;
                }
            }
        },
    }
}
/// Returns `N' = -n^{-1} mod 2^k_bits` for an odd `n`, the constant used
/// by Montgomery reduction (it satisfies `n * N' ≡ -1 (mod 2^k_bits)`).
///
/// The inverse is obtained by Newton–Hensel lifting: `inv = 1` is already
/// correct modulo 2 for any odd `n`, and each step doubles the number of
/// valid low bits, so seven steps reach the full 128-bit width (2^7 = 128).
fn mod_inverse_pow2(n: u128, k_bits: u32) -> u128 {
    // Hensel lift: x <- x * (2 - n*x) doubles the precision each step.
    let inv = (0..7).fold(1u128, |x, _| {
        x.wrapping_mul(2u128.wrapping_sub(n.wrapping_mul(x)))
    });
    // Negate (two's complement) and truncate to the low k_bits.
    inv.wrapping_neg() & ((1u128 << k_bits) - 1)
}
/// Emits inherent helpers tied to the Montgomery representation:
/// * `exit` — converts a value out of Montgomery form in place by
///   REDC-multiplying with raw 1 (REDC(a * 1) = a * R^{-1} mod p).
/// * `from_u128` — const-evaluable constructor: reduces `value` mod p and
///   enters Montgomery form via double-and-add modular multiplication by
///   `R mod p` (usable in const contexts where `new` is not).
pub(crate) fn exit_impl(modulus: u128, r_mod_p: u128) -> proc_macro2::TokenStream {
    quote! {
        // Leave Montgomery form in place.
        pub fn exit(a: &mut SmallFp<Self>) {
            let one = SmallFp::from_raw(1 as <Self as SmallFpConfig>::T);
            <Self as SmallFpConfig>::mul_assign(a, &one);
        }
        pub const fn from_u128(value: u128) -> SmallFp<Self> {
            const MODULUS: u128 = #modulus;
            const R_MOD_P: u128 = #r_mod_p;
            // Overflow-free const modular multiplication (double-and-add);
            // safe because the modulus is < 2^64, so a + a cannot wrap u128.
            const fn mod_mul(mut a: u128, mut b: u128, m: u128) -> u128 {
                a %= m;
                let mut result = 0u128;
                while b > 0 {
                    if b & 1 == 1 {
                        result = (result + a) % m;
                    }
                    a = (a + a) % m;
                    b >>= 1;
                }
                result
            }
            let val = value % MODULUS;
            // Multiply by R mod p to enter Montgomery form.
            let mont = mod_mul(val, R_MOD_P, MODULUS);
            SmallFp::from_raw(mont as <Self as SmallFpConfig>::T)
        }
    }
}