Skip to main content

arcis_compiler/utils/
used_field.rs

1use crate::{
2    traits::{FromLeBytes, Invert, Pow},
3    utils::{matrix::Matrix, number::Number},
4};
5use ff::{
6    derive::bitvec::{order::Lsb0, view::AsBits},
7    PrimeField,
8};
9use num_bigint::{BigInt, BigUint};
10use num_traits::Zero;
11use rand::Rng;
12use std::{cmp::Ordering, hash::Hash};
13
14/// A collection of functions on prime fields.
15/// A lot of these functions are linked to the different orders one can define over a cyclic group.
16pub trait UsedField:
17    PrimeField
18    + Hash
19    + PartialOrd
20    + From<Number>
21    + From<i32>
22    + From<bool>
23    + From<f64>
24    + Zero
25    + std::ops::Shr<usize, Output = Self>
26    + FromLeBytes
27{
28    /// The prime number p such that the field is F_p.
29    fn modulus() -> Number;
30
31    /// The smallest prime such that it does not divide p-1.
32    fn get_alpha() -> Number;
33
34    /// The smallest positive integer n such that alpha*n = 1 mod (p-1).
35    fn get_alpha_inverse() -> Number;
36
37    /// An MDS matrix and its inverse.
38    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40    /// Returns two^exponent.
41    fn power_of_two(exponent: usize) -> Self;
42
43    /// Returns -two^exponent.
44    fn negative_power_of_two(exponent: usize) -> Self {
45        Self::ZERO - Self::power_of_two(exponent)
46    }
47
48    fn to_unsigned_number(self) -> Number {
49        BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
50    }
51
52    fn to_signed_number(self) -> Number {
53        if self.is_ge_zero() {
54            self.to_unsigned_number()
55        } else {
56            -(Self::ZERO - self).to_unsigned_number()
57        }
58    }
59
60    /// Whether a number is binary or not.
61    fn is_binary(self) -> bool {
62        self <= Self::ONE
63    }
64
65    /// Whether a number is greater or equal to zero according to the signed order.
66    #[inline(always)]
67    fn is_ge_zero(self) -> bool {
68        // should be equivalent to self <= Self::ZERO - self (see test)
69        self < Self::TWO_INV
70    }
71
72    /// Whether a number is less than or equal to zero according to the signed order.
73    fn is_le_zero(self) -> bool {
74        self >= Self::ZERO - self
75    }
76
77    /// Whether a number is greater than zero according to the signed order.
78    #[inline(always)]
79    fn is_gt_zero(self) -> bool {
80        !self.is_le_zero()
81    }
82
83    /// Whether a number is less than zero according to the signed order.
84    #[inline(always)]
85    fn is_lt_zero(self) -> bool {
86        !self.is_ge_zero()
87    }
88
89    /// Max according to the cyclic order on the smaller interval between the two field elements.
90    fn max_cyclic(self, other: Self) -> Self {
91        if (other - self).is_ge_zero() {
92            other
93        } else {
94            self
95        }
96    }
97    /// Min according to the cyclic order on the smaller interval between the two field elements.
98    fn min_cyclic(self, other: Self) -> Self {
99        if (other - self).is_ge_zero() {
100            self
101        } else {
102            other
103        }
104    }
105    /// Max according to (un)signed order.
106    fn max(self, other: Self, signed: bool) -> Self {
107        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
108        if self - offset < other - offset {
109            other
110        } else {
111            self
112        }
113    }
114    /// Min according to (un)signed order.
115    fn min(self, other: Self, signed: bool) -> Self {
116        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
117        if self - offset > other - offset {
118            other
119        } else {
120            self
121        }
122    }
123    /// Sort according to the cyclic order on the smaller interval between the two field elements.
124    fn sort_pair(self, other: Self) -> (Self, Self) {
125        if (other - self).is_ge_zero() {
126            (self, other)
127        } else {
128            (other, self)
129        }
130    }
131    /// Abs according to the signed order.
132    fn abs(self) -> Self {
133        if self.is_ge_zero() {
134            self
135        } else {
136            Self::ZERO - self
137        }
138    }
139    // assuming self and other >= 0
140    fn does_mul_overflow(self, other: Self) -> bool {
141        let zero = Self::ZERO;
142        if self == zero || other == zero {
143            return false;
144        }
145        let prod = self.to_unsigned_number() * other.to_unsigned_number();
146        prod >= Self::modulus()
147    }
148    fn does_add_signed_overflow(self, other: Self) -> bool {
149        let sum = self + other;
150        match (self.is_ge_zero(), other.is_ge_zero()) {
151            (true, true) => sum.is_lt_zero(),
152            (true, false) => false,
153            (false, true) => false,
154            (false, false) => sum.is_ge_zero(),
155        }
156    }
157    fn does_add_unsigned_overflow(self, other: Self) -> bool {
158        if self == Self::ZERO || other == Self::ZERO {
159            false
160        } else {
161            self >= -other
162        }
163    }
164    /// The number of bits of self in unsigned notation.
165    fn unsigned_bits(self) -> usize {
166        let binding = self.to_repr();
167        let bits = binding.as_bits::<Lsb0>();
168        bits.len() - bits.trailing_zeros()
169    }
170    /// The number of bits of self in signed notation.
171    fn signed_bits(self) -> usize {
172        self.abs().unsigned_bits()
173    }
174    /// The idx bit of self in unsigned notation.
175    fn unsigned_bit(&self, idx: usize) -> bool {
176        let repr = self.to_repr();
177        let bits = repr.as_bits::<Lsb0>();
178        if idx < bits.len() {
179            bits[idx]
180        } else {
181            false
182        }
183    }
184    /// The idx bit of self in signed notation.
185    fn signed_bit(&self, idx: usize) -> bool {
186        if self.is_ge_zero() {
187            self.unsigned_bit(idx)
188        } else {
189            !(self.abs() - Self::ONE).unsigned_bit(idx)
190        }
191    }
192    /// The unsigned Euclidean division. Returns 0 if the divisor is 0.
193    fn unsigned_euclidean_division(self, other: Self) -> Self {
194        if other == Self::ZERO {
195            Self::ZERO
196        } else {
197            (self.to_unsigned_number() / other.to_unsigned_number()).into()
198        }
199    }
200    /// The signed Euclidean division. Returns 0 if the divisor is 0.
201    fn signed_euclidean_division(self, other: Self) -> Self {
202        if other == Self::ZERO {
203            Self::ZERO
204        } else {
205            (self.to_signed_number() / other.to_signed_number()).into()
206        }
207    }
208    /// Generates a field element between min and max, included.
209    fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
210        min + Self::from(Number::gen_range(
211            rng,
212            &0.into(),
213            &((max - min).to_unsigned_number() + 1),
214        ))
215    }
216
217    /// Converts a number in lsb-to-msb binary expansion to the corresponding element in Self.
218    fn from_bin(bin: &str) -> Self {
219        Self::from(
220            bin.chars()
221                .enumerate()
222                .fold(Number::from(0), |acc, (i, c)| {
223                    if c == '1' {
224                        acc + Number::power_of_two(i)
225                    } else {
226                        acc
227                    }
228                }),
229        )
230    }
231
232    /// Converts self to its lsb-to-msb binary expansion.
233    fn to_bin(&self) -> String {
234        (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
235            if self.unsigned_bit(i) {
236                acc.push('1');
237            } else {
238                acc.push('0');
239            }
240            acc
241        })
242    }
243    fn as_power_of_two(self) -> Option<usize> {
244        if self == Self::ZERO {
245            return None;
246        }
247        let mut min_possible_exponent = 0usize;
248        let mut max_possible_exponent = Self::CAPACITY as usize;
249        while max_possible_exponent >= min_possible_exponent {
250            let mid = (min_possible_exponent + max_possible_exponent) / 2;
251            match self.partial_cmp(&Self::power_of_two(mid)) {
252                None => panic!("order should be total"),
253                Some(Ordering::Less) => {
254                    max_possible_exponent = mid - 1;
255                }
256                Some(Ordering::Equal) => return Some(mid),
257                Some(Ordering::Greater) => {
258                    min_possible_exponent = mid + 1;
259                }
260            }
261        }
262        None
263    }
264}
265
266impl<F: UsedField> Invert for F {
267    fn invert(self, _is_expected_non_zero: bool) -> Self {
268        F::invert(&self).unwrap_or(F::ZERO)
269    }
270}
271
272impl<F: UsedField> Pow for F {
273    fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
274        let e = e % (F::modulus() - 1);
275        let mut e_u64 = [0u64; 4];
276        let bytes: [u8; 32] = e.into();
277        for (i, chunk) in bytes.chunks_exact(8).enumerate() {
278            e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
279        }
280
281        F::pow(&self, e_u64)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// Checks that `is_ge_zero` agrees with `n <= -n` on values around the sign boundaries.
    #[test]
    fn is_ge_zero() {
        let samples = [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            ScalarField::ZERO - ScalarField::ONE,
        ];
        for n in samples {
            assert_eq!(n.is_ge_zero(), n <= ScalarField::ZERO - n)
        }
    }
}