pub use modtype_derive::{use_modtype, ConstValue, ModType};
use num::{
integer, BigInt, BigUint, CheckedAdd as _, CheckedMul as _, CheckedSub as _, Float,
FromPrimitive, Integer, Num, One as _, PrimInt, Signed, ToPrimitive as _, Unsigned, Zero as _,
};
use std::convert::Infallible;
use std::fmt;
use std::iter::{Product, Sum};
use std::marker::PhantomData;
use std::num::ParseIntError;
use std::ops::{
AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, DivAssign, MulAssign, RemAssign, SubAssign,
};
use std::str::FromStr;
/// Unsigned primitive integers (`u8`, `u16`, `u32`, `u64`, `u128`, `usize`).
///
/// This is the storage type (`Cartridge::Target`) of every modular integer. It is a pure
/// marker trait: it only bundles the `num`/`std` capabilities the arithmetic relies on.
pub trait UnsignedPrimitive:
    // Numeric core.
    Unsigned
    + PrimInt
    + Integer
    + Num<FromStrRadixErr = ParseIntError>
    + FromPrimitive
    + Default
    // Widening into arbitrary precision.
    + Into<BigUint>
    + Into<BigInt>
    // Parsing and formatting.
    + FromStr<Err = ParseIntError>
    + fmt::Display
    + fmt::Debug
    + fmt::LowerHex
    + fmt::UpperHex
    // Iterator folding.
    + Sum
    + Product
    // Compound assignment.
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + RemAssign
    + BitAndAssign
    + BitOrAssign
    + BitXorAssign
    // Thread-safety / lifetime.
    + Send
    + Sync
    + 'static
{}

macro_rules! impl_unsigned_primitive {
    ($($ty:ty),*) => {
        $(impl UnsignedPrimitive for $ty {})*
    };
}

impl_unsigned_primitive!(u8, u16, u32, u64, u128, usize);
/// Signed primitive integers (`i8`, `i16`, `i32`, `i64`, `i128`, `isize`).
///
/// Used, for example, as the exponent type of `Cartridge::pow_signed`, where a negative
/// exponent raises the modular inverse. Pure marker trait bundling the required capabilities.
pub trait SignedPrimitive:
    // Numeric core.
    Signed
    + PrimInt
    + Integer
    + Num<FromStrRadixErr = ParseIntError>
    + FromPrimitive
    + Default
    // Widening into arbitrary precision.
    + Into<BigInt>
    // Parsing and formatting.
    + FromStr<Err = ParseIntError>
    + fmt::Display
    + fmt::Debug
    + fmt::LowerHex
    + fmt::UpperHex
    // Iterator folding.
    + Sum
    + Product
    // Compound assignment.
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + RemAssign
    + BitAndAssign
    + BitOrAssign
    + BitXorAssign
    // Thread-safety / lifetime.
    + Send
    + Sync
    + 'static
{}

macro_rules! impl_signed_primitive {
    ($($ty:ty),*) => {
        $(impl SignedPrimitive for $ty {})*
    };
}

impl_signed_primitive!(i8, i16, i32, i64, i128, isize);
/// Floating-point primitives (`f32`, `f64`).
///
/// Accepted by `Cartridge::from_float_prim`, which converts a float into a residue via its
/// exact binary decomposition. Pure marker trait bundling the required capabilities.
pub trait FloatPrimitive:
    // Numeric core.
    Signed
    + Float
    + Num<FromStrRadixErr = num::traits::ParseFloatError>
    + FromPrimitive
    + Default
    // Parsing and formatting.
    + FromStr<Err = std::num::ParseFloatError>
    + fmt::Display
    + fmt::Debug
    // Iterator folding.
    + Sum
    + Product
    // Compound assignment.
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + RemAssign
    // Thread-safety / lifetime.
    + Send
    + Sync
    + 'static
{}

impl FloatPrimitive for f64 {}
impl FloatPrimitive for f32 {}
/// A type-level constant: an implementing type stands in for the value `VALUE`.
///
/// The statically-moduled `ModType` takes its modulus through a parameter bounded by
/// `ConstValue<Value = T>`.
pub trait ConstValue {
    /// Type of the constant; `Copy` so it can be read out by value.
    type Value: Copy;

    /// The constant itself.
    const VALUE: Self::Value;
}
pub trait Cartridge {
type Target: UnsignedPrimitive;
type Features: Features;
#[inline]
fn new(value: Self::Target, modulus: Self::Target) -> Self::Target {
if value >= modulus {
value % modulus
} else {
value
}
}
#[inline]
fn get(value: Self::Target, _modulus: Self::Target) -> Self::Target {
value
}
#[inline]
fn from_biguint(value: BigUint, modulus: Self::Target) -> Self::Target {
let modulus = Into::<BigUint>::into(modulus);
(value % modulus).to_string().parse().unwrap()
}
#[inline]
fn from_bigint(mut value: BigInt, modulus: Self::Target) -> Self::Target {
let is_neg = value.is_negative();
if is_neg {
value = -value;
}
let modulus_big = Into::<BigInt>::into(modulus);
let acc = (value % modulus_big)
.to_string()
.parse::<Self::Target>()
.unwrap();
if is_neg {
modulus - acc
} else {
acc
}
}
#[inline]
fn fmt_display(
value: Self::Target,
_modulus: Self::Target,
fmt: &mut fmt::Formatter,
) -> fmt::Result {
<Self::Target as fmt::Display>::fmt(&value, fmt)
}
#[inline]
fn fmt_debug(
value: Self::Target,
_modulus: Self::Target,
_ty: &'static str,
fmt: &mut fmt::Formatter,
) -> fmt::Result {
<Self::Target as fmt::Debug>::fmt(&value, fmt)
}
#[inline]
fn from_str(str: &str, modulus: Self::Target) -> Result<Self::Target, ParseIntError> {
str.parse().map(|v| Self::new(v, modulus))
}
#[inline]
fn neg(value: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Subtraction = True>,
{
modulus - value
}
#[inline]
fn add(lhs: Self::Target, rhs: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Addition = True>,
{
Self::new(lhs + rhs, modulus)
}
#[inline]
fn sub(lhs: Self::Target, rhs: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Subtraction = True>,
{
let acc = if lhs < rhs {
modulus + lhs - rhs
} else {
lhs - rhs
};
Self::new(acc, modulus)
}
#[inline]
fn mul(lhs: Self::Target, rhs: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Multiplication = True>,
{
Self::new(lhs * rhs, modulus)
}
#[inline]
fn div(lhs: Self::Target, rhs: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Division = True>,
{
if rhs == Self::Target::zero() {
panic!("attempt to divide by zero");
}
Self::checked_div(lhs, rhs, modulus)
.expect("could not divide. if the modulus is a prime, THIS IS A BUG.")
}
#[inline]
fn rem(lhs: Self::Target, rhs: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Division = True>,
{
if rhs == Self::Target::zero() {
panic!("attempt to calculate the remainder with a divisor of zero");
}
if integer::gcd(rhs, modulus) != Self::Target::one() {
panic!("{}/{} (mod {}) does not exist", lhs, rhs, modulus);
}
Self::Target::zero()
}
#[inline]
fn inv(value: Self::Target, modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Division = True>,
{
Self::div(Self::Target::one(), value, modulus)
}
#[inline]
fn from_str_radix(
str: &str,
radix: u32,
modulus: Self::Target,
) -> Result<Self::Target, ParseIntError>
where
Self::Features:
Features<Addition = True, Subtraction = True, Multiplication = True, Division = True>,
{
Self::Target::from_str_radix(str, radix).map(|v| Self::new(v, modulus))
}
#[inline]
fn min_value(_modulus: Self::Target) -> Self::Target {
Self::Target::zero()
}
#[inline]
fn max_value(modulus: Self::Target) -> Self::Target {
modulus - Self::Target::one()
}
#[inline]
fn zero(_modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Addition = True>,
{
Self::Target::zero()
}
#[inline]
fn is_zero(value: Self::Target, _modulus: Self::Target) -> bool
where
Self::Features: Features<Addition = True>,
{
value == Self::Target::zero()
}
#[inline]
fn one(_modulus: Self::Target) -> Self::Target
where
Self::Features: Features<Multiplication = True>,
{
Self::Target::one()
}
#[inline]
fn is_one(value: Self::Target, _modulus: Self::Target) -> bool
where
Self::Features: Features<Multiplication = True>,
{
value == Self::Target::one()
}
#[inline]
fn from_i64(value: i64, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_i128(value.to_i128()?, modulus)
}
#[inline]
fn from_u64(value: u64, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_u128(value.to_u128()?, modulus)
}
#[inline]
fn from_isize(value: isize, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_i128(value.to_i128()?, modulus)
}
#[inline]
fn from_i8(value: i8, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_i128(value.to_i128()?, modulus)
}
#[inline]
fn from_i16(value: i16, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_i128(value.to_i128()?, modulus)
}
#[inline]
fn from_i32(value: i32, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_i128(value.to_i128()?, modulus)
}
#[inline]
fn from_i128(value: i128, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
if value < 0 {
Self::from_u128((-value).to_u128()?, modulus).map(|v| Self::neg(v, modulus))
} else {
Self::from_u128(value.to_u128()?, modulus)
}
}
#[inline]
fn from_usize(value: usize, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_u128(value.to_u128()?, modulus)
}
#[inline]
fn from_u8(value: u8, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_u128(value.to_u128()?, modulus)
}
#[inline]
fn from_u16(value: u16, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_u128(value.to_u128()?, modulus)
}
#[inline]
fn from_u32(value: u32, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
Self::from_u128(value.to_u128()?, modulus)
}
#[inline]
fn from_u128(mut value: u128, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
let modulus = modulus.to_u128()?;
if value >= modulus {
value %= modulus;
}
Self::Target::from_u128(value)
}
#[inline]
fn from_float_prim<F: FloatPrimitive>(value: F, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True, Multiplication = True, Division = True>,
{
let (man, exp, sign) = value.integer_decode();
let acc = Self::mul(
Self::from_u64(man, modulus)?,
Self::pow_signed(Self::Target::one() + Self::Target::one(), exp, modulus),
modulus,
);
Some(match sign {
-1 => Self::neg(acc, modulus),
_ => acc,
})
}
#[inline]
fn checked_neg(value: Self::Target, modulus: Self::Target) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True>,
{
Some(Self::neg(value, modulus))
}
#[inline]
fn checked_add(
lhs: Self::Target,
rhs: Self::Target,
modulus: Self::Target,
) -> Option<Self::Target>
where
Self::Features: Features<Addition = True>,
{
lhs.checked_add(&rhs).map(|v| Self::new(v, modulus))
}
#[inline]
fn checked_sub(
lhs: Self::Target,
rhs: Self::Target,
modulus: Self::Target,
) -> Option<Self::Target>
where
Self::Features: Features<Subtraction = True>,
{
(lhs + modulus)
.checked_sub(&rhs)
.map(|v| Self::new(v, modulus))
}
#[inline]
fn checked_mul(
lhs: Self::Target,
rhs: Self::Target,
modulus: Self::Target,
) -> Option<Self::Target>
where
Self::Features: Features<Multiplication = True>,
{
lhs.checked_mul(&rhs).map(|v| Self::new(v, modulus))
}
#[inline]
fn checked_div(
lhs: Self::Target,
rhs: Self::Target,
modulus: Self::Target,
) -> Option<Self::Target>
where
Self::Features: Features<Division = True>,
{
#[allow(clippy::many_single_char_names)]
fn egcd(a: i128, b: i128) -> (i128, i128, i128) {
if a == 0 {
(b, 0, 1)
} else {
let (d, u, v) = egcd(b % a, a);
(d, v - (b / a) * u, u)
}
}
let lhs = lhs.to_i128()?;
let rhs = rhs.to_i128()?;
let modulus = modulus.to_i128()?;
if rhs == 0 {
return None;
}
let (d, u, _) = egcd(rhs, modulus);
if rhs % d != 0 {
return None;
}
let mut acc = (lhs * u) % modulus;
if acc < 0 {
acc += modulus;
}
Self::Target::from_i128(acc)
}
#[inline]
fn checked_rem(
_lhs: Self::Target,
rhs: Self::Target,
modulus: Self::Target,
) -> Option<Self::Target> {
if integer::gcd(rhs, modulus) == Self::Target::one() {
Some(Self::Target::zero())
} else {
None
}
}
#[inline]
fn pow_unsigned<E: UnsignedPrimitive>(
base: Self::Target,
exp: E,
modulus: Self::Target,
) -> Self::Target
where
Self::Features: Features<Multiplication = True>,
{
let (mut base, mut exp, mut acc) = (base, exp, Self::Target::one());
while exp > E::zero() {
if (exp & E::one()) == E::one() {
acc = Self::mul(acc, base, modulus);
}
exp /= E::one() + E::one();
base = Self::mul(base, base, modulus);
}
acc
}
#[inline]
fn pow_signed<E: SignedPrimitive>(
base: Self::Target,
exp: E,
modulus: Self::Target,
) -> Self::Target
where
Self::Features: Features<Multiplication = True, Division = True>,
{
let (mut base, mut exp, mut acc) = (base, exp, Self::Target::one());
let exp_neg = exp < E::zero();
if exp_neg {
exp = -exp;
}
while exp > E::zero() {
if (exp & E::one()) == E::one() {
acc = Self::mul(acc, base, modulus);
}
exp /= E::one() + E::one();
base = Self::mul(base, base, modulus);
}
if exp_neg {
acc = Self::inv(acc, modulus);
}
acc
}
}
/// The standard [`Cartridge`]: all four operations enabled and every method left at its
/// default implementation.
///
/// Uninhabited (its only variant holds [`Infallible`]); it is used purely as a type parameter.
pub enum DefaultCartridge<T>
where
    T: UnsignedPrimitive,
{
    Infallible(Infallible, PhantomData<fn() -> T>),
}

impl<T> Cartridge for DefaultCartridge<T>
where
    T: UnsignedPrimitive,
{
    type Features = DefaultFeatures;
    type Target = T;
}
/// [`Cartridge`] intended for non-prime moduli.
///
/// Uninhabited (its only variant holds [`Infallible`]); it is used purely as a type parameter.
///
/// NOTE(review): it currently uses [`DefaultFeatures`] and overrides no methods, so it
/// behaves exactly like `DefaultCartridge`. Confirm whether it is meant to diverge.
pub enum NonPrime<T>
where
    T: UnsignedPrimitive,
{
    Infallible(Infallible, PhantomData<fn() -> T>),
}

impl<T> Cartridge for NonPrime<T>
where
    T: UnsignedPrimitive,
{
    type Features = DefaultFeatures;
    type Target = T;
}
/// Type-level switches selecting which arithmetic operations a [`Cartridge`] supports.
///
/// Each associated type is a [`TypedBool`] (typically [`True`] or [`False`]). `Cartridge`
/// methods are bounded on them, e.g. `add` requires `Features<Addition = True>`.
pub trait Features {
    /// Enables `add`, `zero`, `checked_add`, ...
    type Addition: TypedBool;
    /// Enables `sub`, `neg`, `checked_sub`, ...
    type Subtraction: TypedBool;
    /// Enables `mul`, `one`, `pow_unsigned`, ...
    type Multiplication: TypedBool;
    /// Enables `div`, `rem`, `inv`, ...
    type Division: TypedBool;
}

/// Feature set with every operation switched on.
pub enum DefaultFeatures {}

impl Features for DefaultFeatures {
    type Addition = True;
    type Subtraction = True;
    type Multiplication = True;
    type Division = True;
}

/// Marker for type-level booleans.
pub trait TypedBool {}

/// Type-level `true`. Uninhabited: it only appears as a type argument.
pub enum True {}

/// Type-level `false`. Uninhabited: it only appears as a type argument.
pub enum False {}

impl TypedBool for True {}
impl TypedBool for False {}
/// [`ModType`] whose modulus is `M::VALUE`, using the [`DefaultCartridge`].
pub type DefaultModType<M> = ModType<
    <M as ConstValue>::Value,
    DefaultCartridge<<M as ConstValue>::Value>,
    M,
>;
/// Modular integer whose modulus is the compile-time constant `M::VALUE`.
///
/// `T` stores the residue, `C` supplies the arithmetic and `M` supplies the modulus.
/// Operator and trait impls are generated by `#[derive(crate::ModType)]` (presumably routed
/// through `C`'s associated functions; see `modtype_derive` to confirm).
#[derive(crate::ModType)]
#[modtype(modulus = "M::VALUE", cartridge = "C", modtype = "crate")]
pub struct ModType<T: UnsignedPrimitive, C: Cartridge<Target = T>, M: ConstValue<Value = T>> {
    // The residue. `new` reduces it into `0..M::VALUE`; `new_unchecked` trusts the caller.
    #[modtype(value)]
    value: T,
    // Ties `C` and `M` to the type without storing them. The `fn() -> _` form keeps the
    // struct free of ownership/auto-trait implications from `C` and `M`.
    phantom: PhantomData<fn() -> (C, M)>,
}
impl<T, C, M> ModType<T, C, M>
where
    T: UnsignedPrimitive,
    C: Cartridge<Target = T>,
    M: ConstValue<Value = T>,
{
    /// Returns the modulus, `M::VALUE`.
    #[inline]
    pub fn modulus() -> T {
        M::VALUE
    }

    /// Creates a value, reducing `value` modulo `M::VALUE` through `C::new`.
    #[inline]
    pub fn new(value: T) -> Self {
        let reduced = C::new(value, Self::modulus());
        Self::new_unchecked(reduced)
    }

    /// Wraps `value` as-is, without reduction. The caller is responsible for keeping it
    /// below the modulus.
    #[inline]
    pub fn new_unchecked(value: T) -> Self {
        Self { value, phantom: PhantomData }
    }

    /// Returns the stored residue.
    #[inline]
    pub fn get(self) -> T {
        self.value
    }

    /// Mutable access to the stored residue. Writes are not re-reduced.
    #[inline]
    pub fn get_mut_unchecked(&mut self) -> &mut T {
        &mut self.value
    }
}
pub mod field_param {
    //! Modular integers that carry their modulus next to the value, so the modulus can be
    //! chosen at runtime.
    use crate::{Cartridge, DefaultCartridge, UnsignedPrimitive};
    use std::marker::PhantomData;

    /// [`ModType`] using the [`DefaultCartridge`].
    pub type DefaultModType<T> = ModType<T, DefaultCartridge<T>>;

    /// Modular integer storing its own modulus.
    #[derive(crate::ModType)]
    #[modtype(
        modulus = "self.modulus",
        cartridge = "C",
        modtype = "crate",
        non_static_modulus
    )]
    pub struct ModType<T: UnsignedPrimitive, C: Cartridge<Target = T>> {
        // The residue; `new` reduces it, `new_unchecked` does not.
        #[modtype(value)]
        value: T,
        // The modulus this residue is taken against.
        modulus: T,
        // Ties the cartridge `C` to the type without storing it.
        phantom: PhantomData<fn() -> C>,
    }

    impl<T, C> ModType<T, C>
    where
        T: UnsignedPrimitive,
        C: Cartridge<Target = T>,
    {
        /// Creates a value, reducing `value` modulo `modulus` through `C::new`.
        #[inline]
        pub fn new(value: T, modulus: T) -> Self {
            Self::new_unchecked(C::new(value, modulus), modulus)
        }

        /// Stores `value` and `modulus` as given, without reduction.
        #[inline]
        pub fn new_unchecked(value: T, modulus: T) -> Self {
            Self { value, modulus, phantom: PhantomData }
        }

        /// Returns a constructor closure with `modulus` fixed.
        #[inline]
        pub fn factory(modulus: T) -> impl Fn(T) -> Self {
            move |value| Self::new(value, modulus)
        }

        /// Returns the stored residue.
        #[inline]
        pub fn get(self) -> T {
            self.value
        }

        /// Returns the stored modulus.
        #[inline]
        pub fn modulus(self) -> T {
            self.modulus
        }
    }
}
pub mod thread_local {
    //! Modular integers whose modulus lives in a per-thread slot, installed for the duration
    //! of a [`ModType::with`] call.
    use crate::{Cartridge, DefaultCartridge, UnsignedPrimitive};
    use std::cell::UnsafeCell;
    use std::marker::PhantomData;
    use std::thread::LocalKey;

    /// [`ModType`] using the [`DefaultCartridge`].
    pub type DefaultModType<T> = ModType<T, DefaultCartridge<T>>;

    /// Modular integer whose modulus is read from the current thread's slot for `T`.
    #[derive(crate::ModType)]
    #[modtype(
        modulus = "unsafe { T::modulus() }",
        cartridge = "C",
        modtype = "crate"
    )]
    pub struct ModType<T: HasThreadLocalModulus, C: Cartridge<Target = T>> {
        // The residue; `new` reduces it, `new_unchecked` does not.
        #[modtype(value)]
        value: T,
        // Ties the cartridge `C` to the type without storing it.
        phantom: PhantomData<fn() -> C>,
    }

    impl<T: HasThreadLocalModulus, C: Cartridge<Target = T>> ModType<T, C> {
        /// Runs `f` with `modulus` installed as this thread's modulus for `T`. `f` receives
        /// `Self::new` as a constructor.
        ///
        /// The previously installed modulus is restored when `f` returns *or unwinds*. Nested
        /// calls therefore leave the outer modulus intact, and a caught panic does not leave a
        /// stale modulus behind.
        pub fn with<O, F: FnOnce(fn(T) -> Self) -> O>(modulus: T, f: F) -> O {
            /// Reinstalls the saved modulus when dropped.
            struct Restore<M: HasThreadLocalModulus>(M);

            impl<M: HasThreadLocalModulus> Drop for Restore<M> {
                fn drop(&mut self) {
                    // SAFETY: a plain copy is written; this crate never keeps a reference
                    // into the thread-local cell.
                    unsafe { M::set_modulus(self.0) };
                }
            }

            // SAFETY: as above, only copies go in and out of the cell.
            let _restore = Restore(unsafe { T::modulus() });
            // SAFETY: as above.
            unsafe { T::set_modulus(modulus) };
            f(Self::new)
        }

        /// Returns the modulus currently installed for `T` on this thread.
        #[inline]
        pub fn modulus() -> T {
            // SAFETY: the value is copied out; no reference into the cell escapes.
            unsafe { T::modulus() }
        }

        /// Creates a value, reducing `value` modulo the installed modulus through `C::new`.
        ///
        /// # Panics
        ///
        /// With the default `Cartridge::new`, panics when the installed modulus is `0`. The
        /// built-in slots start at `0`, so this happens outside [`ModType::with`].
        #[inline]
        pub fn new(value: T) -> Self {
            Self {
                value: C::new(value, Self::modulus()),
                phantom: PhantomData,
            }
        }

        /// Wraps `value` as-is, without reduction.
        #[inline]
        pub fn new_unchecked(value: T) -> Self {
            Self {
                value,
                phantom: PhantomData,
            }
        }

        /// Returns the stored residue.
        #[inline]
        pub fn get(self) -> T {
            self.value
        }

        /// Mutable access to the stored residue. Writes are not re-reduced.
        #[inline]
        pub fn get_mut_unchecked(&mut self) -> &mut T {
            &mut self.value
        }
    }

    /// Unsigned primitives that own a thread-local modulus slot.
    pub trait HasThreadLocalModulus: UnsignedPrimitive {
        /// Returns the thread-local cell storing the current modulus for `Self`.
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>>;

        /// Reads this thread's current modulus.
        ///
        /// # Safety
        ///
        /// No reference into the cell returned by [`local_key`](Self::local_key) may be live
        /// while this runs.
        unsafe fn modulus() -> Self {
            Self::local_key().with(|m| *m.get())
        }

        /// Overwrites this thread's current modulus.
        ///
        /// # Safety
        ///
        /// No reference into the cell returned by [`local_key`](Self::local_key) may be live
        /// while this runs.
        unsafe fn set_modulus(modulus: Self) {
            Self::local_key().with(|m| *m.get() = modulus)
        }
    }

    impl HasThreadLocalModulus for u8 {
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>> {
            &MODULUS_U8
        }
    }

    impl HasThreadLocalModulus for u16 {
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>> {
            &MODULUS_U16
        }
    }

    impl HasThreadLocalModulus for u32 {
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>> {
            &MODULUS_U32
        }
    }

    impl HasThreadLocalModulus for u64 {
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>> {
            &MODULUS_U64
        }
    }

    impl HasThreadLocalModulus for u128 {
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>> {
            &MODULUS_U128
        }
    }

    impl HasThreadLocalModulus for usize {
        fn local_key() -> &'static LocalKey<UnsafeCell<Self>> {
            &MODULUS_USIZE
        }
    }

    // One slot per primitive. `0` means "no modulus installed".
    thread_local! {
        static MODULUS_U8: UnsafeCell<u8> = UnsafeCell::new(0);
        static MODULUS_U16: UnsafeCell<u16> = UnsafeCell::new(0);
        static MODULUS_U32: UnsafeCell<u32> = UnsafeCell::new(0);
        static MODULUS_U64: UnsafeCell<u64> = UnsafeCell::new(0);
        static MODULUS_U128: UnsafeCell<u128> = UnsafeCell::new(0);
        static MODULUS_USIZE: UnsafeCell<usize> = UnsafeCell::new(0);
    }
}