use num_traits::{
ops::overflowing::{OverflowingAdd, OverflowingSub},
AsPrimitive, ConstOne, ConstZero, PrimInt, Unsigned, WrappingAdd, WrappingMul, WrappingSub,
};
use crate::fp::MAX_ROOTS;
/// Unsigned machine word used to hold field elements.
///
/// Collects every `num_traits` capability the field arithmetic in this module relies on:
/// zero/one constants, overflow-reporting and wrapping arithmetic, and conversion from `bool`.
/// The `bool` conversion turns carry/borrow flags into words. The trait also exposes the bit
/// width as a `usize`.
pub trait Word:
    'static
    + Unsigned
    + PrimInt
    + ConstZero
    + ConstOne
    + OverflowingAdd
    + OverflowingSub
    + WrappingAdd
    + WrappingSub
    + WrappingMul
    + From<bool>
{
    /// Width of the word in bits.
    const BITS: usize;
}
/// 32-bit words.
impl Word for u32 {
    // `u32::BITS` is the inherent constant (a `u32`), not this trait item.
    const BITS: usize = u32::BITS as usize;
}
/// 64-bit words.
impl Word for u64 {
    // `u64::BITS` is the inherent constant (a `u32`), not this trait item.
    const BITS: usize = u64::BITS as usize;
}
/// 128-bit words.
impl Word for u128 {
    // `u128::BITS` is the inherent constant (a `u32`), not this trait item.
    const BITS: usize = u128::BITS as usize;
}
/// Compile-time constants describing a prime field stored in words of type `W`.
pub trait FieldParameters<W: Word> {
    /// The field modulus. `FieldOps` reduces additions, subtractions and products modulo it.
    const PRIME: W;
    /// Reduction constant read by the multiplication routines.
    /// NOTE(review): presumably the Montgomery constant `-PRIME^-1` modulo the word base — confirm.
    const MU: W;
    /// Multiplier used by `FieldOps::montgomery` to enter `mul`'s representation.
    /// NOTE(review): presumably `R^2 mod PRIME` for the Montgomery radix `R` — confirm.
    const R2: W;

    /// Table of precomputed constants.
    ///
    /// `FieldOps::pow` starts from `ROOTS[0]`, so that entry must act as the multiplicative
    /// identity under `mul`. NOTE(review): later entries are presumably roots of unity — confirm.
    const ROOTS: [W; MAX_ROOTS + 1];
    /// NOTE(review): not read in this module; presumably the number of usable `ROOTS` — confirm.
    const NUM_ROOTS: usize;
    /// NOTE(review): not read in this module; presumably a multiplicative generator — confirm.
    const G: W;
    /// NOTE(review): not read in this module; presumably masks bits above the modulus width.
    const BIT_MASK: W;

    /// Test-only constant (not read in this module).
    #[cfg(test)]
    const LOG2_BASE: usize;
    /// Test-only constant (not read in this module).
    #[cfg(test)]
    const LOG2_RADIX: usize;
}
/// Prime-field arithmetic on raw words of type `W`, parameterized by [`FieldParameters`].
///
/// Implementors supply only `mul`. Every other operation is derived from `mul` and the field
/// constants: addition, subtraction, negation, single-step reduction, exponentiation, inversion,
/// and the `montgomery`/`residue` conversions. `add` and `sub` select their result with bit
/// masks instead of branches.
pub trait FieldOps<W: Word>: FieldParameters<W> {
    /// Modular addition `x + y mod PRIME`.
    ///
    /// Expects `x` and `y` in `[0, PRIME)`; the result is then also in `[0, PRIME)`.
    #[inline(always)]
    fn add(x: W, y: W) -> W {
        // (overflow, sum) is the exact two-word value of x + y.
        let (sum, overflow) = x.overflowing_add(&y);
        // Subtract (0, PRIME) from (overflow, sum); the upper word only contributes its borrow.
        let (reduced, low_borrow) = sum.overflowing_sub(&Self::PRIME);
        let high_borrow = <W as From<bool>>::from(overflow)
            .overflowing_sub(&<W as From<bool>>::from(low_borrow))
            .1;
        // All ones when x + y < PRIME (keep `sum`), all zeros otherwise (take `reduced`).
        let keep_sum = W::ZERO.wrapping_sub(&<W as From<bool>>::from(high_borrow));
        (sum & keep_sum) | (reduced & !keep_sum)
    }

    /// Modular subtraction `x - y mod PRIME` for `x`, `y` in `[0, PRIME)`.
    #[inline(always)]
    fn sub(x: W, y: W) -> W {
        let (diff, borrowed) = x.overflowing_sub(&y);
        // All ones if the subtraction wrapped, in which case PRIME is added back.
        let borrow_mask = W::ZERO.wrapping_sub(&<W as From<bool>>::from(borrowed));
        diff.wrapping_add(&(Self::PRIME & borrow_mask))
    }

    /// Additive inverse, computed as `0 - x`.
    #[inline(always)]
    fn neg(x: W) -> W {
        let zero = W::ZERO;
        Self::sub(zero, x)
    }

    /// Subtracts `PRIME` once when `x >= PRIME`, mapping `[0, 2 * PRIME)` onto `[0, PRIME)`.
    #[inline(always)]
    fn modp(x: W) -> W {
        let modulus = Self::PRIME;
        Self::sub(x, modulus)
    }

    /// Multiplication, supplied by each implementor. Everything else in this trait is built on it.
    fn mul(x: W, y: W) -> W;

    /// Computes `x^exp` by left-to-right square-and-multiply, starting from `ROOTS[0]`.
    ///
    /// `x` and the result use `mul`'s representation. The loop runs once per bit of `exp`
    /// (up to its highest set bit) and branches on those bits, so timing depends on `exp`.
    fn pow(x: W, exp: W) -> W {
        let bit_len = W::BITS - exp.leading_zeros() as usize;
        (0..bit_len).rev().fold(Self::ROOTS[0], |acc, bit| {
            let squared = Self::mul(acc, acc);
            if (exp >> bit) & W::ONE == W::ZERO {
                squared
            } else {
                Self::mul(squared, x)
            }
        })
    }

    /// Multiplicative inverse via Fermat's little theorem: `x^(PRIME - 2)`.
    ///
    /// Zero has no inverse. `inv(0)` is not reported as an error; it returns `pow(0, PRIME - 2)`.
    #[inline(always)]
    fn inv(x: W) -> W {
        let two = W::ONE + W::ONE;
        Self::pow(x, Self::PRIME - two)
    }

    /// Converts a plain integer into `mul`'s representation.
    ///
    /// Computes `mul(x, R2)`, then applies one conditional subtraction of `PRIME`.
    #[inline(always)]
    fn montgomery(x: W) -> W {
        let product = Self::mul(x, Self::R2);
        Self::modp(product)
    }

    /// Converts out of `mul`'s representation.
    ///
    /// Computes `mul(x, 1)`, then applies one conditional subtraction of `PRIME`.
    #[inline(always)]
    fn residue(x: W) -> W {
        let product = Self::mul(x, W::ONE);
        Self::modp(product)
    }
}
/// Montgomery multiplication for a word type `W` that has a wider partner type.
///
/// Full products are formed in `DoubleWord`, so each reduction step is a single wide
/// multiply or add instead of half-word limb arithmetic.
pub(crate) trait FieldMulOpsSingleWord<W>: FieldParameters<W>
where
    W: Word + AsPrimitive<Self::DoubleWord>,
{
    /// Unsigned type used to hold products of two `W` values.
    type DoubleWord: Word + AsPrimitive<W>;

    /// Montgomery product of `x` and `y` using single-step REDC.
    ///
    /// The steps are:
    /// 1. Form `t = x * y`.
    /// 2. Add `m * PRIME`, where `m = MU * (low word of t)`.
    /// 3. Keep the upper word of the sum.
    /// 4. Subtract `PRIME` once if needed, without branching.
    ///
    /// NOTE(review): this equals `x * y * 2^-BITS mod PRIME` only if `MU` is the REDC
    /// constant `-PRIME^-1 mod 2^BITS` and `DoubleWord` is twice as wide as `W`. Confirm
    /// against the field definitions.
    fn mul(x: W, y: W) -> W {
        // Splits a double-width value into its (upper, lower) words.
        let split = |wide: Self::DoubleWord| -> (W, W) {
            let upper: W = (wide >> W::BITS).as_();
            let lower: W = wide.as_();
            (upper, lower)
        };

        // t = x * y, as (t_hi, t_lo).
        let (t_hi, t_lo) = split(x.as_() * y.as_());

        // m * PRIME, as (mp_hi, mp_lo); m = MU * t_lo wrapped to one word.
        let m = Self::MU.wrapping_mul(&t_lo);
        let (mp_hi, mp_lo) = split(Self::PRIME.as_() * m.as_());

        // The low word of t + m * PRIME is dropped; only its carry into the upper word matters.
        let low_carry = t_lo.overflowing_add(&mp_lo).1;
        let upper_sum: Self::DoubleWord =
            t_hi.as_() + mp_hi.as_() + <Self::DoubleWord as From<bool>>::from(low_carry);
        let (top, u) = split(upper_sum);

        // Subtract PRIME from (top, u). If that borrows out of `top`, keep `u` unchanged.
        let (u_minus_p, low_borrow) = u.overflowing_sub(&Self::PRIME);
        let high_borrow = top.overflowing_sub(&<W as From<bool>>::from(low_borrow)).1;
        let keep_u = W::ZERO.wrapping_sub(&<W as From<bool>>::from(high_borrow));
        (u & keep_u) | (u_minus_p & !keep_u)
    }
}
/// Montgomery multiplication for a word type `W`, carried out on half-word limbs.
///
/// Each operand is split into two `BITS / 2`-bit limbs. Every partial product is then of two
/// half-word limbs, so all intermediate arithmetic fits in a single `W`.
pub(crate) trait FieldMulOpsSplitWord<W>: FieldParameters<W>
where
    W: Word + AsPrimitive<Self::HalfWord>,
{
    /// Unsigned type in which the half-word reduction multiplier is computed.
    type HalfWord: Word + AsPrimitive<W>;
    /// Reduction constant for one half-word round. `mul` uses it instead of
    /// `FieldParameters::MU`.
    /// NOTE(review): presumably `-PRIME^-1 mod 2^(BITS/2)` — confirm.
    const MU: Self::HalfWord;
    /// Montgomery product of `x` and `y` using two half-word REDC rounds.
    ///
    /// Forms the four-limb product `x * y`. Two rounds then each add `w * PRIME` to clear the
    /// lowest remaining limb and drop it. Finally, `PRIME` is subtracted once if needed, without
    /// branching.
    ///
    /// NOTE(review): correctness assumes `MU` is the matching REDC constant and that `PRIME` is
    /// sufficiently below `2^BITS` (see the dropped carry in round 1) — confirm.
    fn mul(x: W, y: W) -> W {
        // Upper / lower `BITS / 2` bits of a word, each returned in the low half of a `W`.
        let high = |v: W| v >> (W::BITS / 2);
        let low = |v: W| v & ((W::ONE << (W::BITS / 2)) - W::ONE);
        let (x1, x0) = (high(x), low(x));
        let (y1, y0) = (high(y), low(y));
        // ---- Schoolbook product z = x * y as four half-word limbs (z3, z2, z1, z0). ----
        // A half-limb product plus a half-limb carry always fits in `W`. Each step therefore
        // splits `result` into a carry (`high`) and a digit (`low`).
        // Row 1: x0 * (y1, y0) -> (z2, z1, z0).
        let mut result = x0 * y0;
        let mut carry = high(result);
        let z0 = low(result);
        result = x0 * y1;
        let mut hi = high(result);
        let mut lo = low(result);
        result = lo + carry;
        let mut z1 = low(result);
        let mut cc = high(result);
        result = hi + cc;
        let mut z2 = low(result);
        // Row 2, part 1: add x1 * y0 at limb 1; its high half plus carry goes to `carry`.
        result = x1 * y0;
        hi = high(result);
        lo = low(result);
        result = z1 + lo;
        z1 = low(result);
        cc = high(result);
        result = hi + cc;
        carry = low(result);
        // Row 2, part 2: add x1 * y1 (plus `carry`) at limb 2, producing z2 and z3.
        result = x1 * y1;
        hi = high(result);
        lo = low(result);
        result = lo + carry;
        lo = low(result);
        cc = high(result);
        result = hi + cc;
        hi = low(result);
        result = z2 + lo;
        z2 = low(result);
        cc = high(result);
        result = hi + cc;
        let mut z3 = low(result);
        // ---- Reduction round 1: add w * (p1, p0) at limb 0, with w = MU * z0 (wrapping). ----
        let mut w = <Self as FieldMulOpsSplitWord<W>>::MU.wrapping_mul(&z0.as_());
        let p0 = low(Self::PRIME);
        result = p0 * w.as_();
        hi = high(result);
        lo = low(result);
        // The low half of `z0 + lo` is the limb being cleared and is not stored. It is
        // presumably zero for a correct MU; only its carry propagates.
        result = z0 + lo;
        cc = high(result);
        result = hi + cc;
        carry = low(result);
        let p1 = high(Self::PRIME);
        result = p1 * w.as_();
        hi = high(result);
        lo = low(result);
        result = lo + carry;
        lo = low(result);
        cc = high(result);
        result = hi + cc;
        hi = low(result);
        result = z1 + lo;
        z1 = low(result);
        cc = high(result);
        result = z2 + hi + cc;
        z2 = low(result);
        cc = high(result);
        // NOTE(review): only the low half of `z3 + cc` is kept; any carry out of z3 is dropped.
        // This appears to assume z + w * PRIME < 2^(2 * BITS), i.e. PRIME comfortably below
        // 2^BITS — confirm for every field using this trait.
        result = z3 + cc;
        z3 = low(result);
        // ---- Reduction round 2: same as round 1, one limb higher, with w = MU * z1. ----
        w = <Self as FieldMulOpsSplitWord<W>>::MU.wrapping_mul(&z1.as_());
        result = p0 * w.as_();
        hi = high(result);
        lo = low(result);
        // Limb 1 is cleared here; only the carry of `z1 + lo` is kept.
        result = z1 + lo;
        cc = high(result);
        result = hi + cc;
        carry = low(result);
        result = p1 * w.as_();
        hi = high(result);
        lo = low(result);
        result = lo + carry;
        lo = low(result);
        cc = high(result);
        result = hi + cc;
        hi = low(result);
        result = z2 + lo;
        z2 = low(result);
        cc = high(result);
        result = z3 + hi + cc;
        z3 = low(result);
        // The carry out of the top limb is kept, so the reduced value is (cc, z3, z2).
        cc = high(result);
        // Reassemble the two surviving limbs into one word.
        let prod = z2 | (z3 << (W::BITS / 2));
        // Final conditional subtraction: compute (cc, prod) - (0, PRIME). If that borrows out
        // of `cc`, the mask is all ones and `prod` is returned; otherwise `prod - PRIME`.
        let (s0, b0) = prod.overflowing_sub(&Self::PRIME);
        let (_s1, b1) = cc.overflowing_sub(&<W as From<bool>>::from(b0));
        let mask = W::ZERO.wrapping_sub(&<W as From<bool>>::from(b1));
        (prod & mask) | (s0 & !mask)
    }
}
/// Wires up Montgomery multiplication for `$field` using a wider native partner type.
///
/// `$field` must already implement `FieldParameters<$word>`. `$double` must be exactly twice as
/// wide as `$word`, which is checked at compile time. The macro emits
/// `FieldMulOpsSingleWord<$word>` and a `FieldOps<$word>` impl whose `mul` forwards to it.
macro_rules! impl_field_ops_single_word {
    ($field:ident, $word:ty, $double:ty) => {
        const _: () = assert!(2 * <$word>::BITS == <$double>::BITS);
        impl $crate::fp::ops::FieldOps<$word> for $field {
            #[inline(always)]
            fn mul(x: $word, y: $word) -> $word {
                <Self as $crate::fp::ops::FieldMulOpsSingleWord<$word>>::mul(x, y)
            }
        }
        impl $crate::fp::ops::FieldMulOpsSingleWord<$word> for $field {
            type DoubleWord = $double;
        }
    };
}
/// Wires up split-word Montgomery multiplication for `$struct_name`.
///
/// `$struct_name` must already implement `FieldParameters<$W>`. `$W2` must be exactly half as
/// wide as `$W`, which is checked at compile time. The macro emits `FieldMulOpsSplitWord<$W>`,
/// whose half-word `MU` is the full-width `FieldParameters::MU` narrowed to `$W2`. It also emits
/// a `FieldOps<$W>` impl whose `mul` forwards to it.
///
/// Every trait is named through `$crate::fp::ops`, so call sites need none of them in scope.
macro_rules! impl_field_ops_split_word {
    ($struct_name:ident, $W:ty, $W2:ty) => {
        const _: () = assert!(
            2 * <$W2>::BITS == <$W>::BITS,
            "impl_field_ops_split_word: half-word type must be exactly half as wide as the word type"
        );
        impl $crate::fp::ops::FieldMulOpsSplitWord<$W> for $struct_name {
            type HalfWord = $W2;
            // Narrow the full-width MU to the half-word type. Fail to compile rather than
            // silently truncate it.
            const MU: Self::HalfWord = {
                let mu = <$struct_name as $crate::fp::ops::FieldParameters<$W>>::MU;
                assert!(
                    mu <= (<$W2>::MAX as $W),
                    "impl_field_ops_split_word: FieldParameters::MU does not fit in the half-word type"
                );
                mu as $W2
            };
        }
        impl $crate::fp::ops::FieldOps<$W> for $struct_name {
            #[inline(always)]
            fn mul(x: $W, y: $W) -> $W {
                <Self as $crate::fp::ops::FieldMulOpsSplitWord<_>>::mul(x, y)
            }
        }
    };
}