// arcis_compiler/utils/used_field.rs
1use crate::{
2    traits::{FromLeBytes, Invert, Pow},
3    utils::{matrix::Matrix, number::Number},
4};
5use ff::{
6    derive::bitvec::{order::Lsb0, view::AsBits},
7    PrimeField,
8};
9use num_bigint::{BigInt, BigUint};
10use num_traits::Zero;
11use rand::Rng;
12use std::{cmp::Ordering, hash::Hash};
13
14/// A collection of functions on prime fields.
15/// A lot of these functions are linked to the different orders one can define over a cyclic group.
16pub trait UsedField:
17    PrimeField
18    + Hash
19    + PartialOrd
20    + From<Number>
21    + From<i32>
22    + From<bool>
23    + From<f64>
24    + Zero
25    + std::ops::Shr<usize, Output = Self>
26    + FromLeBytes
27{
28    /// The prime number p such that the field is F_p.
29    fn modulus() -> Number;
30
31    /// The smallest prime such that it does not divide p-1.
32    fn get_alpha() -> Number;
33
34    /// The smallest positive integer n such that alpha*n = 1 mod (p-1).
35    fn get_alpha_inverse() -> Number;
36
37    /// An MDS matrix and its inverse.
38    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40    /// Returns two^exponent.
41    fn power_of_two(exponent: usize) -> Self;
42
43    /// Returns -two^exponent.
44    fn negative_power_of_two(exponent: usize) -> Self {
45        Self::ZERO - Self::power_of_two(exponent)
46    }
47
48    fn exponent_close_power_of_two() -> usize;
49
50    fn to_unsigned_number(self) -> Number {
51        BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
52    }
53
54    fn to_signed_number(self) -> Number {
55        if self.is_ge_zero() {
56            self.to_unsigned_number()
57        } else {
58            -(Self::ZERO - self).to_unsigned_number()
59        }
60    }
61
62    /// Whether a number is binary or not.
63    fn is_binary(self) -> bool {
64        self <= Self::ONE
65    }
66
67    /// Whether a number is greater or equal to zero according to the signed order.
68    #[inline(always)]
69    fn is_ge_zero(self) -> bool {
70        // should be equivalent to self <= Self::ZERO - self (see test)
71        self < Self::TWO_INV
72    }
73
74    /// Whether a number is less than or equal to zero according to the signed order.
75    fn is_le_zero(self) -> bool {
76        self >= Self::ZERO - self
77    }
78
79    /// Whether a number is greater than zero according to the signed order.
80    #[inline(always)]
81    fn is_gt_zero(self) -> bool {
82        !self.is_le_zero()
83    }
84
85    /// Whether a number is less than zero according to the signed order.
86    #[inline(always)]
87    fn is_lt_zero(self) -> bool {
88        !self.is_ge_zero()
89    }
90
91    /// Max according to the cyclic order on the smaller interval between the two field elements.
92    fn max_cyclic(self, other: Self) -> (Self, bool) {
93        if (other - self).is_ge_zero() {
94            (other, true)
95        } else {
96            (self, false)
97        }
98    }
99    /// Min according to the cyclic order on the smaller interval between the two field elements.
100    fn min_cyclic(self, other: Self) -> (Self, bool) {
101        if (other - self).is_ge_zero() {
102            (self, false)
103        } else {
104            (other, true)
105        }
106    }
107    /// Max according to (un)signed order.
108    fn max(self, other: Self, signed: bool) -> Self {
109        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
110        if self - offset < other - offset {
111            other
112        } else {
113            self
114        }
115    }
116    /// Min according to (un)signed order.
117    fn min(self, other: Self, signed: bool) -> Self {
118        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
119        if self - offset > other - offset {
120            other
121        } else {
122            self
123        }
124    }
125    /// Sort according to the cyclic order on the smaller interval between the two field elements.
126    fn sort_pair(self, other: Self) -> (Self, Self) {
127        if (other - self).is_ge_zero() {
128            (self, other)
129        } else {
130            (other, self)
131        }
132    }
133    /// Abs according to the signed order.
134    fn abs(self) -> Self {
135        if self.is_ge_zero() {
136            self
137        } else {
138            Self::ZERO - self
139        }
140    }
141    // assuming self and other >= 0
142    fn does_mul_overflow(self, other: Self) -> bool {
143        if self.is_zero_vartime() || other.is_zero_vartime() {
144            return false;
145        }
146        let prod = self.to_unsigned_number() * other.to_unsigned_number();
147        prod >= Self::modulus()
148    }
149    fn does_add_signed_overflow(self, other: Self) -> bool {
150        let sum = self + other;
151        match (self.is_ge_zero(), other.is_ge_zero()) {
152            (true, true) => sum.is_lt_zero(),
153            (true, false) => false,
154            (false, true) => false,
155            (false, false) => sum.is_ge_zero(),
156        }
157    }
158    fn does_add_unsigned_overflow(self, other: Self) -> bool {
159        if self == Self::ZERO || other == Self::ZERO {
160            false
161        } else {
162            self >= -other
163        }
164    }
165    /// The number of bits of self in unsigned notation.
166    fn unsigned_bits(self) -> usize {
167        let binding = self.to_repr();
168        let bits = binding.as_bits::<Lsb0>();
169        bits.len() - bits.trailing_zeros()
170    }
171    /// The number of bits of self in signed notation.
172    fn signed_bits(self) -> usize {
173        self.abs().unsigned_bits()
174    }
175    /// The idx bit of self in unsigned notation.
176    fn unsigned_bit(&self, idx: usize) -> bool {
177        let repr = self.to_repr();
178        let bits = repr.as_bits::<Lsb0>();
179        if idx < bits.len() {
180            bits[idx]
181        } else {
182            false
183        }
184    }
185    /// The idx bit of self in signed notation.
186    fn signed_bit(&self, idx: usize) -> bool {
187        if self.is_ge_zero() {
188            self.unsigned_bit(idx)
189        } else {
190            !(self.abs() - Self::ONE).unsigned_bit(idx)
191        }
192    }
193    /// The unsigned Euclidean division. Returns 0 if the divisor is 0.
194    fn unsigned_euclidean_division(self, other: Self) -> Self {
195        if other == Self::ZERO {
196            Self::ZERO
197        } else {
198            (self.to_unsigned_number() / other.to_unsigned_number()).into()
199        }
200    }
201    /// The unsigned Euclidean division. Returns 0 if the divisor is 0.
202    /// Returns better bounds if the inputs have bounds that overlap with negatives.
203    /// -1 / 2 = -1, with modulo 1
204    fn unsigned_euclidean_division_better_bounds(self, other: Self) -> Self {
205        if other == Self::ZERO {
206            Self::ZERO
207        } else {
208            let s = self.to_signed_number();
209            let other = other.to_unsigned_number();
210            if s < 0 {
211                ((&s - &s * &other) / &other + &s).into()
212            } else {
213                (s / other).into()
214            }
215        }
216    }
217    /// The signed Euclidean division. Returns 0 if the divisor is 0.
218    fn signed_euclidean_division(self, other: Self) -> Self {
219        if other == Self::ZERO {
220            Self::ZERO
221        } else {
222            (self.to_signed_number() / other.to_signed_number()).into()
223        }
224    }
225    /// Generates a field element between min and max, included.
226    fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
227        min + Self::from(Number::gen_range(
228            rng,
229            &0.into(),
230            &((max - min).to_unsigned_number() + 1),
231        ))
232    }
233
234    /// Converts a number in lsb-to-msb binary expansion to the corresponding element in Self.
235    fn from_bin(bin: &str) -> Self {
236        Self::from(
237            bin.chars()
238                .enumerate()
239                .fold(Number::from(0), |acc, (i, c)| {
240                    if c == '1' {
241                        acc + Number::power_of_two(i)
242                    } else {
243                        acc
244                    }
245                }),
246        )
247    }
248
249    /// Converts self to its lsb-to-msb binary expansion.
250    fn to_bin(&self) -> String {
251        (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
252            if self.unsigned_bit(i) {
253                acc.push('1');
254            } else {
255                acc.push('0');
256            }
257            acc
258        })
259    }
260    fn as_power_of_two(self) -> Option<usize> {
261        if self == Self::ZERO {
262            return None;
263        }
264        let mut min_possible_exponent = 0usize;
265        let mut max_possible_exponent = Self::CAPACITY as usize;
266        while max_possible_exponent >= min_possible_exponent {
267            let mid = (min_possible_exponent + max_possible_exponent) / 2;
268            match self.partial_cmp(&Self::power_of_two(mid)) {
269                None => panic!("order should be total"),
270                Some(Ordering::Less) => {
271                    max_possible_exponent = mid - 1;
272                }
273                Some(Ordering::Equal) => return Some(mid),
274                Some(Ordering::Greater) => {
275                    min_possible_exponent = mid + 1;
276                }
277            }
278        }
279        None
280    }
281    fn signed_gt(self, other: Self) -> bool {
282        self.max(other, true) != other
283    }
284    fn signed_ge(self, other: Self) -> bool {
285        self.max(other, true) == self
286    }
287    fn signed_lt(self, other: Self) -> bool {
288        self.min(other, true) != other
289    }
290    fn signed_le(self, other: Self) -> bool {
291        self.min(other, true) == self
292    }
293}
294
295impl<F: UsedField> Invert for F {
296    fn invert(self, _is_expected_non_zero: bool) -> Self {
297        F::invert(&self).unwrap_or(F::ZERO)
298    }
299}
300
301impl<F: UsedField> Pow for F {
302    fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
303        let e = e % (F::modulus() - 1);
304        let mut e_u64 = [0u64; 4];
305        let bytes: [u8; 32] = e.into();
306        for (i, chunk) in bytes.chunks_exact(8).enumerate() {
307            e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
308        }
309
310        F::pow(&self, e_u64)
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// The fast `is_ge_zero` (a comparison against `TWO_INV`) must agree with the definition
    /// `n <= -n` on zero, one, the elements around the sign boundary, and minus one.
    #[test]
    fn is_ge_zero() {
        let minus_one = ScalarField::ZERO - ScalarField::ONE;
        let samples = [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            minus_one,
        ];
        samples.iter().for_each(|&n| {
            let by_definition = n <= ScalarField::ZERO - n;
            assert_eq!(n.is_ge_zero(), by_definition);
        });
    }
}