// arcis_compiler/utils/field.rs

1pub use base_field::BaseField;
2pub use scalar_field::ScalarField;
3
4macro_rules! impl_field {
5    ($FIELD: ty, $pow: literal) => {
6use crate::utils::matrix::Matrix;
7use crate::utils::number::Number;
8use crate::utils::used_field::UsedField;
9use ff::{Field, PrimeField};
10use num_traits::{Num, ToPrimitive, Zero};
11use num_bigint::{BigInt, BigUint, Sign};
12const MAX_CACHED_EXPONENT: usize = 256;
13use std::ops::Shr;
14use paste::paste;
15use crate::traits::{Invert, FromLeBytes};
16
17// We cache a few things that we tend to reuse.
18thread_local! {
19    /// The powers of two from 0 to MAX_CACHED_EXPONENT included.
20    static POWERS_OF_TWO: [$FIELD; MAX_CACHED_EXPONENT + 1] = {
21        let mut arr: [$FIELD; MAX_CACHED_EXPONENT + 1] = [<$FIELD>::ONE; MAX_CACHED_EXPONENT + 1];
22        let two = <$FIELD>::from(2);
23        for i in 0..MAX_CACHED_EXPONENT {
24            arr[i+1] = two * arr[i]
25        }
26        arr
27    };
28    /// The modulus of the field.
29    static MODULUS: Number = BigInt::from(<$FIELD>::modulus_big_uint()).into()
30}
31
32
33impl $FIELD {
34    fn modulus_big_uint() -> BigUint {
35        BigUint::from_str_radix(&(<$FIELD>::MODULUS[2..]), 16).unwrap()
36    }
37    fn modulus_number() -> Number {
38        MODULUS.with(|x| x.clone())
39    }
40
41    fn power_of_two(exponent: usize) -> $FIELD {
42        if exponent <= MAX_CACHED_EXPONENT {
43            POWERS_OF_TWO.with(|x| x[exponent])
44        } else {
45            <$FIELD>::from(2).pow([exponent as u64])
46        }
47    }
48
49    pub fn from_le_bytes_checked(bytes: [u8; 32]) -> Option<Self> {
50        Option::<$FIELD>::from(<$FIELD>::from_repr(paste! { [<$FIELD Repr>] }(bytes)))
51    }
52
53    pub fn to_le_bytes(&self) -> [u8; 32] {
54        <[u8; 32]>::try_from(self.to_repr().as_ref()).unwrap()
55    }
56    pub fn to_usize(&self) -> Option<usize> {
57        const USIZE_BYTES: usize = usize::BITS as usize / 8;
58        let bytes = self.to_le_bytes();
59        if &bytes[USIZE_BYTES..32] == &[0; 32 - USIZE_BYTES] {
60            Some(usize::from_le_bytes(bytes[0..USIZE_BYTES].try_into().unwrap()))
61        } else {
62            None
63        }
64    }
65    /// The string should only include chars in '0'..'9' with maybe a leading '-'.
66    pub fn from_simple_string(a: &str) -> Option<Self> {
67        let chars = a.as_bytes();
68        let is_negative = chars[0] == b'-';
69        let ten = Self::from(10u64);
70        let mut res = Self::ZERO;
71        for idx in (is_negative as usize)..(chars.len()) {
72            if !matches!(chars[idx], b'0'..=b'9') {
73                return None;
74            }
75            res *= ten;
76            res += Self::from((chars[idx] - b'0') as u64);
77        }
78        Some(if is_negative {
79            -res
80        } else {
81            res
82        })
83    }
84}
85
86impl From<bool> for $FIELD {
87    fn from(value: bool) -> Self {
88        if value {
89            <$FIELD>::ONE
90        } else {
91            <$FIELD>::ZERO
92        }
93    }
94}
95
96impl From<i32> for $FIELD {
97    fn from(value: i32) -> Self {
98        if value < 0 {
99            <$FIELD>::ZERO - <$FIELD>::from((-value) as u64)
100        } else {
101            <$FIELD>::from(value as u64)
102        }
103    }
104}
105
106impl From<&BigUint> for $FIELD {
107    fn from(number: &BigUint) -> Self {
108        let mut res: $FIELD = 0.into();
109        for (i, digit) in number
110            .iter_u64_digits()
111            .enumerate()
112        {
113            res += <$FIELD>::from(digit) * <$FIELD>::power_of_two(i * 64);
114        }
115        res
116    }
117}
118
119impl From<&BigInt> for $FIELD {
120    fn from(number: &BigInt) -> Self {
121        let magnitude = <$FIELD>::from(number.magnitude());
122        let zero = <$FIELD>::from(0);
123        match number.sign() {
124            Sign::Minus => zero - magnitude,
125            Sign::NoSign => zero,
126            Sign::Plus => magnitude,
127        }
128    }
129}
130
131impl From<&Number> for $FIELD {
132    fn from(number: &Number) -> Self {
133        match number {
134            Number::SmallNum(i) => (&BigInt::from(*i)).into(),
135            Number::BigNum(n) => n.into(),
136        }
137    }
138}
139
140impl From<Number> for $FIELD {
141    fn from(number: Number) -> Self {
142        (&number).into()
143    }
144}
145
146impl From<f64> for $FIELD {
147    fn from(value: f64) -> Self {
148        let mut bytes = value.to_le_bytes();
149        let sign = bytes[7] >> 7;
150        let exponent_hi = (bytes[7] & 127) as i16;
151        let exponent_lo = (bytes[6] & 240) as i16;
152        let exponent = (exponent_hi << 4) + (exponent_lo >> 4) - 1023;
153        // we get rid of the sign and the exponent
154        bytes[7] = 0;
155        bytes[6] &= 15;
156        // we need to set the implicit 1-bit (the value being 1.mantissa)
157        bytes[6] |= 16;
158        let value_unsigned = u64::from_le_bytes(bytes) >> (-exponent.min(0)).min(63);
159        <$FIELD>::power_of_two(exponent.max(0) as usize) * (if sign == 1u8 { <$FIELD>::ZERO - <$FIELD>::from(value_unsigned)} else {<$FIELD>::from(value_unsigned)})
160    }
161}
162
163impl FromLeBytes for $FIELD {
164    fn from_le_bytes(bytes: [u8; 32]) -> Self {
165        <$FIELD>::from_le_bytes_checked(bytes).unwrap()
166    }
167}
168
169fn find_alpha() -> i32 {
170    let p_minus_one = <$FIELD>::modulus_number() - 1;
171    for alpha in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47] {
172        if &p_minus_one % alpha != 0 {
173            return alpha;
174        }
175    }
176    panic!("Could not find prime alpha that does not divide p-1.")
177}
178/// Finds b such that x^alpha^b = x.
179/// Such b is the inverse of alpha modulo p-1
180fn find_alpha_inverse(alpha: i32) -> Number {
181    let q = <$FIELD>::modulus_number() - 1;
182    let m = (&q % alpha).to_i32().unwrap();
183    if m == 0 {
184        panic!("alpha divides p_minus_one");
185    }
186    // Since alpha is prime and m != 0 in F_alpha, the map k |-> m * k is a bijection
187    // F_alpha -> F_alpha. We search for the pre-image of alpha - 1, which we call n.
188    let n = (1..alpha).find(|k| (m * k) % alpha == (alpha - 1)).unwrap();
189    let l = m * n / alpha;
190    let k = q / alpha;
191    // q   = k*alpha + m
192    // n*m = l*alpha + alpha - 1
193    // n*q = n*k*alpha + n*m = (n*k + l + 1)*alpha - 1
194    n * k + l + 1
195}
196
197fn find_alphas() -> (Number, Number) {
198    let alpha = find_alpha();
199    let alpha_inverse = find_alpha_inverse(alpha);
200    (alpha.into(), alpha_inverse)
201}
202
203thread_local! {
204    static ALPHAS: (Number, Number) = find_alphas();
205}
206
207fn get_alpha() -> Number {
208    ALPHAS.with(|(alpha, _)| alpha.clone())
209}
210
211fn get_alpha_inverse() -> Number {
212    ALPHAS.with(|(_, alpha_inverse)| alpha_inverse.clone())
213}
214
215pub(super) fn build_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
216    assert_eq!(x.len(), y.len());
217    let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
218    for i in 0..x.len() {
219        for j in 0..y.len() {
220            mat[(i, j)] = (x[i] - y[j]).invert(true);
221        }
222    }
223    mat
224}
225/// Computes the inverse of a cauchy matrix.
226/// See <https://en.wikipedia.org/wiki/Cauchy_matrix>
227pub(super) fn inverse_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
228    assert_eq!(x.len(), y.len());
229    /// Computes some sort of derivation.
230    fn prime(arr: &[$FIELD], val: $FIELD) -> $FIELD {
231        arr.iter()
232            .map(|u| if *u != val { val - u } else { 1.into() })
233            .product()
234    }
235    let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
236    for i in 0..x.len() {
237        for j in 0..y.len() {
238            let a = x.iter().map(|u| y[i] - u).product::<$FIELD>();
239            let a_prime = prime(x, x[j]);
240            let b = y.iter().map(|v| x[j] - v).product::<$FIELD>();
241            let b_prime = prime(y, y[i]);
242            mat[(i, j)] = a
243                * b
244                * a_prime.invert(true)
245                * b_prime.invert(true)
246                * (y[i] - x[j]).invert(true);
247        }
248    }
249    mat
250}
251
252fn mds_matrix_and_inverse(size: usize) -> (Matrix<$FIELD>, Matrix<$FIELD>) {
253    let x = (1..=size).map(|i| <$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
254    let y = (1..=size).map(|i| -<$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
255    let mds = build_cauchy(x.as_slice(), y.as_slice());
256    let inverse_mds = inverse_cauchy(x.as_slice(), y.as_slice());
257    (mds, inverse_mds)
258}
259
260impl Shr<usize> for $FIELD {
261    type Output = $FIELD;
262
263    fn shr(self, rhs: usize) -> Self::Output {
264        self.unsigned_euclidean_division(<$FIELD>::power_of_two(rhs))
265    }
266}
267
268impl UsedField for $FIELD {
269    fn modulus() -> Number {
270        <$FIELD>::modulus_number()
271    }
272
273    fn get_alpha() -> Number {
274        get_alpha()
275    }
276
277    fn get_alpha_inverse() -> Number {
278        get_alpha_inverse()
279    }
280
281    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>) {
282        mds_matrix_and_inverse(width)
283    }
284
285    fn power_of_two(exponent: usize) -> Self {
286        <$FIELD>::power_of_two(exponent)
287    }
288    fn exponent_close_power_of_two() -> usize {
289        $pow
290    }
291}
292
293impl Zero for $FIELD {
294    fn zero() -> Self {
295        <$FIELD>::ZERO
296    }
297    fn is_zero(&self) -> bool {
298        *self == <$FIELD>::zero()
299    }
300}
301
302    };
303}
304
305#[allow(clippy::derived_hash_with_manual_eq)]
306mod scalar_field {
307
308    mod field_derive {
309        use ff::PrimeField;
310        use serde::{Deserialize, Serialize};
311        #[derive(PrimeField, Hash, Serialize, Deserialize)]
312        // modulus = 2^252 + 27742317777372353535851937790883648493
313        #[PrimeFieldModulus = "7237005577332262213973186563042994240857116359379907606001950938285454250989"]
314        #[PrimeFieldGenerator = "2"]
315        #[PrimeFieldReprEndianness = "little"]
316        pub struct ScalarField([u64; 4]);
317    }
318
319    use curve25519_dalek::Scalar;
320    pub use field_derive::ScalarField;
321    use field_derive::ScalarFieldRepr;
322    impl_field!(ScalarField, 252);
323
324    impl From<Scalar> for ScalarField {
325        fn from(value: Scalar) -> Self {
326            ScalarField::from_le_bytes(value.to_bytes())
327        }
328    }
329}
330#[allow(clippy::derived_hash_with_manual_eq)]
331mod base_field {
332    mod field_derive {
333        use ff::PrimeField;
334        use serde::{Deserialize, Serialize};
335
336        #[derive(PrimeField, Hash, Serialize, Deserialize)]
337        // modulus = 2^255 - 19
338        #[PrimeFieldModulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949"]
339        #[PrimeFieldGenerator = "2"]
340        #[PrimeFieldReprEndianness = "little"]
341        pub struct BaseField([u64; 4]);
342    }
343    pub use field_derive::BaseField;
344    use field_derive::BaseFieldRepr;
345    impl_field!(BaseField, 255);
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        traits::{Invert, Pow},
        utils::{number::Number, used_field::UsedField},
    };
    use ff::{Field, PrimeField};
    use std::{f64::consts::PI, str::FromStr};

    #[test]
    fn from_f64() {
        assert_eq!(
            ScalarField::from(2f64.sqrt()),
            ScalarField::from(Number::from_str("6369051672525773").unwrap())
        );
        assert_eq!(
            ScalarField::from(-PI * 2f64.powi(150)),
            ScalarField::from(
                Number::from_str(
                    "0x0ffffffffffff36f0255dde97400000014def9dea2f79cd65812631a5cf5d3ed"
                )
                .unwrap()
            )
        );
        assert_eq!(
            ScalarField::from(0.001),
            ScalarField::from(Number::from_str("4503599627370").unwrap())
        );
        assert_eq!(
            ScalarField::from(-0.00000383),
            ScalarField::from(
                Number::from_str(
                    "0x1000000000000000000000000000000014def9dea2f79cd65812631658da3b61"
                )
                .unwrap()
            )
        );
        assert_eq!(ScalarField::from(3f64 * 2f64.powi(-150)), ScalarField::ZERO);
    }

    #[test]
    fn multiplicative_generator() {
        let a = ScalarField::MULTIPLICATIVE_GENERATOR;
        let b = a.pow(&((ScalarField::modulus() - 1) / 2), true);
        assert_ne!(b, ScalarField::ONE);
    }

    #[test]
    fn sqrt() {
        fn test(square_root: ScalarField) {
            let square = square_root.square();
            let square_root = square.sqrt().unwrap();
            assert_eq!(square_root.square(), square);
        }

        test(ScalarField::ZERO);
        test(ScalarField::ONE);
        use rand::rngs::OsRng;
        for _ in 0..1024 {
            test(ScalarField::random(OsRng));
        }
    }

    #[test]
    fn test_safe_field_inverse() {
        for n in [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::from(2),
            ScalarField::from(3),
        ] {
            let inv = n.invert(false);
            if n == ScalarField::ZERO {
                assert_eq!(inv, ScalarField::ZERO);
            } else {
                assert_eq!(n * inv, ScalarField::ONE);
            }
        }
    }

    #[test]
    fn test_cauchy_inverse() {
        let x = [
            ScalarField::ONE,
            ScalarField::from(2),
            ScalarField::from(3),
            ScalarField::from(4),
            ScalarField::from(5),
        ];
        let y = [
            ScalarField::ZERO,
            -ScalarField::from(1),
            -ScalarField::from(2),
            -ScalarField::from(3),
            -ScalarField::from(4),
        ];
        let cauchy = scalar_field::build_cauchy(&x, &y);
        let inverse = scalar_field::inverse_cauchy(&x, &y);
        let identity = cauchy.mat_mul(&inverse);
        for i in 0..x.len() {
            for j in 0..y.len() {
                let expected = if i == j {
                    ScalarField::ONE
                } else {
                    ScalarField::ZERO
                };
                assert_eq!(identity[(i, j)], expected);
            }
        }
    }

    /// `from_simple_string` parses optionally-negated decimal digits and rejects others.
    #[test]
    fn from_simple_string_parses_decimal() {
        assert_eq!(ScalarField::from_simple_string("0"), Some(ScalarField::ZERO));
        assert_eq!(
            ScalarField::from_simple_string("1234"),
            Some(ScalarField::from(1234u64))
        );
        assert_eq!(
            ScalarField::from_simple_string("-42"),
            Some(-ScalarField::from(42u64))
        );
        assert_eq!(ScalarField::from_simple_string("12a"), None);
        assert_eq!(ScalarField::from_simple_string("--5"), None);
    }

    /// `to_usize` succeeds exactly when the canonical representative fits in a `usize`.
    #[test]
    fn to_usize_round_trips_small_values() {
        assert_eq!(ScalarField::from(12_345u64).to_usize(), Some(12_345));
        assert_eq!(ScalarField::ZERO.to_usize(), Some(0));
        assert_eq!((-ScalarField::ONE).to_usize(), None);
    }

    /// Raising to `alpha` and then to `alpha_inverse` must give back the input.
    #[test]
    fn alpha_inverse_undoes_alpha() {
        macro_rules! check_alpha_round_trip {
            ($F:ty) => {{
                let alpha = <$F>::get_alpha();
                let alpha_inverse = <$F>::get_alpha_inverse();
                for x in [
                    <$F>::ONE,
                    <$F>::from(7u64),
                    <$F>::from(123_456_789u64),
                    <$F>::MULTIPLICATIVE_GENERATOR,
                ] {
                    assert_eq!(x.pow(&alpha, true).pow(&alpha_inverse, true), x);
                }
            }};
        }
        check_alpha_round_trip!(ScalarField);
        check_alpha_round_trip!(BaseField);
    }

    /// The MDS matrix exposed through `UsedField` must be inverted by its companion.
    #[test]
    fn mds_matrix_and_inverse_multiply_to_identity() {
        macro_rules! check_mds_inverse {
            ($F:ty, $width:expr) => {{
                let n: usize = $width;
                let (mds, inverse) = <$F>::mds_matrix_and_inverse(n);
                let product = mds.mat_mul(&inverse);
                for i in 0..n {
                    for j in 0..n {
                        let expected = if i == j { <$F>::ONE } else { <$F>::ZERO };
                        assert_eq!(product[(i, j)], expected);
                    }
                }
            }};
        }
        for width in [1usize, 2, 3, 5] {
            check_mds_inverse!(ScalarField, width);
            check_mds_inverse!(BaseField, width);
        }
    }
}