1pub use base_field::BaseField;
2pub use scalar_field::ScalarField;
3
4macro_rules! impl_field {
5 ($FIELD: ty) => {
6use crate::utils::matrix::Matrix;
7use crate::utils::number::Number;
8use crate::utils::used_field::UsedField;
9use ff::{Field, PrimeField};
10use num_traits::{Num, ToPrimitive, Zero};
11use num_bigint::{BigInt, BigUint, Sign};
/// Largest exponent `e` for which `2^e` is kept in the thread-local `POWERS_OF_TWO` table;
/// the table therefore holds `MAX_CACHED_EXPONENT + 1` entries.
const MAX_CACHED_EXPONENT: usize = 1 << 8;
13use std::ops::Shr;
14use paste::paste;
15use crate::traits::{Invert, FromLeBytes};
16
17thread_local! {
19 static POWERS_OF_TWO: [$FIELD; MAX_CACHED_EXPONENT + 1] = {
21 let mut arr: [$FIELD; MAX_CACHED_EXPONENT + 1] = [<$FIELD>::ONE; MAX_CACHED_EXPONENT + 1];
22 let two = <$FIELD>::from(2);
23 for i in 0..MAX_CACHED_EXPONENT {
24 arr[i+1] = two * arr[i]
25 }
26 arr
27 };
28 static MODULUS: Number = BigInt::from(<$FIELD>::modulus_big_uint()).into()
30}
31
32
33impl $FIELD {
34 fn modulus_big_uint() -> BigUint {
35 BigUint::from_str_radix(&(<$FIELD>::MODULUS[2..]), 16).unwrap()
36 }
37 fn modulus_number() -> Number {
38 MODULUS.with(|x| x.clone())
39 }
40
41 fn power_of_two(exponent: usize) -> $FIELD {
42 if exponent <= MAX_CACHED_EXPONENT {
43 POWERS_OF_TWO.with(|x| x[exponent])
44 } else {
45 <$FIELD>::from(2).pow([exponent as u64])
46 }
47 }
48
49 pub fn from_le_bytes_checked(bytes: [u8; 32]) -> Option<Self> {
50 Option::<$FIELD>::from(<$FIELD>::from_repr(paste! { [<$FIELD Repr>] }(bytes)))
51 }
52
53 pub fn to_le_bytes(&self) -> [u8; 32] {
54 <[u8; 32]>::try_from(self.to_repr().as_ref()).unwrap()
55 }
56 pub fn to_usize(&self) -> Option<usize> {
57 const USIZE_BYTES: usize = usize::BITS as usize / 8;
58 let bytes = self.to_le_bytes();
59 if &bytes[USIZE_BYTES..32] == &[0; 32 - USIZE_BYTES] {
60 Some(usize::from_le_bytes(bytes[0..USIZE_BYTES].try_into().unwrap()))
61 } else {
62 None
63 }
64 }
65 pub fn from_simple_string(a: &str) -> Option<Self> {
67 let chars = a.as_bytes();
68 let is_negative = chars[0] == b'-';
69 let ten = Self::from(10u64);
70 let mut res = Self::ZERO;
71 for idx in (is_negative as usize)..(chars.len()) {
72 if !matches!(chars[idx], b'0'..=b'9') {
73 return None;
74 }
75 res *= ten;
76 res += Self::from((chars[idx] - b'0') as u64);
77 }
78 Some(if is_negative {
79 -res
80 } else {
81 res
82 })
83 }
84}
85
86impl From<bool> for $FIELD {
87 fn from(value: bool) -> Self {
88 if value {
89 <$FIELD>::ONE
90 } else {
91 <$FIELD>::ZERO
92 }
93 }
94}
95
96impl From<i32> for $FIELD {
97 fn from(value: i32) -> Self {
98 if value < 0 {
99 <$FIELD>::ZERO - <$FIELD>::from((-value) as u64)
100 } else {
101 <$FIELD>::from(value as u64)
102 }
103 }
104}
105
106impl From<&BigUint> for $FIELD {
107 fn from(number: &BigUint) -> Self {
108 let mut res: $FIELD = 0.into();
109 for (i, digit) in number
110 .iter_u64_digits()
111 .enumerate()
112 {
113 res += <$FIELD>::from(digit) * <$FIELD>::power_of_two(i * 64);
114 }
115 res
116 }
117}
118
119impl From<&BigInt> for $FIELD {
120 fn from(number: &BigInt) -> Self {
121 let magnitude = <$FIELD>::from(number.magnitude());
122 let zero = <$FIELD>::from(0);
123 match number.sign() {
124 Sign::Minus => zero - magnitude,
125 Sign::NoSign => zero,
126 Sign::Plus => magnitude,
127 }
128 }
129}
130
131impl From<&Number> for $FIELD {
132 fn from(number: &Number) -> Self {
133 match number {
134 Number::SmallNum(i) => (&BigInt::from(*i)).into(),
135 Number::BigNum(n) => n.into(),
136 }
137 }
138}
139
140impl From<Number> for $FIELD {
141 fn from(number: Number) -> Self {
142 (&number).into()
143 }
144}
145
146impl From<f64> for $FIELD {
147 fn from(value: f64) -> Self {
148 let mut bytes = value.to_le_bytes();
149 let sign = bytes[7] >> 7;
150 let exponent_hi = (bytes[7] & 127) as i16;
151 let exponent_lo = (bytes[6] & 240) as i16;
152 let exponent = (exponent_hi << 4) + (exponent_lo >> 4) - 1023;
153 bytes[7] = 0;
155 bytes[6] &= 15;
156 bytes[6] |= 16;
158 let value_unsigned = u64::from_le_bytes(bytes) >> (-exponent.min(0)).min(63);
159 <$FIELD>::power_of_two(exponent.max(0) as usize) * (if sign == 1u8 { <$FIELD>::ZERO - <$FIELD>::from(value_unsigned)} else {<$FIELD>::from(value_unsigned)})
160 }
161}
162
163impl FromLeBytes for $FIELD {
164 fn from_le_bytes(bytes: [u8; 32]) -> Self {
165 <$FIELD>::from_le_bytes_checked(bytes).unwrap()
166 }
167}
168
169fn find_alpha() -> i32 {
170 let p_minus_one = <$FIELD>::modulus_number() - 1;
171 for alpha in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47] {
172 if &p_minus_one % alpha != 0 {
173 return alpha;
174 }
175 }
176 panic!("Could not find prime alpha that does not divide p-1.")
177}
178fn find_alpha_inverse(alpha: i32) -> Number {
181 let q = <$FIELD>::modulus_number() - 1;
182 let m = (&q % alpha).to_i32().unwrap();
183 if m == 0 {
184 panic!("alpha divides p_minus_one");
185 }
186 let n = (1..alpha).find(|k| (m * k) % alpha == (alpha - 1)).unwrap();
189 let l = m * n / alpha;
190 let k = q / alpha;
191 n * k + l + 1
195}
196
197fn find_alphas() -> (Number, Number) {
198 let alpha = find_alpha();
199 let alpha_inverse = find_alpha_inverse(alpha);
200 (alpha.into(), alpha_inverse)
201}
202
203thread_local! {
204 static ALPHAS: (Number, Number) = find_alphas();
205}
206
207fn get_alpha() -> Number {
208 ALPHAS.with(|(alpha, _)| alpha.clone())
209}
210
211fn get_alpha_inverse() -> Number {
212 ALPHAS.with(|(_, alpha_inverse)| alpha_inverse.clone())
213}
214
215pub(super) fn build_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
216 assert_eq!(x.len(), y.len());
217 let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
218 for i in 0..x.len() {
219 for j in 0..y.len() {
220 mat[(i, j)] = (x[i] - y[j]).invert(true);
221 }
222 }
223 mat
224}
225pub(super) fn inverse_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
228 assert_eq!(x.len(), y.len());
229 fn prime(arr: &[$FIELD], val: $FIELD) -> $FIELD {
231 arr.iter()
232 .map(|u| if *u != val { val - u } else { 1.into() })
233 .product()
234 }
235 let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
236 for i in 0..x.len() {
237 for j in 0..y.len() {
238 let a = x.iter().map(|u| y[i] - u).product::<$FIELD>();
239 let a_prime = prime(x, x[j]);
240 let b = y.iter().map(|v| x[j] - v).product::<$FIELD>();
241 let b_prime = prime(y, y[i]);
242 mat[(i, j)] = a
243 * b
244 * a_prime.invert(true)
245 * b_prime.invert(true)
246 * (y[i] - x[j]).invert(true);
247 }
248 }
249 mat
250}
251
252fn mds_matrix_and_inverse(size: usize) -> (Matrix<$FIELD>, Matrix<$FIELD>) {
253 let x = (1..=size).map(|i| <$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
254 let y = (1..=size).map(|i| -<$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
255 let mds = build_cauchy(x.as_slice(), y.as_slice());
256 let inverse_mds = inverse_cauchy(x.as_slice(), y.as_slice());
257 (mds, inverse_mds)
258}
259
260impl Shr<usize> for $FIELD {
261 type Output = $FIELD;
262
263 fn shr(self, rhs: usize) -> Self::Output {
264 self.unsigned_euclidean_division(<$FIELD>::power_of_two(rhs))
265 }
266}
267
268impl UsedField for $FIELD {
269 fn modulus() -> Number {
270 <$FIELD>::modulus_number()
271 }
272
273 fn get_alpha() -> Number {
274 get_alpha()
275 }
276
277 fn get_alpha_inverse() -> Number {
278 get_alpha_inverse()
279 }
280
281 fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>) {
282 mds_matrix_and_inverse(width)
283 }
284
285 fn power_of_two(exponent: usize) -> Self {
286 <$FIELD>::power_of_two(exponent)
287 }
288}
289
290impl Zero for $FIELD {
291 fn zero() -> Self {
292 <$FIELD>::ZERO
293 }
294 fn is_zero(&self) -> bool {
295 *self == <$FIELD>::zero()
296 }
297}
298
299 };
300}
301
// `Hash` is derived below while equality comes from the `PrimeField` derive, which is
// presumably what triggers `clippy::derived_hash_with_manual_eq` — hence the allow.
#[allow(clippy::derived_hash_with_manual_eq)]
mod scalar_field {

    // The derive is kept in its own module; only the struct and its `Repr` type are
    // brought back into scope below.
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};
        // NOTE(review): the modulus looks like the prime order of the curve25519 group
        // (2^252 + 27742317777372353535851937790883648493), consistent with the
        // `curve25519_dalek::Scalar` conversion below — confirm.
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        #[PrimeFieldModulus = "7237005577332262213973186563042994240857116359379907606001950938285454250989"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct ScalarField([u64; 4]);
    }

    use curve25519_dalek::Scalar;
    pub use field_derive::ScalarField;
    // `impl_field!` names the repr type as `<FIELD>Repr`, so it must be in scope here.
    use field_derive::ScalarFieldRepr;
    // Generates the conversions, caches and Cauchy-matrix helpers for `ScalarField`.
    impl_field!(ScalarField);

    impl From<Scalar> for ScalarField {
        // Reinterprets the scalar's bytes as a field element. `from_le_bytes` unwraps
        // the checked conversion, so bytes it rejects cause a panic.
        fn from(value: Scalar) -> Self {
            ScalarField::from_le_bytes(value.to_bytes())
        }
    }
}
// `Hash` is derived below while equality comes from the `PrimeField` derive, which is
// presumably what triggers `clippy::derived_hash_with_manual_eq` — hence the allow.
#[allow(clippy::derived_hash_with_manual_eq)]
mod base_field {
    // The derive is kept in its own module; only the struct and its `Repr` type are
    // brought back into scope below.
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};

        // NOTE(review): the modulus looks like 2^255 - 19 (the curve25519 base field
        // prime) — confirm.
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        #[PrimeFieldModulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct BaseField([u64; 4]);
    }
    pub use field_derive::BaseField;
    // `impl_field!` names the repr type as `<FIELD>Repr`, so it must be in scope here.
    use field_derive::BaseFieldRepr;
    // Generates the conversions, caches and Cauchy-matrix helpers for `BaseField`.
    impl_field!(BaseField);
}
344
345#[cfg(test)]
346mod tests {
347 use super::*;
348 use crate::{
349 traits::{Invert, Pow},
350 utils::{number::Number, used_field::UsedField},
351 };
352 use ff::{Field, PrimeField};
353 use std::{f64::consts::PI, str::FromStr};
354
355 #[test]
356 fn from_f64() {
357 assert_eq!(
358 ScalarField::from(2f64.sqrt()),
359 ScalarField::from(Number::from_str("6369051672525773").unwrap())
360 );
361 assert_eq!(
362 ScalarField::from(-PI * 2f64.powi(150)),
363 ScalarField::from(
364 Number::from_str(
365 "0x0ffffffffffff36f0255dde97400000014def9dea2f79cd65812631a5cf5d3ed"
366 )
367 .unwrap()
368 )
369 );
370 assert_eq!(
371 ScalarField::from(0.001),
372 ScalarField::from(Number::from_str("4503599627370").unwrap())
373 );
374 assert_eq!(
375 ScalarField::from(-0.00000383),
376 ScalarField::from(
377 Number::from_str(
378 "0x1000000000000000000000000000000014def9dea2f79cd65812631658da3b61"
379 )
380 .unwrap()
381 )
382 );
383 assert_eq!(ScalarField::from(3f64 * 2f64.powi(-150)), ScalarField::ZERO);
384 }
385
386 #[test]
387 fn multiplicative_generator() {
388 let a = ScalarField::MULTIPLICATIVE_GENERATOR;
389 let b = a.pow(&((ScalarField::modulus() - 1) / 2), true);
390 assert_ne!(b, ScalarField::ONE);
391 }
392
393 #[test]
394 fn sqrt() {
395 fn test(square_root: ScalarField) {
396 let square = square_root.square();
397 let square_root = square.sqrt().unwrap();
398 assert_eq!(square_root.square(), square);
399 }
400
401 test(ScalarField::ZERO);
402 test(ScalarField::ONE);
403 use rand::rngs::OsRng;
404 for _ in 0..1024 {
405 test(ScalarField::random(OsRng));
406 }
407 }
408
409 #[test]
410 fn test_safe_field_inverse() {
411 for n in [
412 ScalarField::ZERO,
413 ScalarField::ONE,
414 ScalarField::from(2),
415 ScalarField::from(3),
416 ] {
417 let inv = n.invert(false);
418 if n == ScalarField::ZERO {
419 assert_eq!(inv, ScalarField::ZERO);
420 } else {
421 assert_eq!(n * inv, ScalarField::ONE);
422 }
423 }
424 }
425 #[test]
426 fn test_cauchy_inverse() {
427 let x = [
428 ScalarField::ONE,
429 ScalarField::from(2),
430 ScalarField::from(3),
431 ScalarField::from(4),
432 ScalarField::from(5),
433 ];
434 let y = [
435 ScalarField::ZERO,
436 -ScalarField::from(1),
437 -ScalarField::from(2),
438 -ScalarField::from(3),
439 -ScalarField::from(4),
440 ];
441 let cauchy = scalar_field::build_cauchy(&x, &y);
442 let inverse = scalar_field::inverse_cauchy(&x, &y);
443 let identity = cauchy.mat_mul(&inverse);
444 for i in 0..x.len() {
445 for j in 0..y.len() {
446 let expected = if i == j {
447 ScalarField::ONE
448 } else {
449 ScalarField::ZERO
450 };
451 assert_eq!(identity[(i, j)], expected);
452 }
453 }
454 }
455}