m61_modulus/
definition.rs

1//! Definition of the [`M61`] type as well as basic operations on it.
2
3use core::fmt;
4use core::iter;
5use core::ops;
6
/// Modulus for all arithmetic on [`M61`]: `2^61 - 1`.
///
/// Its binary form is 61 consecutive one bits, so the same constant also
/// serves as the bitmask that extracts the lowest base-`2^61` digit of a
/// number when computing digit sums.
pub(crate) const MODULUS: u64 = u64::MAX >> 3;
11
12/// When calculating the reduction of an arbitary precision integer
13/// using a digit sum, the sum itself must be reduced aswell.
14/// This function performs this reduction, assuming that
15/// are themselves partially reduced, meaning `x <= 2 * (2^61 - 1)`.
16#[inline(always)]
17pub(crate) fn final_reduction(mut x: u64) -> M61 {
18    if x >= MODULUS {
19        x -= MODULUS;
20    }
21
22    if x >= MODULUS {
23        M61(x - MODULUS)
24    } else {
25        M61(x)
26    }
27}
28
/// A 64-bit integer in which arithmetic is performed modulo `2^61 - 1`.
///
/// The wrapped value is meant to always be the canonical residue, i.e.
/// strictly less than `2^61 - 1`. The conversions and operators in this
/// module maintain that invariant, which is what lets the derived equality,
/// ordering and hashing compare residues directly.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct M61(pub(crate) u64);
33
34impl M61 {
35    /// Returns the contained value.
36    #[inline(always)]
37    #[must_use]
38    pub const fn get(self) -> u64 {
39        self.0
40    }
41
42    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
43    pub fn pow(mut self, mut exp: u64) -> Self {
44        if exp == 0 {
45            return Self(1);
46        }
47        let mut acc = Self(1);
48
49        while exp != 1 {
50            if exp & 1 != 0 {
51                acc *= self;
52            }
53
54            exp /= 2;
55            self = self * self;
56        }
57
58        acc * self
59    }
60}
61
62/// A helper macro for the quick generation of formatting trait implementations.
63macro_rules! make_fmt_impl {
64    ($trait:ident) => {
65        impl fmt::$trait for M61 {
66            #[inline(always)]
67            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68                <u64 as fmt::$trait>::fmt(&self.0, f)
69            }
70        }
71    };
72}
73
74make_fmt_impl!(Display);
75make_fmt_impl!(Debug);
76make_fmt_impl!(LowerExp);
77make_fmt_impl!(UpperExp);
78make_fmt_impl!(LowerHex);
79make_fmt_impl!(UpperHex);
80make_fmt_impl!(Octal);
81make_fmt_impl!(Binary);
82
/// Generates a [`From`] implementation for an integer type whose whole range
/// lies strictly between `-(2^61 - 1)` and `2^61 - 1`.
///
/// For such types a single conditional addition of the modulus yields the
/// canonical residue.
macro_rules! make_trivial_from {
    ($type:ty) => {
        impl From<$type> for M61 {
            #[inline(always)]
            fn from(value: $type) -> Self {
                // Every type this macro is applied to widens losslessly to
                // `i64`. Testing the sign after widening keeps one code path
                // for signed and unsigned sources, and avoids the
                // `unused_comparisons` lint on unsigned types.
                let wide = value as i64;
                if wide < 0 {
                    Self((wide + MODULUS as i64) as u64)
                } else {
                    Self(wide as u64)
                }
            }
        }
    };
}
104
105make_trivial_from!(u8);
106make_trivial_from!(u16);
107make_trivial_from!(u32);
108#[cfg(not(target_pointer_width = "64"))]
109make_trivial_from!(usize);
110
111#[cfg(target_pointer_width = "64")]
112impl From<usize> for M61 {
113    #[inline(always)]
114    fn from(value: usize) -> Self {
115        Self::from(value as u64)
116    }
117}
118
119make_trivial_from!(i8);
120make_trivial_from!(i16);
121make_trivial_from!(i32);
122#[cfg(not(target_pointer_width = "64"))]
123make_trivial_from!(isize);
124
125#[cfg(target_pointer_width = "64")]
126impl From<isize> for M61 {
127    #[inline(always)]
128    fn from(value: isize) -> Self {
129        Self::from(value as i64)
130    }
131}
132
133impl From<u64> for M61 {
134    #[inline]
135    fn from(value: u64) -> Self {
136        let tmp = (value & MODULUS) + (value >> 61);
137        if tmp >= MODULUS {
138            Self(tmp - MODULUS)
139        } else {
140            Self(tmp)
141        }
142    }
143}
144
145impl From<i64> for M61 {
146    #[inline]
147    fn from(mut value: i64) -> Self {
148        if value < 0 {
149            value = value.wrapping_add(4 * MODULUS as i64);
150        }
151        if value < 0 {
152            value = value.wrapping_add(MODULUS as i64);
153        }
154
155        Self::from(value as u64)
156    }
157}
158
159impl From<u128> for M61 {
160    #[inline]
161    fn from(value: u128) -> Self {
162        let mut x = value as u64 & MODULUS;
163        x += (value >> 61) as u64 & MODULUS;
164        x += (value >> 122) as u64;
165        Self::from(x)
166    }
167}
168
169impl From<i128> for M61 {
170    #[inline]
171    fn from(mut value: i128) -> Self {
172        while value < 0 {
173            value += 16 * ((1 << 122) - 1);
174        }
175
176        Self::from(value as u128)
177    }
178}
179
/// A helper macro for the quick implementation of arithmetic operators.
///
/// `$impl` is a closure taking the two wrapped `u64` residues and returning
/// the resulting residue. It is invoked only from the by-value
/// `M61 $op M61` impl. Every other combination (by-reference operands on
/// either side, and the compound-assignment operators) forwards to that impl,
/// mirroring the operator impls the standard library provides for primitive
/// integers.
macro_rules! make_arith_impl {
    ($trait:ident, $trait_assign:ident, $func:ident, $func_assign:ident, $op:tt, $impl:expr) => {
        impl ops::$trait for M61 {
            type Output = Self;

            #[inline]
            fn $func(self, rhs: Self) -> Self::Output {
                // The closure literal passed as `$impl` is called immediately.
                #[allow(clippy::redundant_closure_call)]
                Self($impl(self.0, rhs.0))
            }
        }

        impl<'a> ops::$trait<&'a M61> for M61 {
            type Output = Self;

            #[inline(always)]
            fn $func(self, rhs: &Self) -> Self::Output {
                self $op *rhs
            }
        }

        impl<'a> ops::$trait<M61> for &'a M61 {
            type Output = M61;

            #[inline(always)]
            fn $func(self, rhs: M61) -> Self::Output {
                *self $op rhs
            }
        }

        impl<'a, 'b> ops::$trait<&'b M61> for &'a M61 {
            type Output = M61;

            #[inline(always)]
            fn $func(self, rhs: &'b M61) -> Self::Output {
                *self $op *rhs
            }
        }

        impl ops::$trait_assign for M61 {
            #[inline(always)]
            fn $func_assign(&mut self, rhs: Self) {
                *self = *self $op rhs
            }
        }

        impl<'a> ops::$trait_assign<&'a M61> for M61 {
            #[inline(always)]
            fn $func_assign(&mut self, rhs: &Self) {
                *self = *self $op rhs
            }
        }
    };
}
217
218make_arith_impl!(Add, AddAssign, add, add_assign, +, |a, b| {
219    let x = a + b;
220    if x >= MODULUS {
221        x - MODULUS
222    } else {
223        x
224    }
225});
226make_arith_impl!(Sub, SubAssign, sub, sub_assign, -, |a, b| {
227    let x = a + MODULUS - b;
228    if x >= MODULUS {
229        x - MODULUS
230    } else {
231        x
232    }
233});
234make_arith_impl!(Mul, MulAssign, mul, mul_assign, *, |a, b| {
235    let x = a as u128 * b as u128;
236    let mut hi = (x >> 61) as u64;
237    let mut lo = (x as u64) & MODULUS;
238    lo = lo.wrapping_add(hi);
239    hi = lo.wrapping_sub(MODULUS);
240    if lo < MODULUS {
241        lo
242    } else {
243        hi
244    }
245});
246make_arith_impl!(Div, DivAssign, div, div_assign, /, |a, b| {
247    if b == 0 {
248        panic!("attempt to divide by zero");
249    }
250
251    // Calculate the multiplicative inverse using the extended Euclidean algorithm.
252    // (https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm)
253
254    let mut r0 = MODULUS;
255    let mut r1 = b;
256    let mut s0 = 1i64;
257    let mut s1 = 0i64;
258    let mut t0 = 0i64;
259    let mut t1 = 1i64;
260
261    while r1 != 0 {
262        let (q, rn) = (r0 / r1, r0 % r1);
263        let sn = s0 - q as i64 * s1;
264        let tn = t0 - q as i64 * t1;
265
266        r0 = r1;
267        r1 = rn;
268        s0 = s1;
269        s1 = sn;
270        t0 = t1;
271        t1 = tn;
272    }
273
274    debug_assert_eq!(MODULUS as i128 * s0 as i128 + b as i128 * t0 as i128, 1);
275
276    (Self(a) * Self::from(t0)).0
277});
278//make_arith_impl!(Rem, RemAssign, rem, rem_assign, %, |a, b| {
279//    a % b
280//});
281
282impl iter::Sum for M61 {
283    #[inline(always)]
284    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
285        iter.fold(Self(0), |a, b| a + b)
286    }
287}
288
289impl<'a> iter::Sum<&'a M61> for M61 {
290    #[inline(always)]
291    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
292        iter.fold(Self(0), |a, b| a + b)
293    }
294}
295
296impl iter::Product for M61 {
297    #[inline(always)]
298    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
299        iter.fold(Self(1), |a, b| a * b)
300    }
301}
302
303impl<'a> iter::Product<&'a M61> for M61 {
304    #[inline(always)]
305    fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
306        iter.fold(Self(1), |a, b| a * b)
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::M61;
    use super::MODULUS;

    quickcheck::quickcheck! {
        fn creation_u64_correct(x: u64) -> bool {
            let expected = x % MODULUS;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_u128_correct(x: u128) -> bool {
            let expected = (x % MODULUS as u128) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_i64_correct(x: i64) -> bool {
            let expected = x.rem_euclid(MODULUS as i64) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_i128_correct(x: i128) -> bool {
            let expected = x.rem_euclid(MODULUS as i128) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn add_distributive(x: u64, y: u64) -> bool {
            let x = x >> 1;
            let y = y >> 1;

            let expected = M61::from(x + y);
            let actual = M61::from(x) + M61::from(y);

            expected == actual
        }

        fn sub_distributive(x: u64, y: u64) -> bool {
            let x = (x >> 1) as i64;
            let y = (y >> 1) as i64;

            let expected = M61::from(x - y);
            let actual = M61::from(x) - M61::from(y);

            expected == actual
        }

        fn mul_distributive(x: u64, y: u64) -> bool {
            let x = x as u128;
            let y = y as u128;

            let expected = M61::from(x * y);
            let actual = M61::from(x) * M61::from(y);

            expected == actual
        }
    }

    /// Representative non-zero canonical residues, including both extremes.
    const SAMPLES: [u64; 7] = [1, 2, 3, 7, 12_345_678_901, MODULUS - 2, MODULUS - 1];

    #[test]
    fn creation_small_types_correct() {
        assert_eq!(M61::from(0u8).get(), 0);
        assert_eq!(M61::from(u32::MAX).get(), u64::from(u32::MAX));
        assert_eq!(M61::from(-1i8).get(), MODULUS - 1);
        assert_eq!(M61::from(i16::MIN).get(), MODULUS - (1 << 15));
        assert_eq!(M61::from(i32::MIN).get(), MODULUS - (1 << 31));
    }

    #[test]
    fn div_inverts_mul() {
        for &x in &SAMPLES {
            for &y in &SAMPLES {
                let (x, y) = (M61::from(x), M61::from(y));
                assert_eq!((x / y) * y, x);
                assert_eq!((x * y) / y, x);
            }
        }
        assert_eq!(M61::from(0u64) / M61::from(5u64), M61::from(0u64));
    }

    #[test]
    #[should_panic(expected = "attempt to divide by zero")]
    fn div_by_zero_panics() {
        let _ = M61::from(1u64) / M61::from(0u64);
    }

    #[test]
    fn pow_matches_repeated_mul() {
        let base = M61::from(3u64);
        let mut expected = M61::from(1u64);
        for exp in 0..64u64 {
            assert_eq!(base.pow(exp), expected);
            expected *= base;
        }
    }

    #[test]
    fn pow_satisfies_fermat() {
        // 2^61 - 1 is prime, so a^(p - 1) == 1 for every non-zero a.
        for &a in &SAMPLES {
            assert_eq!(M61::from(a).pow(MODULUS - 1), M61::from(1u64));
        }
        assert_eq!(M61::from(0u64).pow(0), M61::from(1u64));
    }
}