arcis_compiler/utils/
used_field.rs1use crate::{
2 traits::{FromLeBytes, Invert, Pow},
3 utils::{matrix::Matrix, number::Number},
4};
5use ff::{
6 derive::bitvec::{order::Lsb0, view::AsBits},
7 PrimeField,
8};
9use num_bigint::{BigInt, BigUint};
10use num_traits::Zero;
11use rand::Rng;
12use std::{cmp::Ordering, hash::Hash};
13
14pub trait UsedField:
17 PrimeField
18 + Hash
19 + PartialOrd
20 + From<Number>
21 + From<i32>
22 + From<bool>
23 + From<f64>
24 + Zero
25 + std::ops::Shr<usize, Output = Self>
26 + FromLeBytes
27{
28 fn modulus() -> Number;
30
31 fn get_alpha() -> Number;
33
34 fn get_alpha_inverse() -> Number;
36
37 fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);
39
40 fn power_of_two(exponent: usize) -> Self;
42
43 fn negative_power_of_two(exponent: usize) -> Self {
45 Self::ZERO - Self::power_of_two(exponent)
46 }
47
48 fn to_unsigned_number(self) -> Number {
49 BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
50 }
51
52 fn to_signed_number(self) -> Number {
53 if self.is_ge_zero() {
54 self.to_unsigned_number()
55 } else {
56 -(Self::ZERO - self).to_unsigned_number()
57 }
58 }
59
60 fn is_binary(self) -> bool {
62 self <= Self::ONE
63 }
64
65 #[inline(always)]
67 fn is_ge_zero(self) -> bool {
68 self < Self::TWO_INV
70 }
71
72 fn is_le_zero(self) -> bool {
74 self >= Self::ZERO - self
75 }
76
77 #[inline(always)]
79 fn is_gt_zero(self) -> bool {
80 !self.is_le_zero()
81 }
82
83 #[inline(always)]
85 fn is_lt_zero(self) -> bool {
86 !self.is_ge_zero()
87 }
88
89 fn max_cyclic(self, other: Self) -> (Self, bool) {
91 if (other - self).is_ge_zero() {
92 (other, true)
93 } else {
94 (self, false)
95 }
96 }
97 fn min_cyclic(self, other: Self) -> (Self, bool) {
99 if (other - self).is_ge_zero() {
100 (self, false)
101 } else {
102 (other, true)
103 }
104 }
105 fn max(self, other: Self, signed: bool) -> Self {
107 let offset = if signed { Self::TWO_INV } else { Self::ZERO };
108 if self - offset < other - offset {
109 other
110 } else {
111 self
112 }
113 }
114 fn min(self, other: Self, signed: bool) -> Self {
116 let offset = if signed { Self::TWO_INV } else { Self::ZERO };
117 if self - offset > other - offset {
118 other
119 } else {
120 self
121 }
122 }
123 fn sort_pair(self, other: Self) -> (Self, Self) {
125 if (other - self).is_ge_zero() {
126 (self, other)
127 } else {
128 (other, self)
129 }
130 }
131 fn abs(self) -> Self {
133 if self.is_ge_zero() {
134 self
135 } else {
136 Self::ZERO - self
137 }
138 }
139 fn does_mul_overflow(self, other: Self) -> bool {
141 if self.is_zero_vartime() || other.is_zero_vartime() {
142 return false;
143 }
144 let prod = self.to_unsigned_number() * other.to_unsigned_number();
145 prod >= Self::modulus()
146 }
147 fn does_add_signed_overflow(self, other: Self) -> bool {
148 let sum = self + other;
149 match (self.is_ge_zero(), other.is_ge_zero()) {
150 (true, true) => sum.is_lt_zero(),
151 (true, false) => false,
152 (false, true) => false,
153 (false, false) => sum.is_ge_zero(),
154 }
155 }
156 fn does_add_unsigned_overflow(self, other: Self) -> bool {
157 if self == Self::ZERO || other == Self::ZERO {
158 false
159 } else {
160 self >= -other
161 }
162 }
163 fn unsigned_bits(self) -> usize {
165 let binding = self.to_repr();
166 let bits = binding.as_bits::<Lsb0>();
167 bits.len() - bits.trailing_zeros()
168 }
169 fn signed_bits(self) -> usize {
171 self.abs().unsigned_bits()
172 }
173 fn unsigned_bit(&self, idx: usize) -> bool {
175 let repr = self.to_repr();
176 let bits = repr.as_bits::<Lsb0>();
177 if idx < bits.len() {
178 bits[idx]
179 } else {
180 false
181 }
182 }
183 fn signed_bit(&self, idx: usize) -> bool {
185 if self.is_ge_zero() {
186 self.unsigned_bit(idx)
187 } else {
188 !(self.abs() - Self::ONE).unsigned_bit(idx)
189 }
190 }
191 fn unsigned_euclidean_division(self, other: Self) -> Self {
193 if other == Self::ZERO {
194 Self::ZERO
195 } else {
196 (self.to_unsigned_number() / other.to_unsigned_number()).into()
197 }
198 }
199 fn signed_euclidean_division(self, other: Self) -> Self {
201 if other == Self::ZERO {
202 Self::ZERO
203 } else {
204 (self.to_signed_number() / other.to_signed_number()).into()
205 }
206 }
207 fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
209 min + Self::from(Number::gen_range(
210 rng,
211 &0.into(),
212 &((max - min).to_unsigned_number() + 1),
213 ))
214 }
215
216 fn from_bin(bin: &str) -> Self {
218 Self::from(
219 bin.chars()
220 .enumerate()
221 .fold(Number::from(0), |acc, (i, c)| {
222 if c == '1' {
223 acc + Number::power_of_two(i)
224 } else {
225 acc
226 }
227 }),
228 )
229 }
230
231 fn to_bin(&self) -> String {
233 (0..Self::modulus().bits()).fold(String::new(), |mut acc, i| {
234 if self.unsigned_bit(i) {
235 acc.push('1');
236 } else {
237 acc.push('0');
238 }
239 acc
240 })
241 }
242 fn as_power_of_two(self) -> Option<usize> {
243 if self == Self::ZERO {
244 return None;
245 }
246 let mut min_possible_exponent = 0usize;
247 let mut max_possible_exponent = Self::CAPACITY as usize;
248 while max_possible_exponent >= min_possible_exponent {
249 let mid = (min_possible_exponent + max_possible_exponent) / 2;
250 match self.partial_cmp(&Self::power_of_two(mid)) {
251 None => panic!("order should be total"),
252 Some(Ordering::Less) => {
253 max_possible_exponent = mid - 1;
254 }
255 Some(Ordering::Equal) => return Some(mid),
256 Some(Ordering::Greater) => {
257 min_possible_exponent = mid + 1;
258 }
259 }
260 }
261 None
262 }
263 fn signed_gt(self, other: Self) -> bool {
264 self.max(other, true) != other
265 }
266 fn signed_ge(self, other: Self) -> bool {
267 self.max(other, true) == self
268 }
269 fn signed_lt(self, other: Self) -> bool {
270 self.min(other, true) != other
271 }
272 fn signed_le(self, other: Self) -> bool {
273 self.min(other, true) == self
274 }
275}
276
277impl<F: UsedField> Invert for F {
278 fn invert(self, _is_expected_non_zero: bool) -> Self {
279 F::invert(&self).unwrap_or(F::ZERO)
280 }
281}
282
283impl<F: UsedField> Pow for F {
284 fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
285 let e = e % (F::modulus() - 1);
286 let mut e_u64 = [0u64; 4];
287 let bytes: [u8; 32] = e.into();
288 for (i, chunk) in bytes.chunks_exact(8).enumerate() {
289 e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
290 }
291
292 F::pow(&self, e_u64)
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// Boundary elements of the signed view: zero, one, the largest non-negative element
    /// `TWO_INV - 1`, the most negative element `TWO_INV`, and `-1`.
    fn sign_boundaries() -> [ScalarField; 5] {
        [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            ScalarField::ZERO - ScalarField::ONE,
        ]
    }

    #[test]
    fn is_ge_zero() {
        // `n >= 0` must coincide with `n <= -n` in the unsigned order.
        for n in sign_boundaries() {
            assert_eq!(n.is_ge_zero(), n <= ScalarField::ZERO - n)
        }
    }

    #[test]
    fn sign_predicates_agree() {
        for n in sign_boundaries() {
            let negated = ScalarField::ZERO - n;
            // `n <= 0` iff `-n >= 0`.
            assert_eq!(n.is_le_zero(), negated.is_ge_zero());
            // Signed comparisons against zero must match the sign predicates.
            assert_eq!(n.is_lt_zero(), n.signed_lt(ScalarField::ZERO));
            assert_eq!(n.is_ge_zero(), n.signed_ge(ScalarField::ZERO));
            assert_eq!(n.is_gt_zero(), n.signed_gt(ScalarField::ZERO));
            assert_eq!(n.is_le_zero(), n.signed_le(ScalarField::ZERO));
            // `abs` never produces a negative element.
            assert!(n.abs().is_ge_zero());
        }
    }

    #[test]
    fn boundary_signs() {
        assert!(!ScalarField::ZERO.is_gt_zero());
        assert!(!ScalarField::ZERO.is_lt_zero());
        assert!((ScalarField::TWO_INV - ScalarField::ONE).is_gt_zero());
        assert!(ScalarField::TWO_INV.is_lt_zero());
        assert!((ScalarField::ZERO - ScalarField::ONE).is_lt_zero());
    }
}