arcis_compiler/utils/
used_field.rs1use crate::{
2 traits::{FromLeBytes, Invert, Pow},
3 utils::{matrix::Matrix, number::Number},
4};
5use ff::{
6 derive::bitvec::{order::Lsb0, view::AsBits},
7 PrimeField,
8};
9use num_bigint::{BigInt, BigUint};
10use num_traits::Zero;
11use rand::Rng;
12use std::{cmp::Ordering, hash::Hash};
13
14pub trait UsedField:
17 PrimeField
18 + Hash
19 + PartialOrd
20 + From<Number>
21 + From<i32>
22 + From<bool>
23 + From<f64>
24 + Zero
25 + std::ops::Shr<usize, Output = Self>
26 + FromLeBytes
27{
28 fn modulus() -> Number;
30
31 fn get_alpha() -> Number;
33
34 fn get_alpha_inverse() -> Number;
36
37 fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40 fn power_of_two(exponent: usize) -> Self;
42
43 fn negative_power_of_two(exponent: usize) -> Self {
45 Self::ZERO - Self::power_of_two(exponent)
46 }
47
48 fn to_unsigned_number(self) -> Number {
49 BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
50 }
51
52 fn to_signed_number(self) -> Number {
53 if self.is_ge_zero() {
54 self.to_unsigned_number()
55 } else {
56 -(Self::ZERO - self).to_unsigned_number()
57 }
58 }
59
60 fn is_binary(self) -> bool {
62 self <= Self::ONE
63 }
64
65 #[inline(always)]
67 fn is_ge_zero(self) -> bool {
68 self < Self::TWO_INV
70 }
71
72 fn is_le_zero(self) -> bool {
74 self >= Self::ZERO - self
75 }
76
77 #[inline(always)]
79 fn is_gt_zero(self) -> bool {
80 !self.is_le_zero()
81 }
82
83 #[inline(always)]
85 fn is_lt_zero(self) -> bool {
86 !self.is_ge_zero()
87 }
88
89 fn max_cyclic(self, other: Self) -> Self {
91 if (other - self).is_ge_zero() {
92 other
93 } else {
94 self
95 }
96 }
97 fn min_cyclic(self, other: Self) -> Self {
99 if (other - self).is_ge_zero() {
100 self
101 } else {
102 other
103 }
104 }
105 fn max(self, other: Self, signed: bool) -> Self {
107 let offset = if signed { Self::TWO_INV } else { Self::ZERO };
108 if self - offset < other - offset {
109 other
110 } else {
111 self
112 }
113 }
114 fn min(self, other: Self, signed: bool) -> Self {
116 let offset = if signed { Self::TWO_INV } else { Self::ZERO };
117 if self - offset > other - offset {
118 other
119 } else {
120 self
121 }
122 }
123 fn sort_pair(self, other: Self) -> (Self, Self) {
125 if (other - self).is_ge_zero() {
126 (self, other)
127 } else {
128 (other, self)
129 }
130 }
131 fn abs(self) -> Self {
133 if self.is_ge_zero() {
134 self
135 } else {
136 Self::ZERO - self
137 }
138 }
139 fn does_mul_overflow(self, other: Self) -> bool {
141 let zero = Self::ZERO;
142 if self == zero || other == zero {
143 return false;
144 }
145 let prod = self.to_unsigned_number() * other.to_unsigned_number();
146 prod >= Self::modulus()
147 }
148 fn does_add_signed_overflow(self, other: Self) -> bool {
149 let sum = self + other;
150 match (self.is_ge_zero(), other.is_ge_zero()) {
151 (true, true) => sum.is_lt_zero(),
152 (true, false) => false,
153 (false, true) => false,
154 (false, false) => sum.is_ge_zero(),
155 }
156 }
157 fn does_add_unsigned_overflow(self, other: Self) -> bool {
158 if self == Self::ZERO || other == Self::ZERO {
159 false
160 } else {
161 self >= -other
162 }
163 }
164 fn unsigned_bits(self) -> usize {
166 let binding = self.to_repr();
167 let bits = binding.as_bits::<Lsb0>();
168 bits.len() - bits.trailing_zeros()
169 }
170 fn signed_bits(self) -> usize {
172 self.abs().unsigned_bits()
173 }
174 fn unsigned_bit(&self, idx: usize) -> bool {
176 let repr = self.to_repr();
177 let bits = repr.as_bits::<Lsb0>();
178 if idx < bits.len() {
179 bits[idx]
180 } else {
181 false
182 }
183 }
184 fn signed_bit(&self, idx: usize) -> bool {
186 if self.is_ge_zero() {
187 self.unsigned_bit(idx)
188 } else {
189 !(self.abs() - Self::ONE).unsigned_bit(idx)
190 }
191 }
192 fn unsigned_euclidean_division(self, other: Self) -> Self {
194 if other == Self::ZERO {
195 Self::ZERO
196 } else {
197 (self.to_unsigned_number() / other.to_unsigned_number()).into()
198 }
199 }
200 fn signed_euclidean_division(self, other: Self) -> Self {
202 if other == Self::ZERO {
203 Self::ZERO
204 } else {
205 (self.to_signed_number() / other.to_signed_number()).into()
206 }
207 }
208 fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
210 min + Self::from(Number::gen_range(
211 rng,
212 &0.into(),
213 &((max - min).to_unsigned_number() + 1),
214 ))
215 }
216
217 fn from_bin(bin: &str) -> Self {
219 Self::from(
220 bin.chars()
221 .enumerate()
222 .fold(Number::from(0), |acc, (i, c)| {
223 if c == '1' {
224 acc + Number::power_of_two(i)
225 } else {
226 acc
227 }
228 }),
229 )
230 }
231
232 fn to_bin(&self) -> String {
234 (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
235 if self.unsigned_bit(i) {
236 acc.push('1');
237 } else {
238 acc.push('0');
239 }
240 acc
241 })
242 }
243 fn as_power_of_two(self) -> Option<usize> {
244 if self == Self::ZERO {
245 return None;
246 }
247 let mut min_possible_exponent = 0usize;
248 let mut max_possible_exponent = Self::CAPACITY as usize;
249 while max_possible_exponent >= min_possible_exponent {
250 let mid = (min_possible_exponent + max_possible_exponent) / 2;
251 match self.partial_cmp(&Self::power_of_two(mid)) {
252 None => panic!("order should be total"),
253 Some(Ordering::Less) => {
254 max_possible_exponent = mid - 1;
255 }
256 Some(Ordering::Equal) => return Some(mid),
257 Some(Ordering::Greater) => {
258 min_possible_exponent = mid + 1;
259 }
260 }
261 }
262 None
263 }
264}
265
266impl<F: UsedField> Invert for F {
267 fn invert(self, _is_expected_non_zero: bool) -> Self {
268 F::invert(&self).unwrap_or(F::ZERO)
269 }
270}
271
272impl<F: UsedField> Pow for F {
273 fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
274 let e = e % (F::modulus() - 1);
275 let mut e_u64 = [0u64; 4];
276 let bytes: [u8; 32] = e.into();
277 for (i, chunk) in bytes.chunks_exact(8).enumerate() {
278 e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
279 }
280
281 F::pow(&self, e_u64)
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// Sample points around zero and around the signed boundary
    /// `(p - 1) / 2 | (p + 1) / 2`.
    fn boundary_samples() -> [ScalarField; 5] {
        [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            ScalarField::ZERO - ScalarField::ONE,
        ]
    }

    #[test]
    fn is_ge_zero() {
        for n in boundary_samples() {
            assert_eq!(n.is_ge_zero(), n <= ScalarField::ZERO - n)
        }
    }

    #[test]
    fn is_le_zero_matches_is_ge_zero() {
        // `x <= 0` holds exactly for zero and for the negative half.
        for n in boundary_samples() {
            assert_eq!(n.is_le_zero(), n == ScalarField::ZERO || !n.is_ge_zero());
        }
    }

    #[test]
    fn as_power_of_two() {
        assert_eq!(ScalarField::ZERO.as_power_of_two(), None);
        assert_eq!(ScalarField::ONE.as_power_of_two(), Some(0));
        assert_eq!(ScalarField::power_of_two(5).as_power_of_two(), Some(5));
        assert_eq!(
            (ScalarField::power_of_two(5) + ScalarField::ONE).as_power_of_two(),
            None
        );
    }
}