iris_ztd/belt/
mod.rs

1use alloc::vec::Vec;
2use core::ops::{Add, Div, Mul, Neg, Sub};
3use num_traits::Pow;
4use serde::{Deserialize, Serialize};
5
6pub mod bpoly;
7pub mod poly;
8
9pub use bpoly::*;
10pub use poly::*;
11
// Base field: the prime p = 2^64 - 2^32 + 1 and the Montgomery constants used by its arithmetic.

/// The field modulus p = 2^64 - 2^32 + 1.
pub const PRIME: u64 = 0xffff_ffff_0000_0001;
/// [`PRIME`] widened to `u128` for 128-bit intermediate arithmetic.
pub const PRIME_128: u128 = PRIME as u128;
/// `R * p` with Montgomery radix `R = 2^64`; exclusive upper bound on `mont_reduction` input.
const RP: u128 = PRIME_128 << 64;
/// `R^2 mod p`, which equals `(2^32 - 1)^2` because `2^64 ≡ 2^32 - 1 (mod p)`.
/// Multiplying by it and Montgomery-reducing converts a value into Montgomery form.
pub const R2: u128 = (u32::MAX as u128) * (u32::MAX as u128);
17
18#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Default, Serialize, Deserialize)]
19#[repr(transparent)]
20pub struct Belt(pub u64);
21
22impl Belt {
23    pub fn from_bytes(bytes: &[u8]) -> Vec<Belt> {
24        let mut belts = Vec::new();
25        for chunk in bytes.chunks(4) {
26            let mut arr = [0u8; 4];
27            arr[..chunk.len()].copy_from_slice(chunk);
28            belts.push(Belt(u32::from_le_bytes(arr) as u64));
29        }
30        belts
31    }
32
33    pub fn to_bytes(belts: &[Belt]) -> Vec<u8> {
34        let mut bytes = Vec::new();
35        for b in belts {
36            bytes.extend(u32::try_from(b.0).expect("Too big for u32").to_le_bytes());
37        }
38        bytes
39    }
40}
41
42pub fn based_check(a: u64) -> bool {
43    a < PRIME
44}
45
/// Debug-only guard that every argument is a canonical field element.
///
/// Expands to one `debug_assert!` on `$crate::belt::based_check` per argument, so it
/// compiles to nothing when debug assertions are disabled.
#[macro_export]
macro_rules! based {
    ($($value:expr),*) => {{
        $(
            debug_assert!(
                $crate::belt::based_check($value),
                "element must be inside the field\r"
            );
        )*
    }};
}
56
/// Table of power-of-two roots of unity, indexed by `Belt::ordered_root`.
///
/// `ROOTS[k]` is returned for an order of `2^k`, for `k` in `0..=32` (33 entries).
/// Spot checks: `ROOTS[0] = 1`; `ROOTS[1] = p - 1 ≡ -1`; `ROOTS[2] = 2^48`, and
/// `(2^48)^2 = 2^96 ≡ -1`; `ROOTS[3] = p - 2^24`, which squares to `2^48`.
/// NOTE(review): the remaining entries are assumed to continue this chain (each a square
/// root of the previous, i.e. a primitive `2^k`-th root). They were not re-verified here.
const ROOTS: &[u64] = &[
    0x0000000000000001,
    0xffffffff00000000,
    0x0001000000000000,
    0xfffffffeff000001,
    0xefffffff00000001,
    0x00003fffffffc000,
    0x0000008000000000,
    0xf80007ff08000001,
    0xbf79143ce60ca966,
    0x1905d02a5c411f4e,
    0x9d8f2ad78bfed972,
    0x0653b4801da1c8cf,
    0xf2c35199959dfcb6,
    0x1544ef2335d17997,
    0xe0ee099310bba1e2,
    0xf6b2cffe2306baac,
    0x54df9630bf79450e,
    0xabd0a6e8aa3d8a0e,
    0x81281a7b05f9beac,
    0xfbd41c6b8caa3302,
    0x30ba2ecd5e93e76d,
    0xf502aef532322654,
    0x4b2a18ade67246b5,
    0xea9d5a1336fbc98b,
    0x86cdcc31c307e171,
    0x4bbaf5976ecfefd8,
    0xed41d05b78d6e286,
    0x10d78dd8915a171d,
    0x59049500004a4485,
    0xdfa8c93ba46d2666,
    0x7e9bd009b86a0845,
    0x400a7f755588e659,
    0x185629dcda58878c,
];
92
93impl Belt {
94    #[inline(always)]
95    pub fn zero() -> Self {
96        Belt(Default::default())
97    }
98
99    #[inline(always)]
100    pub fn one() -> Self {
101        Belt(1)
102    }
103
104    #[inline(always)]
105    pub fn is_zero(&self) -> bool {
106        self.0 == 0
107    }
108
109    #[inline(always)]
110    pub fn is_one(&self) -> bool {
111        self.0 == 1
112    }
113
114    #[inline(always)]
115    pub fn ordered_root(&self) -> Result<Self, FieldError> {
116        let log_of_self = self.0.ilog2();
117        if (log_of_self as usize) >= ROOTS.len() {
118            return Err(FieldError::OrderedRootError);
119        }
120        // assert that it was an even power of two
121        if self.0 != 1 << log_of_self {
122            return Err(FieldError::OrderedRootError);
123        }
124        Ok(ROOTS[log_of_self as usize].into())
125    }
126
127    #[inline(always)]
128    pub fn inv(&self) -> Self {
129        Belt(binv(self.0))
130    }
131}
132
133impl Add for Belt {
134    type Output = Self;
135
136    #[inline(always)]
137    fn add(self, rhs: Self) -> Self::Output {
138        let a = self.0;
139        let b = rhs.0;
140        Belt(badd(a, b))
141    }
142}
143
144impl Sub for Belt {
145    type Output = Self;
146
147    #[inline(always)]
148    fn sub(self, rhs: Self) -> Self::Output {
149        let a = self.0;
150        let b = rhs.0;
151        Belt(bsub(a, b))
152    }
153}
154
155impl Neg for Belt {
156    type Output = Self;
157
158    #[inline(always)]
159    fn neg(self) -> Self::Output {
160        let a = self.0;
161        Belt(bneg(a))
162    }
163}
164
165impl Mul for Belt {
166    type Output = Self;
167
168    #[inline(always)]
169    fn mul(self, rhs: Self) -> Self::Output {
170        let a = self.0;
171        let b = rhs.0;
172        Belt(bmul(a, b))
173    }
174}
175
176impl Pow<usize> for Belt {
177    type Output = Self;
178
179    #[inline(always)]
180    fn pow(self, rhs: usize) -> Self::Output {
181        Belt(bpow(self.0, rhs as u64))
182    }
183}
184
185impl Div for Belt {
186    type Output = Self;
187
188    #[inline(always)]
189    #[allow(clippy::suspicious_arithmetic_impl)]
190    fn div(self, rhs: Self) -> Self::Output {
191        self * rhs.inv()
192    }
193}
194
195impl PartialEq<u64> for Belt {
196    #[inline(always)]
197    fn eq(&self, other: &u64) -> bool {
198        self.0 == *other
199    }
200}
201
202impl PartialEq<Belt> for u64 {
203    #[inline(always)]
204    fn eq(&self, other: &Belt) -> bool {
205        *self == other.0
206    }
207}
208
209impl AsRef<u64> for Belt {
210    #[inline(always)]
211    fn as_ref(&self) -> &u64 {
212        &self.0
213    }
214}
215
216impl TryFrom<&u64> for Belt {
217    type Error = ();
218
219    #[inline(always)]
220    fn try_from(f: &u64) -> Result<Self, Self::Error> {
221        based!(*f);
222        Ok(Belt(*f))
223    }
224}
225
226impl From<u64> for Belt {
227    #[inline(always)]
228    fn from(f: u64) -> Self {
229        Belt(f)
230    }
231}
232
233impl From<Belt> for u64 {
234    #[inline(always)]
235    fn from(b: Belt) -> Self {
236        b.0
237    }
238}
239
240impl From<u32> for Belt {
241    #[inline(always)]
242    fn from(f: u32) -> Self {
243        Belt(f as u64)
244    }
245}
246
247impl From<Belt> for u32 {
248    #[inline(always)]
249    fn from(b: Belt) -> Self {
250        b.0 as u32
251    }
252}
253
/// Errors produced by the base-field helpers in this module.
#[derive(Debug)]
pub enum FieldError {
    /// No root of unity is available for the order passed to `Belt::ordered_root`.
    OrderedRootError,
}
258
259#[inline(always)]
260pub fn mont_reduction(a: u128) -> u64 {
261    debug_assert!(a < RP, "element must be inside the field\r");
262    let x1: u128 = (a >> 32) & 0xffffffff;
263    let x2: u128 = a >> 64;
264    let c: u128 = {
265        let x0: u128 = a & 0xffffffff;
266        (x0 + x1) << 32
267    };
268    let f: u128 = c >> 64;
269    let d: u128 = c - (x1 + (f * PRIME_128));
270    if x2 >= d {
271        (x2 - d) as u64
272    } else {
273        (x2 + PRIME_128 - d) as u64
274    }
275}
276
277#[inline(always)]
278pub fn montiply(a: u64, b: u64) -> u64 {
279    based!(a);
280    based!(b);
281
282    mont_reduction((a as u128) * (b as u128))
283}
284
285#[inline(always)]
286pub fn montify(a: u64) -> u64 {
287    based!(a);
288
289    mont_reduction((a as u128) * R2)
290}
291
292#[inline(always)]
293pub fn badd(a: u64, b: u64) -> u64 {
294    based!(a);
295    based!(b);
296
297    let b = PRIME.wrapping_sub(b);
298    let (r, c) = a.overflowing_sub(b);
299    let adj = 0u32.wrapping_sub(c as u32);
300    r.wrapping_sub(adj as u64)
301}
302
303#[inline(always)]
304pub fn bneg(a: u64) -> u64 {
305    based!(a);
306    if a != 0 {
307        PRIME - a
308    } else {
309        0
310    }
311}
312
313#[inline(always)]
314pub fn bsub(a: u64, b: u64) -> u64 {
315    based!(a);
316    based!(b);
317
318    let (r, c) = a.overflowing_sub(b);
319    let adj = 0u32.wrapping_sub(c as u32);
320    r.wrapping_sub(adj as u64)
321}
322
323#[inline(always)]
324pub fn reduce(n: u128) -> u64 {
325    reduce_159(n as u64, (n >> 64) as u32, (n >> 96) as u64)
326}
327
328#[inline(always)]
329pub fn reduce_159(low: u64, mid: u32, high: u64) -> u64 {
330    let (mut low2, carry) = low.overflowing_sub(high);
331    if carry {
332        low2 = low2.wrapping_add(PRIME);
333    }
334
335    let mut product = (mid as u64) << 32;
336    product -= product >> 32;
337
338    let (mut result, carry) = product.overflowing_add(low2);
339    if carry {
340        result = result.wrapping_sub(PRIME);
341    }
342
343    if result >= PRIME {
344        result -= PRIME;
345    }
346    result
347}
348
349#[inline(always)]
350pub fn bmul(a: u64, b: u64) -> u64 {
351    based!(a);
352    based!(b);
353    reduce((a as u128) * (b as u128))
354}
355
356#[inline(always)]
357pub fn binv(a: u64) -> u64 {
358    based!(a);
359    let y = montify(a);
360    let y2 = montiply(y, montiply(y, y));
361    let y3 = montiply(y, montiply(y2, y2));
362    let y5 = montiply(y2, montwopow(y3, 2));
363    let y10 = montiply(y5, montwopow(y5, 5));
364    let y20 = montiply(y10, montwopow(y10, 10));
365    let y30 = montiply(y10, montwopow(y20, 10));
366    let y31 = montiply(y, montiply(y30, y30));
367    let dup = montiply(montwopow(y31, 32), y31);
368
369    mont_reduction(montiply(y, montiply(dup, dup)).into())
370}
371
372#[inline(always)]
373pub fn montwopow(a: u64, b: u32) -> u64 {
374    based!(a);
375
376    let mut res = a;
377    for _ in 0..b {
378        res = montiply(res, res);
379    }
380    res
381}
382
383#[inline(always)]
384pub fn bpow(mut a: u64, mut b: u64) -> u64 {
385    based!(a);
386    based!(b);
387
388    let mut c: u64 = 1;
389    if b == 0 {
390        return c;
391    }
392
393    while b > 1 {
394        if b & 1 == 0 {
395            a = reduce((a as u128) * (a as u128));
396            b /= 2;
397        } else {
398            c = reduce((c as u128) * (a as u128));
399            a = reduce((a as u128) * (a as u128));
400            b = (b - 1) / 2;
401        }
402    }
403    reduce((c as u128) * (a as u128))
404}