use core::{
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
str::FromStr,
};
/// A residue class modulo the compile-time constant `M`, stored as its `u32` representative.
///
/// The representative is meant to lie in `0..M` (`From<u32>` reduces its input with `% M`).
/// The derived `PartialEq`, `Eq` and `Hash` compare the stored value directly, so two values
/// are equal only when their stored representatives are identical.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ConstModulo32<const M: u32>(u32);
impl<const M: u32> From<u32> for ConstModulo32<M> {
fn from(x: u32) -> Self {
Self(x % M)
}
}
#[auto_impl_ops::auto_ops]
impl<const M: u32> AddAssign<&ConstModulo32<M>> for ConstModulo32<M> {
    /// Replaces `self` with `(self + other) mod M`.
    ///
    /// Both operands are assumed to already be reduced into `0..M`. Their sum is then below
    /// `2 * M`, so a single conditional subtraction of `M` suffices.
    fn add_assign(&mut self, other: &Self) {
        // Widen to u64: when `M > 2^31` the sum of two residues can exceed `u32::MAX`.
        // In u32 that would overflow (panic in debug, wrong residue in release).
        let sum = u64::from(self.0) + u64::from(other.0);
        let modulus = u64::from(M);
        let reduced = if sum >= modulus { sum - modulus } else { sum };
        // `reduced < M <= u32::MAX`, so narrowing back to u32 cannot truncate.
        self.0 = reduced as u32;
    }
}
#[auto_impl_ops::auto_ops]
impl<const M: u32> SubAssign<&ConstModulo32<M>> for ConstModulo32<M> {
    /// Replaces `self` with `(self - other) mod M`.
    ///
    /// Both operands are assumed to already be reduced into `0..M`.
    fn sub_assign(&mut self, other: &Self) {
        // Branch on the borrow instead of computing `self + (M - other)`, which can exceed
        // `u32::MAX` when `M > 2^31`. Each arm stays within `0..M` without a wider type.
        self.0 = if self.0 >= other.0 {
            self.0 - other.0
        } else {
            // Here `0 < other - self < M`, so `M - (other - self)` lies in `1..M`.
            M - (other.0 - self.0)
        };
    }
}
impl<const M: u32> Neg for ConstModulo32<M> {
type Output = Self;
fn neg(self) -> Self::Output {
Self(M - self.0)
}
}
impl<const M: u32> Neg for &ConstModulo32<M> {
type Output = ConstModulo32<M>;
fn neg(self) -> Self::Output {
ConstModulo32(M - self.0)
}
}
#[auto_impl_ops::auto_ops]
impl<const M: u32> MulAssign<&ConstModulo32<M>> for ConstModulo32<M> {
    /// Replaces `self` with `(self * other) mod M`.
    fn mul_assign(&mut self, other: &Self) {
        // Widen to u64: the product of two values below 2^32 always fits in 64 bits.
        let product = u64::from(self.0) * u64::from(other.0);
        // The remainder is below `M`, so narrowing back to u32 is lossless.
        self.0 = (product % u64::from(M)) as u32;
    }
}
#[auto_impl_ops::auto_ops]
impl<const M: u32> DivAssign<&ConstModulo32<M>> for ConstModulo32<M> {
    /// Replaces `self` with `self / other` modulo `M`, computed by
    /// `ring_algorithm::modulo_division` on `i64` operands.
    ///
    /// # Panics
    /// Panics with "Can't divide" when `modulo_division` returns `None`. This is presumably
    /// the case when `other` has no inverse modulo `M`.
    fn div_assign(&mut self, other: &Self) {
        let modulus = i64::from(M);
        let quotient =
            ring_algorithm::modulo_division(i64::from(self.0), i64::from(other.0), modulus)
                .expect("Can't divide");
        // NOTE(review): a single `+ M` only normalises results in `(-M, M)`. This assumes
        // `modulo_division` never returns anything outside that range; confirm against its docs.
        let normalized = if quotient < 0 {
            quotient + modulus
        } else {
            quotient
        };
        self.0 = normalized as u32;
    }
}
/// Parses a decimal integer (optionally negative, up to the `i128` range) and reduces it
/// into `0..M`.
///
/// # Errors
/// Returns the `ParseIntError` from parsing the text as `i128`.
///
/// # Panics
/// Panics if `M == 0` (remainder by zero).
impl<const M: u32> FromStr for ConstModulo32<M> {
    type Err = std::num::ParseIntError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `i128` holds every `u32` and every negative input the caller might write. For a
        // positive modulus the Euclidean remainder is always in `0..M`, so the cast is lossless.
        s.parse::<i128>()
            .map(|value| Self(value.rem_euclid(i128::from(M)) as u32))
    }
}
/// Formats the residue as its stored decimal representative.
///
/// Delegates to `u32`'s `Display`, so the caller's format spec (width, fill, alignment,
/// zero-padding, e.g. `{:>4}` or `{:03}`) is honoured. Re-formatting with `write!(f, "{}", ..)`
/// would silently discard it.
impl<const M: u32> std::fmt::Display for ConstModulo32<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// `num-traits` identity-element impls, compiled only with the `num-traits` feature.
#[cfg(feature = "num-traits")]
mod impl_num_traits {
    use crate::ConstModulo32;
    use num_traits::{ConstOne, ConstZero, One, Zero};
    impl<const M: u32> Zero for ConstModulo32<M> {
        fn zero() -> Self {
            Self::ZERO
        }
        fn is_zero(&self) -> bool {
            *self == Self::ZERO
        }
    }
    impl<const M: u32> ConstZero for ConstModulo32<M> {
        /// The additive identity, stored as `0`.
        const ZERO: Self = Self(0);
    }
    impl<const M: u32> One for ConstModulo32<M> {
        fn one() -> Self {
            Self::ONE
        }
    }
    impl<const M: u32> ConstOne for ConstModulo32<M> {
        /// The multiplicative identity, reduced into `0..M`.
        ///
        /// For `M == 1` (the trivial ring) this is `0`. A bare `Self(1)` there would be
        /// unreduced, and the operator impls would produce `ONE * ONE != ONE`. For `M == 0`,
        /// using this constant is a compile-time error (remainder by zero).
        const ONE: Self = Self(1 % M);
    }
}
#[test]
fn add() {
    let r = ConstModulo32::<6>;
    // (lhs, rhs, expected sum mod 6); the last two cases wrap past the modulus.
    let cases = [(2, 3, 5), (3, 4, 1), (3, 3, 0)];
    for &(lhs, rhs, expected) in &cases {
        assert_eq!(r(lhs) + r(rhs), r(expected));
    }
}
#[test]
fn sub() {
    let r = ConstModulo32::<6>;
    // (minuend, subtrahend, expected difference mod 6); the middle case borrows.
    for &(lhs, rhs, expected) in &[(3, 2, 1), (2, 4, 4), (3, 3, 0)] {
        assert_eq!(r(lhs) - r(rhs), r(expected));
    }
}
#[test]
fn mul() {
    let r = ConstModulo32::<6>;
    // (lhs, rhs, expected product mod 6); 2 and 3 are zero divisors modulo 6.
    let cases = [(2, 5, 4), (5, 3, 3), (2, 3, 0)];
    for &(lhs, rhs, expected) in &cases {
        assert_eq!(r(lhs) * r(rhs), r(expected));
    }
}
#[test]
fn div() {
    let r = ConstModulo32::<5>;
    // (dividend, divisor, expected quotient) modulo the prime 5, where every nonzero
    // divisor is invertible.
    let cases = [(2, 3, 4), (3, 4, 2), (2, 4, 3)];
    for &(lhs, rhs, expected) in &cases {
        assert_eq!(r(lhs) / r(rhs), r(expected));
    }
}