1use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
2use crate::{powm_u32, powm_u64, ModularUnaryOps, Reducer, Vanilla};
3
pub(crate) mod neg_mod_inv {
    //! `-m^{-1} mod 2^BITS` for odd `m`, the constant used by Montgomery reduction.
    //!
    //! An 8-bit inverse is read from `BINV_TABLE`, then refined with Newton-Raphson
    //! steps `x <- x * (2 - m * x)`; each step doubles the number of correct low-order
    //! bits (8 -> 16 -> 32 -> 64 -> 128), so the loop below stops once `BITS` is reached.

    /// `BINV_TABLE[k]` is the multiplicative inverse of `2k + 1` modulo `2^8`.
    #[rustfmt::skip]
    const BINV_TABLE: [u8; 128] = [
        0x01, 0xAB, 0xCD, 0xB7, 0x39, 0xA3, 0xC5, 0xEF, 0xF1, 0x1B, 0x3D, 0xA7, 0x29, 0x13, 0x35, 0xDF,
        0xE1, 0x8B, 0xAD, 0x97, 0x19, 0x83, 0xA5, 0xCF, 0xD1, 0xFB, 0x1D, 0x87, 0x09, 0xF3, 0x15, 0xBF,
        0xC1, 0x6B, 0x8D, 0x77, 0xF9, 0x63, 0x85, 0xAF, 0xB1, 0xDB, 0xFD, 0x67, 0xE9, 0xD3, 0xF5, 0x9F,
        0xA1, 0x4B, 0x6D, 0x57, 0xD9, 0x43, 0x65, 0x8F, 0x91, 0xBB, 0xDD, 0x47, 0xC9, 0xB3, 0xD5, 0x7F,
        0x81, 0x2B, 0x4D, 0x37, 0xB9, 0x23, 0x45, 0x6F, 0x71, 0x9B, 0xBD, 0x27, 0xA9, 0x93, 0xB5, 0x5F,
        0x61, 0x0B, 0x2D, 0x17, 0x99, 0x03, 0x25, 0x4F, 0x51, 0x7B, 0x9D, 0x07, 0x89, 0x73, 0x95, 0x3F,
        0x41, 0xEB, 0x0D, 0xF7, 0x79, 0xE3, 0x05, 0x2F, 0x31, 0x5B, 0x7D, 0xE7, 0x69, 0x53, 0x75, 0x1F,
        0x21, 0xCB, 0xED, 0xD7, 0x59, 0xC3, 0xE5, 0x0F, 0x11, 0x3B, 0x5D, 0xC7, 0x49, 0x33, 0x55, 0xFF,
    ];

    /// Emits `pub mod <ty> { pub const fn neginv(m: <ty>) -> <ty> }` for every listed type.
    macro_rules! newton_neginv_mods {
        ($($t:ident),*) => {
            $(
                pub mod $t {
                    /// Returns `-m^{-1}` modulo `2^BITS`; only meaningful for odd `m`.
                    pub const fn neginv(m: $t) -> $t {
                        // 8 correct bits straight from the table.
                        let mut x = super::BINV_TABLE[((m >> 1) & 0x7F) as usize] as $t;
                        let mut correct_bits = 8;
                        while correct_bits < <$t>::BITS {
                            x = x.wrapping_mul((2 as $t).wrapping_sub(m.wrapping_mul(x)));
                            correct_bits *= 2;
                        }
                        x.wrapping_neg()
                    }
                }
            )*
        };
    }

    newton_neginv_mods!(u8, u16, u32, u64, u128);

    pub mod usize {
        // Alias the fixed-width module that matches the target's pointer width.
        #[cfg(target_pointer_width = "16")]
        use super::u16 as native;
        #[cfg(target_pointer_width = "32")]
        use super::u32 as native;
        #[cfg(target_pointer_width = "64")]
        use super::u64 as native;

        /// Returns `-m^{-1}` modulo `2^usize::BITS`; only meaningful for odd `m`.
        #[inline]
        pub const fn neginv(m: usize) -> usize {
            native::neginv(m as _) as _
        }
    }
}
83
/// Montgomery-form reducer for an odd modulus chosen at runtime.
///
/// Values are kept in Montgomery form `x * R mod m`, where `R = 2^BITS` of `T`.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct Montgomery<T> {
    /// The odd modulus `m`.
    m: T,
    /// `-m^{-1} mod 2^BITS`, cached for the reduction step.
    inv: T,
}
95
/// Generates, inside a private module `$ns`, the inherent `Montgomery<$t>` methods
/// (constructor and reduction) and the `Reducer<$t>` implementation, using the
/// double-word helpers from `crate::word::$t`.
macro_rules! impl_montgomery_for {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;
            use neg_mod_inv::$t::neginv;

            impl Montgomery<$t> {
                /// Creates a reducer for modulus `m`, caching `-m^{-1} mod 2^BITS`.
                ///
                /// # Panics
                ///
                /// Panics if `m` is even.
                pub const fn new(m: $t) -> Self {
                    let odd = m & 1 != 0;
                    assert!(odd, "Only odd moduli are supported by the Montgomery form");
                    Self { inv: neginv(m), m }
                }

                /// Montgomery reduction: returns `monty * 2^-BITS mod m`, fully reduced.
                ///
                /// Requires `high(monty) < m` (checked in debug builds), which keeps the
                /// pre-correction quotient below `2m` so one subtraction suffices.
                const fn reduce(&self, monty: DoubleWord) -> $t {
                    debug_assert!(high(monty) < self.m);

                    // q makes `monty + q * m` a multiple of 2^BITS, so the low word cancels.
                    let q = self.inv.wrapping_mul(low(monty));
                    let (sum, carried) = monty.overflowing_add(wmul(q, self.m));
                    let quotient = high(sum);

                    if carried {
                        // The dropped carry is worth 2^BITS; adding `2^BITS - m` subtracts m.
                        quotient + self.m.wrapping_neg()
                    } else if quotient < self.m {
                        quotient
                    } else {
                        quotient - self.m
                    }
                }
            }

            impl Reducer<$t> for Montgomery<$t> {
                #[inline]
                fn new(m: &$t) -> Self {
                    Self::new(*m)
                }

                /// Enters Montgomery form; the inverse of `residue`.
                #[inline]
                fn transform(&self, target: $t) -> $t {
                    match target {
                        0 => 0,
                        _ => nrem(merge(0, target), self.m),
                    }
                }

                #[inline]
                fn check(&self, target: &$t) -> bool {
                    *target < self.m
                }

                /// Leaves Montgomery form by a single reduction.
                #[inline]
                fn residue(&self, target: $t) -> $t {
                    let widened = extend(target);
                    self.reduce(widened)
                }

                #[inline(always)]
                fn modulus(&self) -> $t {
                    self.m
                }

                #[inline(always)]
                fn is_zero(&self, target: &$t) -> bool {
                    *target == 0
                }

                // Addition-like operations are unaffected by the Montgomery factor,
                // so they are delegated to the plain modular implementation.
                #[inline(always)]
                fn add(&self, lhs: &$t, rhs: &$t) -> $t {
                    Vanilla::<$t>::add(&self.m, *lhs, *rhs)
                }

                #[inline(always)]
                fn dbl(&self, target: $t) -> $t {
                    Vanilla::<$t>::dbl(&self.m, target)
                }

                #[inline(always)]
                fn sub(&self, lhs: &$t, rhs: &$t) -> $t {
                    Vanilla::<$t>::sub(&self.m, *lhs, *rhs)
                }

                #[inline(always)]
                fn neg(&self, target: $t) -> $t {
                    Vanilla::<$t>::neg(&self.m, target)
                }

                #[inline]
                fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
                    let product = wmul(*lhs, *rhs);
                    self.reduce(product)
                }

                #[inline]
                fn sqr(&self, target: $t) -> $t {
                    let square = wsqr(target);
                    self.reduce(square)
                }

                /// Inverts in the plain domain, then maps the result back into Montgomery form.
                #[inline(always)]
                fn inv(&self, target: $t) -> Option<$t> {
                    let plain = self.residue(target);
                    let plain_inv = plain.invm(&self.m)?;
                    Some(self.transform(plain_inv))
                }

                impl_reduced_binary_pow!(Word);
            }
        }
    };
}
203impl_montgomery_for!(u8, u8_impl);
204impl_montgomery_for!(u16, u16_impl);
205impl_montgomery_for!(u32, u32_impl);
206impl_montgomery_for!(u64, u64_impl);
207impl_montgomery_for!(u128, u128_impl);
208impl_montgomery_for!(usize, usize_impl);
209
/// Implements the arithmetic methods of `Reducer` for a fixed-modulus Montgomery type.
///
/// Must be invoked inside an `impl Reducer<$T> for ...` block whose `Self` provides
/// `const fn reduce(&self, $D) -> $T` and an associated constant `MODULUS`.
///
/// * `$T` - the word type; `$D` - the double-width type used for products.
/// * `$r2` - expected to be `R^2 mod MODULUS` with `R = 2^BITS`; multiplying by it and
///   reducing once moves a value into Montgomery form.
/// * `primitive` / `udouble` - whether `$D` is a primitive integer or the `udouble` type.
///
/// NOTE(review): although `#[macro_export]`ed, the body names `impl_reduced_ops!`,
/// `impl_reduced_binary_pow!`, `udouble` and the `invm` method without `$crate::` paths,
/// so it presumably only expands where those names are already in scope.
#[macro_export]
macro_rules! impl_fixed_monty_ops {
    ($T:ty, $D:ty, $r2:expr, primitive) => {
        /// Enters Montgomery form: `reduce(target * R^2) = target * R mod MODULUS`.
        #[inline]
        fn transform(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce((target as $D) * ($r2 as $D))
        }
        /// Leaves Montgomery form with a single reduction.
        #[inline]
        fn residue(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce(target as $D)
        }

        impl_reduced_ops!($T);

        #[inline]
        fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
            self.reduce((*lhs as $D) * (*rhs as $D))
        }
        #[inline]
        fn sqr(&self, target: $T) -> $T {
            self.reduce((target as $D) * (target as $D))
        }
        /// Inverts in the plain domain, then maps the result back into Montgomery form.
        #[inline]
        fn inv(&self, target: $T) -> Option<$T> {
            let plain = self.residue(target);
            // `transform` already short-circuits zero, so no special case is needed here.
            plain
                .invm(&Self::MODULUS)
                .map(|inv_plain| self.transform(inv_plain))
        }

        impl_reduced_binary_pow!($T);
    };
    ($T:ty, $D:ty, $r2:expr, udouble) => {
        /// Enters Montgomery form: `reduce(target * R^2) = target * R mod MODULUS`.
        #[inline]
        fn transform(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce(udouble::widening_mul(target, $r2))
        }
        /// Leaves Montgomery form with a single reduction.
        #[inline]
        fn residue(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce(udouble { hi: 0, lo: target })
        }

        impl_reduced_ops!($T);

        #[inline]
        fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
            self.reduce(udouble::widening_mul(*lhs, *rhs))
        }
        #[inline]
        fn sqr(&self, target: $T) -> $T {
            self.reduce(udouble::widening_square(target))
        }
        /// Inverts in the plain domain, then maps the result back into Montgomery form.
        #[inline]
        fn inv(&self, target: $T) -> Option<$T> {
            let plain = self.residue(target);
            // `transform` already short-circuits zero, so no special case is needed here.
            plain
                .invm(&Self::MODULUS)
                .map(|inv_plain| self.transform(inv_plain))
        }

        impl_reduced_binary_pow!($T);
    };
}
297
/// Generates the inherent constants and the reduction routine shared by the
/// fixed-modulus Montgomery reducers.
///
/// * `$TypeName` - a unit struct generic over `const P: $T` (the modulus).
/// * `$T` / `$D` - the word type and a primitive type twice as wide.
/// * `$neginv_fn` - `const fn($T) -> $T`, expected to return `-P^{-1} mod 2^BITS`.
/// * `$powm` - `const fn(base, exponent, modulus) -> $T` modular exponentiation.
macro_rules! impl_fixed_montgomery_inherent {
    ($TypeName:ident, $T:ty, $D:ty, $neginv_fn:path, $powm:ident) => {
        impl<const P: $T> $TypeName<P> {
            /// The modulus, exposed as an associated constant.
            pub const MODULUS: $T = P;

            /// `$neginv_fn(P)`: the multiplier that cancels the low word in `reduce`.
            const N0: $T = $neginv_fn(P);

            /// `2^(2 * BITS) mod P`, i.e. `R^2 mod P` with `R = 2^BITS`.
            const R2: $T = $powm(2, (2 * <$T>::BITS) as $T, P);

            /// Montgomery reduction: `monty * R^{-1} mod P`, returned in `[0, P)`.
            ///
            /// Valid for `monty < P * R`, which keeps the pre-correction quotient below `2P`.
            #[inline]
            const fn reduce(&self, monty: $D) -> $T {
                // q is chosen so that `monty + q * P` has an all-zero low word.
                let q = Self::N0.wrapping_mul(monty as $T);
                let (sum, carried) = monty.overflowing_add((q as $D) * (P as $D));
                let quotient = (sum >> <$T>::BITS) as $T;
                if carried {
                    // Re-insert the lost 2^BITS carry while subtracting P.
                    quotient.wrapping_add(P.wrapping_neg())
                } else if quotient < P {
                    quotient
                } else {
                    quotient - P
                }
            }
        }
    };
}
330
/// Montgomery reducer whose 32-bit odd modulus `P` is fixed at compile time.
///
/// Uses `u64` for intermediate products; the modulus-derived constants are
/// associated consts, so the value itself carries no data.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedMontgomery32<const P: u32>;
348
349impl_fixed_montgomery_inherent!(
350 FixedMontgomery32,
351 u32,
352 u64,
353 neg_mod_inv::u32::neginv,
354 powm_u32
355);
356
357impl<const P: u32> Reducer<u32> for FixedMontgomery32<P> {
358 #[inline]
359 fn new(m: &u32) -> Self {
360 assert!(*m == P, "modulus does not match const generic parameter");
361 assert!(
362 P & 1 != 0,
363 "only odd moduli are supported by the Montgomery form"
364 );
365 Self {}
366 }
367 impl_fixed_monty_ops!(u32, u64, Self::R2, primitive);
368}
369
/// Montgomery reducer whose 64-bit odd modulus `P` is fixed at compile time.
///
/// Uses `u128` for intermediate products; the modulus-derived constants are
/// associated consts, so the value itself carries no data.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct FixedMontgomery64<const P: u64>;
387
388impl_fixed_montgomery_inherent!(
389 FixedMontgomery64,
390 u64,
391 u128,
392 neg_mod_inv::u64::neginv,
393 powm_u64
394);
395
396impl<const P: u64> Reducer<u64> for FixedMontgomery64<P> {
397 #[inline]
398 fn new(m: &u64) -> Self {
399 assert!(*m == P, "modulus does not match const generic parameter");
400 assert!(
401 P & 1 != 0,
402 "only odd moduli are supported by the Montgomery form"
403 );
404 Self {}
405 }
406 impl_fixed_monty_ops!(u64, u128, Self::R2, primitive);
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    const NRANDOM: u32 = 10;

    #[test]
    fn creation_test() {
        // a handmade 128-bit case
        let a = (0x81u128 << 120) - 1;
        let m = (0x81u128 << 119) - 1;
        let m = m >> m.trailing_zeros();
        let r = Montgomery::<u128>::new(m);
        assert_eq!(r.residue(r.transform(a)), a % m);

        // 5 + 6 is zero modulo 11, also in Montgomery form
        let r = Montgomery::<u8>::new(11u8);
        assert!(r.is_zero(&r.transform(0)));
        let five = r.transform(5u8);
        let six = r.transform(6u8);
        assert!(r.is_zero(&r.add(&five, &six)));

        // transform followed by residue must round-trip to `a % m` for every width
        for _ in 0..NRANDOM {
            let a = random::<u8>();
            let m = random::<u8>() | 1;
            let r = Montgomery::<u8>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u16>();
            let m = random::<u16>() | 1;
            let r = Montgomery::<u16>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u32>();
            let m = random::<u32>() | 1;
            let r = Montgomery::<u32>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u64>();
            let m = random::<u64>() | 1;
            let r = Montgomery::<u64>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u128>();
            let m = random::<u128>() | 1;
            let r = Montgomery::<u128>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);
        }
    }

    #[test]
    fn test_against_modops() {
        use crate::reduced::tests::ReducedTester;
        for _ in 0..NRANDOM {
            ReducedTester::<u8>::test_against_modops::<Montgomery<u8>>(1);
            ReducedTester::<u16>::test_against_modops::<Montgomery<u16>>(1);
            ReducedTester::<u32>::test_against_modops::<Montgomery<u32>>(1);
            ReducedTester::<u64>::test_against_modops::<Montgomery<u64>>(1);
            ReducedTester::<u128>::test_against_modops::<Montgomery<u128>>(1);
            ReducedTester::<usize>::test_against_modops::<Montgomery<usize>>(1);
        }
    }

    /// The const-generic reducers must agree with plain modular arithmetic, and with
    /// `Montgomery<T>` (both use `R = 2^BITS`, so their Montgomery forms coincide).
    /// Moduli just below `2^BITS` make the carry branch of `reduce` reachable.
    #[test]
    fn fixed_montgomery_test() {
        const P32: u32 = 4_294_967_291; // largest prime below 2^32
        const P64: u64 = 18_446_744_073_709_551_557; // largest prime below 2^64

        let fixed32 = FixedMontgomery32::<P32>::new(&P32);
        let rt32 = Montgomery::<u32>::new(P32);
        let fixed64 = FixedMontgomery64::<P64>::new(&P64);
        let rt64 = Montgomery::<u64>::new(P64);

        for _ in 0..NRANDOM {
            let (a, b) = (random::<u32>(), random::<u32>());
            let (am, bm) = (fixed32.transform(a), fixed32.transform(b));
            assert_eq!(am, rt32.transform(a));
            assert_eq!(fixed32.residue(am), a % P32);
            let (ar, br) = ((a % P32) as u64, (b % P32) as u64);
            assert_eq!(fixed32.residue(fixed32.mul(&am, &bm)) as u64, ar * br % P32 as u64);
            assert_eq!(fixed32.residue(fixed32.sqr(am)) as u64, ar * ar % P32 as u64);
            if ar != 0 {
                let inv = fixed32
                    .inv(am)
                    .expect("nonzero residue is invertible modulo a prime");
                assert_eq!(fixed32.residue(fixed32.mul(&am, &inv)), 1);
            }

            let (a, b) = (random::<u64>(), random::<u64>());
            let (am, bm) = (fixed64.transform(a), fixed64.transform(b));
            assert_eq!(am, rt64.transform(a));
            assert_eq!(fixed64.residue(am), a % P64);
            let (ar, br) = ((a % P64) as u128, (b % P64) as u128);
            assert_eq!(fixed64.residue(fixed64.mul(&am, &bm)) as u128, ar * br % P64 as u128);
            assert_eq!(fixed64.residue(fixed64.sqr(am)) as u128, ar * ar % P64 as u128);
            if ar != 0 {
                let inv = fixed64
                    .inv(am)
                    .expect("nonzero residue is invertible modulo a prime");
                assert_eq!(fixed64.residue(fixed64.mul(&am, &inv)), 1);
            }
        }
    }

    #[test]
    #[should_panic(expected = "modulus does not match const generic parameter")]
    fn fixed_montgomery_rejects_mismatched_modulus() {
        let _ = FixedMontgomery32::<7>::new(&9);
    }
}