Skip to main content

arcis_compiler/utils/
field.rs

1pub use base_field::BaseField;
2pub use scalar_field::ScalarField;
3
4macro_rules! impl_field {
5    ($FIELD: ty) => {
6use crate::utils::matrix::Matrix;
7use crate::utils::number::Number;
8use crate::utils::used_field::UsedField;
9use ff::{Field, PrimeField};
10use num_traits::{Num, ToPrimitive, Zero};
11use num_bigint::{BigInt, BigUint, Sign};
/// Largest exponent `e` for which `2^e` is precomputed in `POWERS_OF_TWO`; larger
/// exponents are computed on demand by `power_of_two`.
const MAX_CACHED_EXPONENT: usize = 256;
13use std::ops::Shr;
14use paste::paste;
15use crate::traits::{Invert, FromLeBytes};
16
17// We cache a few things that we tend to reuse.
18thread_local! {
19    /// The powers of two from 0 to MAX_CACHED_EXPONENT included.
20    static POWERS_OF_TWO: [$FIELD; MAX_CACHED_EXPONENT + 1] = {
21        let mut arr: [$FIELD; MAX_CACHED_EXPONENT + 1] = [<$FIELD>::ONE; MAX_CACHED_EXPONENT + 1];
22        let two = <$FIELD>::from(2);
23        for i in 0..MAX_CACHED_EXPONENT {
24            arr[i+1] = two * arr[i]
25        }
26        arr
27    };
28    /// The modulus of the field.
29    static MODULUS: Number = BigInt::from(<$FIELD>::modulus_big_uint()).into()
30}
31
32
33impl $FIELD {
34    fn modulus_big_uint() -> BigUint {
35        BigUint::from_str_radix(&(<$FIELD>::MODULUS[2..]), 16).unwrap()
36    }
37    fn modulus_number() -> Number {
38        MODULUS.with(|x| x.clone())
39    }
40
41    fn power_of_two(exponent: usize) -> $FIELD {
42        if exponent <= MAX_CACHED_EXPONENT {
43            POWERS_OF_TWO.with(|x| x[exponent])
44        } else {
45            <$FIELD>::from(2).pow([exponent as u64])
46        }
47    }
48
49    pub fn from_le_bytes_checked(bytes: [u8; 32]) -> Option<Self> {
50        Option::<$FIELD>::from(<$FIELD>::from_repr(paste! { [<$FIELD Repr>] }(bytes)))
51    }
52
53    pub fn to_le_bytes(&self) -> [u8; 32] {
54        <[u8; 32]>::try_from(self.to_repr().as_ref()).unwrap()
55    }
56}
57
58impl From<bool> for $FIELD {
59    fn from(value: bool) -> Self {
60        if value {
61            <$FIELD>::ONE
62        } else {
63            <$FIELD>::ZERO
64        }
65    }
66}
67
68impl From<i32> for $FIELD {
69    fn from(value: i32) -> Self {
70        if value < 0 {
71            <$FIELD>::ZERO - <$FIELD>::from((-value) as u64)
72        } else {
73            <$FIELD>::from(value as u64)
74        }
75    }
76}
77
78impl From<&BigUint> for $FIELD {
79    fn from(number: &BigUint) -> Self {
80        let mut res: $FIELD = 0.into();
81        for (i, digit) in (number % <$FIELD>::modulus_big_uint())
82            .iter_u64_digits()
83            .enumerate()
84        {
85            res += <$FIELD>::from(digit) * <$FIELD>::power_of_two(i * 64);
86        }
87        res
88    }
89}
90
91impl From<&BigInt> for $FIELD {
92    fn from(number: &BigInt) -> Self {
93        let magnitude = <$FIELD>::from(number.magnitude());
94        let zero = <$FIELD>::from(0);
95        match number.sign() {
96            Sign::Minus => zero - magnitude,
97            Sign::NoSign => zero,
98            Sign::Plus => magnitude,
99        }
100    }
101}
102
103impl From<&Number> for $FIELD {
104    fn from(number: &Number) -> Self {
105        match number {
106            Number::SmallNum(i) => (&BigInt::from(*i)).into(),
107            Number::BigNum(n) => n.into(),
108        }
109    }
110}
111
112impl From<Number> for $FIELD {
113    fn from(number: Number) -> Self {
114        (&number).into()
115    }
116}
117
118impl From<f64> for $FIELD {
119    fn from(value: f64) -> Self {
120        let mut bytes = value.to_le_bytes();
121        let sign = bytes[7] >> 7;
122        let exponent_hi = (bytes[7] & 127) as i16;
123        let exponent_lo = (bytes[6] & 240) as i16;
124        let exponent = (exponent_hi << 4) + (exponent_lo >> 4) - 1023;
125        // we get rid of the sign and the exponent
126        bytes[7] = 0;
127        bytes[6] &= 15;
128        // we need to set the implicit 1-bit (the value being 1.mantissa)
129        bytes[6] |= 16;
130        let value_unsigned = u64::from_le_bytes(bytes) >> (-exponent.min(0)).min(63);
131        <$FIELD>::power_of_two(exponent.max(0) as usize) * (if sign == 1u8 { <$FIELD>::ZERO - <$FIELD>::from(value_unsigned)} else {<$FIELD>::from(value_unsigned)})
132    }
133}
134
135impl FromLeBytes for $FIELD {
136    fn from_le_bytes(bytes: [u8; 32]) -> Self {
137        <$FIELD>::from_le_bytes_checked(bytes).unwrap()
138    }
139}
140
141fn find_alpha() -> i32 {
142    let p_minus_one = <$FIELD>::modulus_number() - 1;
143    for alpha in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47] {
144        if &p_minus_one % alpha != 0 {
145            return alpha;
146        }
147    }
148    panic!("Could not find prime alpha that does not divide p-1.")
149}
150/// Finds b such that x^alpha^b = x.
151/// Such b is the inverse of alpha modulo p-1
152fn find_alpha_inverse(alpha: i32) -> Number {
153    let q = <$FIELD>::modulus_number() - 1;
154    let m = (&q % alpha).to_i32().unwrap();
155    if m == 0 {
156        panic!("alpha divides p_minus_one");
157    }
158    // Since alpha is prime and m != 0 in F_alpha, the map k |-> m * k is a bijection
159    // F_alpha -> F_alpha. We search for the pre-image of alpha - 1, which we call n.
160    let n = (1..alpha).find(|k| (m * k) % alpha == (alpha - 1)).unwrap();
161    let l = m * n / alpha;
162    let k = q / alpha;
163    // q   = k*alpha + m
164    // n*m = l*alpha + alpha - 1
165    // n*q = n*k*alpha + n*m = (n*k + l + 1)*alpha - 1
166    n * k + l + 1
167}
168
169fn find_alphas() -> (Number, Number) {
170    let alpha = find_alpha();
171    let alpha_inverse = find_alpha_inverse(alpha);
172    (alpha.into(), alpha_inverse)
173}
174
175thread_local! {
176    static ALPHAS: (Number, Number) = find_alphas();
177}
178
179fn get_alpha() -> Number {
180    ALPHAS.with(|(alpha, _)| alpha.clone())
181}
182
183fn get_alpha_inverse() -> Number {
184    ALPHAS.with(|(_, alpha_inverse)| alpha_inverse.clone())
185}
186
187pub(super) fn build_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
188    assert_eq!(x.len(), y.len());
189    let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
190    for i in 0..x.len() {
191        for j in 0..y.len() {
192            mat[(i, j)] = (x[i] - y[j]).invert(true);
193        }
194    }
195    mat
196}
197/// Computes the inverse of a cauchy matrix.
198/// See <https://en.wikipedia.org/wiki/Cauchy_matrix>
199pub(super) fn inverse_cauchy(x: &[$FIELD], y: &[$FIELD]) -> Matrix<$FIELD> {
200    assert_eq!(x.len(), y.len());
201    /// Computes some sort of derivation.
202    fn prime(arr: &[$FIELD], val: $FIELD) -> $FIELD {
203        arr.iter()
204            .map(|u| if *u != val { val - u } else { 1.into() })
205            .product()
206    }
207    let mut mat: Matrix<$FIELD> = Matrix::new((x.len(), y.len()), <$FIELD>::ZERO);
208    for i in 0..x.len() {
209        for j in 0..y.len() {
210            let a = x.iter().map(|u| y[i] - u).product::<$FIELD>();
211            let a_prime = prime(x, x[j]);
212            let b = y.iter().map(|v| x[j] - v).product::<$FIELD>();
213            let b_prime = prime(y, y[i]);
214            mat[(i, j)] = a
215                * b
216                * a_prime.invert(true)
217                * b_prime.invert(true)
218                * (y[i] - x[j]).invert(true);
219        }
220    }
221    mat
222}
223
224fn mds_matrix_and_inverse(size: usize) -> (Matrix<$FIELD>, Matrix<$FIELD>) {
225    let x = (1..=size).map(|i| <$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
226    let y = (1..=size).map(|i| -<$FIELD>::from(i as u64)).collect::<Vec<$FIELD>>();
227    let mds = build_cauchy(x.as_slice(), y.as_slice());
228    let inverse_mds = inverse_cauchy(x.as_slice(), y.as_slice());
229    (mds, inverse_mds)
230}
231
232impl Shr<usize> for $FIELD {
233    type Output = $FIELD;
234
235    fn shr(self, rhs: usize) -> Self::Output {
236        self.unsigned_euclidean_division(<$FIELD>::power_of_two(rhs))
237    }
238}
239
240impl UsedField for $FIELD {
241    fn modulus() -> Number {
242        <$FIELD>::modulus_number()
243    }
244
245    fn get_alpha() -> Number {
246        get_alpha()
247    }
248
249    fn get_alpha_inverse() -> Number {
250        get_alpha_inverse()
251    }
252
253    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>) {
254        mds_matrix_and_inverse(width)
255    }
256
257    fn power_of_two(exponent: usize) -> Self {
258        <$FIELD>::power_of_two(exponent)
259    }
260}
261
262impl Zero for $FIELD {
263    fn zero() -> Self {
264        <$FIELD>::ZERO
265    }
266    fn is_zero(&self) -> bool {
267        *self == <$FIELD>::zero()
268    }
269}
270
271    };
272}
273
// `#[derive(PrimeField)]` presumably provides its own `PartialEq` while `Hash` is derived next to
// it, which is exactly the combination this clippy lint flags.
#[allow(clippy::derived_hash_with_manual_eq)]
/// Prime field modulo 2^252 + 27742317777372353535851937790883648493, the modulus of
/// `curve25519_dalek::Scalar` values (which convert into it via `From<Scalar>`).
mod scalar_field {

    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};
        /// Element of the scalar field, stored as four 64-bit limbs by the `PrimeField` derive.
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        // modulus = 2^252 + 27742317777372353535851937790883648493
        #[PrimeFieldModulus = "7237005577332262213973186563042994240857116359379907606001950938285454250989"]
        // NOTE(review): the `multiplicative_generator` test only checks that 2 is a quadratic
        // non-residue, not that it generates the whole multiplicative group — confirm.
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct ScalarField([u64; 4]);
    }

    use curve25519_dalek::Scalar;
    pub use field_derive::ScalarField;
    // Needed by `impl_field!`, which names `ScalarFieldRepr` through `paste!`.
    use field_derive::ScalarFieldRepr;
    // Expands the shared caches, conversions, and `UsedField` impl for this field.
    impl_field!(ScalarField);

    impl From<Scalar> for ScalarField {
        /// Re-decodes the dalek scalar's 32-byte encoding. `from_le_bytes` panics on
        /// non-canonical bytes; `Scalar::to_bytes` is expected to always be canonical — verify
        /// against the curve25519-dalek version in use.
        fn from(value: Scalar) -> Self {
            ScalarField::from_le_bytes(value.to_bytes())
        }
    }
}
// `#[derive(PrimeField)]` presumably provides its own `PartialEq` while `Hash` is derived next to
// it, which is exactly the combination this clippy lint flags.
#[allow(clippy::derived_hash_with_manual_eq)]
/// Prime field modulo 2^255 - 19.
mod base_field {
    mod field_derive {
        use ff::PrimeField;
        use serde::{Deserialize, Serialize};

        /// Element of the base field, stored as four 64-bit limbs by the `PrimeField` derive.
        #[derive(PrimeField, Hash, Serialize, Deserialize)]
        // modulus = 2^255 - 19
        #[PrimeFieldModulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949"]
        #[PrimeFieldGenerator = "2"]
        #[PrimeFieldReprEndianness = "little"]
        pub struct BaseField([u64; 4]);
    }
    pub use field_derive::BaseField;
    // Needed by `impl_field!`, which names `BaseFieldRepr` through `paste!`.
    use field_derive::BaseFieldRepr;
    // Expands the shared caches, conversions, and `UsedField` impl for this field.
    impl_field!(BaseField);
}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        traits::{Invert, Pow},
        utils::{number::Number, used_field::UsedField},
    };
    use ff::{Field, PrimeField};
    use std::{f64::consts::PI, str::FromStr};

    #[test]
    fn from_f64() {
        assert_eq!(
            ScalarField::from(2f64.sqrt()),
            ScalarField::from(Number::from_str("6369051672525773").unwrap())
        );
        assert_eq!(
            ScalarField::from(-PI * 2f64.powi(150)),
            ScalarField::from(
                Number::from_str(
                    "0x0ffffffffffff36f0255dde97400000014def9dea2f79cd65812631a5cf5d3ed"
                )
                .unwrap()
            )
        );
        assert_eq!(
            ScalarField::from(0.001),
            ScalarField::from(Number::from_str("4503599627370").unwrap())
        );
        assert_eq!(
            ScalarField::from(-0.00000383),
            ScalarField::from(
                Number::from_str(
                    "0x1000000000000000000000000000000014def9dea2f79cd65812631658da3b61"
                )
                .unwrap()
            )
        );
        assert_eq!(ScalarField::from(3f64 * 2f64.powi(-150)), ScalarField::ZERO);
    }

    #[test]
    fn from_negative_i32() {
        // A negative integer must map to the additive inverse of its magnitude.
        for v in [1, 2, 19, i32::MAX] {
            assert_eq!(ScalarField::from(-v), -ScalarField::from(v as u64));
            assert_eq!(BaseField::from(-v), -BaseField::from(v as u64));
        }
    }

    #[test]
    fn multiplicative_generator() {
        let a = ScalarField::MULTIPLICATIVE_GENERATOR;
        let b = a.pow(&((ScalarField::modulus() - 1) / 2), true);
        assert_ne!(b, ScalarField::ONE);
    }

    #[test]
    fn sqrt() {
        fn test(square_root: ScalarField) {
            let square = square_root.square();
            let square_root = square.sqrt().unwrap();
            assert_eq!(square_root.square(), square);
        }

        test(ScalarField::ZERO);
        test(ScalarField::ONE);
        use rand::rngs::OsRng;
        for _ in 0..1024 {
            test(ScalarField::random(OsRng));
        }
    }

    #[test]
    fn test_safe_field_inverse() {
        for n in [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::from(2),
            ScalarField::from(3),
        ] {
            let inv = n.invert(false);
            if n == ScalarField::ZERO {
                assert_eq!(inv, ScalarField::ZERO);
            } else {
                assert_eq!(n * inv, ScalarField::ONE);
            }
        }
    }

    #[test]
    fn test_cauchy_inverse() {
        let x = [
            ScalarField::ONE,
            ScalarField::from(2),
            ScalarField::from(3),
            ScalarField::from(4),
            ScalarField::from(5),
        ];
        let y = [
            ScalarField::ZERO,
            -ScalarField::from(1),
            -ScalarField::from(2),
            -ScalarField::from(3),
            -ScalarField::from(4),
        ];
        let cauchy = scalar_field::build_cauchy(&x, &y);
        let inverse = scalar_field::inverse_cauchy(&x, &y);
        let identity = cauchy.mat_mul(&inverse);
        for i in 0..x.len() {
            for j in 0..y.len() {
                let expected = if i == j {
                    ScalarField::ONE
                } else {
                    ScalarField::ZERO
                };
                assert_eq!(identity[(i, j)], expected);
            }
        }
    }

    /// Asserts that `<$field as UsedField>::mds_matrix_and_inverse($width)` returns a matrix
    /// and its inverse, i.e. their product is the identity.
    macro_rules! assert_mds_inverse {
        ($field:ty, $width:expr) => {{
            let width = $width;
            let (mds, inverse) = <$field as UsedField>::mds_matrix_and_inverse(width);
            let product = mds.mat_mul(&inverse);
            for i in 0..width {
                for j in 0..width {
                    let expected = if i == j {
                        <$field>::ONE
                    } else {
                        <$field>::ZERO
                    };
                    assert_eq!(product[(i, j)], expected);
                }
            }
        }};
    }

    #[test]
    fn mds_matrix_and_inverse_are_inverses() {
        for width in 1..=6 {
            assert_mds_inverse!(ScalarField, width);
            assert_mds_inverse!(BaseField, width);
        }
    }
}