// arcis_compiler/utils/used_field.rs
use crate::{
    traits::{FromLeBytes, Invert, Pow},
    utils::{matrix::Matrix, number::Number},
};
use ff::{
    derive::bitvec::{order::Lsb0, view::AsBits},
    PrimeField,
};
use num_bigint::{BigInt, BigUint};
use num_traits::Zero;
use rand::Rng;
use std::{cmp::Ordering, hash::Hash};

14pub trait UsedField:
17 PrimeField
18 + Hash
19 + PartialOrd
20 + From<Number>
21 + From<i32>
22 + From<bool>
23 + From<f64>
24 + Zero
25 + std::ops::Shr<usize, Output = Self>
26 + FromLeBytes
27{
28 fn modulus() -> Number;
30
31 fn get_alpha() -> Number;
33
34 fn get_alpha_inverse() -> Number;
36
37 fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40 fn power_of_two(exponent: usize) -> Self;
42
43 fn negative_power_of_two(exponent: usize) -> Self {
45 Self::ZERO - Self::power_of_two(exponent)
46 }
47
48 fn exponent_close_power_of_two() -> usize;
49
50 fn to_unsigned_number(self) -> Number {
51 BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
52 }
53
54 fn to_signed_number(self) -> Number {
55 if self.is_ge_zero() {
56 self.to_unsigned_number()
57 } else {
58 -(Self::ZERO - self).to_unsigned_number()
59 }
60 }
61
62 fn is_binary(self) -> bool {
64 self <= Self::ONE
65 }
66
67 #[inline(always)]
69 fn is_ge_zero(self) -> bool {
70 self < Self::TWO_INV
72 }
73
74 fn is_le_zero(self) -> bool {
76 self >= Self::ZERO - self
77 }
78
79 #[inline(always)]
81 fn is_gt_zero(self) -> bool {
82 !self.is_le_zero()
83 }
84
85 #[inline(always)]
87 fn is_lt_zero(self) -> bool {
88 !self.is_ge_zero()
89 }
90
91 fn max_cyclic(self, other: Self) -> (Self, bool) {
93 if (other - self).is_ge_zero() {
94 (other, true)
95 } else {
96 (self, false)
97 }
98 }
99 fn min_cyclic(self, other: Self) -> (Self, bool) {
101 if (other - self).is_ge_zero() {
102 (self, false)
103 } else {
104 (other, true)
105 }
106 }
107 fn max(self, other: Self, signed: bool) -> Self {
109 let offset = if signed { Self::TWO_INV } else { Self::ZERO };
110 if self - offset < other - offset {
111 other
112 } else {
113 self
114 }
115 }
116 fn min(self, other: Self, signed: bool) -> Self {
118 let offset = if signed { Self::TWO_INV } else { Self::ZERO };
119 if self - offset > other - offset {
120 other
121 } else {
122 self
123 }
124 }
125 fn sort_pair(self, other: Self) -> (Self, Self) {
127 if (other - self).is_ge_zero() {
128 (self, other)
129 } else {
130 (other, self)
131 }
132 }
133 fn abs(self) -> Self {
135 if self.is_ge_zero() {
136 self
137 } else {
138 Self::ZERO - self
139 }
140 }
141 fn does_mul_overflow(self, other: Self) -> bool {
143 if self.is_zero_vartime() || other.is_zero_vartime() {
144 return false;
145 }
146 let prod = self.to_unsigned_number() * other.to_unsigned_number();
147 prod >= Self::modulus()
148 }
149 fn does_add_signed_overflow(self, other: Self) -> bool {
150 let sum = self + other;
151 match (self.is_ge_zero(), other.is_ge_zero()) {
152 (true, true) => sum.is_lt_zero(),
153 (true, false) => false,
154 (false, true) => false,
155 (false, false) => sum.is_ge_zero(),
156 }
157 }
158 fn does_add_unsigned_overflow(self, other: Self) -> bool {
159 if self == Self::ZERO || other == Self::ZERO {
160 false
161 } else {
162 self >= -other
163 }
164 }
165 fn unsigned_bits(self) -> usize {
167 let binding = self.to_repr();
168 let bits = binding.as_bits::<Lsb0>();
169 bits.len() - bits.trailing_zeros()
170 }
171 fn signed_bits(self) -> usize {
173 self.abs().unsigned_bits()
174 }
175 fn unsigned_bit(&self, idx: usize) -> bool {
177 let repr = self.to_repr();
178 let bits = repr.as_bits::<Lsb0>();
179 if idx < bits.len() {
180 bits[idx]
181 } else {
182 false
183 }
184 }
185 fn signed_bit(&self, idx: usize) -> bool {
187 if self.is_ge_zero() {
188 self.unsigned_bit(idx)
189 } else {
190 !(self.abs() - Self::ONE).unsigned_bit(idx)
191 }
192 }
193 fn unsigned_euclidean_division(self, other: Self) -> Self {
195 if other == Self::ZERO {
196 Self::ZERO
197 } else {
198 (self.to_unsigned_number() / other.to_unsigned_number()).into()
199 }
200 }
201 fn unsigned_euclidean_division_better_bounds(self, other: Self) -> Self {
205 if other == Self::ZERO {
206 Self::ZERO
207 } else {
208 let s = self.to_signed_number();
209 let other = other.to_unsigned_number();
210 if s < 0 {
211 ((&s - &s * &other) / &other + &s).into()
212 } else {
213 (s / other).into()
214 }
215 }
216 }
217 fn signed_euclidean_division(self, other: Self) -> Self {
219 if other == Self::ZERO {
220 Self::ZERO
221 } else {
222 (self.to_signed_number() / other.to_signed_number()).into()
223 }
224 }
225 fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
227 min + Self::from(Number::gen_range(
228 rng,
229 &0.into(),
230 &((max - min).to_unsigned_number() + 1),
231 ))
232 }
233
234 fn from_bin(bin: &str) -> Self {
236 Self::from(
237 bin.chars()
238 .enumerate()
239 .fold(Number::from(0), |acc, (i, c)| {
240 if c == '1' {
241 acc + Number::power_of_two(i)
242 } else {
243 acc
244 }
245 }),
246 )
247 }
248
249 fn to_bin(&self) -> String {
251 (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
252 if self.unsigned_bit(i) {
253 acc.push('1');
254 } else {
255 acc.push('0');
256 }
257 acc
258 })
259 }
260 fn as_power_of_two(self) -> Option<usize> {
261 if self == Self::ZERO {
262 return None;
263 }
264 let mut min_possible_exponent = 0usize;
265 let mut max_possible_exponent = Self::CAPACITY as usize;
266 while max_possible_exponent >= min_possible_exponent {
267 let mid = (min_possible_exponent + max_possible_exponent) / 2;
268 match self.partial_cmp(&Self::power_of_two(mid)) {
269 None => panic!("order should be total"),
270 Some(Ordering::Less) => {
271 max_possible_exponent = mid - 1;
272 }
273 Some(Ordering::Equal) => return Some(mid),
274 Some(Ordering::Greater) => {
275 min_possible_exponent = mid + 1;
276 }
277 }
278 }
279 None
280 }
281 fn signed_gt(self, other: Self) -> bool {
282 self.max(other, true) != other
283 }
284 fn signed_ge(self, other: Self) -> bool {
285 self.max(other, true) == self
286 }
287 fn signed_lt(self, other: Self) -> bool {
288 self.min(other, true) != other
289 }
290 fn signed_le(self, other: Self) -> bool {
291 self.min(other, true) == self
292 }
293}
294
295impl<F: UsedField> Invert for F {
296 fn invert(self, _is_expected_non_zero: bool) -> Self {
297 F::invert(&self).unwrap_or(F::ZERO)
298 }
299}
300
301impl<F: UsedField> Pow for F {
302 fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
303 let e = e % (F::modulus() - 1);
304 let mut e_u64 = [0u64; 4];
305 let bytes: [u8; 32] = e.into();
306 for (i, chunk) in bytes.chunks_exact(8).enumerate() {
307 e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
308 }
309
310 F::pow(&self, e_u64)
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// `is_ge_zero` must agree with the definition "x is non-negative iff x <= -x".
    /// The samples cover zero, one, both sides of the `TWO_INV` boundary and minus one.
    #[test]
    fn is_ge_zero() {
        let boundary = ScalarField::TWO_INV;
        let samples = [
            ScalarField::ZERO,
            ScalarField::ONE,
            boundary - ScalarField::ONE,
            boundary,
            ScalarField::ZERO - ScalarField::ONE,
        ];
        for sample in samples {
            let negated = ScalarField::ZERO - sample;
            assert_eq!(sample.is_ge_zero(), sample <= negated)
        }
    }
}