use crate::{AdditiveGroup, BigInt, FftField, One, PrimeField, SqrtPrecomputation, Zero};
use ark_std::{
cmp::*,
fmt::{Display, Formatter, Result as FmtResult},
hash::Hash,
marker::PhantomData,
str::FromStr,
};
use educe::Educe;
use num_traits::Unsigned;
pub use ark_ff_macros::SmallFpConfig;
/// Compile-time configuration for a prime field whose elements are stored in a single
/// primitive unsigned integer.
///
/// Implementors supply the modulus, the FFT-related constants, and the core arithmetic;
/// the generic [`SmallFp`] element type forwards its `Zero`, `One`, `AdditiveGroup`,
/// `PrimeField` and `FftField` behaviour to the items declared here. A derive macro with
/// the same name is re-exported from `ark_ff_macros`; presumably it generates impls of this
/// trait (TODO confirm against the macro crate).
pub trait SmallFpConfig: Send + Sync + 'static + Sized {
/// Primitive unsigned integer type holding an element's raw value.
///
/// The `Into<u128>` / `TryFrom<u128>` bounds allow widening to, and checked narrowing
/// from, `u128`.
type T: Copy
+ Default
+ PartialEq
+ Eq
+ Hash
+ Sync
+ Send
+ PartialOrd
+ Display
+ Unsigned
+ core::fmt::Debug
+ core::ops::Add<Output = Self::T>
+ core::ops::Sub<Output = Self::T>
+ core::ops::Mul<Output = Self::T>
+ core::ops::Div<Output = Self::T>
+ core::ops::Rem<Output = Self::T>
+ Into<u128>
+ TryFrom<u128>;
/// The field modulus expressed in the backing type `T`.
const MODULUS: Self::T;
/// The field modulus widened to `u128`. Assumed to equal `MODULUS` (NOTE(review):
/// nothing here enforces that). It drives `PrimeField::MODULUS`, `MODULUS_BIT_SIZE`,
/// and the word width used for random sampling.
const MODULUS_U128: u128;
/// Number of `BigInt` limbs; defaults to 1. Note that `from_bigint`/`into_bigint` below
/// are hard-coded to `BigInt<1>`.
const NUM_BIG_INT_LIMBS: usize = 1;
/// Multiplicative generator, exposed as `FftField::GENERATOR`.
const GENERATOR: SmallFp<Self>;
/// Additive identity, returned by `Zero::zero` and used as `AdditiveGroup::ZERO`.
const ZERO: SmallFp<Self>;
/// Multiplicative identity, returned by `One::one`.
const ONE: SmallFp<Self>;
/// Presumably the element `-1` (i.e. `MODULUS - 1`); not referenced in this section.
const NEG_ONE: SmallFp<Self>;
/// Exposed as `FftField::TWO_ADICITY`.
const TWO_ADICITY: u32;
/// Exposed as `FftField::TWO_ADIC_ROOT_OF_UNITY`.
const TWO_ADIC_ROOT_OF_UNITY: SmallFp<Self>;
/// Optional mixed-radix FFT base, forwarded to `FftField`; defaults to `None`.
const SMALL_SUBGROUP_BASE: Option<u32> = None;
/// Optional adicity for `SMALL_SUBGROUP_BASE`, forwarded to `FftField`; defaults to `None`.
const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = None;
/// Optional large-subgroup root of unity, forwarded to `FftField`; defaults to `None`.
const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<SmallFp<Self>> = None;
/// Optional square-root precomputation. Not referenced in this section; presumably it is
/// consumed by a `sqrt` implementation elsewhere in the file.
const SQRT_PRECOMP: Option<SqrtPrecomputation<SmallFp<Self>>>;
/// In-place field addition (`a += b`); the implementor defines the reduction strategy.
fn add_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>);
/// In-place field subtraction (`a -= b`).
fn sub_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>);
/// In-place doubling; `AdditiveGroup::double_in_place` delegates here.
fn double_in_place(a: &mut SmallFp<Self>);
/// In-place negation; `AdditiveGroup::neg_in_place` delegates here.
fn neg_in_place(a: &mut SmallFp<Self>);
/// In-place field multiplication (`a *= b`).
fn mul_assign(a: &mut SmallFp<Self>, b: &SmallFp<Self>);
/// Presumably returns the inner product `sum(a[i] * b[i])` over `T` pairs.
fn sum_of_products<const T: usize>(
a: &[SmallFp<Self>; T],
b: &[SmallFp<Self>; T],
) -> SmallFp<Self>;
/// In-place squaring (`a = a * a`).
fn square_in_place(a: &mut SmallFp<Self>);
/// Multiplicative inverse. Presumably returns `None` for zero (TODO confirm in impls).
fn inverse(a: &SmallFp<Self>) -> Option<SmallFp<Self>>;
/// Builds an element from a raw value; `SmallFp::new` delegates here. Whether the
/// value is reduced or converted (e.g. to Montgomery form) is implementor-defined.
fn new(value: Self::T) -> SmallFp<Self>;
/// Converts from a single-limb big integer, returning `None` on failure (used by
/// `PrimeField::from_bigint`).
fn from_bigint(other: BigInt<1>) -> Option<SmallFp<Self>>;
/// Converts an element to its single-limb big-integer form (used by
/// `PrimeField::into_bigint`, and therefore by `Debug`, `Display` and `Ord`).
fn into_bigint(other: SmallFp<Self>) -> BigInt<1>;
}
/// An element of the prime field described by the configuration `P`.
///
/// Holds one raw value of the configuration's primitive type `P::T`. Whether that value
/// is canonical (reduced below the modulus) depends on `P`'s arithmetic; `from_raw`
/// stores it unchecked.
#[derive(Educe)]
// `educe` is used rather than std derives, presumably so the generated impls do not
// demand that `P` itself implement these traits (P is only held via `PhantomData`).
// Equality and hashing therefore act on the raw `value` field.
#[educe(Default, Hash, Clone, Copy, PartialEq, Eq)]
pub struct SmallFp<P: SmallFpConfig> {
/// Raw backing value of the element.
pub value: P::T,
/// Zero-sized marker tying the element to its configuration `P`.
_phantom: PhantomData<P>,
}
impl<P: SmallFpConfig> SmallFp<P> {
#[doc(hidden)]
#[inline]
pub fn is_geq_modulus(&self) -> bool {
self.value >= P::MODULUS
}
pub const fn from_raw(value: P::T) -> Self {
Self {
value,
_phantom: PhantomData,
}
}
#[inline]
pub fn new(value: P::T) -> Self {
P::new(value)
}
pub const fn num_bits_to_shave() -> usize {
primitive_type_bit_size(P::MODULUS_U128) - (Self::MODULUS_BIT_SIZE as usize)
}
}
impl<P: SmallFpConfig> ark_std::fmt::Debug for SmallFp<P> {
fn fmt(&self, f: &mut Formatter<'_>) -> ark_std::fmt::Result {
ark_std::fmt::Debug::fmt(&self.into_bigint(), f)
}
}
impl<P: SmallFpConfig> Zero for SmallFp<P> {
    /// The additive identity declared by the configuration.
    #[inline]
    fn zero() -> Self {
        <P as SmallFpConfig>::ZERO
    }

    /// Compares (on the raw representation) against the configured zero element.
    #[inline]
    fn is_zero(&self) -> bool {
        P::ZERO == *self
    }
}
impl<P: SmallFpConfig> One for SmallFp<P> {
    /// The multiplicative identity declared by the configuration.
    #[inline]
    fn one() -> Self {
        <P as SmallFpConfig>::ONE
    }

    /// Compares (on the raw representation) against the configured one element.
    #[inline]
    fn is_one(&self) -> bool {
        P::ONE == *self
    }
}
impl<P: SmallFpConfig> AdditiveGroup for SmallFp<P> {
    type Scalar = Self;

    const ZERO: Self = <P as SmallFpConfig>::ZERO;

    /// Returns `self + self`, leaving `self` unchanged.
    #[inline]
    fn double(&self) -> Self {
        let mut doubled = *self;
        <P as SmallFpConfig>::double_in_place(&mut doubled);
        doubled
    }

    /// Negates `self` in place via the configuration and returns it for chaining.
    #[inline]
    fn neg_in_place(&mut self) -> &mut Self {
        <P as SmallFpConfig>::neg_in_place(self);
        self
    }

    /// Doubles `self` in place via the configuration and returns it for chaining.
    #[inline]
    fn double_in_place(&mut self) -> &mut Self {
        <P as SmallFpConfig>::double_in_place(self);
        self
    }
}
/// Packs a `u128` value (the field modulus) into a single-limb [`BigInt`], usable in
/// `const` contexts.
///
/// # Panics
///
/// Panics if `value` exceeds `u64::MAX`: a single 64-bit limb cannot represent it, and
/// truncating would silently produce a wrong modulus. When evaluated for a `const` item,
/// this panic becomes a compile-time error.
const fn const_to_bigint(value: u128) -> BigInt<1> {
    assert!(
        value <= u64::MAX as u128,
        "SmallFp modulus must fit in a single 64-bit limb"
    );
    // The assert above guarantees this narrowing cast is lossless.
    BigInt::<1>::new([value as u64])
}
/// Number of significant bits in `value` (the bit length), usable in `const` contexts.
/// Returns `0` for `0`.
const fn const_num_bits_u128(value: u128) -> u32 {
    // `0u128.leading_zeros()` is 128, so zero maps to 0 bits without a special case.
    u128::BITS - value.leading_zeros()
}
/// Bit width of the narrowest of `u8`/`u16`/`u32`/`u64` whose maximum is at least
/// `modulus_u128`. Anything larger than `u32::MAX` reports 64.
const fn primitive_type_bit_size(modulus_u128: u128) -> usize {
    if modulus_u128 > u32::MAX as u128 {
        64
    } else if modulus_u128 > u16::MAX as u128 {
        32
    } else if modulus_u128 > u8::MAX as u128 {
        16
    } else {
        8
    }
}
impl<P: SmallFpConfig> PrimeField for SmallFp<P> {
    type BigInt = BigInt<1>;

    // Every modulus-derived constant starts from `P::MODULUS_U128`.
    const MODULUS: Self::BigInt = const_to_bigint(P::MODULUS_U128);
    const MODULUS_BIT_SIZE: u32 = const_num_bits_u128(P::MODULUS_U128);
    // For an odd modulus p, floor(p / 2) == (p - 1) / 2.
    const MODULUS_MINUS_ONE_DIV_TWO: Self::BigInt = Self::MODULUS.divide_by_2_round_down();
    // Computed via `BigInt::two_adic_coefficient`; presumably the odd part t of
    // MODULUS - 1 = 2^s * t (TODO confirm against the BigInt docs).
    const TRACE: Self::BigInt = Self::MODULUS.two_adic_coefficient();
    const TRACE_MINUS_ONE_DIV_TWO: Self::BigInt = Self::TRACE.divide_by_2_round_down();

    /// Delegates to [`SmallFpConfig::from_bigint`].
    #[inline]
    fn from_bigint(repr: BigInt<1>) -> Option<Self> {
        <P as SmallFpConfig>::from_bigint(repr)
    }

    /// Delegates to [`SmallFpConfig::into_bigint`].
    fn into_bigint(self) -> BigInt<1> {
        <P as SmallFpConfig>::into_bigint(self)
    }
}
/// FFT parameters are taken verbatim from the configuration `P`.
impl<P: SmallFpConfig> FftField for SmallFp<P> {
    const GENERATOR: Self = <P as SmallFpConfig>::GENERATOR;

    // Radix-2 domain parameters.
    const TWO_ADICITY: u32 = <P as SmallFpConfig>::TWO_ADICITY;
    const TWO_ADIC_ROOT_OF_UNITY: Self = <P as SmallFpConfig>::TWO_ADIC_ROOT_OF_UNITY;

    // Optional mixed-radix parameters (`None` unless `P` overrides them).
    const SMALL_SUBGROUP_BASE: Option<u32> = <P as SmallFpConfig>::SMALL_SUBGROUP_BASE;
    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> =
        <P as SmallFpConfig>::SMALL_SUBGROUP_BASE_ADICITY;
    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Self> =
        <P as SmallFpConfig>::LARGE_SUBGROUP_ROOT_OF_UNITY;
}
impl<P: SmallFpConfig> Ord for SmallFp<P> {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
self.into_bigint().cmp(&other.into_bigint())
}
}
impl<P: SmallFpConfig> PartialOrd for SmallFp<P> {
    /// Total order: always `Some`, consistent with the `Ord` impl.
    #[inline(always)]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<P: SmallFpConfig> ark_std::rand::distributions::Distribution<SmallFp<P>>
for ark_std::rand::distributions::Standard
{
#[inline]
fn sample<R: ark_std::rand::Rng + ?Sized>(&self, rng: &mut R) -> SmallFp<P> {
macro_rules! sample_loop {
($ty:ty) => {
loop {
let mut val: $ty = rng.sample(ark_std::rand::distributions::Standard);
let shave_bits = SmallFp::<P>::num_bits_to_shave();
let mask = <$ty>::MAX >> shave_bits;
val &= mask;
if val > 0 && u128::from(val) < P::MODULUS_U128 {
return SmallFp::from(val);
}
}
};
}
match P::MODULUS_U128 {
modulus if modulus <= u8::MAX as u128 => sample_loop!(u8),
modulus if modulus <= u16::MAX as u128 => sample_loop!(u16),
modulus if modulus <= u32::MAX as u128 => sample_loop!(u32),
_ => sample_loop!(u64),
}
}
}
/// Error returned when parsing a [`SmallFp`] element from a decimal string fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseSmallFpError {
    /// The input string was empty.
    Empty,
    /// The input was not an unsigned decimal integer that fits in a `u128`.
    InvalidFormat,
    /// The input started with `0` but was longer than one character (e.g. `"007"`).
    InvalidLeadingZero,
}

/// Human-readable messages, so the error can be shown to users or wrapped by other
/// error types.
impl Display for ParseSmallFpError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let msg = match self {
            Self::Empty => "cannot parse field element from empty string",
            Self::InvalidFormat => "invalid decimal representation of field element",
            Self::InvalidLeadingZero => "field element string has a leading zero",
        };
        f.write_str(msg)
    }
}
impl<P: SmallFpConfig> FromStr for SmallFp<P> {
    type Err = ParseSmallFpError;

    /// Parses an unsigned decimal string into a field element.
    ///
    /// Rejects the empty string and any multi-character string that begins with `0`.
    /// Otherwise the text is parsed as a `u128`, and the conversion (including any
    /// reduction) is delegated to `From<u128>`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.as_bytes() {
            [] => Err(ParseSmallFpError::Empty),
            // A leading '0' followed by at least one more byte.
            [b'0', _, ..] => Err(ParseSmallFpError::InvalidLeadingZero),
            _ => s
                .parse::<u128>()
                .map(Self::from)
                .map_err(|_| ParseSmallFpError::InvalidFormat),
        }
    }
}
impl<P: SmallFpConfig> Display for SmallFp<P> {
    /// Writes the element's `BigInt` form with `{}`.
    ///
    /// Outer formatter flags (width, fill, ...) are not forwarded, because `write!`
    /// formats with a fresh default specification.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(f, "{}", <P as SmallFpConfig>::into_bigint(*self))
    }
}