#![warn(missing_docs)]
use std::fmt;
use std::ops::{Add, Mul, Sub};
/// Conversion and congruence helpers for integers used in modular arithmetic.
///
/// Within this file the trait is implemented for `i32`.
pub trait Modular {
/// Builds the [`Modulo`] that represents `self` under the given `modulus`.
fn to_modulo(self, modulus: u32) -> Modulo;
/// Reports whether `self` and `with` are congruent modulo `modulus`,
/// i.e. whether their difference is a multiple of `modulus`.
fn is_congruent(self, with: impl Into<i32>, modulus: u32) -> bool;
}
/// A remainder paired with the modulus it was reduced by.
///
/// The fields are private; values are obtained through
/// [`Modular::to_modulo`] or the `modulo!` macro.
///
/// Equality and hashing are structural: two values compare equal only when
/// both the stored remainder and the modulus match. Because both fields are
/// plain integers the equality is total, so the type also derives `Eq` and
/// `Hash` and can be used as a key in hash-based collections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Modulo {
    /// Remainder left after reducing the original value by `modulus`.
    remainder: i32,
    /// Modulus the remainder was taken against.
    modulus: u32,
}
impl Modulo {
pub fn remainder(self) -> i32 {
self.remainder
}
pub fn modulus(self) -> u32 {
self.modulus
}
}
impl Modular for i32 {
    /// Reduces `self` by `modulus` and wraps the result in a [`Modulo`].
    ///
    /// The remainder follows Rust's `%` operator, so it carries the sign of
    /// `self`: `-7` reduced by `5` is stored as `-2`, not `3`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    fn to_modulo(self, modulus: u32) -> Modulo {
        // Reduce in i64: casting a modulus above `i32::MAX` to `i32` would wrap
        // it into a negative divisor and yield a wrong remainder (or overflow
        // for `i32::MIN % -1`).
        let remainder = i64::from(self) % i64::from(modulus);
        Modulo {
            // |remainder| <= |self|, so narrowing back to i32 is lossless.
            remainder: remainder as i32,
            modulus,
        }
    }

    /// Reports whether `self - with` is a multiple of `modulus`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    fn is_congruent(self, with: impl Into<i32>, modulus: u32) -> bool {
        let other: i32 = with.into();
        // The difference of two i32 values always fits in i64, so this cannot
        // overflow the way an i32 subtraction can.
        let difference = i64::from(self) - i64::from(other);
        difference % i64::from(modulus) == 0
    }
}
impl Add for Modulo {
type Output = Self;
fn add(self, rhs: Self) -> Self {
if self.modulus() != rhs.modulus() {
panic!("Addition is only valid for modulo numbers with the same dividend")
}
(self.remainder() + rhs.remainder()).to_modulo(self.modulus())
}
}
impl Sub for Modulo {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
if self.modulus() != rhs.modulus() {
panic!("Subtraction is only valid for modulo numbers with the same dividend")
}
if self.remainder() >= rhs.remainder() {
modulo!(self.remainder() - rhs.remainder(), self.modulus())
} else {
modulo!(
self.remainder() - rhs.remainder() + self.modulus() as i32,
self.modulus()
)
}
}
}
impl Mul for Modulo {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
if self.modulus() != rhs.modulus() {
panic!("Multiplication is only valid for modulo numbers with the same dividend")
}
(self.remainder() * rhs.remainder()).to_modulo(self.modulus())
}
}
impl fmt::Display for Modulo {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?} mod {:?}", self.remainder, self.modulus)
}
}
/// Shorthand for calling `to_modulo` on a value.
///
/// `modulo!(value, modulus)` expands to `value.to_modulo(modulus)`, so a trait
/// providing `to_modulo` for the value's type must be in scope at the call
/// site.
#[macro_export]
macro_rules! modulo {
    ($value:expr, $modulus:expr) => {
        $value.to_modulo($modulus)
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction: the trait method and the macro must agree.

    #[test]
    fn create_using_trait() {
        assert_eq!(27.to_modulo(5), modulo!(2, 5));
    }

    #[test]
    fn create_using_macro() {
        assert_eq!(modulo!(99, 4), 99.to_modulo(4));
    }

    // Accessors.

    #[test]
    fn get_remainder() {
        assert_eq!(modulo!(26, 11).remainder(), 4);
    }

    #[test]
    fn get_modulus() {
        assert_eq!(modulo!(121, 17).modulus(), 17);
    }

    // Arithmetic. Each `should_panic` test pins the operator-specific message
    // so a test cannot pass by exercising the wrong operator.

    #[test]
    fn add_successfully() {
        assert_eq!(modulo!(23, 4) + modulo!(11, 4), modulo!(2, 4));
    }

    #[test]
    #[should_panic(expected = "Addition is only valid")]
    fn add_panics_with_different_moduli() {
        let _ = modulo!(23, 5) + modulo!(11, 6);
    }

    #[test]
    fn subtract_successfully() {
        assert_eq!(modulo!(22, 4) - modulo!(13, 4), modulo!(1, 4));
    }

    #[test]
    #[should_panic(expected = "Subtraction is only valid")]
    fn subtract_panics_with_different_moduli() {
        let _ = modulo!(47, 43) - modulo!(5, 27);
    }

    #[test]
    fn multiply_successfully() {
        assert_eq!(modulo!(2, 4) * modulo!(19, 4), modulo!(2, 4));
    }

    #[test]
    #[should_panic(expected = "Multiplication is only valid")]
    fn multiply_panics_with_different_moduli() {
        // Uses `*`; the previous version subtracted and never reached `Mul`.
        let _ = modulo!(91, 92) * modulo!(8, 9);
    }

    // Formatting.

    #[test]
    fn string_representation() {
        let mod_new = modulo!(6, 7u32);
        assert_eq!(format!("{}", mod_new), "6 mod 7");
    }
}