use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
pub mod factorize;
pub mod prime;
/// Montgomery context for `u64` moduli.
pub type Context64 = Context<u64>;
/// Montgomery context for `u32` moduli.
pub type Context32 = Context<u32>;
/// Residue bound to a [`Context64`].
pub type Modulo64<'a> = Modulo<'a, u64>;
/// Residue bound to a [`Context32`].
pub type Modulo32<'a> = Modulo<'a, u32>;
/// Precomputed constants for Montgomery arithmetic with a fixed odd modulus.
///
/// Values are created through the `new` constructors generated by
/// `montgomery_impl!` for the concrete integer types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Context<U> {
/// The modulus; `new` asserts that it is odd.
n: U,
/// Multiplicative inverse of `n` modulo `2^BITS` (`n * inv_n` wraps to 1).
inv_n: U,
/// `2^(2 * BITS) mod n`; multiplying by it converts a plain integer into Montgomery form.
r2_mod_n: U,
}
impl<U> Context<U> {
    /// Borrows the modulus this context was built for.
    pub const fn modulus(&self) -> &U {
        let Self { n: modulus, .. } = self;
        modulus
    }
}
/// A residue modulo the modulus of the borrowed [`Context`].
///
/// The residue is kept in Montgomery form; the generated `get` method converts
/// it back to a plain integer.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Modulo<'a, U> {
/// Montgomery representation of the residue (`x * 2^BITS mod n`), always below `n`.
value: U,
/// Context supplying the modulus and the precomputed reduction constants.
ctx: &'a Context<U>,
}
/// Generates the Montgomery arithmetic for one integer width.
///
/// `$single` is the integer type holding residues and `$double` is an unsigned
/// type twice as wide, used for intermediate products. Below, `R` denotes
/// `2^BITS`, where `BITS` is the bit width of `$single`.
macro_rules! montgomery_impl {
    ( $single:ty, $double:ty ) => {
        impl Context<$single> {
            /// Builds the context for the odd modulus `n`.
            ///
            /// Precomputes the inverse of `n` modulo `R` and the constant `R^2 mod n`.
            ///
            /// # Panics
            ///
            /// Panics if `n` is even.
            #[inline]
            pub const fn new(n: $single) -> Self {
                assert!(n & 1 == 1, "modulus should be an odd number");

                // Packed lookup table: nibble `i` holds the inverse of `2 * i + 1` modulo 16.
                const NIBBLE_INVERSES: u32 = {
                    let inverses: [u32; 8] = [1, 11, 13, 7, 9, 3, 5, 15];
                    let mut packed: u32 = 0;
                    let mut slot = 0;
                    while slot < 8 {
                        packed |= inverses[slot] << (slot * 4);
                        slot += 1;
                    }
                    packed
                };

                // `n & 0b1110` is twice the table index, so doubling it again yields the
                // bit offset of the nibble holding the inverse of `n` modulo 16.
                let nibble_shift = (n & 0b1110) * 2;
                let mut inverse = ((NIBBLE_INVERSES >> nibble_shift) & 0b1111) as $single;

                // Each Newton step doubles the number of correct low bits, starting from 4.
                let mut rounds = const { <$single>::BITS.ilog2() - 2 };
                while rounds > 0 {
                    let correction = (2 as $single).wrapping_sub(n.wrapping_mul(inverse));
                    inverse = inverse.wrapping_mul(correction);
                    rounds -= 1;
                }
                debug_assert!(n.wrapping_mul(inverse) == 1);

                // In the double-width type, `-n` wraps to `R^2 - n`, which is congruent
                // to `R^2` modulo `n`.
                let wide_n = n as $double;
                let r2_mod_n = (wide_n.wrapping_neg() % wide_n) as $single;

                Self {
                    n,
                    inv_n: inverse,
                    r2_mod_n,
                }
            }

            /// Reduces `x` modulo the modulus and wraps it as a [`Modulo`] bound to `self`.
            #[inline(always)]
            pub const fn modulo(&self, x: $single) -> Modulo<'_, $single> {
                // Multiplying by `R^2 mod n` under Montgomery reduction leaves `x * R mod n`.
                Modulo {
                    value: self.mul(x, self.r2_mod_n),
                    ctx: self,
                }
            }

            /// Montgomery product: `lhs * rhs * R^-1 mod n`.
            #[inline(always)]
            const fn mul(&self, lhs: $single, rhs: $single) -> $single {
                self.mul_add(lhs, rhs, 0)
            }

            /// Montgomery reduction of `lhs * rhs + add`, i.e. that sum times `R^-1` modulo `n`.
            ///
            /// The result is below `n` as long as the high word of the sum is below `n`,
            /// which holds whenever one of the factors is already below `n`.
            #[inline(always)]
            const fn mul_add(&self, lhs: $single, rhs: $single, add: $single) -> $single {
                let wide = (lhs as $double) * (rhs as $double) + (add as $double);
                let high = (wide >> <$single>::BITS) as $single;
                let low = wide as $single;

                // `m * n` has the same low word as `wide`, so subtracting it clears the
                // low word and only the high words need to be combined.
                let m = low.wrapping_mul(self.inv_n);
                let cancel_high = (((m as $double) * (self.n as $double)) >> <$single>::BITS) as $single;

                match high.overflowing_sub(cancel_high) {
                    (diff, true) => diff.wrapping_add(self.n),
                    (diff, false) => diff,
                }
            }

            /// Returns `true` when the modulus divides `x`.
            ///
            /// # Examples
            ///
            /// ```ignore
            /// let n = 15;
            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
            /// assert!(ctx.can_divide(30));
            /// assert!(!ctx.can_divide(31));
            /// ```
            #[inline]
            pub const fn can_divide(&self, x: $single) -> bool {
                // `x * R^-1 mod n` vanishes exactly when `x` is a multiple of `n`.
                self.mul(x, 1) == 0
            }
        }

        impl<'a> Modulo<'a, $single> {
            /// Converts the residue out of Montgomery form into a plain integer below the modulus.
            ///
            /// # Examples
            ///
            /// ```ignore
            /// let n = 17;
            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
            /// assert_eq!(ctx.modulo(20).get(), 3);
            /// ```
            #[inline(always)]
            pub const fn get(&self) -> $single {
                let Self { value, ctx } = self;
                ctx.mul(*value, 1)
            }

            /// Returns the modulus of the context this residue belongs to.
            ///
            /// # Examples
            ///
            /// ```ignore
            /// let n = 17;
            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
            /// assert_eq!(ctx.modulo(5).modulus(), 17);
            /// ```
            #[inline(always)]
            pub const fn modulus(&self) -> $single {
                self.ctx.n
            }

            /// Returns `true` when the residue is zero.
            ///
            /// # Examples
            ///
            /// ```ignore
            /// let n = 17;
            #[doc = concat!(" let ctx = Context::<", stringify!($single), ">::new(n);")]
            /// assert!(ctx.modulo(34).is_zero());
            /// assert!(!ctx.modulo(35).is_zero());
            /// ```
            #[inline(always)]
            pub const fn is_zero(self) -> bool {
                // Zero has the same representation in and out of Montgomery form.
                matches!(self.value, 0)
            }

            /// Creates the zero residue for `ctx`.
            ///
            /// # Examples
            ///
            /// ```ignore
            /// let n = 17;
            #[doc = concat!(" let ctx = Context::<", stringify!($single), ">::new(n);")]
            #[doc = concat!(" assert_eq!(Modulo::<'_, ", stringify!($single), ">::zero(&ctx).get(), 0);")]
            /// ```
            #[inline(always)]
            pub const fn zero(ctx: &'a Context<$single>) -> Self {
                Self { ctx, value: 0 }
            }

            /// Raises the residue to the power `exp` by binary exponentiation.
            ///
            /// An exponent of zero yields the residue 1.
            ///
            /// # Examples
            ///
            /// ```ignore
            /// let n = 17;
            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(n);")]
            /// assert_eq!(ctx.modulo(2).pow(10).get(), 1024 % 17);
            /// ```
            #[inline]
            pub const fn pow(mut self, mut exp: $single) -> Self {
                let ctx = self.ctx;
                let mut acc = ctx.modulo(1).value;
                let mut base = self.value;
                while exp != 0 {
                    if exp & 1 != 0 {
                        acc = ctx.mul(acc, base);
                    }
                    base = ctx.mul(base, base);
                    exp >>= 1;
                }
                self.value = acc;
                self
            }

            /// Computes the multiplicative inverse with a binary GCD.
            ///
            /// # Errors
            ///
            /// Returns `Err(g)` with `g = gcd(self, modulus)` when `g != 1`, in which case
            /// no inverse exists. A zero residue yields `Err(modulus)` unless the modulus is 1.
            ///
            /// # Examples
            ///
            /// ```ignore
            #[doc = concat!("let ctx = Context::<", stringify!($single), ">::new(998_244_353);")]
            /// let three = ctx.modulo(3);
            /// let inv = three.try_inv().unwrap();
            /// assert_eq!((inv * three).get(), 1);
            /// ```
            #[inline]
            pub const fn try_inv(self) -> Result<Self, $single> {
                let ctx = self.ctx;
                // Montgomery form of 1/2; the modulus is odd, so (n + 1) / 2 halves modulo n.
                let half = ctx.modulo(ctx.n.div_ceil(2));

                // Loop invariant (residues modulo n): a = self * x and b = self * y,
                // with `x` and `y` kept in Montgomery form.
                let mut a = self.get();
                let mut b = ctx.n;
                let mut x = ctx.modulo(1).value;
                let mut y: $single = 0;

                while a != 0 {
                    // Strip the factors of two from `a`, dividing `x` by the same power.
                    let twos = a.trailing_zeros();
                    x = ctx.mul(x, half.pow(twos as $single).value);
                    a >>= twos;

                    if a < b {
                        (a, b) = (b, a);
                        (x, y) = (y, x);
                    }

                    a -= b;
                    x = match x.overflowing_sub(y) {
                        (diff, true) => diff.wrapping_add(ctx.n),
                        (diff, false) => diff,
                    };
                }

                // `b` now holds gcd(self, n); `y` is the inverse exactly when that is 1.
                match b {
                    1 => Ok(Self { value: y, ctx }),
                    gcd => Err(gcd),
                }
            }
        }

        impl<'a> Add for Modulo<'a, $single> {
            type Output = Self;

            /// Modular addition; the result uses the context of `self`.
            #[inline(always)]
            fn add(self, rhs: Self) -> Self {
                let n = self.ctx.n;
                let value = match self.value.overflowing_add(rhs.value) {
                    // Either the machine word wrapped or the sum reached the modulus.
                    (sum, wrapped) if wrapped || sum >= n => sum.wrapping_sub(n),
                    (sum, _) => sum,
                };
                Self {
                    value,
                    ctx: self.ctx,
                }
            }
        }

        impl<'a> Sub for Modulo<'a, $single> {
            type Output = Self;

            /// Modular subtraction; the result uses the context of `self`.
            #[inline(always)]
            fn sub(self, rhs: Self) -> Self {
                let value = match self.value.overflowing_sub(rhs.value) {
                    (diff, true) => diff.wrapping_add(self.ctx.n),
                    (diff, false) => diff,
                };
                Self {
                    value,
                    ctx: self.ctx,
                }
            }
        }

        impl<'a> Mul for Modulo<'a, $single> {
            type Output = Self;

            /// Modular multiplication; the result uses the context of `self`.
            #[inline(always)]
            fn mul(self, rhs: Self) -> Self {
                Self {
                    value: self.ctx.mul(self.value, rhs.value),
                    ctx: self.ctx,
                }
            }
        }

        impl<'a> Neg for Modulo<'a, $single> {
            type Output = Self;

            /// Additive inverse modulo the modulus; zero maps to zero.
            #[inline(always)]
            fn neg(self) -> Self::Output {
                let value = match self.value {
                    0 => 0,
                    nonzero => self.ctx.n - nonzero,
                };
                Self {
                    value,
                    ctx: self.ctx,
                }
            }
        }
    };
}
// Instantiate the arithmetic for `u64` residues, with `u128` as the double-width type.
montgomery_impl!(u64, u128);
// Instantiate the arithmetic for `u32` residues, with `u64` as the double-width type.
montgomery_impl!(u32, u64);
impl<'a, U> AddAssign for Modulo<'a, U>
where
    Self: Add<Output = Self> + Copy,
{
    /// Adds `rhs` in place by delegating to the by-value `Add` implementation.
    #[inline(always)]
    fn add_assign(&mut self, rhs: Self) {
        let sum = Add::add(*self, rhs);
        *self = sum;
    }
}
impl<'a, U> SubAssign for Modulo<'a, U>
where
    Self: Sub<Output = Self> + Copy,
{
    /// Subtracts `rhs` in place by delegating to the by-value `Sub` implementation.
    #[inline(always)]
    fn sub_assign(&mut self, rhs: Self) {
        let difference = Sub::sub(*self, rhs);
        *self = difference;
    }
}
impl<'a, U> MulAssign for Modulo<'a, U>
where
    Self: Mul<Output = Self> + Copy,
{
    /// Multiplies by `rhs` in place by delegating to the by-value `Mul` implementation.
    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let product = Mul::mul(*self, rhs);
        *self = product;
    }
}