pub use base_field::BaseField;
pub use scalar_field::ScalarField;
macro_rules! impl_field {
($FIELD: ty) => {
use crate::traits::{FromLeBytes, Invert};
use crate::utils::{matrix::Matrix, number::Number, used_field::UsedField};
use ff::{Field, PrimeField};
use num_bigint::{BigInt, BigUint, Sign};
use num_traits::{Num, ToPrimitive, Zero};
use paste::paste;
use std::ops::Shr;

/// Largest exponent `e` whose power `2^e` is served from the per-thread table.
const MAX_CACHED_EXPONENT: usize = 256;

thread_local! {
    /// `POWERS_OF_TWO[e] == 2^e` in the field, for every `e` in `0..=MAX_CACHED_EXPONENT`.
    static POWERS_OF_TWO: [$FIELD; MAX_CACHED_EXPONENT + 1] = {
        let mut table = [<$FIELD>::ONE; MAX_CACHED_EXPONENT + 1];
        for e in 1..=MAX_CACHED_EXPONENT {
            // Doubling the previous entry is exactly multiplication by two.
            table[e] = table[e - 1].double();
        }
        table
    };
    /// The field modulus `p` as a `Number`, parsed once per thread.
    static MODULUS: Number = BigInt::from(<$FIELD>::modulus_big_uint()).into();
}
impl $FIELD {
    /// The field modulus `p` as a `BigUint`.
    ///
    /// `PrimeField::MODULUS` is a `0x`-prefixed hex string; the first two
    /// characters are skipped before parsing.
    fn modulus_big_uint() -> BigUint {
        BigUint::from_str_radix(&(<$FIELD>::MODULUS[2..]), 16).unwrap()
    }

    /// The field modulus `p` as a `Number`, cloned from the per-thread cache.
    fn modulus_number() -> Number {
        MODULUS.with(|x| x.clone())
    }

    /// Returns `2^exponent` as a field element.
    ///
    /// Exponents up to `MAX_CACHED_EXPONENT` are read from the per-thread
    /// table; larger ones are computed with `Field::pow`.
    fn power_of_two(exponent: usize) -> $FIELD {
        if exponent <= MAX_CACHED_EXPONENT {
            POWERS_OF_TWO.with(|x| x[exponent])
        } else {
            <$FIELD>::from(2).pow([exponent as u64])
        }
    }

    /// Decodes a 32-byte little-endian representation.
    ///
    /// Returns `None` when `from_repr` rejects the bytes (i.e. they do not
    /// encode a canonical element).
    pub fn from_le_bytes_checked(bytes: [u8; 32]) -> Option<Self> {
        Option::<$FIELD>::from(<$FIELD>::from_repr(paste! { [<$FIELD Repr>] }(bytes)))
    }

    /// Returns the 32-byte representation from `to_repr` (little-endian for the
    /// fields instantiated in this file).
    pub fn to_le_bytes(&self) -> [u8; 32] {
        <[u8; 32]>::try_from(self.to_repr().as_ref()).unwrap()
    }

    /// Returns the element as a `usize` if its representation fits in the low
    /// `usize::BITS / 8` bytes (all higher bytes are zero), otherwise `None`.
    pub fn to_usize(&self) -> Option<usize> {
        const USIZE_BYTES: usize = usize::BITS as usize / 8;
        let bytes = self.to_le_bytes();
        if bytes[USIZE_BYTES..].iter().all(|&b| b == 0) {
            Some(usize::from_le_bytes(bytes[..USIZE_BYTES].try_into().unwrap()))
        } else {
            None
        }
    }

    /// Parses an optionally `-`-prefixed string of ASCII decimal digits.
    ///
    /// The value is accumulated in the field, so it is reduced modulo `p`; a
    /// leading `-` negates the result.
    ///
    /// Returns `None` for an empty string, a lone `"-"`, or any byte after the
    /// optional sign that is not an ASCII digit.
    pub fn from_simple_string(a: &str) -> Option<Self> {
        let (is_negative, digits) = match a.strip_prefix('-') {
            Some(rest) => (true, rest),
            None => (false, a),
        };
        // Without at least one digit there is nothing to parse; this also avoids
        // indexing into an empty input.
        if digits.is_empty() {
            return None;
        }
        let ten = Self::from(10u64);
        let mut res = Self::ZERO;
        for byte in digits.bytes() {
            if !byte.is_ascii_digit() {
                return None;
            }
            res *= ten;
            res += Self::from(u64::from(byte - b'0'));
        }
        Some(if is_negative { -res } else { res })
    }
}
/// Maps `true` to the multiplicative identity and `false` to zero.
impl From<bool> for $FIELD {
    fn from(value: bool) -> Self {
        match value {
            true => Self::ONE,
            false => Self::ZERO,
        }
    }
}
/// Converts a signed 32-bit integer; negative values map to `p - |value|`.
impl From<i32> for $FIELD {
    fn from(value: i32) -> Self {
        // `unsigned_abs` is defined for every i32, including i32::MIN, whereas
        // `-value` overflows for i32::MIN (panic in debug, wrong value in release).
        let magnitude = <$FIELD>::from(u64::from(value.unsigned_abs()));
        if value < 0 {
            <$FIELD>::ZERO - magnitude
        } else {
            magnitude
        }
    }
}
/// Reduces an arbitrary-precision unsigned integer modulo `p`.
///
/// The integer is consumed as little-endian 64-bit limbs; limb `k` contributes
/// `limb * 2^(64 k)`.
impl From<&BigUint> for $FIELD {
    fn from(number: &BigUint) -> Self {
        number
            .iter_u64_digits()
            .enumerate()
            .fold(<$FIELD>::ZERO, |acc, (limb_index, limb)| {
                acc + <$FIELD>::from(limb) * <$FIELD>::power_of_two(limb_index * 64)
            })
    }
}
/// Reduces a signed arbitrary-precision integer modulo `p`; negative inputs
/// map to the additive inverse of their magnitude.
impl From<&BigInt> for $FIELD {
    fn from(number: &BigInt) -> Self {
        let abs_value: $FIELD = number.magnitude().into();
        match number.sign() {
            Sign::Plus => abs_value,
            Sign::Minus => -abs_value,
            Sign::NoSign => <$FIELD>::ZERO,
        }
    }
}
/// Converts a crate `Number`; small values are widened to `BigInt` first.
impl From<&Number> for $FIELD {
    fn from(number: &Number) -> Self {
        match number {
            Number::BigNum(big) => big.into(),
            Number::SmallNum(small) => {
                let widened = BigInt::from(*small);
                (&widened).into()
            }
        }
    }
}

/// By-value convenience wrapper around the `&Number` conversion.
impl From<Number> for $FIELD {
    fn from(number: Number) -> Self {
        Self::from(&number)
    }
}
/// Encodes an `f64` as a signed fixed-point value with 52 fractional bits.
///
/// The result is `sign(value) * floor(|value| * 2^52)`, i.e. truncation toward
/// zero of the scaled value. Zero, subnormals and values below `2^-52` in
/// magnitude map to zero.
///
/// NOTE(review): NaN and ±infinity are not rejected; their all-ones exponent
/// field decodes as `2^1024` times the significand and yields a meaningless
/// element.
impl From<f64> for $FIELD {
    fn from(value: f64) -> Self {
        const FRACTION_BITS: u32 = 52;
        const EXPONENT_BIAS: i16 = 1023;
        let bits = value.to_bits();
        let negative = (bits >> 63) == 1;
        let exponent = ((bits >> FRACTION_BITS) & 0x7ff) as i16 - EXPONENT_BIAS;
        // Stored fraction with the implicit leading one restored at bit 52.
        let significand = (bits & ((1u64 << FRACTION_BITS) - 1)) | (1u64 << FRACTION_BITS);
        // Non-negative exponents scale up inside the field; negative ones shift the
        // significand down (capped at 63, which already clears a 53-bit value).
        let (scale, truncated) = if exponent >= 0 {
            (exponent as usize, significand)
        } else {
            (0, significand >> (-exponent).min(63))
        };
        let magnitude = <$FIELD>::from(truncated);
        let signed = if negative {
            <$FIELD>::ZERO - magnitude
        } else {
            magnitude
        };
        <$FIELD>::power_of_two(scale) * signed
    }
}
/// Panicking counterpart of `from_le_bytes_checked`: unwraps its result, so
/// non-canonical input panics.
impl FromLeBytes for $FIELD {
    fn from_le_bytes(bytes: [u8; 32]) -> Self {
        let decoded = Self::from_le_bytes_checked(bytes);
        decoded.unwrap()
    }
}
/// Returns the smallest prime in `2..=47` that does not divide `p - 1`.
///
/// Because such an `alpha` is coprime to `p - 1`, `x -> x^alpha` is a
/// permutation of the field.
///
/// # Panics
/// If every listed prime divides `p - 1`.
fn find_alpha() -> i32 {
    const SMALL_PRIMES: [i32; 15] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47];
    let p_minus_one = <$FIELD>::modulus_number() - 1;
    SMALL_PRIMES
        .iter()
        .copied()
        .find(|&candidate| &p_minus_one % candidate != 0)
        .unwrap_or_else(|| panic!("Could not find prime alpha that does not divide p-1."))
}
/// Returns `d` with `alpha * d ≡ 1 (mod p - 1)`, so that `(x^alpha)^d == x`.
///
/// Let `q = p - 1 = alpha * k + m` with `0 < m < alpha`. Pick the smallest `n`
/// in `1..alpha` with `m * n ≡ alpha - 1 (mod alpha)`, and let
/// `l = (m * n) / alpha`. Then `alpha * (n * k + l + 1) = n * q + 1`.
///
/// # Panics
/// If `alpha` divides `p - 1`.
fn find_alpha_inverse(alpha: i32) -> Number {
    let p_minus_one = <$FIELD>::modulus_number() - 1;
    let remainder = match (&p_minus_one % alpha).to_i32().unwrap() {
        0 => panic!("alpha divides p_minus_one"),
        r => r,
    };
    let multiplier = (1..alpha)
        .find(|candidate| (remainder * candidate) % alpha == alpha - 1)
        .unwrap();
    let carry = remainder * multiplier / alpha;
    let quotient = p_minus_one / alpha;
    multiplier * quotient + carry + 1
}
/// Computes `(alpha, alpha^{-1} mod (p - 1))` for this field.
fn find_alphas() -> (Number, Number) {
    let alpha = find_alpha();
    let inverse = find_alpha_inverse(alpha);
    let alpha_number: Number = alpha.into();
    (alpha_number, inverse)
}

thread_local! {
    /// Per-thread cache of `find_alphas()`, computed on first access.
    static ALPHAS: (Number, Number) = find_alphas();
}

/// The exponent `alpha` (a clone of the cached value).
fn get_alpha() -> Number {
    ALPHAS.with(|pair| pair.0.clone())
}

/// `alpha^{-1} mod (p - 1)` (a clone of the cached value).
fn get_alpha_inverse() -> Number {
    ALPHAS.with(|pair| pair.1.clone())
}
/// Builds the Cauchy matrix `C[i][j] = 1 / (x[i] - y[j])`.
///
/// # Panics
/// If `x` and `y` have different lengths. Each entry is inverted with
/// `invert(true)`.
pub(super) fn build_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
    assert_eq!(x.len(), y.len());
    let mut cauchy: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
    for (row, &xi) in x.iter().enumerate() {
        for (col, &yj) in y.iter().enumerate() {
            cauchy[(row, col)] = (xi - yj).invert(true);
        }
    }
    cauchy
}
/// Computes the inverse of the Cauchy matrix built by `build_cauchy(x, y)`.
///
/// Uses the closed form
/// `B[i][j] = A_i * B_j / (y[i] - x[j])`, where
/// `A_i = prod_k (y[i] - x[k]) / prod_{k: y[k] != y[i]} (y[i] - y[k])` and
/// `B_j = prod_k (x[j] - y[k]) / prod_{k: x[k] != x[j]} (x[j] - x[k])`.
///
/// The row and column factors are computed once each, so the work is
/// O(n^2) with 2n + n^2 inversions. Field arithmetic is exact, so the
/// regrouped products give the same entries as the per-entry formula.
///
/// # Panics
/// If `x` and `y` have different lengths. Inversions use `invert(true)`.
pub(super) fn inverse_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
    assert_eq!(x.len(), y.len());
    // Product of `val - u` over all `u` in `arr` that differ from `val`.
    fn prime(arr: &[$FIELD], val: $FIELD) -> $FIELD {
        arr.iter()
            .map(|u| if *u != val { val - u } else { 1.into() })
            .product()
    }
    // Row factors A_i depend only on y[i].
    let row_factors: Vec<$FIELD> = y
        .iter()
        .map(|&yi| x.iter().map(|u| yi - u).product::<$FIELD>() * prime(y, yi).invert(true))
        .collect();
    // Column factors B_j depend only on x[j].
    let col_factors: Vec<$FIELD> = x
        .iter()
        .map(|&xj| y.iter().map(|v| xj - v).product::<$FIELD>() * prime(x, xj).invert(true))
        .collect();
    let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
    for (i, (&yi, &row)) in y.iter().zip(&row_factors).enumerate() {
        for (j, (&xj, &col)) in x.iter().zip(&col_factors).enumerate() {
            mat[(i, j)] = row * col * (yi - xj).invert(true);
        }
    }
    mat
}
/// Returns a `size x size` Cauchy matrix over the points `x = 1..=size` and
/// `y = -1..=-size`, together with its inverse.
fn mds_matrix_and_inverse(size: usize) -> (Matrix<$FIELD>, Matrix<$FIELD>) {
    let mut xs = Vec::with_capacity(size);
    let mut ys = Vec::with_capacity(size);
    for k in 1..=size {
        let point = <$FIELD>::from(k as u64);
        xs.push(point);
        ys.push(-point);
    }
    (build_cauchy(&xs, &ys), inverse_cauchy(&xs, &ys))
}
/// `self >> rhs` divides by `2^rhs` using `unsigned_euclidean_division`.
///
/// NOTE(review): that method is defined outside this file; presumably it
/// performs floor division on canonical representatives. Confirm.
impl Shr<usize> for $FIELD {
    type Output = Self;
    fn shr(self, rhs: usize) -> Self::Output {
        let divisor = <$FIELD>::power_of_two(rhs);
        self.unsigned_euclidean_division(divisor)
    }
}
/// Exposes this field's modulus, exponents, MDS matrices and cached powers of
/// two through the crate's `UsedField` trait.
impl UsedField for $FIELD {
    fn modulus() -> Number {
        Self::modulus_number()
    }

    fn power_of_two(exponent: usize) -> Self {
        // Type-qualified paths prefer inherent items, so this is the cached
        // inherent helper, not a recursive call to this trait method.
        <$FIELD>::power_of_two(exponent)
    }

    fn get_alpha() -> Number {
        // Unqualified call: resolves to the free function of the same name.
        get_alpha()
    }

    fn get_alpha_inverse() -> Number {
        get_alpha_inverse()
    }

    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>) {
        mds_matrix_and_inverse(width)
    }
}
/// `num_traits::Zero` backed by the field's additive identity.
impl Zero for $FIELD {
    fn zero() -> Self {
        Self::ZERO
    }

    fn is_zero(&self) -> bool {
        Self::ZERO == *self
    }
}
};
}
// `Hash` is derived while `PartialEq` comes from the `PrimeField` derive, which
// Clippy treats as a manual impl. NOTE(review): this is only sound if equal
// elements always share identical limbs, which is presumably guaranteed by the
// derive keeping values fully reduced. Confirm.
#[allow(clippy::derived_hash_with_manual_eq)]
mod scalar_field {
    use curve25519_dalek::Scalar;
    use field_derive::ScalarFieldRepr;
    pub use field_derive::ScalarField;

    // Integers modulo the ed25519 group order l = 2^252 + 27742317777372353535851937790883648493.
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        #[PrimeFieldModulus = "7237005577332262213973186563042994240857116359379907606001950938285454250989"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct ScalarField([u64; 4]);
    }

    impl_field!(ScalarField);

    /// Converts a dalek `Scalar` through its 32-byte little-endian encoding.
    impl From<Scalar> for ScalarField {
        fn from(value: Scalar) -> Self {
            Self::from_le_bytes(value.to_bytes())
        }
    }
}
// `Hash` is derived while `PartialEq` comes from the `PrimeField` derive, which
// Clippy treats as a manual impl. NOTE(review): this is only sound if equal
// elements always share identical limbs, which is presumably guaranteed by the
// derive keeping values fully reduced. Confirm.
#[allow(clippy::derived_hash_with_manual_eq)]
mod base_field {
    pub use field_derive::BaseField;
    use field_derive::BaseFieldRepr;

    // Integers modulo p = 2^255 - 19 (the curve25519 base field).
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        #[PrimeFieldModulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct BaseField([u64; 4]);
    }

    impl_field!(BaseField);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        traits::{Invert, Pow},
        utils::{number::Number, used_field::UsedField},
    };
    use ff::{Field, PrimeField};
    use std::{f64::consts::PI, str::FromStr};

    /// Parses a decimal or `0x`-hex literal through `Number` into the scalar field.
    fn scalar(literal: &str) -> ScalarField {
        ScalarField::from(Number::from_str(literal).unwrap())
    }

    #[test]
    fn from_f64() {
        let cases = [
            (2f64.sqrt(), scalar("6369051672525773")),
            (
                -PI * 2f64.powi(150),
                scalar("0x0ffffffffffff36f0255dde97400000014def9dea2f79cd65812631a5cf5d3ed"),
            ),
            (0.001, scalar("4503599627370")),
            (
                -0.00000383,
                scalar("0x1000000000000000000000000000000014def9dea2f79cd65812631658da3b61"),
            ),
            (3f64 * 2f64.powi(-150), ScalarField::ZERO),
        ];
        for (input, expected) in cases {
            assert_eq!(ScalarField::from(input), expected);
        }
    }

    #[test]
    fn multiplicative_generator() {
        // Euler's criterion: a non-residue raised to (p - 1) / 2 is not one.
        let euler_exponent = (ScalarField::modulus() - 1) / 2;
        let legendre = ScalarField::MULTIPLICATIVE_GENERATOR.pow(&euler_exponent, true);
        assert_ne!(legendre, ScalarField::ONE);
    }

    #[test]
    fn sqrt() {
        fn check(root: ScalarField) {
            let square = root.square();
            let recovered = square.sqrt().unwrap();
            assert_eq!(recovered.square(), square);
        }
        use rand::rngs::OsRng;
        [ScalarField::ZERO, ScalarField::ONE]
            .into_iter()
            .chain((0..1024).map(|_| ScalarField::random(OsRng)))
            .for_each(check);
    }

    #[test]
    fn test_safe_field_inverse() {
        let samples = [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::from(2),
            ScalarField::from(3),
        ];
        for value in samples {
            let inverse = value.invert(false);
            match value == ScalarField::ZERO {
                true => assert_eq!(inverse, ScalarField::ZERO),
                false => assert_eq!(value * inverse, ScalarField::ONE),
            }
        }
    }

    #[test]
    fn test_cauchy_inverse() {
        let x: Vec<ScalarField> = (1..=5i32).map(ScalarField::from).collect();
        let y: Vec<ScalarField> = (0..5i32).map(|k| -ScalarField::from(k)).collect();
        let cauchy = scalar_field::build_cauchy(&x, &y);
        let inverse = scalar_field::inverse_cauchy(&x, &y);
        let identity = cauchy.mat_mul(&inverse);
        let size = x.len();
        for row in 0..size {
            for col in 0..size {
                assert_eq!(identity[(row, col)], ScalarField::from(row == col));
            }
        }
    }
}