m61_modulus/
definition.rs1use core::fmt;
4use core::iter;
5use core::ops;
6
/// The Mersenne prime 2^61 - 1: sixty-one one-bits.
///
/// Every residue fits in 61 bits, and 2^61 ≡ 1 (mod `MODULUS`), which the reductions
/// in this file rely on to fold high bits back onto low bits.
pub(crate) const MODULUS: u64 = 0x1FFF_FFFF_FFFF_FFFF;
11
12#[inline(always)]
17pub(crate) fn final_reduction(mut x: u64) -> M61 {
18 if x >= MODULUS {
19 x -= MODULUS;
20 }
21
22 if x >= MODULUS {
23 M61(x - MODULUS)
24 } else {
25 M61(x)
26 }
27}
28
/// An integer modulo the Mersenne prime `MODULUS` (2^61 - 1), stored as its `u64`
/// representative.
///
/// `#[repr(transparent)]` gives it exactly the layout of a `u64`. `Default` is zero, and
/// equality, ordering and hashing all act on the stored representative.
#[derive(Clone, Copy, Default)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct M61(pub(crate) u64);
33
34impl M61 {
35 #[inline(always)]
37 #[must_use]
38 pub const fn get(self) -> u64 {
39 self.0
40 }
41
42 pub fn pow(mut self, mut exp: u64) -> Self {
44 if exp == 0 {
45 return Self(1);
46 }
47 let mut acc = Self(1);
48
49 while exp != 1 {
50 if exp & 1 != 0 {
51 acc *= self;
52 }
53
54 exp /= 2;
55 self = self * self;
56 }
57
58 acc * self
59 }
60}
61
62macro_rules! make_fmt_impl {
64 ($trait:ident) => {
65 impl fmt::$trait for M61 {
66 #[inline(always)]
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 <u64 as fmt::$trait>::fmt(&self.0, f)
69 }
70 }
71 };
72}
73
74make_fmt_impl!(Display);
75make_fmt_impl!(Debug);
76make_fmt_impl!(LowerExp);
77make_fmt_impl!(UpperExp);
78make_fmt_impl!(LowerHex);
79make_fmt_impl!(UpperHex);
80make_fmt_impl!(Octal);
81make_fmt_impl!(Binary);
82
83macro_rules! make_trivial_from {
87 ($type:ty) => {
88 impl From<$type> for M61 {
89 #[inline(always)]
90 fn from(value: $type) -> Self {
91 #[allow(unused_comparisons)]
95 if value < 0 {
96 Self((value as i64 + MODULUS as i64) as u64)
97 } else {
98 Self(value as u64)
99 }
100 }
101 }
102 };
103}
104
105make_trivial_from!(u8);
106make_trivial_from!(u16);
107make_trivial_from!(u32);
108#[cfg(not(target_pointer_width = "64"))]
109make_trivial_from!(usize);
110
111#[cfg(target_pointer_width = "64")]
112impl From<usize> for M61 {
113 #[inline(always)]
114 fn from(value: usize) -> Self {
115 Self::from(value as u64)
116 }
117}
118
119make_trivial_from!(i8);
120make_trivial_from!(i16);
121make_trivial_from!(i32);
122#[cfg(not(target_pointer_width = "64"))]
123make_trivial_from!(isize);
124
125#[cfg(target_pointer_width = "64")]
126impl From<isize> for M61 {
127 #[inline(always)]
128 fn from(value: isize) -> Self {
129 Self::from(value as i64)
130 }
131}
132
133impl From<u64> for M61 {
134 #[inline]
135 fn from(value: u64) -> Self {
136 let tmp = (value & MODULUS) + (value >> 61);
137 if tmp >= MODULUS {
138 Self(tmp - MODULUS)
139 } else {
140 Self(tmp)
141 }
142 }
143}
144
145impl From<i64> for M61 {
146 #[inline]
147 fn from(mut value: i64) -> Self {
148 if value < 0 {
149 value = value.wrapping_add(4 * MODULUS as i64);
150 }
151 if value < 0 {
152 value = value.wrapping_add(MODULUS as i64);
153 }
154
155 Self::from(value as u64)
156 }
157}
158
159impl From<u128> for M61 {
160 #[inline]
161 fn from(value: u128) -> Self {
162 let mut x = value as u64 & MODULUS;
163 x += (value >> 61) as u64 & MODULUS;
164 x += (value >> 122) as u64;
165 Self::from(x)
166 }
167}
168
169impl From<i128> for M61 {
170 #[inline]
171 fn from(mut value: i128) -> Self {
172 while value < 0 {
173 value += 16 * ((1 << 122) - 1);
174 }
175
176 Self::from(value as u128)
177 }
178}
179
/// Implements a binary operator for [`M61`] in every operand combination, plus its
/// compound-assignment form.
///
/// * `$trait` / `$func`: the operator trait and method (e.g. `Add`, `add`).
/// * `$trait_assign` / `$func_assign`: the matching assignment trait and method.
/// * `$op`: the operator token, used to forward the reference and assignment forms.
/// * `$impl`: a closure-like expression taking the two raw `u64` representatives and
///   returning the raw `u64` result.
///
/// `$impl` is expanded only inside the by-value `M61 $op M61` impl, where `Self` is
/// `M61`. Every other form (`M61 $op &M61`, `&M61 $op M61`, `&M61 $op &M61`, `$op=` with
/// a value or a reference) copies its operands and forwards to that impl. This mirrors
/// the operand forms std provides for its numeric types.
macro_rules! make_arith_impl {
    ($trait:ident, $trait_assign:ident, $func:ident, $func_assign:ident, $op:tt, $impl:expr) => {
        impl ops::$trait for M61 {
            type Output = Self;

            #[inline]
            // `$impl` is a closure literal that is called immediately.
            #[allow(clippy::redundant_closure_call)]
            fn $func(self, rhs: Self) -> Self::Output {
                Self($impl(self.0, rhs.0))
            }
        }

        impl<'a> ops::$trait<&'a M61> for M61 {
            type Output = M61;

            #[inline(always)]
            fn $func(self, rhs: &'a M61) -> Self::Output {
                self $op *rhs
            }
        }

        impl<'a> ops::$trait<M61> for &'a M61 {
            type Output = M61;

            #[inline(always)]
            fn $func(self, rhs: M61) -> Self::Output {
                *self $op rhs
            }
        }

        impl<'a, 'b> ops::$trait<&'b M61> for &'a M61 {
            type Output = M61;

            #[inline(always)]
            fn $func(self, rhs: &'b M61) -> Self::Output {
                *self $op *rhs
            }
        }

        impl ops::$trait_assign for M61 {
            #[inline(always)]
            fn $func_assign(&mut self, rhs: Self) {
                *self = *self $op rhs;
            }
        }

        impl<'a> ops::$trait_assign<&'a M61> for M61 {
            #[inline(always)]
            fn $func_assign(&mut self, rhs: &'a M61) {
                *self = *self $op *rhs;
            }
        }
    };
}
217
218make_arith_impl!(Add, AddAssign, add, add_assign, +, |a, b| {
219 let x = a + b;
220 if x >= MODULUS {
221 x - MODULUS
222 } else {
223 x
224 }
225});
226make_arith_impl!(Sub, SubAssign, sub, sub_assign, -, |a, b| {
227 let x = a + MODULUS - b;
228 if x >= MODULUS {
229 x - MODULUS
230 } else {
231 x
232 }
233});
234make_arith_impl!(Mul, MulAssign, mul, mul_assign, *, |a, b| {
235 let x = a as u128 * b as u128;
236 let mut hi = (x >> 61) as u64;
237 let mut lo = (x as u64) & MODULUS;
238 lo = lo.wrapping_add(hi);
239 hi = lo.wrapping_sub(MODULUS);
240 if lo < MODULUS {
241 lo
242 } else {
243 hi
244 }
245});
246make_arith_impl!(Div, DivAssign, div, div_assign, /, |a, b| {
247 if b == 0 {
248 panic!("attempt to divide by zero");
249 }
250
251 let mut r0 = MODULUS;
255 let mut r1 = b;
256 let mut s0 = 1i64;
257 let mut s1 = 0i64;
258 let mut t0 = 0i64;
259 let mut t1 = 1i64;
260
261 while r1 != 0 {
262 let (q, rn) = (r0 / r1, r0 % r1);
263 let sn = s0 - q as i64 * s1;
264 let tn = t0 - q as i64 * t1;
265
266 r0 = r1;
267 r1 = rn;
268 s0 = s1;
269 s1 = sn;
270 t0 = t1;
271 t1 = tn;
272 }
273
274 debug_assert_eq!(MODULUS as i128 * s0 as i128 + b as i128 * t0 as i128, 1);
275
276 (Self(a) * Self::from(t0)).0
277});
278impl iter::Sum for M61 {
283 #[inline(always)]
284 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
285 iter.fold(Self(0), |a, b| a + b)
286 }
287}
288
289impl<'a> iter::Sum<&'a M61> for M61 {
290 #[inline(always)]
291 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
292 iter.fold(Self(0), |a, b| a + b)
293 }
294}
295
296impl iter::Product for M61 {
297 #[inline(always)]
298 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
299 iter.fold(Self(1), |a, b| a * b)
300 }
301}
302
303impl<'a> iter::Product<&'a M61> for M61 {
304 #[inline(always)]
305 fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
306 iter.fold(Self(1), |a, b| a * b)
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::M61;
    use super::MODULUS;

    quickcheck::quickcheck! {
        fn creation_u64_correct(x: u64) -> bool {
            let expected = x % MODULUS;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_u128_correct(x: u128) -> bool {
            let expected = (x % MODULUS as u128) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_i64_correct(x: i64) -> bool {
            let expected = x.rem_euclid(MODULUS as i64) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_i128_correct(x: i128) -> bool {
            let expected = x.rem_euclid(MODULUS as i128) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        // Exercises the negative branch of the "trivial" small-integer conversions.
        fn creation_i8_correct(x: i8) -> bool {
            let expected = i64::from(x).rem_euclid(MODULUS as i64) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn creation_i32_correct(x: i32) -> bool {
            let expected = i64::from(x).rem_euclid(MODULUS as i64) as u64;
            let actual = M61::from(x).get();
            expected == actual
        }

        fn add_distributive(x: u64, y: u64) -> bool {
            let x = x >> 1;
            let y = y >> 1;

            let expected = M61::from(x + y);
            let actual = M61::from(x) + M61::from(y);

            expected == actual
        }

        fn sub_distributive(x: u64, y: u64) -> bool {
            let x = (x >> 1) as i64;
            let y = (y >> 1) as i64;

            let expected = M61::from(x - y);
            let actual = M61::from(x) - M61::from(y);

            expected == actual
        }

        fn mul_distributive(x: u64, y: u64) -> bool {
            let x = x as u128;
            let y = y as u128;

            let expected = M61::from(x * y);
            let actual = M61::from(x) * M61::from(y);

            expected == actual
        }

        fn div_inverts_mul(x: u64, y: u64) -> bool {
            let x = M61::from(x);
            let y = M61::from(y);

            // A zero divisor panics (covered by `div_by_zero_panics`), so it is skipped here.
            y.get() == 0 || (x * y) / y == x
        }

        fn pow_matches_repeated_mul(x: u64, e: u8) -> bool {
            let base = M61::from(x);
            let expected = (0..e).fold(M61::from(1u64), |acc, _| acc * base);

            base.pow(u64::from(e)) == expected
        }

        fn pow_fermat(x: u64) -> bool {
            // MODULUS is prime, so x^(MODULUS - 1) == 1 for every non-zero x.
            let x = M61::from(x);
            x.get() == 0 || x.pow(MODULUS - 1) == M61::from(1u64)
        }

        fn sum_and_product_match_operators(a: u64, b: u64, c: u64) -> bool {
            let values = [M61::from(a), M61::from(b), M61::from(c)];
            let sum = values[0] + values[1] + values[2];
            let product = values[0] * values[1] * values[2];

            values.iter().sum::<M61>() == sum
                && values.iter().copied().sum::<M61>() == sum
                && values.iter().product::<M61>() == product
                && values.iter().copied().product::<M61>() == product
        }
    }

    #[test]
    #[should_panic(expected = "attempt to divide by zero")]
    fn div_by_zero_panics() {
        let _ = M61::from(1u64) / M61::from(0u64);
    }
}