1use alloc::vec::Vec;
2use core::ops::{Add, Div, Mul, Neg, Sub};
3use num_traits::Pow;
4use serde::{Deserialize, Serialize};
5
6pub mod bpoly;
7pub mod poly;
8
9pub use bpoly::*;
10pub use poly::*;
11
/// The Goldilocks prime `2^64 - 2^32 + 1`; canonical field elements lie in `[0, PRIME)`.
pub const PRIME: u64 = 0xffff_ffff_0000_0001;
/// [`PRIME`] widened to `u128` for double-width arithmetic.
pub const PRIME_128: u128 = PRIME as u128;
/// `PRIME * 2^64`: exclusive upper bound on the input accepted by `mont_reduction`.
const RP: u128 = PRIME_128 << 64;
/// `2^128 mod PRIME` (equal to `PRIME - 2^32`), the factor `montify` multiplies by to
/// enter Montgomery form with `R = 2^64`.
pub const R2: u128 = 0xffff_fffe_0000_0001;
17
18#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Default, Serialize, Deserialize)]
19#[repr(transparent)]
20pub struct Belt(pub u64);
21
22impl Belt {
23 pub fn from_bytes(bytes: &[u8]) -> Vec<Belt> {
24 let mut belts = Vec::new();
25 for chunk in bytes.chunks(4) {
26 let mut arr = [0u8; 4];
27 arr[..chunk.len()].copy_from_slice(chunk);
28 belts.push(Belt(u32::from_le_bytes(arr) as u64));
29 }
30 belts
31 }
32
33 pub fn to_bytes(belts: &[Belt]) -> Vec<u8> {
34 let mut bytes = Vec::new();
35 for b in belts {
36 bytes.extend(u32::try_from(b.0).expect("Too big for u32").to_le_bytes());
37 }
38 bytes
39 }
40}
41
42pub fn based_check(a: u64) -> bool {
43 a < PRIME
44}
45
/// Debug-only guard asserting that every argument is a canonical field element.
///
/// Each argument is passed to `based_check` inside a `debug_assert!`, so the checks
/// vanish in release builds. With no arguments the macro expands to an empty block.
/// NOTE(review): the `$crate::belt::based_check` path presumes this module is reachable
/// as `crate::belt`; update it if the module moves.
#[macro_export]
macro_rules! based {
    ($($x:expr),*) => {{
        $(
            debug_assert!(
                $crate::belt::based_check($x),
                "element must be inside the field\r"
            );
        )*
    }};
}
56
/// Lookup table for `Belt::ordered_root`: entry `i` is returned for the input `2^i`.
///
/// The 33 entries cover `i = 0..=32`, matching the largest power of two dividing
/// `PRIME - 1 = 2^32 * (2^32 - 1)`. The entries appear to be primitive `2^i`-th roots of
/// unity. For example, entry 1 is `PRIME - 1 = -1`, and entry 2 is `2^48`, whose square
/// `2^96` is `-1 mod PRIME`.
/// NOTE(review): only a few entries were spot-checked; confirm the rest against the
/// generator used to produce this table.
const ROOTS: &[u64] = &[
    0x0000000000000001,
    0xffffffff00000000,
    0x0001000000000000,
    0xfffffffeff000001,
    0xefffffff00000001,
    0x00003fffffffc000,
    0x0000008000000000,
    0xf80007ff08000001,
    0xbf79143ce60ca966,
    0x1905d02a5c411f4e,
    0x9d8f2ad78bfed972,
    0x0653b4801da1c8cf,
    0xf2c35199959dfcb6,
    0x1544ef2335d17997,
    0xe0ee099310bba1e2,
    0xf6b2cffe2306baac,
    0x54df9630bf79450e,
    0xabd0a6e8aa3d8a0e,
    0x81281a7b05f9beac,
    0xfbd41c6b8caa3302,
    0x30ba2ecd5e93e76d,
    0xf502aef532322654,
    0x4b2a18ade67246b5,
    0xea9d5a1336fbc98b,
    0x86cdcc31c307e171,
    0x4bbaf5976ecfefd8,
    0xed41d05b78d6e286,
    0x10d78dd8915a171d,
    0x59049500004a4485,
    0xdfa8c93ba46d2666,
    0x7e9bd009b86a0845,
    0x400a7f755588e659,
    0x185629dcda58878c,
];
92
93impl Belt {
94 #[inline(always)]
95 pub fn zero() -> Self {
96 Belt(Default::default())
97 }
98
99 #[inline(always)]
100 pub fn one() -> Self {
101 Belt(1)
102 }
103
104 #[inline(always)]
105 pub fn is_zero(&self) -> bool {
106 self.0 == 0
107 }
108
109 #[inline(always)]
110 pub fn is_one(&self) -> bool {
111 self.0 == 1
112 }
113
114 #[inline(always)]
115 pub fn ordered_root(&self) -> Result<Self, FieldError> {
116 let log_of_self = self.0.ilog2();
117 if (log_of_self as usize) >= ROOTS.len() {
118 return Err(FieldError::OrderedRootError);
119 }
120 if self.0 != 1 << log_of_self {
122 return Err(FieldError::OrderedRootError);
123 }
124 Ok(ROOTS[log_of_self as usize].into())
125 }
126
127 #[inline(always)]
128 pub fn inv(&self) -> Self {
129 Belt(binv(self.0))
130 }
131}
132
133impl Add for Belt {
134 type Output = Self;
135
136 #[inline(always)]
137 fn add(self, rhs: Self) -> Self::Output {
138 let a = self.0;
139 let b = rhs.0;
140 Belt(badd(a, b))
141 }
142}
143
144impl Sub for Belt {
145 type Output = Self;
146
147 #[inline(always)]
148 fn sub(self, rhs: Self) -> Self::Output {
149 let a = self.0;
150 let b = rhs.0;
151 Belt(bsub(a, b))
152 }
153}
154
155impl Neg for Belt {
156 type Output = Self;
157
158 #[inline(always)]
159 fn neg(self) -> Self::Output {
160 let a = self.0;
161 Belt(bneg(a))
162 }
163}
164
165impl Mul for Belt {
166 type Output = Self;
167
168 #[inline(always)]
169 fn mul(self, rhs: Self) -> Self::Output {
170 let a = self.0;
171 let b = rhs.0;
172 Belt(bmul(a, b))
173 }
174}
175
176impl Pow<usize> for Belt {
177 type Output = Self;
178
179 #[inline(always)]
180 fn pow(self, rhs: usize) -> Self::Output {
181 Belt(bpow(self.0, rhs as u64))
182 }
183}
184
185impl Div for Belt {
186 type Output = Self;
187
188 #[inline(always)]
189 #[allow(clippy::suspicious_arithmetic_impl)]
190 fn div(self, rhs: Self) -> Self::Output {
191 self * rhs.inv()
192 }
193}
194
195impl PartialEq<u64> for Belt {
196 #[inline(always)]
197 fn eq(&self, other: &u64) -> bool {
198 self.0 == *other
199 }
200}
201
202impl PartialEq<Belt> for u64 {
203 #[inline(always)]
204 fn eq(&self, other: &Belt) -> bool {
205 *self == other.0
206 }
207}
208
209impl AsRef<u64> for Belt {
210 #[inline(always)]
211 fn as_ref(&self) -> &u64 {
212 &self.0
213 }
214}
215
216impl TryFrom<&u64> for Belt {
217 type Error = ();
218
219 #[inline(always)]
220 fn try_from(f: &u64) -> Result<Self, Self::Error> {
221 based!(*f);
222 Ok(Belt(*f))
223 }
224}
225
226impl From<u64> for Belt {
227 #[inline(always)]
228 fn from(f: u64) -> Self {
229 Belt(f)
230 }
231}
232
233impl From<Belt> for u64 {
234 #[inline(always)]
235 fn from(b: Belt) -> Self {
236 b.0
237 }
238}
239
240impl From<u32> for Belt {
241 #[inline(always)]
242 fn from(f: u32) -> Self {
243 Belt(f as u64)
244 }
245}
246
247impl From<Belt> for u32 {
248 #[inline(always)]
249 fn from(b: Belt) -> Self {
250 b.0 as u32
251 }
252}
253
/// Errors reported by the field operations in this module.
// Clone/Copy/PartialEq/Eq let callers compare and copy errors (e.g. `assert_eq!` on a
// `Result`); the enum carries no data, so these derives are free.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FieldError {
    /// `Belt::ordered_root` was called on a value that is not `2^k` for a `k` covered
    /// by the `ROOTS` table.
    OrderedRootError,
}
258
/// Montgomery reduction for `PRIME = 2^64 - 2^32 + 1` with `R = 2^64`.
///
/// Returns `a * 2^-64 mod PRIME` as a canonical value in `[0, PRIME)`. The input must
/// satisfy `a < RP = PRIME * 2^64`; this is checked only in debug builds. The bound
/// guarantees that the high word `a >> 64` is itself below `PRIME`.
#[inline(always)]
pub fn mont_reduction(a: u128) -> u64 {
    debug_assert!(a < RP, "element must be inside the field\r");
    // Split a = x2 * 2^64 + x1 * 2^32 + x0, where x0 and x1 are 32-bit limbs.
    let x1: u128 = (a >> 32) & 0xffffffff;
    let x2: u128 = a >> 64;
    let c: u128 = {
        let x0: u128 = a & 0xffffffff;
        (x0 + x1) << 32
    };
    // c < 2^65, so f is 0 or 1. When c reaches 2^64, subtracting PRIME once more keeps
    // d below PRIME without changing its residue.
    let f: u128 = c >> 64;
    // d = (x0 + x1) * 2^32 - x1 (mod PRIME) satisfies d * 2^64 ≡ -(x1 * 2^32 + x0),
    // hence x2 - d ≡ a * 2^-64 (mod PRIME).
    let d: u128 = c - (x1 + (f * PRIME_128));
    // x2 < PRIME and d < PRIME, so one conditional PRIME-add yields the canonical result.
    if x2 >= d {
        (x2 - d) as u64
    } else {
        (x2 + PRIME_128 - d) as u64
    }
}
276
277#[inline(always)]
278pub fn montiply(a: u64, b: u64) -> u64 {
279 based!(a);
280 based!(b);
281
282 mont_reduction((a as u128) * (b as u128))
283}
284
285#[inline(always)]
286pub fn montify(a: u64) -> u64 {
287 based!(a);
288
289 mont_reduction((a as u128) * R2)
290}
291
/// Adds two canonical field elements modulo `PRIME`, written without an `if`.
#[inline(always)]
pub fn badd(a: u64, b: u64) -> u64 {
    based!(a);
    based!(b);

    // a + b is computed as a - (PRIME - b) so the intermediate never overflows u64.
    // For b == 0 this is PRIME, and a - PRIME always borrows (handled below).
    let b = PRIME.wrapping_sub(b);
    let (r, c) = a.overflowing_sub(b);
    // `adj` is 0xffff_ffff (= 2^32 - 1 = 2^64 - PRIME) on borrow, else 0. On borrow
    // r = a + b - PRIME + 2^64, and subtracting 2^64 - PRIME leaves exactly a + b.
    // Without a borrow, r = a + b - PRIME is already canonical.
    let adj = 0u32.wrapping_sub(c as u32);
    r.wrapping_sub(adj as u64)
}
302
303#[inline(always)]
304pub fn bneg(a: u64) -> u64 {
305 based!(a);
306 if a != 0 {
307 PRIME - a
308 } else {
309 0
310 }
311}
312
/// Subtracts `b` from `a` modulo `PRIME` for canonical inputs, written without an `if`.
#[inline(always)]
pub fn bsub(a: u64, b: u64) -> u64 {
    based!(a);
    based!(b);

    let (r, c) = a.overflowing_sub(b);
    // `adj` is 0xffff_ffff (= 2^32 - 1 = 2^64 - PRIME) on borrow, else 0. On borrow
    // r = a - b + 2^64 (at least 2^32), and subtracting 2^64 - PRIME gives a - b + PRIME.
    let adj = 0u32.wrapping_sub(c as u32);
    r.wrapping_sub(adj as u64)
}
322
323#[inline(always)]
324pub fn reduce(n: u128) -> u64 {
325 reduce_159(n as u64, (n >> 64) as u32, (n >> 96) as u64)
326}
327
/// Reduces `low + mid * 2^64 + high * 2^96` modulo `PRIME` to a canonical value in
/// `[0, PRIME)`.
///
/// Uses `2^64 ≡ 2^32 - 1` and `2^96 ≡ -1 (mod PRIME)`, so the input is congruent to
/// `low - high + mid * (2^32 - 1)`. The first step is exact whenever
/// `high <= low + PRIME`, which holds for any `high < 2^63` (inputs of at most 159 bits).
/// `reduce` only passes `high < 2^32`.
#[inline(always)]
pub fn reduce_159(low: u64, mid: u32, high: u64) -> u64 {
    // low - high. On borrow the wrapped value is low - high + 2^64, and adding PRIME
    // (mod 2^64) turns it into low - high + PRIME.
    let (mut low2, carry) = low.overflowing_sub(high);
    if carry {
        low2 = low2.wrapping_add(PRIME);
    }

    // mid * 2^32 - mid = mid * (2^32 - 1) ≡ mid * 2^64; fits in u64 since mid < 2^32.
    let mut product = (mid as u64) << 32;
    product -= product >> 32;

    // On overflow the true sum is result + 2^64 ≡ result + 2^32 - 1. Subtracting PRIME
    // (mod 2^64) adds exactly 2^32 - 1, and this cannot wrap a second time.
    let (mut result, carry) = product.overflowing_add(low2);
    if carry {
        result = result.wrapping_sub(PRIME);
    }

    // result < 2^64 < 2 * PRIME, so one conditional subtraction makes it canonical.
    if result >= PRIME {
        result -= PRIME;
    }
    result
}
348
349#[inline(always)]
350pub fn bmul(a: u64, b: u64) -> u64 {
351 based!(a);
352 based!(b);
353 reduce((a as u128) * (b as u128))
354}
355
/// Returns the multiplicative inverse of `a` modulo `PRIME`.
///
/// By Fermat's little theorem this is `a^(PRIME - 2)`. It is evaluated with a fixed
/// addition chain in Montgomery form.
///
/// Zero has no inverse. For `a == 0` every intermediate value is 0, so the result is 0
/// (no error is raised).
#[inline(always)]
pub fn binv(a: u64) -> u64 {
    based!(a);
    // Enter Montgomery form. Below, `yN` holds y^(2^N - 1), i.e. N one-bits of exponent.
    let y = montify(a);
    // y^3 = y^(2^2 - 1)
    let y2 = montiply(y, montiply(y, y));
    // y * (y^3)^2 = y^7 = y^(2^3 - 1)
    let y3 = montiply(y, montiply(y2, y2));
    // y^(2^2 - 1) * (y^(2^3 - 1))^(2^2) = y^(2^5 - 1)
    let y5 = montiply(y2, montwopow(y3, 2));
    // y^(2^10 - 1)
    let y10 = montiply(y5, montwopow(y5, 5));
    // y^(2^20 - 1)
    let y20 = montiply(y10, montwopow(y10, 10));
    // y^(2^30 - 1)
    let y30 = montiply(y10, montwopow(y20, 10));
    // y^(2^31 - 1)
    let y31 = montiply(y, montiply(y30, y30));
    // y^((2^31 - 1) * (2^32 + 1)) = y^(2^63 - 2^32 + 2^31 - 1)
    let dup = montiply(montwopow(y31, 32), y31);

    // One more square-and-multiply gives y^(2^64 - 2^32 - 1) = y^(PRIME - 2);
    // `mont_reduction` then converts the result out of Montgomery form.
    mont_reduction(montiply(y, montiply(dup, dup)).into())
}
371
372#[inline(always)]
373pub fn montwopow(a: u64, b: u32) -> u64 {
374 based!(a);
375
376 let mut res = a;
377 for _ in 0..b {
378 res = montiply(res, res);
379 }
380 res
381}
382
383#[inline(always)]
384pub fn bpow(mut a: u64, mut b: u64) -> u64 {
385 based!(a);
386 based!(b);
387
388 let mut c: u64 = 1;
389 if b == 0 {
390 return c;
391 }
392
393 while b > 1 {
394 if b & 1 == 0 {
395 a = reduce((a as u128) * (a as u128));
396 b /= 2;
397 } else {
398 c = reduce((c as u128) * (a as u128));
399 a = reduce((a as u128) * (a as u128));
400 b = (b - 1) / 2;
401 }
402 }
403 reduce((c as u128) * (a as u128))
404}