#![allow(clippy::unknown_clippy_lints)]
use num_traits::{
checked_pow, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedSub, NumOps, One, Pow,
Signed, Unsigned, WrappingAdd, WrappingMul, WrappingNeg, WrappingSub, Zero,
};
use core::{
cmp::Ordering,
convert::{TryFrom, TryInto},
fmt,
marker::PhantomData,
mem, ops,
};
use crate::error::ArithmeticError;
#[cfg(feature = "bigint")]
mod bigint;
/// Encapsulates the arithmetic operations available for values of type `T`.
///
/// Every operation except `eq` is fallible and reports failures as an [`ArithmeticError`].
pub trait Arithmetic<T> {
/// Adds `x` and `y`.
fn add(&self, x: T, y: T) -> Result<T, ArithmeticError>;
/// Subtracts `y` from `x`.
fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError>;
/// Multiplies `x` by `y`.
fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError>;
/// Divides `x` by `y`.
fn div(&self, x: T, y: T) -> Result<T, ArithmeticError>;
/// Raises `x` to the power of `y`.
fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError>;
/// Negates `x`.
fn neg(&self, x: T) -> Result<T, ArithmeticError>;
/// Checks whether `x` and `y` are equal within this arithmetic. Unlike the other
/// operations, the comparison borrows its operands and cannot fail.
fn eq(&self, x: &T, y: &T) -> bool;
}
/// Extension of [`Arithmetic`] that can additionally order values of type `T`.
pub trait OrdArithmetic<T>: Arithmetic<T> {
/// Compares `x` and `y`. Returns `None` if the two values cannot be ordered.
fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering>;
}
/// Formats any `OrdArithmetic` trait object as the bare trait name; the concrete
/// implementation behind the object is not inspected.
impl<T> fmt::Debug for dyn OrdArithmetic<T> + '_ {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as a field-less debug tuple: only the name, with no decoration.
        formatter.write_str("OrdArithmetic")
    }
}
/// Stateless arithmetic marker type. Its `Arithmetic` implementation in this file forwards
/// every operation to the corresponding operator of the value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct StdArithmetic;
impl<T> Arithmetic<T> for StdArithmetic
where
T: Clone + NumOps + PartialEq + ops::Neg<Output = T> + Pow<T, Output = T>,
{
#[inline]
fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
Ok(x + y)
}
#[inline]
fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
Ok(x - y)
}
#[inline]
fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
Ok(x * y)
}
#[inline]
fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
Ok(x / y)
}
#[inline]
fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
Ok(x.pow(y))
}
#[inline]
fn neg(&self, x: T) -> Result<T, ArithmeticError> {
Ok(-x)
}
#[inline]
fn eq(&self, x: &T, y: &T) -> bool {
*x == *y
}
}
/// Orders values with the `PartialOrd` implementation of `T`.
impl<T> OrdArithmetic<T> for StdArithmetic
where
    Self: Arithmetic<T>,
    T: PartialOrd,
{
    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
        PartialOrd::partial_cmp(x, y)
    }
}
#[cfg(all(test, feature = "std"))]
static_assertions::assert_impl_all!(StdArithmetic: OrdArithmetic<f32>, OrdArithmetic<f64>);
#[cfg(all(test, feature = "complex"))]
static_assertions::assert_impl_all!(
StdArithmetic: Arithmetic<num_complex::Complex32>,
Arithmetic<num_complex::Complex64>
);
/// Strategy used by [`CheckedArithmetic`] to negate values of type `T`.
pub trait CheckedArithmeticKind<T> {
/// Negates `value`, returning `None` if this kind does not allow the negation.
/// `CheckedArithmetic::neg` reports `None` as an integer overflow.
fn checked_neg(value: T) -> Option<T>;
}
/// Arithmetic marker parameterized by a negation strategy `Kind` (`Checked` by default).
///
/// The type stores no data; `Kind` is only tracked at the type level.
#[derive(Debug)]
pub struct CheckedArithmetic<Kind = Checked>(PhantomData<Kind>);

// `Clone`, `Copy` and `Default` are written by hand so that they are available for every
// `Kind`; deriving them would require `Kind` itself to implement the respective trait.
impl<Kind> Clone for CheckedArithmetic<Kind> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Kind> Copy for CheckedArithmetic<Kind> {}

impl<Kind> Default for CheckedArithmetic<Kind> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Kind> CheckedArithmetic<Kind> {
    /// Creates a new arithmetic instance; usable in constant contexts.
    pub const fn new() -> Self {
        CheckedArithmetic(PhantomData)
    }
}
/// Overflow-checked arithmetic: every operation that does not fit into `T` is reported
/// as an error instead of panicking or wrapping around.
impl<T, Kind> Arithmetic<T> for CheckedArithmetic<Kind>
where
    T: Clone + PartialEq + Zero + One + CheckedAdd + CheckedSub + CheckedMul + CheckedDiv,
    Kind: CheckedArithmeticKind<T>,
    usize: TryFrom<T>,
{
    /// Returns `x + y`, or `IntegerOverflow` if the checked addition fails.
    #[inline]
    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        x.checked_add(&y).ok_or(ArithmeticError::IntegerOverflow)
    }

    /// Returns `x - y`, or `IntegerOverflow` if the checked subtraction fails.
    #[inline]
    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        x.checked_sub(&y).ok_or(ArithmeticError::IntegerOverflow)
    }

    /// Returns `x * y`, or `IntegerOverflow` if the checked multiplication fails.
    #[inline]
    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        x.checked_mul(&y).ok_or(ArithmeticError::IntegerOverflow)
    }

    /// Returns `x / y`.
    ///
    /// Fails with `DivisionByZero` if `y` is zero, and with `IntegerOverflow` if the
    /// checked division fails for a non-zero divisor (e.g., `MIN / -1` for signed
    /// integers).
    #[inline]
    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        if y.is_zero() {
            return Err(ArithmeticError::DivisionByZero);
        }
        // With the zero divisor ruled out, the only remaining reason for `checked_div`
        // to fail is that the quotient does not fit into `T`.
        x.checked_div(&y).ok_or(ArithmeticError::IntegerOverflow)
    }

    /// Returns `x` raised to the power `y`.
    ///
    /// Fails with `InvalidExponent` if `y` cannot be converted to `usize`, and with
    /// `IntegerOverflow` if the checked exponentiation fails.
    #[inline]
    // The conversion error carries no information beyond "the exponent is unusable".
    #[allow(clippy::map_err_ignore)]
    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
        checked_pow(x, exp).ok_or(ArithmeticError::IntegerOverflow)
    }

    /// Negates `x` as dictated by `Kind`; a refused negation is reported as
    /// `IntegerOverflow`.
    #[inline]
    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
        Kind::checked_neg(x).ok_or(ArithmeticError::IntegerOverflow)
    }

    /// Compares the values with `PartialEq`.
    #[inline]
    fn eq(&self, x: &T, y: &T) -> bool {
        *x == *y
    }
}
/// Marker for the negation strategy that relies on `CheckedNeg`: negation succeeds
/// exactly when `CheckedNeg::checked_neg` returns a value.
///
/// The private unit field prevents the marker from being constructed outside this module.
#[derive(Debug)]
pub struct Checked(());

impl<T: CheckedNeg> CheckedArithmeticKind<T> for Checked {
    fn checked_neg(value: T) -> Option<T> {
        <T as CheckedNeg>::checked_neg(&value)
    }
}
/// Marker for the negation strategy for unsigned types: only zero can be negated
/// (yielding zero); every other value is refused.
///
/// The private unit field prevents the marker from being constructed outside this module.
#[derive(Debug)]
pub struct NegateOnlyZero(());

impl<T: Unsigned + Zero> CheckedArithmeticKind<T> for NegateOnlyZero {
    fn checked_neg(value: T) -> Option<T> {
        if !value.is_zero() {
            return None;
        }
        Some(value)
    }
}
#[derive(Debug)]
pub struct Unchecked(());
impl<T: Signed> CheckedArithmeticKind<T> for Unchecked {
fn checked_neg(value: T) -> Option<T> {
Some(-value)
}
}
/// Orders values with the `PartialOrd` implementation of `T`, whatever the `Kind`.
impl<T, Kind> OrdArithmetic<T> for CheckedArithmetic<Kind>
where
    Self: Arithmetic<T>,
    T: PartialOrd,
{
    #[inline]
    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
        <T as PartialOrd>::partial_cmp(x, y)
    }
}
#[cfg(test)]
static_assertions::assert_impl_all!(
CheckedArithmetic: OrdArithmetic<u8>,
OrdArithmetic<i8>,
OrdArithmetic<u16>,
OrdArithmetic<i16>,
OrdArithmetic<u32>,
OrdArithmetic<i32>,
OrdArithmetic<u64>,
OrdArithmetic<i64>,
OrdArithmetic<u128>,
OrdArithmetic<i128>
);
/// Stateless arithmetic marker type. Its `Arithmetic` implementation in this file performs
/// addition, subtraction, multiplication, negation and exponentiation with wrapping
/// (`wrapping_*`) operations of the value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct WrappingArithmetic;
/// Wrapping arithmetic: results that do not fit into `T` wrap around instead of failing.
///
/// The only errors are `DivisionByZero` (from `div`) and `InvalidExponent` (from `pow`).
impl<T> Arithmetic<T> for WrappingArithmetic
where
    T: Copy
        + PartialEq
        + Zero
        + One
        + WrappingAdd
        + WrappingSub
        + WrappingMul
        + WrappingNeg
        + ops::Div<T, Output = T>,
    usize: TryFrom<T>,
{
    /// Returns `x + y`, wrapping around on overflow.
    #[inline]
    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        Ok(x.wrapping_add(&y))
    }

    /// Returns `x - y`, wrapping around on overflow.
    #[inline]
    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        Ok(x.wrapping_sub(&y))
    }

    /// Returns `x * y`, wrapping around on overflow.
    #[inline]
    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        Ok(x.wrapping_mul(&y))
    }

    /// Returns `x / y`, or `DivisionByZero` if `y` is zero.
    ///
    /// Division by `-1` is performed as a wrapping negation because it is the only
    /// division that can overflow (`MIN / -1`). For unsigned types the division is always
    /// the plain one.
    #[inline]
    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        if y.is_zero() {
            return Err(ArithmeticError::DivisionByZero);
        }

        // `y.wrapping_neg() == 1` holds both for a signed `-1` and for an unsigned `MAX`.
        // The two cases are told apart by `1 / y`, which equals `y` for `-1` and is zero
        // for `MAX`. Only the signed case may be replaced with a negation.
        let divisor_is_minus_one = y.wrapping_neg().is_one() && T::one() / y == y;
        if divisor_is_minus_one {
            Ok(x.wrapping_neg())
        } else {
            Ok(x / y)
        }
    }

    /// Returns `x` raised to the power `y`, wrapping around on overflow.
    ///
    /// Fails with `InvalidExponent` if `y` cannot be converted to `usize`.
    #[inline]
    // The conversion error carries no information beyond "the exponent is unusable".
    #[allow(clippy::map_err_ignore)]
    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
        Ok(wrapping_exp(x, exp))
    }

    /// Returns `-x`, wrapping around on overflow.
    #[inline]
    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
        Ok(x.wrapping_neg())
    }

    /// Compares the values with `PartialEq`.
    #[inline]
    fn eq(&self, x: &T, y: &T) -> bool {
        *x == *y
    }
}
/// Orders values with the `PartialOrd` implementation of `T`.
impl<T> OrdArithmetic<T> for WrappingArithmetic
where
    Self: Arithmetic<T>,
    T: PartialOrd,
{
    #[inline]
    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
        PartialOrd::partial_cmp(x, y)
    }
}
/// Raises `base` to the power `exp` by repeated squaring, with every multiplication
/// performed via `wrapping_mul`. Returns `T::one()` if `exp` is zero.
fn wrapping_exp<T: Copy + One + WrappingMul>(base: T, exp: usize) -> T {
    if exp == 0 {
        return T::one();
    }

    let mut square = base;
    let mut remaining = exp;

    // Trailing zero bits of the exponent only require squaring.
    while remaining % 2 == 0 {
        square = square.wrapping_mul(&square);
        remaining /= 2;
    }

    // `remaining` is odd at this point, so its lowest bit contributes `square` as is.
    // Each further set bit multiplies the corresponding repeated square into the product.
    let mut product = square;
    while remaining > 1 {
        remaining /= 2;
        square = square.wrapping_mul(&square);
        if remaining % 2 == 1 {
            product = product.wrapping_mul(&square);
        }
    }
    product
}
#[cfg(test)]
static_assertions::assert_impl_all!(
WrappingArithmetic: OrdArithmetic<u8>,
OrdArithmetic<i8>,
OrdArithmetic<u16>,
OrdArithmetic<i16>,
OrdArithmetic<u32>,
OrdArithmetic<i32>,
OrdArithmetic<u64>,
OrdArithmetic<i64>,
OrdArithmetic<u128>,
OrdArithmetic<i128>
);
/// Unsigned type that has wider counterparts into which it can be losslessly converted
/// (`From<Self>`) and from which it can be fallibly converted back (`TryInto<Self>`).
///
/// The implementations in this file pair each unsigned primitive with the unsigned and
/// signed primitives of twice its bit width.
pub trait DoubleWidth: Sized + Unsigned {
/// Wider unsigned type.
type Wide: Copy + From<Self> + TryInto<Self> + NumOps + Unsigned;
/// Wider signed type.
type SignedWide: Copy + From<Self> + TryInto<Self> + NumOps + Zero + One + Signed + PartialOrd;
}
impl DoubleWidth for u8 {
type Wide = u16;
type SignedWide = i16;
}
impl DoubleWidth for u16 {
type Wide = u32;
type SignedWide = i32;
}
impl DoubleWidth for u32 {
type Wide = u64;
type SignedWide = i64;
}
impl DoubleWidth for u64 {
type Wide = u128;
type SignedWide = i128;
}
/// Arithmetic parameterized by a modulus of type `T`.
#[derive(Debug, Clone, Copy)]
pub struct ModularArithmetic<T> {
// Modulus used by the operations; `new` rejects the values 0 and 1.
modulus: T,
}
impl<T> ModularArithmetic<T>
where
T: Clone + PartialEq + NumOps + Unsigned + Zero + One,
{
pub fn new(modulus: T) -> Self {
assert!(!modulus.is_zero(), "Modulus cannot be 0");
assert!(!modulus.is_one(), "Modulus cannot be 1");
Self { modulus }
}
pub fn modulus(&self) -> &T {
&self.modulus
}
}
impl<T> ModularArithmetic<T>
where
    T: Copy + PartialEq + NumOps + Unsigned + Zero + One + DoubleWidth,
{
    /// Returns `x * y` reduced modulo the modulus. The product is computed in the wide
    /// type before the reduction.
    #[inline]
    fn mul_inner(self, x: T, y: T) -> T {
        let wide_modulus = <T::Wide>::from(self.modulus);
        let product = <T::Wide>::from(x) * <T::Wide>::from(y);
        let reduced = product % wide_modulus;
        reduced.try_into().ok().unwrap()
    }

    /// Computes the multiplicative inverse of `value` modulo the modulus using
    /// the extended Euclidean algorithm over the signed wide type.
    ///
    /// Returns `None` if the greatest common divisor of `value` and the modulus exceeds 1
    /// (which includes `value` being a multiple of the modulus).
    fn invert(self, value: T) -> Option<T> {
        let modulus = <T::SignedWide>::from(self.modulus);
        let reduced = <T::SignedWide>::from(value % self.modulus);

        // Two consecutive remainders and the matching coefficients of `value`.
        let mut prev_rem = modulus;
        let mut rem = reduced;
        let mut prev_coeff = <T::SignedWide>::zero();
        let mut coeff = <T::SignedWide>::one();

        while !rem.is_zero() {
            let quotient = prev_rem / rem;

            let next_coeff = prev_coeff - quotient * coeff;
            prev_coeff = mem::replace(&mut coeff, next_coeff);

            let next_rem = prev_rem - quotient * rem;
            prev_rem = mem::replace(&mut rem, next_rem);
        }

        // `prev_rem` now holds the greatest common divisor.
        if prev_rem > <T::SignedWide>::one() {
            return None;
        }

        // Bring a negative coefficient back into the `0..modulus` range.
        let inverse = if prev_coeff.is_negative() {
            prev_coeff + modulus
        } else {
            prev_coeff
        };
        Some(inverse.try_into().ok().unwrap())
    }

    /// Raises `base` to the power `exp` modulo the modulus by repeated squaring in
    /// the wide type. Returns `T::one()` if `exp` is zero.
    fn modular_exp(self, base: T, exp: usize) -> T {
        if exp == 0 {
            return T::one();
        }

        let wide_modulus = <T::Wide>::from(self.modulus);
        let square_mod = |value: T::Wide| (value * value) % wide_modulus;

        let mut remaining = exp;
        let mut power = <T::Wide>::from(base % self.modulus);

        // Trailing zero bits of the exponent only require squaring.
        while remaining % 2 == 0 {
            power = square_mod(power);
            remaining /= 2;
        }

        // `remaining` is odd at this point, so its lowest bit contributes `power` as is.
        let mut acc = power;
        while remaining > 1 {
            remaining /= 2;
            power = square_mod(power);
            if remaining % 2 == 1 {
                acc = (acc * power) % wide_modulus;
            }
        }
        acc.try_into().ok().unwrap()
    }
}
/// Arithmetic on residues modulo the stored modulus.
///
/// Operands may be arbitrary values of `T`; the results of `add`, `sub`, `mul`, `div`
/// and `neg` are always reduced, i.e., less than the modulus.
impl<T> Arithmetic<T> for ModularArithmetic<T>
where
    T: Copy + PartialEq + NumOps + Zero + One + DoubleWidth,
    usize: TryFrom<T>,
{
    /// Returns `(x + y) mod modulus`.
    #[inline]
    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        // The sum is computed in the wide type so that it cannot overflow `T`.
        let wide = (<T::Wide>::from(x) + <T::Wide>::from(y)) % <T::Wide>::from(self.modulus);
        Ok(wide.try_into().ok().unwrap())
    }

    /// Returns `(x - y) mod modulus`.
    #[inline]
    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        // Subtraction is replaced with the addition of `modulus - (y mod modulus)`,
        // which never underflows.
        let y = y % self.modulus;
        self.add(x, self.modulus - y)
    }

    /// Returns `(x * y) mod modulus`.
    #[inline]
    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        Ok(self.mul_inner(x, y))
    }

    /// Returns `x` multiplied by the modular inverse of `y`.
    ///
    /// Fails with `DivisionByZero` if `y` is zero, and with `NoInverse` if `y` has
    /// no inverse modulo the modulus (this includes non-zero multiples of the modulus).
    #[inline]
    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        if y.is_zero() {
            Err(ArithmeticError::DivisionByZero)
        } else {
            let y_inv = self.invert(y).ok_or(ArithmeticError::NoInverse)?;
            self.mul(x, y_inv)
        }
    }

    /// Returns `x` raised to the power `y` modulo the modulus.
    ///
    /// Fails with `InvalidExponent` if `y` cannot be converted to `usize`.
    #[inline]
    // The conversion error carries no information beyond "the exponent is unusable".
    #[allow(clippy::map_err_ignore)]
    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
        let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
        Ok(self.modular_exp(x, exp))
    }

    /// Returns `(-x) mod modulus`.
    #[inline]
    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
        // Computed as `0 - x` so that the result is reduced. A plain `modulus - x`
        // would return the modulus itself (rather than 0) for `x` congruent to 0.
        self.sub(T::zero(), x)
    }

    /// Checks whether `x` and `y` are congruent modulo the modulus.
    #[inline]
    fn eq(&self, x: &T, y: &T) -> bool {
        *x % self.modulus == *y % self.modulus
    }
}
#[cfg(test)]
static_assertions::assert_impl_all!(ModularArithmetic<u8>: Arithmetic<u8>);
#[cfg(test)]
static_assertions::assert_impl_all!(ModularArithmetic<u16>: Arithmetic<u16>);
#[cfg(test)]
static_assertions::assert_impl_all!(ModularArithmetic<u32>: Arithmetic<u32>);
#[cfg(test)]
static_assertions::assert_impl_all!(ModularArithmetic<u64>: Arithmetic<u64>);
/// Wrapper that combines a base arithmetic `A` with a comparison function for values of
/// type `T`.
pub struct FullArithmetic<T, A> {
// Arithmetic to which all `Arithmetic` operations are delegated.
base: A,
// Function used to order values; `None` means that the values are not comparable.
comparison: fn(&T, &T) -> Option<Ordering>,
}
impl<T, A: Clone> Clone for FullArithmetic<T, A> {
fn clone(&self) -> Self {
Self {
base: self.base.clone(),
comparison: self.comparison,
}
}
}
impl<T, A: Copy> Copy for FullArithmetic<T, A> {}
impl<T, A: fmt::Debug> fmt::Debug for FullArithmetic<T, A> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("FullArithmetic")
.field("base", &self.base)
.finish()
}
}
impl<T, A> Arithmetic<T> for FullArithmetic<T, A>
where
A: Arithmetic<T>,
{
#[inline]
fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
self.base.add(x, y)
}
#[inline]
fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
self.base.sub(x, y)
}
#[inline]
fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
self.base.mul(x, y)
}
#[inline]
fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
self.base.div(x, y)
}
#[inline]
fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
self.base.pow(x, y)
}
#[inline]
fn neg(&self, x: T) -> Result<T, ArithmeticError> {
self.base.neg(x)
}
#[inline]
fn eq(&self, x: &T, y: &T) -> bool {
self.base.eq(x, y)
}
}
/// Orders values with the comparison function stored in the wrapper.
impl<T, A> OrdArithmetic<T> for FullArithmetic<T, A>
where
    A: Arithmetic<T>,
{
    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
        let compare = self.comparison;
        compare(x, y)
    }
}
/// Extension trait that wraps an [`Arithmetic`] into a [`FullArithmetic`] by supplying
/// a comparison function. Implemented for every `Arithmetic`.
pub trait ArithmeticExt<T>: Arithmetic<T> + Sized {
    /// Wraps this arithmetic so that no two values are comparable: the comparison always
    /// returns `None`.
    fn without_comparisons(self) -> FullArithmetic<T, Self> {
        fn incomparable<U>(_: &U, _: &U) -> Option<Ordering> {
            None
        }

        FullArithmetic {
            base: self,
            comparison: incomparable::<T>,
        }
    }

    /// Wraps this arithmetic so that values are compared with the `PartialOrd`
    /// implementation of `T`.
    fn with_natural_comparison(self) -> FullArithmetic<T, Self>
    where
        T: PartialOrd,
    {
        fn natural_order<U: PartialOrd>(lhs: &U, rhs: &U) -> Option<Ordering> {
            lhs.partial_cmp(rhs)
        }

        FullArithmetic {
            base: self,
            comparison: natural_order::<T>,
        }
    }

    /// Wraps this arithmetic so that values are compared with the supplied function.
    fn with_comparison(
        self,
        comparison: fn(&T, &T) -> Option<Ordering>,
    ) -> FullArithmetic<T, Self> {
        let base = self;
        FullArithmetic { base, comparison }
    }
}

impl<T, A> ArithmeticExt<T> for A where A: Arithmetic<T> {}
// Unit tests for `ModularArithmetic`: fixed examples, an exhaustive check over `u8` and
// randomized checks against reference computations in 128-bit integers.
#[cfg(test)]
mod tests {
use super::*;
use rand::{rngs::StdRng, Rng, SeedableRng};
// Hand-computed examples modulo 11, including operands that exceed the modulus.
#[test]
fn modular_arithmetic_basics() {
let arithmetic = ModularArithmetic::new(11_u32);
assert_eq!(arithmetic.add(1, 5).unwrap(), 6);
assert_eq!(arithmetic.add(2, 9).unwrap(), 0);
assert_eq!(arithmetic.add(5, 9).unwrap(), 3);
assert_eq!(arithmetic.add(5, 20).unwrap(), 3);
assert_eq!(arithmetic.sub(5, 9).unwrap(), 7);
assert_eq!(arithmetic.sub(5, 20).unwrap(), 7);
assert_eq!(arithmetic.mul(5, 4).unwrap(), 9);
assert_eq!(arithmetic.mul(11, 4).unwrap(), 0);
// `u32::max_value()` is congruent to 3 modulo 11, so its square is congruent to 9.
assert_eq!(u32::max_value() % 11, 3);
assert_eq!(
arithmetic.mul(u32::max_value(), u32::max_value()).unwrap(),
9
);
assert_eq!(arithmetic.div(1, 4).unwrap(), 3);
assert_eq!(arithmetic.div(2, 4).unwrap(), 6);
assert_eq!(arithmetic.div(1, 9).unwrap(), 5);
assert_eq!(arithmetic.pow(2, 5).unwrap(), 10);
assert_eq!(arithmetic.pow(3, 10).unwrap(), 1);
assert_eq!(arithmetic.pow(3, 4).unwrap(), 4);
assert_eq!(arithmetic.pow(7, 3).unwrap(), 2);
}
// Exhaustively checks all `u8` operand pairs against reference results computed in
// 16-bit integers, for a modulus close to `u8::max_value()`.
#[test]
fn modular_arithmetic_never_overflows() {
const MODULUS: u8 = 241;
let arithmetic = ModularArithmetic::new(MODULUS);
for x in 0..=u8::max_value() {
for y in 0..=u8::max_value() {
let expected = (u16::from(x) + u16::from(y)) % u16::from(MODULUS);
assert_eq!(u16::from(arithmetic.add(x, y).unwrap()), expected);
// The signed remainder may be negative; shift it into the `0..MODULUS` range.
let mut expected = (i16::from(x) - i16::from(y)) % i16::from(MODULUS);
if expected < 0 {
expected += i16::from(MODULUS);
}
assert_eq!(i16::from(arithmetic.sub(x, y).unwrap()), expected);
let expected = (u16::from(x) * u16::from(y)) % u16::from(MODULUS);
assert_eq!(u16::from(arithmetic.mul(x, y).unwrap()), expected);
}
}
// Every value that is not a multiple of the modulus must have an inverse.
for x in 0..=u8::max_value() {
let inv = arithmetic.invert(x);
if x % MODULUS == 0 {
assert!(inv.is_none());
} else {
let inv = u16::from(inv.unwrap());
assert_eq!((inv * u16::from(x)) % u16::from(MODULUS), 1);
}
}
}
// Number of random samples drawn for each check in `mini_fuzz_for_prime_modulus`.
const SAMPLE_COUNT: usize = 25_000;
// Checks `u64` modular operations on random operands against reference results computed
// in 128-bit integers. The RNG is seeded with the modulus, so the samples are
// reproducible. The inverse and power checks rely on `modulus` being prime.
fn mini_fuzz_for_prime_modulus(modulus: u64) {
let arithmetic = ModularArithmetic::new(modulus);
let unsigned_wide_mod = u128::from(modulus);
let signed_wide_mod = i128::from(modulus);
let mut rng = StdRng::seed_from_u64(modulus);
for (x, y) in (0..SAMPLE_COUNT).map(|_| rng.gen::<(u64, u64)>()) {
let expected = (u128::from(x) + u128::from(y)) % unsigned_wide_mod;
assert_eq!(u128::from(arithmetic.add(x, y).unwrap()), expected);
// The signed remainder may be negative; shift it into the `0..modulus` range.
let mut expected = (i128::from(x) - i128::from(y)) % signed_wide_mod;
if expected < 0 {
expected += signed_wide_mod;
}
assert_eq!(i128::from(arithmetic.sub(x, y).unwrap()), expected);
let expected = (u128::from(x) * u128::from(y)) % unsigned_wide_mod;
assert_eq!(u128::from(arithmetic.mul(x, y).unwrap()), expected);
}
// Every value that is not a multiple of the (prime) modulus must have an inverse.
for x in (0..SAMPLE_COUNT).map(|_| rng.gen::<u64>()) {
let inv = arithmetic.invert(x);
if x % modulus == 0 {
assert!(inv.is_none());
} else {
let inv = u128::from(inv.unwrap());
assert_eq!((inv * u128::from(x)) % unsigned_wide_mod, 1);
}
}
for _ in 0..(SAMPLE_COUNT / 10) {
let x = rng.gen::<u64>();
let wide_x = u128::from(x);
let exp = rng.gen_range(1_u64, 1_000);
// Reference power computed by repeated multiplication.
let expected_pow = (0..exp).fold(1_u128, |acc, _| (acc * wide_x) % unsigned_wide_mod);
assert_eq!(u128::from(arithmetic.pow(x, exp).unwrap()), expected_pow);
// Fermat's little theorem: `x^(p - 1) = 1 (mod p)` for a prime `p` not dividing `x`.
if x % modulus != 0 {
let pow = arithmetic.pow(x, modulus - 1).unwrap();
assert_eq!(pow, 1);
}
}
}
#[test]
fn mini_fuzz_for_small_modulus() {
mini_fuzz_for_prime_modulus(3);
mini_fuzz_for_prime_modulus(7);
mini_fuzz_for_prime_modulus(23);
mini_fuzz_for_prime_modulus(61);
}
#[test]
fn mini_fuzz_for_u32_modulus() {
mini_fuzz_for_prime_modulus(3_000_000_019);
mini_fuzz_for_prime_modulus(3_500_000_011);
mini_fuzz_for_prime_modulus(4_000_000_007);
}
#[test]
fn mini_fuzz_for_large_u64_modulus() {
mini_fuzz_for_prime_modulus(2_594_642_710_891_962_701);
mini_fuzz_for_prime_modulus(5_647_618_287_156_850_721);
mini_fuzz_for_prime_modulus(9_223_372_036_854_775_837);
mini_fuzz_for_prime_modulus(10_902_486_311_044_492_273);
}
}