1pub use base_field::BaseField;
2pub use scalar_field::ScalarField;
3
4macro_rules! impl_field {
5 ($FIELD: ty) => {
6use crate::utils::matrix::Matrix;
7use crate::utils::number::Number;
8use crate::utils::used_field::UsedField;
9use ff::{Field, PrimeField};
10use num_traits::{Num, ToPrimitive, Zero};
11use num_bigint::{BigInt, BigUint, Sign};
/// Largest exponent `e` for which `2^e` is kept in the per-thread `POWERS_OF_TWO` table;
/// `power_of_two` computes larger exponents on demand.
const MAX_CACHED_EXPONENT: usize = 256;
13use std::ops::Shr;
14use paste::paste;
15use crate::traits::{Invert, FromLeBytes};
16
17thread_local! {
19 static POWERS_OF_TWO: [$FIELD; MAX_CACHED_EXPONENT + 1] = {
21 let mut arr: [$FIELD; MAX_CACHED_EXPONENT + 1] = [<$FIELD>::ONE; MAX_CACHED_EXPONENT + 1];
22 let two = <$FIELD>::from(2);
23 for i in 0..MAX_CACHED_EXPONENT {
24 arr[i+1] = two * arr[i]
25 }
26 arr
27 };
28 static MODULUS: Number = BigInt::from(<$FIELD>::modulus_big_uint()).into()
30}
31
32
33impl $FIELD {
34 fn modulus_big_uint() -> BigUint {
35 BigUint::from_str_radix(&(<$FIELD>::MODULUS[2..]), 16).unwrap()
36 }
37 fn modulus_number() -> Number {
38 MODULUS.with(|x| x.clone())
39 }
40
41 fn power_of_two(exponent: usize) -> $FIELD {
42 if exponent <= MAX_CACHED_EXPONENT {
43 POWERS_OF_TWO.with(|x| x[exponent])
44 } else {
45 <$FIELD>::from(2).pow([exponent as u64])
46 }
47 }
48
49 pub fn from_le_bytes_checked(bytes: [u8; 32]) -> Option<Self> {
50 Option::<$FIELD>::from(<$FIELD>::from_repr(paste! { [<$FIELD Repr>] }(bytes)))
51 }
52
53 pub fn to_le_bytes(&self) -> [u8; 32] {
54 <[u8; 32]>::try_from(self.to_repr().as_ref()).unwrap()
55 }
56}
57
58impl From<bool> for $FIELD {
59 fn from(value: bool) -> Self {
60 if value {
61 <$FIELD>::ONE
62 } else {
63 <$FIELD>::ZERO
64 }
65 }
66}
67
68impl From<i32> for $FIELD {
69 fn from(value: i32) -> Self {
70 if value < 0 {
71 <$FIELD>::ZERO - <$FIELD>::from((-value) as u64)
72 } else {
73 <$FIELD>::from(value as u64)
74 }
75 }
76}
77
78impl From<&BigUint> for $FIELD {
79 fn from(number: &BigUint) -> Self {
80 let mut res: $FIELD = 0.into();
81 for (i, digit) in (number % <$FIELD>::modulus_big_uint())
82 .iter_u64_digits()
83 .enumerate()
84 {
85 res += <$FIELD>::from(digit) * <$FIELD>::power_of_two(i * 64);
86 }
87 res
88 }
89}
90
91impl From<&BigInt> for $FIELD {
92 fn from(number: &BigInt) -> Self {
93 let magnitude = <$FIELD>::from(number.magnitude());
94 let zero = <$FIELD>::from(0);
95 match number.sign() {
96 Sign::Minus => zero - magnitude,
97 Sign::NoSign => zero,
98 Sign::Plus => magnitude,
99 }
100 }
101}
102
103impl From<&Number> for $FIELD {
104 fn from(number: &Number) -> Self {
105 match number {
106 Number::SmallNum(i) => (&BigInt::from(*i)).into(),
107 Number::BigNum(n) => n.into(),
108 }
109 }
110}
111
112impl From<Number> for $FIELD {
113 fn from(number: Number) -> Self {
114 (&number).into()
115 }
116}
117
118impl From<f64> for $FIELD {
119 fn from(value: f64) -> Self {
120 let mut bytes = value.to_le_bytes();
121 let sign = bytes[7] >> 7;
122 let exponent_hi = (bytes[7] & 127) as i16;
123 let exponent_lo = (bytes[6] & 240) as i16;
124 let exponent = (exponent_hi << 4) + (exponent_lo >> 4) - 1023;
125 bytes[7] = 0;
127 bytes[6] &= 15;
128 bytes[6] |= 16;
130 let value_unsigned = u64::from_le_bytes(bytes) >> (-exponent.min(0)).min(63);
131 <$FIELD>::power_of_two(exponent.max(0) as usize) * (if sign == 1u8 { <$FIELD>::ZERO - <$FIELD>::from(value_unsigned)} else {<$FIELD>::from(value_unsigned)})
132 }
133}
134
135impl FromLeBytes for $FIELD {
136 fn from_le_bytes(bytes: [u8; 32]) -> Self {
137 <$FIELD>::from_le_bytes_checked(bytes).unwrap()
138 }
139}
140
141fn find_alpha() -> i32 {
142 let p_minus_one = <$FIELD>::modulus_number() - 1;
143 for alpha in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47] {
144 if &p_minus_one % alpha != 0 {
145 return alpha;
146 }
147 }
148 panic!("Could not find prime alpha that does not divide p-1.")
149}
150fn find_alpha_inverse(alpha: i32) -> Number {
153 let q = <$FIELD>::modulus_number() - 1;
154 let m = (&q % alpha).to_i32().unwrap();
155 if m == 0 {
156 panic!("alpha divides p_minus_one");
157 }
158 let n = (1..alpha).find(|k| (m * k) % alpha == (alpha - 1)).unwrap();
161 let l = m * n / alpha;
162 let k = q / alpha;
163 n * k + l + 1
167}
168
169fn find_alphas() -> (Number, Number) {
170 let alpha = find_alpha();
171 let alpha_inverse = find_alpha_inverse(alpha);
172 (alpha.into(), alpha_inverse)
173}
174
175thread_local! {
176 static ALPHAS: (Number, Number) = find_alphas();
177}
178
179fn get_alpha() -> Number {
180 ALPHAS.with(|(alpha, _)| alpha.clone())
181}
182
183fn get_alpha_inverse() -> Number {
184 ALPHAS.with(|(_, alpha_inverse)| alpha_inverse.clone())
185}
186
187pub(super) fn build_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
188 assert_eq!(x.len(), y.len());
189 let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
190 for i in 0..x.len() {
191 for j in 0..y.len() {
192 mat[(i, j)] = (x[i] - y[j]).invert(true);
193 }
194 }
195 mat
196}
197pub(super) fn inverse_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
200 assert_eq!(x.len(), y.len());
201 fn prime(arr: &[$FIELD], val: $FIELD) -> $FIELD {
203 arr.iter()
204 .map(|u| if *u != val { val - u } else { 1.into() })
205 .product()
206 }
207 let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
208 for i in 0..x.len() {
209 for j in 0..y.len() {
210 let a = x.iter().map(|u| y[i] - u).product::<$FIELD>();
211 let a_prime = prime(x, x[j]);
212 let b = y.iter().map(|v| x[j] - v).product::<$FIELD>();
213 let b_prime = prime(y, y[i]);
214 mat[(i, j)] = a
215 * b
216 * a_prime.invert(true)
217 * b_prime.invert(true)
218 * (y[i] - x[j]).invert(true);
219 }
220 }
221 mat
222}
223
224fn mds_matrix_and_inverse(size: usize) -> (Matrix<$FIELD>, Matrix<$FIELD>) {
225 let x = (1..=size).map(|i| <$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
226 let y = (1..=size).map(|i| -<$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
227 let mds = build_cauchy(x.as_slice(), y.as_slice());
228 let inverse_mds = inverse_cauchy(x.as_slice(), y.as_slice());
229 (mds, inverse_mds)
230}
231
232impl Shr<usize> for $FIELD {
233 type Output = $FIELD;
234
235 fn shr(self, rhs: usize) -> Self::Output {
236 self.unsigned_euclidean_division(<$FIELD>::power_of_two(rhs))
237 }
238}
239
240impl UsedField for $FIELD {
241 fn modulus() -> Number {
242 <$FIELD>::modulus_number()
243 }
244
245 fn get_alpha() -> Number {
246 get_alpha()
247 }
248
249 fn get_alpha_inverse() -> Number {
250 get_alpha_inverse()
251 }
252
253 fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>) {
254 mds_matrix_and_inverse(width)
255 }
256
257 fn power_of_two(exponent: usize) -> Self {
258 <$FIELD>::power_of_two(exponent)
259 }
260}
261
262impl Zero for $FIELD {
263 fn zero() -> Self {
264 <$FIELD>::ZERO
265 }
266 fn is_zero(&self) -> bool {
267 *self == <$FIELD>::zero()
268 }
269}
270
271 };
272}
273
// `Hash` is derived while equality comes from elsewhere (presumably the `PrimeField`
// derive), which is what this clippy lint would otherwise flag.
#[allow(clippy::derived_hash_with_manual_eq)]
mod scalar_field {

    // The derive lives in its own module, and only the struct and its `Repr` type
    // are re-imported below.
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};
        // Modulus below is 2^252 + 27742317777372353535851937790883648493.
        // NOTE(review): this looks like the Curve25519 group order; confirm.
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        #[PrimeFieldModulus = "7237005577332262213973186563042994240857116359379907606001950938285454250989"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct ScalarField([u64; 4]);
    }

    use curve25519_dalek::Scalar;
    pub use field_derive::ScalarField;
    // `ScalarFieldRepr` is referenced by `impl_field!` through `paste!`.
    use field_derive::ScalarFieldRepr;
    impl_field!(ScalarField);

    // Converts a dalek `Scalar` through its 32-byte encoding; `from_le_bytes` panics
    // if those bytes are rejected by `from_le_bytes_checked`.
    impl From<Scalar> for ScalarField {
        fn from(value: Scalar) -> Self {
            ScalarField::from_le_bytes(value.to_bytes())
        }
    }
}
// `Hash` is derived while equality comes from elsewhere (presumably the `PrimeField`
// derive), which is what this clippy lint would otherwise flag.
#[allow(clippy::derived_hash_with_manual_eq)]
mod base_field {
    // The derive lives in its own module, and only the struct and its `Repr` type
    // are re-imported below.
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};

        // Modulus below is 2^255 - 19.
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        #[PrimeFieldModulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct BaseField([u64; 4]);
    }
    pub use field_derive::BaseField;
    // `BaseFieldRepr` is referenced by `impl_field!` through `paste!`.
    use field_derive::BaseFieldRepr;
    impl_field!(BaseField);
}
316
317#[cfg(test)]
318mod tests {
319 use super::*;
320 use crate::{
321 traits::{Invert, Pow},
322 utils::{number::Number, used_field::UsedField},
323 };
324 use ff::{Field, PrimeField};
325 use std::{f64::consts::PI, str::FromStr};
326
#[test]
fn from_f64() {
    // Converting `input` must give the field element parsed from `expected`.
    fn check(input: f64, expected: &str) {
        assert_eq!(
            ScalarField::from(input),
            ScalarField::from(Number::from_str(expected).unwrap())
        );
    }

    // Exponent zero: the raw 53-bit significand.
    check(2f64.sqrt(), "6369051672525773");
    // Large negative value: significand scaled up, then negated in the field.
    check(
        -PI * 2f64.powi(150),
        "0x0ffffffffffff36f0255dde97400000014def9dea2f79cd65812631a5cf5d3ed",
    );
    // Small positive value: significand shifted right.
    check(0.001, "4503599627370");
    // Small negative value.
    check(
        -0.00000383,
        "0x1000000000000000000000000000000014def9dea2f79cd65812631658da3b61",
    );
    // Far below the 52 fractional bits of precision: converts to zero.
    assert_eq!(ScalarField::from(3f64 * 2f64.powi(-150)), ScalarField::ZERO);
}
357
#[test]
fn multiplicative_generator() {
    // g^((p - 1) / 2) must differ from one for the declared generator g.
    let half_order = (ScalarField::modulus() - 1) / 2;
    let generator = ScalarField::MULTIPLICATIVE_GENERATOR;
    let raised = generator.pow(&half_order, true);
    assert_ne!(raised, ScalarField::ONE);
}
364
#[test]
fn sqrt() {
    use rand::rngs::OsRng;

    // Squares `value`, takes a square root of the square, and checks that the
    // root squares back to the same element.
    fn assert_root_of_square(value: ScalarField) {
        let squared = value.square();
        let recovered = squared.sqrt().unwrap();
        assert_eq!(recovered.square(), squared);
    }

    assert_root_of_square(ScalarField::ZERO);
    assert_root_of_square(ScalarField::ONE);
    (0..1024).for_each(|_| assert_root_of_square(ScalarField::random(OsRng)));
}
380
#[test]
fn test_safe_field_inverse() {
    // `invert(false)` must map zero to zero and give a true inverse for everything else.
    for value in [
        ScalarField::ZERO,
        ScalarField::ONE,
        ScalarField::from(2),
        ScalarField::from(3),
    ] {
        let inverse = value.invert(false);
        if value != ScalarField::ZERO {
            assert_eq!(value * inverse, ScalarField::ONE);
        } else {
            assert_eq!(inverse, ScalarField::ZERO);
        }
    }
}
#[test]
fn test_cauchy_inverse() {
    // x = 1..=5 and y = 0, -1, ..., -4, so every difference x_i - y_j is non-zero.
    let xs = [
        ScalarField::ONE,
        ScalarField::from(2),
        ScalarField::from(3),
        ScalarField::from(4),
        ScalarField::from(5),
    ];
    let ys = [
        ScalarField::ZERO,
        -ScalarField::from(1),
        -ScalarField::from(2),
        -ScalarField::from(3),
        -ScalarField::from(4),
    ];

    let forward = scalar_field::build_cauchy(&xs, &ys);
    let backward = scalar_field::inverse_cauchy(&xs, &ys);
    let product = forward.mat_mul(&backward);

    // The product must be the identity: one on the diagonal, zero elsewhere.
    for row in 0..xs.len() {
        for col in 0..ys.len() {
            let expected = match row == col {
                true => ScalarField::ONE,
                false => ScalarField::ZERO,
            };
            assert_eq!(product[(row, col)], expected);
        }
    }
}
427}