Skip to main content

arcis_compiler/utils/
used_field.rs

1use crate::{
2    traits::{FromLeBytes, Invert, Pow},
3    utils::{matrix::Matrix, number::Number},
4};
5use ff::{
6    derive::bitvec::{order::Lsb0, view::AsBits},
7    PrimeField,
8};
9use num_bigint::{BigInt, BigUint};
10use num_traits::Zero;
11use rand::Rng;
12use std::{cmp::Ordering, hash::Hash};
13
14/// A collection of functions on prime fields.
15/// A lot of these functions are linked to the different orders one can define over a cyclic group.
16pub trait UsedField:
17    PrimeField
18    + Hash
19    + PartialOrd
20    + From<Number>
21    + From<i32>
22    + From<bool>
23    + From<f64>
24    + Zero
25    + std::ops::Shr<usize, Output = Self>
26    + FromLeBytes
27{
28    /// The prime number p such that the field is F_p.
29    fn modulus() -> Number;
30
31    /// The smallest prime such that it does not divide p-1.
32    fn get_alpha() -> Number;
33
34    /// The smallest positive integer n such that alpha*n = 1 mod (p-1).
35    fn get_alpha_inverse() -> Number;
36
37    /// An MDS matrix and its inverse.
38    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40    /// Returns two^exponent.
41    fn power_of_two(exponent: usize) -> Self;
42
43    /// Returns -two^exponent.
44    fn negative_power_of_two(exponent: usize) -> Self {
45        Self::ZERO - Self::power_of_two(exponent)
46    }
47
48    fn to_unsigned_number(self) -> Number {
49        BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
50    }
51
52    fn to_signed_number(self) -> Number {
53        if self.is_ge_zero() {
54            self.to_unsigned_number()
55        } else {
56            -(Self::ZERO - self).to_unsigned_number()
57        }
58    }
59
60    /// Whether a number is binary or not.
61    fn is_binary(self) -> bool {
62        self <= Self::ONE
63    }
64
65    /// Whether a number is greater or equal to zero according to the signed order.
66    #[inline(always)]
67    fn is_ge_zero(self) -> bool {
68        // should be equivalent to self <= Self::ZERO - self (see test)
69        self < Self::TWO_INV
70    }
71
72    /// Whether a number is less than or equal to zero according to the signed order.
73    fn is_le_zero(self) -> bool {
74        self >= Self::ZERO - self
75    }
76
77    /// Whether a number is greater than zero according to the signed order.
78    #[inline(always)]
79    fn is_gt_zero(self) -> bool {
80        !self.is_le_zero()
81    }
82
83    /// Whether a number is less than zero according to the signed order.
84    #[inline(always)]
85    fn is_lt_zero(self) -> bool {
86        !self.is_ge_zero()
87    }
88
89    /// Max according to the cyclic order on the smaller interval between the two field elements.
90    fn max_cyclic(self, other: Self) -> Self {
91        if (other - self).is_ge_zero() {
92            other
93        } else {
94            self
95        }
96    }
97    /// Min according to the cyclic order on the smaller interval between the two field elements.
98    fn min_cyclic(self, other: Self) -> Self {
99        if (other - self).is_ge_zero() {
100            self
101        } else {
102            other
103        }
104    }
105    /// Max according to (un)signed order.
106    fn max(self, other: Self, signed: bool) -> Self {
107        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
108        if self - offset < other - offset {
109            other
110        } else {
111            self
112        }
113    }
114    /// Min according to (un)signed order.
115    fn min(self, other: Self, signed: bool) -> Self {
116        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
117        if self - offset > other - offset {
118            other
119        } else {
120            self
121        }
122    }
123    /// Sort according to the cyclic order on the smaller interval between the two field elements.
124    fn sort_pair(self, other: Self) -> (Self, Self) {
125        if (other - self).is_ge_zero() {
126            (self, other)
127        } else {
128            (other, self)
129        }
130    }
131    /// Abs according to the signed order.
132    fn abs(self) -> Self {
133        if self.is_ge_zero() {
134            self
135        } else {
136            Self::ZERO - self
137        }
138    }
139    // assuming self and other >= 0
140    fn does_mul_overflow(self, other: Self) -> bool {
141        let zero = Self::ZERO;
142        if self == zero || other == zero {
143            return false;
144        }
145        let prod = self.to_unsigned_number() * other.to_unsigned_number();
146        prod >= Self::modulus()
147    }
148    /// The number of bits of self in unsigned notation.
149    fn unsigned_bits(self) -> usize {
150        let binding = self.to_repr();
151        let bits = binding.as_bits::<Lsb0>();
152        bits.len() - bits.trailing_zeros()
153    }
154    /// The number of bits of self in signed notation.
155    fn signed_bits(self) -> usize {
156        self.abs().unsigned_bits()
157    }
158    /// The idx bit of self in unsigned notation.
159    fn unsigned_bit(&self, idx: usize) -> bool {
160        let repr = self.to_repr();
161        let bits = repr.as_bits::<Lsb0>();
162        if idx < bits.len() {
163            bits[idx]
164        } else {
165            false
166        }
167    }
168    /// The idx bit of self in signed notation.
169    fn signed_bit(&self, idx: usize) -> bool {
170        if self.is_ge_zero() {
171            self.unsigned_bit(idx)
172        } else {
173            !(self.abs() - Self::ONE).unsigned_bit(idx)
174        }
175    }
176    /// The unsigned Euclidean division. Returns 0 if the divisor is 0.
177    fn unsigned_euclidean_division(self, other: Self) -> Self {
178        if other == Self::ZERO {
179            Self::ZERO
180        } else {
181            (self.to_unsigned_number() / other.to_unsigned_number()).into()
182        }
183    }
184    /// The signed Euclidean division. Returns 0 if the divisor is 0.
185    fn signed_euclidean_division(self, other: Self) -> Self {
186        if other == Self::ZERO {
187            Self::ZERO
188        } else {
189            (self.to_signed_number() / other.to_signed_number()).into()
190        }
191    }
192    /// Generates a field element between min and max, included.
193    fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
194        min + Self::from(Number::gen_range(
195            rng,
196            &0.into(),
197            &((max - min).to_unsigned_number() + 1),
198        ))
199    }
200
201    /// Converts a number in lsb-to-msb binary expansion to the corresponding element in Self.
202    fn from_bin(bin: &str) -> Self {
203        Self::from(
204            bin.chars()
205                .enumerate()
206                .fold(Number::from(0), |acc, (i, c)| {
207                    if c == '1' {
208                        acc + Number::power_of_two(i)
209                    } else {
210                        acc
211                    }
212                }),
213        )
214    }
215
216    /// Converts self to its lsb-to-msb binary expansion.
217    fn to_bin(&self) -> String {
218        (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
219            if self.unsigned_bit(i) {
220                acc.push('1');
221            } else {
222                acc.push('0');
223            }
224            acc
225        })
226    }
227    fn as_power_of_two(self) -> Option<usize> {
228        if self == Self::ZERO {
229            return None;
230        }
231        let mut min_possible_exponent = 0usize;
232        let mut max_possible_exponent = Self::CAPACITY as usize;
233        while max_possible_exponent >= min_possible_exponent {
234            let mid = (min_possible_exponent + max_possible_exponent) / 2;
235            match self.partial_cmp(&Self::power_of_two(mid)) {
236                None => panic!("order should be total"),
237                Some(Ordering::Less) => {
238                    max_possible_exponent = mid - 1;
239                }
240                Some(Ordering::Equal) => return Some(mid),
241                Some(Ordering::Greater) => {
242                    min_possible_exponent = mid + 1;
243                }
244            }
245        }
246        None
247    }
248}
249
250impl<F: UsedField> Invert for F {
251    fn invert(self, _is_expected_non_zero: bool) -> Self {
252        F::invert(&self).unwrap_or(F::ZERO)
253    }
254}
255
256impl<F: UsedField> Pow for F {
257    fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
258        let e = e % (F::modulus() - 1);
259        let mut e_u64 = [0u64; 4];
260        let bytes: [u8; 32] = e.into();
261        for (i, chunk) in bytes.chunks_exact(8).enumerate() {
262            e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
263        }
264
265        F::pow(&self, e_u64)
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// Values around the boundaries of the signed order: zero, one, both sides of
    /// (p + 1) / 2 = TWO_INV, and -1.
    fn signed_boundary_values() -> [ScalarField; 6] {
        [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            ScalarField::TWO_INV + ScalarField::ONE,
            ScalarField::ZERO - ScalarField::ONE,
        ]
    }

    #[test]
    fn is_ge_zero() {
        for n in signed_boundary_values() {
            assert_eq!(n.is_ge_zero(), n <= ScalarField::ZERO - n, "n = {:?}", n);
        }
    }

    #[test]
    fn is_le_zero_agrees_with_is_ge_zero() {
        for n in signed_boundary_values() {
            // Zero is the only element that is both >= 0 and <= 0 in the signed order.
            assert_eq!(
                n.is_le_zero(),
                n == ScalarField::ZERO || !n.is_ge_zero(),
                "n = {:?}",
                n
            );
        }
    }
}