use std::mem::size_of;
use std::ops::{Add, AddAssign};
use crate::Group;
/// Declares a modular-integer newtype backed by an unsigned integer type.
///
/// `decl_int_prime_group!(T, Name)` expands to `pub struct Name<const MOD: T>(T)`
/// together with `Add`, `AddAssign`, `Group<LAMBDA>` and conversions to/from
/// `[u8; LAMBDA]`.
///
/// Every constructor reduces its input modulo `MOD`, so the wrapped value is
/// always in `0..MOD`. `MOD` must be non-zero, because the `%` operations
/// panic on a zero divisor. The macro name suggests `MOD` is meant to be
/// prime, but nothing here depends on primality.
macro_rules! decl_int_prime_group {
    ($t:ty, $t_impl:ident) => {
        /// An integer modulo `MOD`, stored in canonical form (`0 <= value < MOD`).
        #[derive(Debug, Clone, PartialEq, Eq)]
        pub struct $t_impl<const MOD: $t>(
            // Kept reduced modulo `MOD` by every constructor, so the derived
            // `PartialEq` compares residues rather than raw integers.
            $t,
        );

        impl<const MOD: $t> $t_impl<MOD> {
            /// Creates the residue of `x` modulo `MOD`.
            pub fn new(x: $t) -> Self {
                $t_impl(x % MOD)
            }
        }

        impl<const MOD: $t> Add for $t_impl<MOD> {
            type Output = Self;

            /// Modular addition that never overflows the backing integer.
            fn add(self, rhs: Self) -> Self::Output {
                $t_impl(match self.0.checked_add(rhs.0) {
                    Some(x) => x % MOD,
                    // The true sum is `wrapped + 2^BITS`, and
                    // `2^BITS ≡ (MAX % MOD) + 1 (mod MOD)`. Reducing `wrapped`
                    // first keeps the intermediate sum within the backing type.
                    None => {
                        (self.0.wrapping_add(rhs.0) % MOD)
                            .wrapping_add(<$t>::MAX % MOD)
                            .wrapping_add(1)
                            % MOD
                    }
                })
            }
        }

        impl<const MOD: $t> AddAssign for $t_impl<MOD> {
            /// In-place modular addition.
            fn add_assign(&mut self, rhs: Self) {
                // Delegate to `Add` so both operators share one
                // overflow-safe implementation.
                *self = self.clone() + rhs;
            }
        }

        impl<const LAMBDA: usize, const MOD: $t> Group<LAMBDA> for $t_impl<MOD> {
            /// The additive identity, `0`.
            fn zero() -> Self {
                $t_impl(0)
            }

            /// Returns `-self mod MOD`, i.e. the `y` with `self + y == zero()`.
            fn add_inverse(self) -> Self {
                // The inner `% MOD` keeps `MOD - r` from underflowing.
                // The outer `% MOD` maps `MOD - 0` back to `0`.
                $t_impl((MOD - self.0 % MOD) % MOD)
            }
        }

        impl<const LAMBDA: usize, const MOD: $t> From<[u8; LAMBDA]> for $t_impl<MOD> {
            /// Decodes the first `size_of` bytes of `value` as an integer and
            /// reduces it modulo `MOD`. Decoding is little-endian, or big-endian
            /// when the `int-be` feature is enabled. The remaining bytes are
            /// ignored.
            ///
            /// # Panics
            /// Panics if `LAMBDA` is smaller than the backing integer's byte size.
            fn from(value: [u8; LAMBDA]) -> Self {
                let head = &value[..size_of::<$t>()];
                let raw = if cfg!(feature = "int-be") {
                    <$t>::from_be_bytes(head.try_into().expect("slice length equals size_of"))
                } else {
                    <$t>::from_le_bytes(head.try_into().expect("slice length equals size_of"))
                };
                // Reduce so the stored value satisfies the `< MOD` invariant.
                Self::new(raw)
            }
        }

        impl<const LAMBDA: usize, const MOD: $t> From<$t_impl<MOD>> for [u8; LAMBDA] {
            /// Encodes the residue into the first `size_of` bytes, using the
            /// same endianness as the decoding direction. The remaining bytes
            /// are zero.
            ///
            /// # Panics
            /// Panics if `LAMBDA` is smaller than the backing integer's byte size.
            fn from(value: $t_impl<MOD>) -> Self {
                let encoded = if cfg!(feature = "int-be") {
                    value.0.to_be_bytes()
                } else {
                    value.0.to_le_bytes()
                };
                let mut bs = [0; LAMBDA];
                bs[..encoded.len()].copy_from_slice(&encoded);
                bs
            }
        }
    };
}
// Concrete modular-integer types, one per unsigned backing width.
decl_int_prime_group! { u8, U8Group }
decl_int_prime_group! { u16, U16Group }
decl_int_prime_group! { u32, U32Group }
decl_int_prime_group! { u64, U64Group }
decl_int_prime_group! { u128, U128Group }
/// Largest prime that fits in a `u8`: 2^8 − 5 = 251 (e.g. for `U8Group<PRIME_MAX_LE_U8_MAX>`).
pub const PRIME_MAX_LE_U8_MAX: u8 = u8::MAX - 4;
/// Largest prime that fits in a `u16`: 2^16 − 15 = 65521.
pub const PRIME_MAX_LE_U16_MAX: u16 = u16::MAX - 14;
/// Largest prime that fits in a `u32`: 2^32 − 5.
pub const PRIME_MAX_LE_U32_MAX: u32 = u32::MAX - 4;
/// Largest prime that fits in a `u64`: 2^64 − 59.
pub const PRIME_MAX_LE_U64_MAX: u64 = u64::MAX - 58;
/// Largest prime that fits in a `u128`: 2^128 − 159.
pub const PRIME_MAX_LE_U128_MAX: u128 = u128::MAX - 158;