// abstract_integers/nat_mod.rs

/// Defines `$name`, a `Copy` newtype over the integer type `$base` whose
/// stored value is kept reduced modulo `$max`.
///
/// This is the shared core expanded by `abstract_secret_modular_integer!` and
/// `abstract_public_modular_integer!`: formatting, conversions to and from
/// `$base`, hex/byte constructors, literal constructors, comparison masks and
/// negation. Operator traits are added by those wrapper macros.
///
/// The expansion names `BigInt` and `BigUint` unqualified, so both must be in
/// scope where the macro is invoked. `$max` is an expression fragment and is
/// re-evaluated at every place it is substituted.
#[macro_export]
macro_rules! modular_integer {
    ($name:ident, $base:ident, $max:expr) => {
        #[derive(Clone, Copy, Default)]
        pub struct $name($base);

        impl core::fmt::Display for $name {
            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                let uint: $base = (*self).into();
                write!(f, "{}", uint)
            }
        }

        // Debug deliberately prints the same representation as Display.
        impl core::fmt::Debug for $name {
            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
                let uint: $base = (*self).into();
                write!(f, "{}", uint)
            }
        }

        impl core::fmt::LowerHex for $name {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let val: $base = (*self).into();
                core::fmt::LowerHex::fmt(&val, f)
            }
        }

        /// Reduces `x` modulo `$max` before wrapping it, so every value built
        /// through this conversion is a canonical residue.
        impl From<$base> for $name {
            fn from(x: $base) -> $name {
                $name(x.rem($max))
            }
        }

        /// Returns the stored `$base` value unchanged.
        impl Into<$base> for $name {
            fn into(self) -> $base {
                self.0
            }
        }

        impl $name {
            /// Same as `From<$base>`: reduces `x` modulo `max()`.
            pub fn from_canvas(x: $base) -> $name {
                $name(x.rem($max))
            }

            /// Returns the stored `$base` value.
            pub fn into_canvas(self) -> $base {
                self.0
            }

            /// Returns the modulus `$max`.
            pub fn max() -> $base {
                $max
            }

            /// Converts the stored value into a `BigInt` (via `$base`).
            pub fn declassify(self) -> BigInt {
                let a: $base = self.into();
                a.into()
            }

            /// Parses `s` with `$base::from_hex`, then reduces modulo `max()`.
            #[allow(dead_code)]
            pub fn from_hex(s: &str) -> Self {
                $base::from_hex(s).into()
            }

            /// Decodes big-endian bytes with `$base::from_be_bytes`, then reduces
            /// modulo `max()`.
            #[allow(dead_code)]
            pub fn from_be_bytes(v: &[u8]) -> Self {
                $base::from_be_bytes(v).into()
            }

            /// Encodes the stored value with `$base::to_be_bytes`.
            #[allow(dead_code)]
            pub fn to_be_bytes(self) -> Vec<u8> {
                $base::to_be_bytes(self.into()).to_vec()
            }

            /// Decodes little-endian bytes with `$base::from_le_bytes`, then
            /// reduces modulo `max()`.
            #[allow(dead_code)]
            pub fn from_le_bytes(v: &[u8]) -> Self {
                $base::from_le_bytes(v).into()
            }

            /// Encodes the stored value with `$base::to_le_bytes`.
            #[allow(dead_code)]
            pub fn to_le_bytes(self) -> Vec<u8> {
                $base::to_le_bytes(self.into()).to_vec()
            }

            /// Gets the `i`-th least significant bit of this integer.
            #[allow(dead_code)]
            pub fn bit(self, i: usize) -> bool {
                $base::bit(self.into(), i)
            }

            /// Builds a value from an unsigned literal without reducing it.
            ///
            /// # Panics
            ///
            /// Panics if `x` is not a canonical residue, i.e. if `x >= max()`.
            #[allow(dead_code)]
            pub fn from_literal(x: u128) -> Self {
                let big_x = BigUint::from(x);
                // `>=`, not `>`: the modulus itself is not a valid residue, and
                // every other constructor keeps the stored value below `max()`.
                if big_x >= $name::max().into() {
                    panic!("literal {} too big for type {}", x, stringify!($name));
                }
                $name(big_x.into())
            }

            /// Builds a value from a signed literal without reducing it.
            ///
            /// NOTE(review): a negative `x` is cast with `as u128`, i.e.
            /// reinterpreted as `2^128 + x`. It is not mapped to `max() - |x|`.
            /// Confirm this is intended.
            ///
            /// # Panics
            ///
            /// Panics if the (reinterpreted) value is `>= max()`.
            #[allow(dead_code)]
            pub fn from_signed_literal(x: i128) -> Self {
                let big_x = BigUint::from(x as u128);
                // Same canonical-residue bound as `from_literal`.
                if big_x >= $name::max().into() {
                    panic!("literal {} too big for type {}", x, stringify!($name));
                }
                $name(big_x.into())
            }

            /// Equality mask from `$base::comp_eq`, converted back through
            /// `From<$base>` (and therefore reduced modulo `max()`).
            #[inline]
            pub fn comp_eq(self, rhs: Self) -> Self {
                let x: $base = self.into();
                x.comp_eq(rhs.into()).into()
            }

            /// Inequality mask from `$base::comp_ne`, reduced like `comp_eq`.
            #[inline]
            pub fn comp_ne(self, rhs: Self) -> Self {
                let x: $base = self.into();
                x.comp_ne(rhs.into()).into()
            }

            /// `>=` mask from `$base::comp_gte`, reduced like `comp_eq`.
            #[inline]
            pub fn comp_gte(self, rhs: Self) -> Self {
                let x: $base = self.into();
                x.comp_gte(rhs.into()).into()
            }

            /// `>` mask from `$base::comp_gt`, reduced like `comp_eq`.
            #[inline]
            pub fn comp_gt(self, rhs: Self) -> Self {
                let x: $base = self.into();
                x.comp_gt(rhs.into()).into()
            }

            /// `<=` mask from `$base::comp_lte`, reduced like `comp_eq`.
            #[inline]
            pub fn comp_lte(self, rhs: Self) -> Self {
                let x: $base = self.into();
                x.comp_lte(rhs.into()).into()
            }

            /// `<` mask from `$base::comp_lt`, reduced like `comp_eq`.
            #[inline]
            pub fn comp_lt(self, rhs: Self) -> Self {
                let x: $base = self.into();
                x.comp_lt(rhs.into()).into()
            }

            /// Negate the value modulo max: `mod_value - self`.
            ///
            /// The difference is computed on `BigInt` and converted back through
            /// `From<$base>`, which reduces it, so negating zero yields zero.
            #[inline]
            pub fn neg(self) -> Self {
                let mod_val: BigInt = $max.into();
                let s: $base = self.into();
                let s: BigInt = s.into();
                let result: $base = mod_val.sub(s).into();
                result.into()
            }
        }
    };
}

// FIXME: Implement constant-time (ct) algorithms.
/// Defines `$name` as a *secret* integer modulo `$max` over `$base`.
///
/// Expands `modular_integer!` and then adds `Add`, `Sub`, `Mul`, `Not`,
/// `BitOr`, `BitXor`, `BitAnd`, `Shr<usize>` and `Shl<usize>`. No comparison
/// traits (`PartialEq`, `Ord`, ...) are implemented for this type; use the
/// `comp_*` mask methods instead.
///
/// The expansion names `BigUint` and the `core::ops` operator traits
/// unqualified, so they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! abstract_secret_modular_integer {
    ($name:ident, $base:ident, $max:expr) => {
        modular_integer!($name, $base, $max);

        impl $name {
            /// Returns `self mod n`, computed on `BigUint`.
            ///
            /// # Panics
            ///
            /// Panics if `n` is zero (remainder by zero on `BigUint`).
            // Not called by the generated code itself; silence the lint in the
            // invoking crate, as the other generated helpers do.
            #[allow(dead_code)]
            fn modulo(self, n: Self) -> Self {
                let a: $base = self.into();
                let b: $base = n.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let r: $base = (a % b).into();
                r.into()
            }
        }

        /// Addition modulo `$max`. The sum is formed on `BigUint` and then
        /// reduced, so it cannot overflow `$base`.
        impl Add for $name {
            type Output = $name;
            fn add(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let c: BigUint = a + b;
                let max: BigUint = $max.into();
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Subtraction modulo `$max`. When `rhs > self` the result wraps to
        /// `$max - rhs + self` instead of underflowing.
        impl Sub for $name {
            type Output = $name;
            fn sub(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let max: BigUint = $max.into();
                // Add the modulus first so the unsigned BigUint subtraction
                // cannot underflow.
                let c: BigUint = if b > a { max.clone() - b + a } else { a - b };
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Multiplication modulo `$max`. The product is formed on `BigUint`
        /// and then reduced.
        impl Mul for $name {
            type Output = $name;
            fn mul(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let c: BigUint = a * b;
                let max: BigUint = $max.into();
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Bitwise NOT of the stored `$base` value, reduced modulo `$max`.
        /// This is generally not `$max - 1 - self`.
        impl Not for $name {
            type Output = $name;
            fn not(self) -> Self::Output {
                let a: $base = self.into();
                let not_a = !a;
                not_a.rem($max).into()
            }
        }

        /// Bitwise OR of the stored values, reduced modulo `$max` via
        /// `From<$base>`.
        impl BitOr for $name {
            type Output = $name;
            fn bitor(self, rhs: Self) -> Self::Output {
                let a: $base = self.into();
                let b: $base = rhs.into();
                (a | b).into()
            }
        }

        /// Bitwise XOR of the stored values, reduced modulo `$max` via
        /// `From<$base>`.
        impl BitXor for $name {
            type Output = $name;
            fn bitxor(self, rhs: Self) -> Self::Output {
                let a: $base = self.into();
                let b: $base = rhs.into();
                (a ^ b).into()
            }
        }

        /// Bitwise AND of the stored values, reduced modulo `$max` via
        /// `From<$base>`.
        impl BitAnd for $name {
            type Output = $name;
            fn bitand(self, rhs: Self) -> Self::Output {
                let a: $base = self.into();
                let b: $base = rhs.into();
                (a & b).into()
            }
        }

        /// Right shift of the stored `$base` value.
        impl Shr<usize> for $name {
            type Output = $name;
            fn shr(self, rhs: usize) -> Self::Output {
                let a: $base = self.into();
                (a >> rhs).into()
            }
        }

        /// Left shift of the stored `$base` value, reduced modulo `$max` via
        /// `From<$base>`.
        ///
        /// NOTE(review): whether bits shifted past the width of `$base` are
        /// dropped depends on `$base`'s `Shl` implementation.
        impl Shl<usize> for $name {
            type Output = $name;
            fn shl(self, rhs: usize) -> Self::Output {
                let a: $base = self.into();
                (a << rhs).into()
            }
        }
    };
}

/// Defines `$name` as a *public* integer modulo `$max` over `$base`.
///
/// Expands `modular_integer!` and then adds:
/// - `PartialEq`, `Eq`, `PartialOrd` and `Ord` on the stored (reduced) values;
/// - `Add`, `Sub`, `Mul`, `Div` and `Rem`;
/// - `Not`, `BitOr`, `BitXor`, `BitAnd`, `Shr<usize>` and `Shl<usize>`;
/// - `inv`, `pow_felem`, `pow` and `pow2`.
///
/// The expansion names `BigUint`, `Ordering` and the `core::ops` operator
/// traits unqualified, so they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! abstract_public_modular_integer {
    ($name:ident, $base:ident, $max:expr) => {
        modular_integer!($name, $base, $max);

        // Equality and ordering compare the stored (already reduced) `$base`
        // values, so they require `$base: Ord`.
        impl PartialOrd for $name {
            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                Some(self.cmp(other))
            }
        }
        impl Ord for $name {
            fn cmp(&self, other: &Self) -> Ordering {
                self.0.cmp(&other.0)
            }
        }
        impl PartialEq for $name {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for $name {}

        /// Addition modulo `$max`. The sum is formed on `BigUint` and then
        /// reduced, so it cannot overflow `$base`.
        impl Add for $name {
            type Output = $name;
            fn add(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let c: BigUint = a + b;
                let max: BigUint = $max.into();
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Subtraction modulo `$max`. When `rhs > self` the result wraps to
        /// `$max - rhs + self` instead of underflowing.
        impl Sub for $name {
            type Output = $name;
            fn sub(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let max: BigUint = $max.into();
                // Add the modulus first so the unsigned BigUint subtraction
                // cannot underflow.
                let c: BigUint = if b > a { max.clone() - b + a } else { a - b };
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Multiplication modulo `$max`. The product is formed on `BigUint`
        /// and then reduced.
        impl Mul for $name {
            type Output = $name;
            fn mul(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let c: BigUint = a * b;
                let max: BigUint = $max.into();
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Modular division, computed as `self * rhs.inv()`.
        ///
        /// NOTE(review): the behavior when `rhs` has no inverse (e.g. zero)
        /// is whatever `$base::inv` does. Confirm before relying on it.
        impl Div for $name {
            type Output = $name;
            fn div(self, rhs: $name) -> $name {
                self * rhs.inv()
            }
        }

        /// Integer remainder of the stored values (`self % rhs` on `BigUint`),
        /// then reduced modulo `$max`.
        ///
        /// **Warning**: panics on division by 0.
        impl Rem for $name {
            type Output = $name;
            fn rem(self, rhs: $name) -> $name {
                let a: $base = self.into();
                let b: $base = rhs.into();
                let a: BigUint = a.into();
                let b: BigUint = b.into();
                let c: BigUint = a % b;
                let max: BigUint = $max.into();
                let d: BigUint = c % max;
                let d: $base = d.into();
                d.into()
            }
        }

        /// Bitwise NOT of the stored `$base` value, reduced modulo `$max` via
        /// `From<$base>`. This is generally not `$max - 1 - self`.
        impl Not for $name {
            type Output = $name;
            fn not(self) -> Self::Output {
                let a: $base = self.into();
                (!a).into()
            }
        }

        /// Bitwise OR of the stored values, reduced modulo `$max` via
        /// `From<$base>`.
        impl BitOr for $name {
            type Output = $name;
            fn bitor(self, rhs: Self) -> Self::Output {
                let a: $base = self.into();
                let b: $base = rhs.into();
                (a | b).into()
            }
        }

        /// Bitwise XOR of the stored values, reduced modulo `$max` via
        /// `From<$base>`.
        impl BitXor for $name {
            type Output = $name;
            fn bitxor(self, rhs: Self) -> Self::Output {
                let a: $base = self.into();
                let b: $base = rhs.into();
                (a ^ b).into()
            }
        }

        /// Bitwise AND of the stored values, reduced modulo `$max` via
        /// `From<$base>`.
        impl BitAnd for $name {
            type Output = $name;
            fn bitand(self, rhs: Self) -> Self::Output {
                let a: $base = self.into();
                let b: $base = rhs.into();
                (a & b).into()
            }
        }

        /// Right shift of the stored `$base` value.
        impl Shr<usize> for $name {
            type Output = $name;
            fn shr(self, rhs: usize) -> Self::Output {
                let a: $base = self.into();
                (a >> rhs).into()
            }
        }

        /// Left shift of the stored `$base` value, reduced modulo `$max` via
        /// `From<$base>`.
        ///
        /// NOTE(review): whether bits shifted past the width of `$base` are
        /// dropped depends on `$base`'s `Shl` implementation.
        impl Shl<usize> for $name {
            type Output = $name;
            fn shl(self, rhs: usize) -> Self::Output {
                let a: $base = self.into();
                (a << rhs).into()
            }
        }

        impl $name {
            /// Modular inverse, delegated to `$base::inv` with modulus `max()`.
            #[allow(dead_code)]
            pub fn inv(self) -> Self {
                let base: $base = self.into();
                base.inv(Self::max()).into()
            }

            /// Returns self raised to `exp` (another `$name`), delegated to
            /// `$base::pow_felem` with modulus `max()`.
            #[allow(dead_code)]
            pub fn pow_felem(self, exp: Self) -> Self {
                let base: $base = self.into();
                base.pow_felem(exp.into(), Self::max()).into()
            }

            /// Returns self to the power of the argument.
            /// The exponent is a u128.
            #[allow(dead_code)]
            pub fn pow(self, exp: u128) -> Self {
                let base: $base = self.into();
                base.pow(exp, Self::max()).into()
            }

            /// Returns 2 to the power of the argument, reduced modulo `max()`
            /// via `From<$base>`.
            #[allow(dead_code)]
            pub fn pow2(x: usize) -> $name {
                $base::pow2(x).into()
            }
        }
    };
}

/// Declares a secret modular integer together with its backing integer type.
///
/// `$canvas` is created by `abstract_unsigned_secret_integer!` with `$width`
/// bits. `$ty` is then created by `abstract_secret_modular_integer!`, with the
/// modulus parsed from the hex string literal `$modulus_hex` by
/// `$canvas::from_hex`.
#[macro_export]
macro_rules! abstract_nat_mod {
    ($ty:ident, $canvas:ident, $width:literal, $modulus_hex:literal) => {
        abstract_unsigned_secret_integer!($canvas, $width);
        abstract_secret_modular_integer!($ty, $canvas, $canvas::from_hex($modulus_hex));
    };
}

/// Declares a public modular integer together with its backing integer type.
///
/// `$canvas` is created by `abstract_unsigned_public_integer!` with `$width`
/// bits. `$ty` is then created by `abstract_public_modular_integer!`, with the
/// modulus parsed from the hex string literal `$modulus_hex` by
/// `$canvas::from_hex`.
#[macro_export]
macro_rules! abstract_public_nat_mod {
    ($ty:ident, $canvas:ident, $width:literal, $modulus_hex:literal) => {
        abstract_unsigned_public_integer!($canvas, $width);
        abstract_public_modular_integer!($ty, $canvas, $canvas::from_hex($modulus_hex));
    };
}

// ============ Legacy API ============

/// Defines a bounded natural integer type with modular arithmetic operations.
///
/// Legacy entry point: forwards its arguments unchanged to
/// `abstract_public_modular_integer!`, so `$ty` wraps `$canvas` and is reduced
/// modulo `$modulus`.
#[macro_export]
macro_rules! define_refined_modular_integer {
    ($ty:ident, $canvas:ident, $modulus:expr) => {
        abstract_public_modular_integer!($ty, $canvas, $modulus);
    };
}