// Source: arcis_compiler/utils/used_field.rs

1use crate::{
2    traits::{FromLeBytes, Invert, Pow},
3    utils::{matrix::Matrix, number::Number},
4};
5use ff::{
6    derive::bitvec::{order::Lsb0, view::AsBits},
7    PrimeField,
8};
9use num_bigint::{BigInt, BigUint};
10use num_traits::Zero;
11use rand::Rng;
12use std::{cmp::Ordering, hash::Hash};
13
14/// A collection of functions on prime fields.
15/// A lot of these functions are linked to the different orders one can define over a cyclic group.
16pub trait UsedField:
17    PrimeField
18    + Hash
19    + PartialOrd
20    + From<Number>
21    + From<i32>
22    + From<bool>
23    + From<f64>
24    + Zero
25    + std::ops::Shr<usize, Output = Self>
26    + FromLeBytes
27{
28    /// The prime number p such that the field is F_p.
29    fn modulus() -> Number;
30
31    /// The smallest prime such that it does not divide p-1.
32    fn get_alpha() -> Number;
33
34    /// The smallest positive integer n such that alpha*n = 1 mod (p-1).
35    fn get_alpha_inverse() -> Number;
36
37    /// An MDS matrix and its inverse.
38    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40    /// Returns two^exponent.
41    fn power_of_two(exponent: usize) -> Self;
42
43    /// Returns -two^exponent.
44    fn negative_power_of_two(exponent: usize) -> Self {
45        Self::ZERO - Self::power_of_two(exponent)
46    }
47
48    fn to_unsigned_number(self) -> Number {
49        BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
50    }
51
52    fn to_signed_number(self) -> Number {
53        if self.is_ge_zero() {
54            self.to_unsigned_number()
55        } else {
56            -(Self::ZERO - self).to_unsigned_number()
57        }
58    }
59
60    /// Whether a number is binary or not.
61    fn is_binary(self) -> bool {
62        self <= Self::ONE
63    }
64
65    /// Whether a number is greater or equal to zero according to the signed order.
66    #[inline(always)]
67    fn is_ge_zero(self) -> bool {
68        // should be equivalent to self <= Self::ZERO - self (see test)
69        self < Self::TWO_INV
70    }
71
72    /// Whether a number is less than or equal to zero according to the signed order.
73    fn is_le_zero(self) -> bool {
74        self >= Self::ZERO - self
75    }
76
77    /// Whether a number is greater than zero according to the signed order.
78    #[inline(always)]
79    fn is_gt_zero(self) -> bool {
80        !self.is_le_zero()
81    }
82
83    /// Whether a number is less than zero according to the signed order.
84    #[inline(always)]
85    fn is_lt_zero(self) -> bool {
86        !self.is_ge_zero()
87    }
88
89    /// Max according to the cyclic order on the smaller interval between the two field elements.
90    fn max_cyclic(self, other: Self) -> (Self, bool) {
91        if (other - self).is_ge_zero() {
92            (other, true)
93        } else {
94            (self, false)
95        }
96    }
97    /// Min according to the cyclic order on the smaller interval between the two field elements.
98    fn min_cyclic(self, other: Self) -> (Self, bool) {
99        if (other - self).is_ge_zero() {
100            (self, false)
101        } else {
102            (other, true)
103        }
104    }
105    /// Max according to (un)signed order.
106    fn max(self, other: Self, signed: bool) -> Self {
107        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
108        if self - offset < other - offset {
109            other
110        } else {
111            self
112        }
113    }
114    /// Min according to (un)signed order.
115    fn min(self, other: Self, signed: bool) -> Self {
116        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
117        if self - offset > other - offset {
118            other
119        } else {
120            self
121        }
122    }
123    /// Sort according to the cyclic order on the smaller interval between the two field elements.
124    fn sort_pair(self, other: Self) -> (Self, Self) {
125        if (other - self).is_ge_zero() {
126            (self, other)
127        } else {
128            (other, self)
129        }
130    }
131    /// Abs according to the signed order.
132    fn abs(self) -> Self {
133        if self.is_ge_zero() {
134            self
135        } else {
136            Self::ZERO - self
137        }
138    }
139    // assuming self and other >= 0
140    fn does_mul_overflow(self, other: Self) -> bool {
141        if self.is_zero_vartime() || other.is_zero_vartime() {
142            return false;
143        }
144        let prod = self.to_unsigned_number() * other.to_unsigned_number();
145        prod >= Self::modulus()
146    }
147    fn does_add_signed_overflow(self, other: Self) -> bool {
148        let sum = self + other;
149        match (self.is_ge_zero(), other.is_ge_zero()) {
150            (true, true) => sum.is_lt_zero(),
151            (true, false) => false,
152            (false, true) => false,
153            (false, false) => sum.is_ge_zero(),
154        }
155    }
156    fn does_add_unsigned_overflow(self, other: Self) -> bool {
157        if self == Self::ZERO || other == Self::ZERO {
158            false
159        } else {
160            self >= -other
161        }
162    }
163    /// The number of bits of self in unsigned notation.
164    fn unsigned_bits(self) -> usize {
165        let binding = self.to_repr();
166        let bits = binding.as_bits::<Lsb0>();
167        bits.len() - bits.trailing_zeros()
168    }
169    /// The number of bits of self in signed notation.
170    fn signed_bits(self) -> usize {
171        self.abs().unsigned_bits()
172    }
173    /// The idx bit of self in unsigned notation.
174    fn unsigned_bit(&self, idx: usize) -> bool {
175        let repr = self.to_repr();
176        let bits = repr.as_bits::<Lsb0>();
177        if idx < bits.len() {
178            bits[idx]
179        } else {
180            false
181        }
182    }
183    /// The idx bit of self in signed notation.
184    fn signed_bit(&self, idx: usize) -> bool {
185        if self.is_ge_zero() {
186            self.unsigned_bit(idx)
187        } else {
188            !(self.abs() - Self::ONE).unsigned_bit(idx)
189        }
190    }
191    /// The unsigned Euclidean division. Returns 0 if the divisor is 0.
192    fn unsigned_euclidean_division(self, other: Self) -> Self {
193        if other == Self::ZERO {
194            Self::ZERO
195        } else {
196            (self.to_unsigned_number() / other.to_unsigned_number()).into()
197        }
198    }
199    /// The signed Euclidean division. Returns 0 if the divisor is 0.
200    fn signed_euclidean_division(self, other: Self) -> Self {
201        if other == Self::ZERO {
202            Self::ZERO
203        } else {
204            (self.to_signed_number() / other.to_signed_number()).into()
205        }
206    }
207    /// Generates a field element between min and max, included.
208    fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
209        min + Self::from(Number::gen_range(
210            rng,
211            &0.into(),
212            &((max - min).to_unsigned_number() + 1),
213        ))
214    }
215
216    /// Converts a number in lsb-to-msb binary expansion to the corresponding element in Self.
217    fn from_bin(bin: &str) -> Self {
218        Self::from(
219            bin.chars()
220                .enumerate()
221                .fold(Number::from(0), |acc, (i, c)| {
222                    if c == '1' {
223                        acc + Number::power_of_two(i)
224                    } else {
225                        acc
226                    }
227                }),
228        )
229    }
230
231    /// Converts self to its lsb-to-msb binary expansion.
232    fn to_bin(&self) -> String {
233        (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
234            if self.unsigned_bit(i) {
235                acc.push('1');
236            } else {
237                acc.push('0');
238            }
239            acc
240        })
241    }
242    fn as_power_of_two(self) -> Option<usize> {
243        if self == Self::ZERO {
244            return None;
245        }
246        let mut min_possible_exponent = 0usize;
247        let mut max_possible_exponent = Self::CAPACITY as usize;
248        while max_possible_exponent >= min_possible_exponent {
249            let mid = (min_possible_exponent + max_possible_exponent) / 2;
250            match self.partial_cmp(&Self::power_of_two(mid)) {
251                None => panic!("order should be total"),
252                Some(Ordering::Less) => {
253                    max_possible_exponent = mid - 1;
254                }
255                Some(Ordering::Equal) => return Some(mid),
256                Some(Ordering::Greater) => {
257                    min_possible_exponent = mid + 1;
258                }
259            }
260        }
261        None
262    }
263    fn signed_gt(self, other: Self) -> bool {
264        self.max(other, true) != other
265    }
266    fn signed_ge(self, other: Self) -> bool {
267        self.max(other, true) == self
268    }
269    fn signed_lt(self, other: Self) -> bool {
270        self.min(other, true) != other
271    }
272    fn signed_le(self, other: Self) -> bool {
273        self.min(other, true) == self
274    }
275}
276
277impl<F: UsedField> Invert for F {
278    fn invert(self, _is_expected_non_zero: bool) -> Self {
279        F::invert(&self).unwrap_or(F::ZERO)
280    }
281}
282
283impl<F: UsedField> Pow for F {
284    fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
285        let e = e % (F::modulus() - 1);
286        let mut e_u64 = [0u64; 4];
287        let bytes: [u8; 32] = e.into();
288        for (i, chunk) in bytes.chunks_exact(8).enumerate() {
289            e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
290        }
291
292        F::pow(&self, e_u64)
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// `is_ge_zero` is implemented as `self < TWO_INV`. Check that it agrees with the defining
    /// property `self <= -self` on zero, one, and the elements around the sign boundary.
    #[test]
    fn is_ge_zero() {
        let one = ScalarField::ONE;
        let half = ScalarField::TWO_INV;
        let samples = [ScalarField::ZERO, one, half - one, half, ScalarField::ZERO - one];
        for &sample in samples.iter() {
            let negated = ScalarField::ZERO - sample;
            assert_eq!(sample.is_ge_zero(), sample <= negated);
        }
    }
}