use crate::{
traits::{FromLeBytes, Invert, Pow},
utils::{matrix::Matrix, number::Number},
};
use ff::{
derive::bitvec::{order::Lsb0, view::AsBits},
PrimeField,
};
use num_bigint::{BigInt, BigUint};
use num_traits::Zero;
use rand::Rng;
use std::{cmp::Ordering, hash::Hash};
pub trait UsedField:
PrimeField
+ Hash
+ PartialOrd
+ From<Number>
+ From<i32>
+ From<bool>
+ From<f64>
+ Zero
+ std::ops::Shr<usize, Output = Self>
+ FromLeBytes
{
fn modulus() -> Number;
fn get_alpha() -> Number;
fn get_alpha_inverse() -> Number;
fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
fn power_of_two(exponent: usize) -> Self;
fn negative_power_of_two(exponent: usize) -> Self {
Self::ZERO - Self::power_of_two(exponent)
}
fn to_unsigned_number(self) -> Number {
BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
}
fn to_signed_number(self) -> Number {
if self.is_ge_zero() {
self.to_unsigned_number()
} else {
-(Self::ZERO - self).to_unsigned_number()
}
}
fn is_binary(self) -> bool {
self <= Self::ONE
}
#[inline(always)]
fn is_ge_zero(self) -> bool {
self < Self::TWO_INV
}
fn is_le_zero(self) -> bool {
self >= Self::ZERO - self
}
#[inline(always)]
fn is_gt_zero(self) -> bool {
!self.is_le_zero()
}
#[inline(always)]
fn is_lt_zero(self) -> bool {
!self.is_ge_zero()
}
fn max_cyclic(self, other: Self) -> (Self, bool) {
if (other - self).is_ge_zero() {
(other, true)
} else {
(self, false)
}
}
fn min_cyclic(self, other: Self) -> (Self, bool) {
if (other - self).is_ge_zero() {
(self, false)
} else {
(other, true)
}
}
fn max(self, other: Self, signed: bool) -> Self {
let offset = if signed { Self::TWO_INV } else { Self::ZERO };
if self - offset < other - offset {
other
} else {
self
}
}
fn min(self, other: Self, signed: bool) -> Self {
let offset = if signed { Self::TWO_INV } else { Self::ZERO };
if self - offset > other - offset {
other
} else {
self
}
}
fn sort_pair(self, other: Self) -> (Self, Self) {
if (other - self).is_ge_zero() {
(self, other)
} else {
(other, self)
}
}
fn abs(self) -> Self {
if self.is_ge_zero() {
self
} else {
Self::ZERO - self
}
}
fn does_mul_overflow(self, other: Self) -> bool {
if self.is_zero_vartime() || other.is_zero_vartime() {
return false;
}
let prod = self.to_unsigned_number() * other.to_unsigned_number();
prod >= Self::modulus()
}
fn does_add_signed_overflow(self, other: Self) -> bool {
let sum = self + other;
match (self.is_ge_zero(), other.is_ge_zero()) {
(true, true) => sum.is_lt_zero(),
(true, false) => false,
(false, true) => false,
(false, false) => sum.is_ge_zero(),
}
}
fn does_add_unsigned_overflow(self, other: Self) -> bool {
if self == Self::ZERO || other == Self::ZERO {
false
} else {
self >= -other
}
}
fn unsigned_bits(self) -> usize {
let binding = self.to_repr();
let bits = binding.as_bits::<Lsb0>();
bits.len() - bits.trailing_zeros()
}
fn signed_bits(self) -> usize {
self.abs().unsigned_bits()
}
fn unsigned_bit(&self, idx: usize) -> bool {
let repr = self.to_repr();
let bits = repr.as_bits::<Lsb0>();
if idx < bits.len() {
bits[idx]
} else {
false
}
}
fn signed_bit(&self, idx: usize) -> bool {
if self.is_ge_zero() {
self.unsigned_bit(idx)
} else {
!(self.abs() - Self::ONE).unsigned_bit(idx)
}
}
fn unsigned_euclidean_division(self, other: Self) -> Self {
if other == Self::ZERO {
Self::ZERO
} else {
(self.to_unsigned_number() / other.to_unsigned_number()).into()
}
}
fn signed_euclidean_division(self, other: Self) -> Self {
if other == Self::ZERO {
Self::ZERO
} else {
(self.to_signed_number() / other.to_signed_number()).into()
}
}
fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
min + Self::from(Number::gen_range(
rng,
&0.into(),
&((max - min).to_unsigned_number() + 1),
))
}
fn from_bin(bin: &str) -> Self {
Self::from(
bin.chars()
.enumerate()
.fold(Number::from(0), |acc, (i, c)| {
if c == '1' {
acc + Number::power_of_two(i)
} else {
acc
}
}),
)
}
fn to_bin(&self) -> String {
(0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
if self.unsigned_bit(i) {
acc.push('1');
} else {
acc.push('0');
}
acc
})
}
fn as_power_of_two(self) -> Option<usize> {
if self == Self::ZERO {
return None;
}
let mut min_possible_exponent = 0usize;
let mut max_possible_exponent = Self::CAPACITY as usize;
while max_possible_exponent >= min_possible_exponent {
let mid = (min_possible_exponent + max_possible_exponent) / 2;
match self.partial_cmp(&Self::power_of_two(mid)) {
None => panic!("order should be total"),
Some(Ordering::Less) => {
max_possible_exponent = mid - 1;
}
Some(Ordering::Equal) => return Some(mid),
Some(Ordering::Greater) => {
min_possible_exponent = mid + 1;
}
}
}
None
}
fn signed_gt(self, other: Self) -> bool {
self.max(other, true) != other
}
fn signed_ge(self, other: Self) -> bool {
self.max(other, true) == self
}
fn signed_lt(self, other: Self) -> bool {
self.min(other, true) != other
}
fn signed_le(self, other: Self) -> bool {
self.min(other, true) == self
}
}
/// Blanket [`Invert`] for every [`UsedField`]: the multiplicative inverse,
/// with zero mapped to zero instead of failing.
///
/// The `_is_expected_non_zero` hint is not used by this implementation.
impl<F: UsedField> Invert for F {
    fn invert(self, _is_expected_non_zero: bool) -> Self {
        // This one-argument `F::invert(&self)` is the field's own inversion,
        // not this trait's method. It returns a `CtOption` that is empty for
        // zero, and `unwrap_or` then substitutes `F::ZERO`.
        let inverse = F::invert(&self);
        inverse.unwrap_or(F::ZERO)
    }
}
impl<F: UsedField> Pow for F {
fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
let e = e % (F::modulus() - 1);
let mut e_u64 = [0u64; 4];
let bytes: [u8; 32] = e.into();
for (i, chunk) in bytes.chunks_exact(8).enumerate() {
e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
}
F::pow(&self, e_u64)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// Representative elements covering both halves of the field, including
    /// the boundary between them. `TWO_INV - 1` is the largest non-negative
    /// element and `TWO_INV` the smallest negative one.
    fn samples() -> [ScalarField; 5] {
        [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            ScalarField::ZERO - ScalarField::ONE,
        ]
    }

    /// `is_ge_zero` agrees with "not larger than its own negation".
    #[test]
    fn is_ge_zero() {
        for n in samples() {
            assert_eq!(n.is_ge_zero(), n <= ScalarField::ZERO - n)
        }
    }

    /// The four sign predicates agree with each other. `<= 0` holds exactly
    /// for zero and for the negative elements.
    #[test]
    fn sign_predicates_are_consistent() {
        for n in samples() {
            assert_eq!(n.is_lt_zero(), !n.is_ge_zero());
            assert_eq!(n.is_gt_zero(), !n.is_le_zero());
            assert_eq!(n.is_le_zero(), n == ScalarField::ZERO || n.is_lt_zero());
        }
    }

    /// The signed comparisons place the negative half below the
    /// non-negative half.
    #[test]
    fn signed_comparisons_follow_signed_order() {
        // Strictly increasing in the signed view:
        // -(p - 1) / 2, -1, 0, 1, (p - 1) / 2.
        let ascending = [
            ScalarField::TWO_INV,
            ScalarField::ZERO - ScalarField::ONE,
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
        ];
        for pair in ascending.windows(2) {
            let (lo, hi) = (pair[0], pair[1]);
            assert!(lo.signed_lt(hi) && lo.signed_le(hi));
            assert!(hi.signed_gt(lo) && hi.signed_ge(lo));
            assert!(!hi.signed_lt(lo) && !lo.signed_gt(hi));
        }
        for x in ascending {
            assert!(x.signed_le(x) && x.signed_ge(x));
            assert!(!x.signed_lt(x) && !x.signed_gt(x));
        }
    }
}