// num_modular/monty.rs

1use crate::reduced::{impl_reduced_binary_pow, impl_reduced_ops};
2use crate::{powm_u32, powm_u64, ModularUnaryOps, Reducer, Vanilla};
3
/// Negated modular inverses modulo a power of two.
///
/// For an odd `m` of an unsigned type with `B` bits, each `neginv` returns `-(m^-1) mod R`
/// with `R = 2^B`, i.e. the value `n'` satisfying `m * n' ≡ -1 (mod R)`. This is the
/// constant consumed by the REDC step of Montgomery reduction.
///
/// The low 8 bits of the inverse come from a lookup table. Each Newton (Hensel lifting)
/// step `x ← x·(2 − m·x)` then doubles the number of correct low bits, so the 16/32/64/128-bit
/// versions use 1/2/3/4 steps respectively.
///
/// `m` must be odd: even numbers have no inverse modulo `2^B`. For even `m` the table lookup
/// treats `m` as `m + 1`, so the `u8` version returns `-(m + 1)^-1 mod 256`. The wider
/// versions then lift using the even `m` itself, so their result is unspecified.
pub(crate) mod neg_mod_inv {
    // Entry i contains (2i+1)^(-1) mod 256.
    #[rustfmt::skip]
    const BINV_TABLE: [u8; 128] = [
        0x01, 0xAB, 0xCD, 0xB7, 0x39, 0xA3, 0xC5, 0xEF, 0xF1, 0x1B, 0x3D, 0xA7, 0x29, 0x13, 0x35, 0xDF,
        0xE1, 0x8B, 0xAD, 0x97, 0x19, 0x83, 0xA5, 0xCF, 0xD1, 0xFB, 0x1D, 0x87, 0x09, 0xF3, 0x15, 0xBF,
        0xC1, 0x6B, 0x8D, 0x77, 0xF9, 0x63, 0x85, 0xAF, 0xB1, 0xDB, 0xFD, 0x67, 0xE9, 0xD3, 0xF5, 0x9F,
        0xA1, 0x4B, 0x6D, 0x57, 0xD9, 0x43, 0x65, 0x8F, 0x91, 0xBB, 0xDD, 0x47, 0xC9, 0xB3, 0xD5, 0x7F,
        0x81, 0x2B, 0x4D, 0x37, 0xB9, 0x23, 0x45, 0x6F, 0x71, 0x9B, 0xBD, 0x27, 0xA9, 0x93, 0xB5, 0x5F,
        0x61, 0x0B, 0x2D, 0x17, 0x99, 0x03, 0x25, 0x4F, 0x51, 0x7B, 0x9D, 0x07, 0x89, 0x73, 0x95, 0x3F,
        0x41, 0xEB, 0x0D, 0xF7, 0x79, 0xE3, 0x05, 0x2F, 0x31, 0x5B, 0x7D, 0xE7, 0x69, 0x53, 0x75, 0x1F,
        0x21, 0xCB, 0xED, 0xD7, 0x59, 0xC3, 0xE5, 0x0F, 0x11, 0x3B, 0x5D, 0xC7, 0x49, 0x33, 0x55, 0xFF,
    ];

    pub mod u8 {
        use super::*;

        /// `-(m^-1) mod 2^8` for odd `m`: the table already holds the full 8-bit inverse.
        pub const fn neginv(m: u8) -> u8 {
            let i = BINV_TABLE[((m >> 1) & 0x7F) as usize];
            i.wrapping_neg()
        }
    }

    pub mod u16 {
        use super::*;

        /// `-(m^-1) mod 2^16` for odd `m`: table lookup plus one lifting step (8 → 16 bits).
        pub const fn neginv(m: u16) -> u16 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u16;
            // hensel lifting
            i = 2u16.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod u32 {
        use super::*;

        /// `-(m^-1) mod 2^32` for odd `m`: table lookup plus two lifting steps (8 → 16 → 32 bits).
        pub const fn neginv(m: u32) -> u32 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u32;
            i = 2u32.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u32.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod u64 {
        use super::*;

        /// `-(m^-1) mod 2^64` for odd `m`: table lookup plus three lifting steps.
        pub const fn neginv(m: u64) -> u64 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u64;
            i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod u128 {
        use super::*;

        /// `-(m^-1) mod 2^128` for odd `m`: table lookup plus four lifting steps.
        pub const fn neginv(m: u128) -> u128 {
            let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u128;
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i);
            i.wrapping_neg()
        }
    }

    pub mod usize {
        /// `-(m^-1) mod 2^usize::BITS` for odd `m`, forwarded to the implementation whose
        /// width matches the target's pointer width.
        #[inline]
        pub const fn neginv(m: usize) -> usize {
            #[cfg(target_pointer_width = "16")]
            return super::u16::neginv(m as _) as _;
            #[cfg(target_pointer_width = "32")]
            return super::u32::neginv(m as _) as _;
            #[cfg(target_pointer_width = "64")]
            return super::u64::neginv(m as _) as _;
        }
    }
}
83
/// Modular reducer using the
/// [Montgomery form](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_form);
/// the modulus must be odd.
///
/// `T` is the word type. The auxiliary modulus is `R = 2^B`, where `B` is the bit width of
/// `T`, so its choice follows automatically from `T`. Besides the modulus, the reducer
/// caches `-m^-1 mod R`, which every reduction step needs.
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct Montgomery<T> {
    /// The (odd) modulus.
    m: T,
    /// `-m^-1 mod R`, precomputed when the reducer is created.
    inv: T,
}
95
96macro_rules! impl_montgomery_for {
97    ($t:ident, $ns:ident) => {
98        mod $ns {
99            use super::*;
100            use crate::word::$t::*;
101            use neg_mod_inv::$t::neginv;
102
103            impl Montgomery<$t> {
104                pub const fn new(m: $t) -> Self {
105                    assert!(
106                        m & 1 != 0,
107                        "Only odd moduli are supported by the Montgomery form"
108                    );
109                    Self { m, inv: neginv(m) }
110                }
111                const fn reduce(&self, monty: DoubleWord) -> $t {
112                    debug_assert!(high(monty) < self.m);
113
114                    // REDC algorithm
115                    let tm = low(monty).wrapping_mul(self.inv);
116                    let (t, overflow) = monty.overflowing_add(wmul(tm, self.m));
117                    let t = high(t);
118
119                    if overflow {
120                        t + self.m.wrapping_neg()
121                    } else if t >= self.m {
122                        t - self.m
123                    } else {
124                        t
125                    }
126                }
127            }
128
129            impl Reducer<$t> for Montgomery<$t> {
130                #[inline]
131                fn new(m: &$t) -> Self {
132                    Self::new(*m)
133                }
134                #[inline]
135                fn transform(&self, target: $t) -> $t {
136                    if target == 0 {
137                        return 0;
138                    }
139                    nrem(merge(0, target), self.m)
140                }
141                #[inline]
142                fn check(&self, target: &$t) -> bool {
143                    *target < self.m
144                }
145
146                #[inline]
147                fn residue(&self, target: $t) -> $t {
148                    self.reduce(extend(target))
149                }
150                #[inline(always)]
151                fn modulus(&self) -> $t {
152                    self.m
153                }
154                #[inline(always)]
155                fn is_zero(&self, target: &$t) -> bool {
156                    *target == 0
157                }
158
159                #[inline(always)]
160                fn add(&self, lhs: &$t, rhs: &$t) -> $t {
161                    Vanilla::<$t>::add(&self.m, *lhs, *rhs)
162                }
163
164                #[inline(always)]
165                fn dbl(&self, target: $t) -> $t {
166                    Vanilla::<$t>::dbl(&self.m, target)
167                }
168
169                #[inline(always)]
170                fn sub(&self, lhs: &$t, rhs: &$t) -> $t {
171                    Vanilla::<$t>::sub(&self.m, *lhs, *rhs)
172                }
173
174                #[inline(always)]
175                fn neg(&self, target: $t) -> $t {
176                    Vanilla::<$t>::neg(&self.m, target)
177                }
178
179                #[inline]
180                fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
181                    self.reduce(wmul(*lhs, *rhs))
182                }
183
184                #[inline]
185                fn sqr(&self, target: $t) -> $t {
186                    self.reduce(wsqr(target))
187                }
188
189                #[inline(always)]
190                fn inv(&self, target: $t) -> Option<$t> {
191                    // TODO: support direct montgomery inverse
192                    // REF: http://cetinkayakoc.net/docs/j82.pdf
193                    self.residue(target)
194                        .invm(&self.m)
195                        .map(|v| self.transform(v))
196                }
197
198                impl_reduced_binary_pow!(Word);
199            }
200        }
201    };
202}
203impl_montgomery_for!(u8, u8_impl);
204impl_montgomery_for!(u16, u16_impl);
205impl_montgomery_for!(u32, u32_impl);
206impl_montgomery_for!(u64, u64_impl);
207impl_montgomery_for!(u128, u128_impl);
208impl_montgomery_for!(usize, usize_impl);
209
// ── Shared Reducer boilerplate ────────────────────────────────────────────

/// Generates all `Reducer` methods except `new()` for a fixed-modulus Montgomery type.
///
/// Expand it inside an `impl Reducer<$T> for ...` block. Arguments:
///
/// - `$T`: the word type.
/// - `$D`: the double-width type used for products.
/// - `$r2`: an expression evaluating to `R² mod MODULUS` (with `R = 2^BITS`), e.g. `Self::R2`.
/// - The trailing token selects how products are widened:
///   - `primitive`: `$D` is a built-in integer and products use `as $D` casts.
///   - `udouble`: `$D` is the crate's `udouble`; products use `udouble::widening_mul` and
///     `udouble::widening_square`.
///
/// The implementing type must provide an inherent `fn reduce(&self, monty: $D) -> $T` and an
/// associated constant `MODULUS`.
///
/// The expansion does not use `$crate::` paths. It names the following directly, so they must
/// be in scope wherever the macro is invoked:
/// - `impl_reduced_ops!` and `impl_reduced_binary_pow!`;
/// - the `invm` method and the `Reducer` trait methods;
/// - `udouble` (in the `udouble` arm only).
#[macro_export]
macro_rules! impl_fixed_monty_ops {
    // Primitive widening: uses `as $D` casts
    ($T:ty, $D:ty, $r2:expr, primitive) => {
        // Into Montgomery form: REDC(target * R²) = target * R mod MODULUS.
        #[inline]
        fn transform(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce((target as $D) * ($r2 as $D))
        }
        // Out of Montgomery form: REDC(target) = target * R⁻¹ mod MODULUS.
        #[inline]
        fn residue(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce(target as $D)
        }

        impl_reduced_ops!($T);

        #[inline]
        fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
            self.reduce((*lhs as $D) * (*rhs as $D))
        }
        #[inline]
        fn sqr(&self, target: $T) -> $T {
            self.reduce((target as $D) * (target as $D))
        }
        // Invert the plain value, then map the inverse back into Montgomery form.
        #[inline]
        fn inv(&self, target: $T) -> Option<$T> {
            self.residue(target)
                .invm(&Self::MODULUS)
                .map(|plain_inv| self.transform(plain_inv))
        }

        impl_reduced_binary_pow!($T);
    };
    // udouble widening: uses udouble::widening_mul / widening_square
    ($T:ty, $D:ty, $r2:expr, udouble) => {
        // Into Montgomery form: REDC(target * R²) = target * R mod MODULUS.
        #[inline]
        fn transform(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce(udouble::widening_mul(target, $r2))
        }
        // Out of Montgomery form: REDC(target) = target * R⁻¹ mod MODULUS.
        #[inline]
        fn residue(&self, target: $T) -> $T {
            if target == 0 {
                return 0;
            }
            self.reduce(udouble { hi: 0, lo: target })
        }

        impl_reduced_ops!($T);

        #[inline]
        fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
            self.reduce(udouble::widening_mul(*lhs, *rhs))
        }
        #[inline]
        fn sqr(&self, target: $T) -> $T {
            self.reduce(udouble::widening_square(target))
        }
        // Invert the plain value, then map the inverse back into Montgomery form.
        #[inline]
        fn inv(&self, target: $T) -> Option<$T> {
            self.residue(target)
                .invm(&Self::MODULUS)
                .map(|plain_inv| self.transform(plain_inv))
        }

        impl_reduced_binary_pow!($T);
    };
}
297
// ── FixedMontgomery32 / FixedMontgomery64 ──────────────────────────────────

/// Emits the compile-time constants and the REDC routine for a const-generic Montgomery
/// reducer whose modulus `P` is known at compile time.
///
/// Arguments:
/// - the type name;
/// - its word type `$T` and the double-width type `$D`;
/// - a `const fn` computing `-m⁻¹ mod 2^BITS`;
/// - a `const fn` computing modular powers as `(base, exp, modulus)`.
///
/// `N0` and `R²` become associated constants, so the generated type carries no runtime state.
macro_rules! impl_fixed_montgomery_inherent {
    ($TypeName:ident, $T:ty, $D:ty, $neginv_fn:path, $powm:ident) => {
        impl<const P: $T> $TypeName<P> {
            /// The modulus, equal to the const parameter `P`.
            pub const MODULUS: $T = P;

            /// Montgomery constant `-P⁻¹ mod 2^BITS`.
            const N0: $T = $neginv_fn(P);

            /// `R² mod P` with `R = 2^BITS` (so `R² = 2^{2·BITS}`).
            const R2: $T = $powm(2, (2 * <$T>::BITS) as $T, P);

            /// Montgomery reduction: `monty * R⁻¹ mod P`, expecting `monty < P * R`.
            #[inline]
            const fn reduce(&self, monty: $D) -> $T {
                // Truncating `monty` to one word keeps exactly its low half.
                let q = (monty as $T).wrapping_mul(Self::N0);
                // Both factors are single words, so the product fits in `$D`.
                let q_times_p = (q as $D) * (Self::MODULUS as $D);
                let (sum, carried) = monty.overflowing_add(q_times_p);
                let hi = (sum >> <$T>::BITS) as $T;
                match (carried, hi >= Self::MODULUS) {
                    // A lost carry is worth R in `hi`: (R + hi) - P == hi + (R - P) mod R.
                    (true, _) => hi.wrapping_add(Self::MODULUS.wrapping_neg()),
                    (false, true) => hi - Self::MODULUS,
                    (false, false) => hi,
                }
            }
        }
    };
}
330
331/// A modular reducer based on [Montgomery form](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_form)
332/// for 32-bit operands, with the modulus `P` known at compile time.
333///
334/// # Example
335///
336/// ```rust
337/// use num_modular::{FixedMontgomery32, Reducer};
338///
339/// const P: u32 = 17;
340/// let reducer = FixedMontgomery32::<P>::new(&P);
341/// let a = reducer.transform(3);
342/// let b = reducer.transform(5);
343/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), 15);
344/// ```
345#[must_use]
346#[derive(Debug, Clone, Copy)]
347pub struct FixedMontgomery32<const P: u32>;
348
349impl_fixed_montgomery_inherent!(
350    FixedMontgomery32,
351    u32,
352    u64,
353    neg_mod_inv::u32::neginv,
354    powm_u32
355);
356
357impl<const P: u32> Reducer<u32> for FixedMontgomery32<P> {
358    #[inline]
359    fn new(m: &u32) -> Self {
360        assert!(*m == P, "modulus does not match const generic parameter");
361        assert!(
362            P & 1 != 0,
363            "only odd moduli are supported by the Montgomery form"
364        );
365        Self {}
366    }
367    impl_fixed_monty_ops!(u32, u64, Self::R2, primitive);
368}
369
370/// A modular reducer based on [Montgomery form](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_form)
371/// for 64-bit operands, with the modulus `P` known at compile time.
372///
373/// # Example
374///
375/// ```rust
376/// use num_modular::{FixedMontgomery64, Reducer};
377///
378/// const P: u64 = 97;
379/// let reducer = FixedMontgomery64::<P>::new(&P);
380/// let a = reducer.transform(10);
381/// let b = reducer.transform(20);
382/// assert_eq!(reducer.residue(reducer.mul(&a, &b)), (10u64 * 20) % 97);
383/// ```
384#[must_use]
385#[derive(Debug, Clone, Copy)]
386pub struct FixedMontgomery64<const P: u64>;
387
388impl_fixed_montgomery_inherent!(
389    FixedMontgomery64,
390    u64,
391    u128,
392    neg_mod_inv::u64::neginv,
393    powm_u64
394);
395
396impl<const P: u64> Reducer<u64> for FixedMontgomery64<P> {
397    #[inline]
398    fn new(m: &u64) -> Self {
399        assert!(*m == P, "modulus does not match const generic parameter");
400        assert!(
401            P & 1 != 0,
402            "only odd moduli are supported by the Montgomery form"
403        );
404        Self {}
405    }
406    impl_fixed_monty_ops!(u64, u128, Self::R2, primitive);
407}
408
// TODO(v0.6.x): accept even moduli by factoring the powers of 2 out of m and storing the exponent.
// Requirements: 1. A separate type to perform modular arithmetic with 2^n as the modulus.
//               2. An algorithm for constructing the residue from the two components (see http://koclab.cs.ucsb.edu/teaching/cs154/docx/Notes7-Montgomery.pdf).
// Alternatively, we could just provide a CRT function and leave a Montgomery integer with full modulus support as example code.
413
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    const NRANDOM: u32 = 10;

    #[test]
    fn creation_test() {
        // a deterministic test case for u128
        let a = (0x81u128 << 120) - 1;
        let m = (0x81u128 << 119) - 1;
        let m = m >> m.trailing_zeros();
        let r = Montgomery::<u128>::new(m);
        assert_eq!(r.residue(r.transform(a)), a % m);

        // is_zero test
        let r = Montgomery::<u8>::new(11u8);
        assert!(r.is_zero(&r.transform(0)));
        let five = r.transform(5u8);
        let six = r.transform(6u8);
        assert!(r.is_zero(&r.add(&five, &six)));

        // random creation test
        for _ in 0..NRANDOM {
            let a = random::<u8>();
            let m = random::<u8>() | 1;
            let r = Montgomery::<u8>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u16>();
            let m = random::<u16>() | 1;
            let r = Montgomery::<u16>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u32>();
            let m = random::<u32>() | 1;
            let r = Montgomery::<u32>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u64>();
            let m = random::<u64>() | 1;
            let r = Montgomery::<u64>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);

            let a = random::<u128>();
            let m = random::<u128>() | 1;
            let r = Montgomery::<u128>::new(m);
            assert_eq!(r.residue(r.transform(a)), a % m);
        }
    }

    #[test]
    fn test_against_modops() {
        use crate::reduced::tests::ReducedTester;
        for _ in 0..NRANDOM {
            ReducedTester::<u8>::test_against_modops::<Montgomery<u8>>(1);
            ReducedTester::<u16>::test_against_modops::<Montgomery<u16>>(1);
            ReducedTester::<u32>::test_against_modops::<Montgomery<u32>>(1);
            ReducedTester::<u64>::test_against_modops::<Montgomery<u64>>(1);
            ReducedTester::<u128>::test_against_modops::<Montgomery<u128>>(1);
            ReducedTester::<usize>::test_against_modops::<Montgomery<usize>>(1);
        }
    }

    #[test]
    fn neginv_test() {
        // For every odd m, m * neginv(m) must be -1 (all ones) modulo 2^BITS.
        for _ in 0..NRANDOM {
            let m = random::<u8>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u8::neginv(m)), u8::MAX);
            let m = random::<u16>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u16::neginv(m)), u16::MAX);
            let m = random::<u32>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u32::neginv(m)), u32::MAX);
            let m = random::<u64>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u64::neginv(m)), u64::MAX);
            let m = random::<u128>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::u128::neginv(m)), u128::MAX);
            let m = random::<usize>() | 1;
            assert_eq!(m.wrapping_mul(neg_mod_inv::usize::neginv(m)), usize::MAX);
        }
    }

    #[test]
    fn fixed_montgomery32_test() {
        const P: u32 = 4_294_967_291; // largest 32-bit prime, exercises the carry branch
        let r = FixedMontgomery32::<P>::new(&P);
        assert!(r.is_zero(&r.transform(0)));
        assert!(r.is_zero(&r.transform(P)));
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u32>(), random::<u32>());
            let (ma, mb) = (r.transform(a), r.transform(b));
            assert_eq!(r.residue(ma), a % P);

            let expected_mul = (a as u64 * b as u64 % P as u64) as u32;
            assert_eq!(r.residue(r.mul(&ma, &mb)), expected_mul);
            let expected_sqr = (a as u64 * a as u64 % P as u64) as u32;
            assert_eq!(r.residue(r.sqr(ma)), expected_sqr);

            if a % P != 0 {
                let inv = r.inv(ma).expect("P is prime, so nonzero values are invertible");
                assert_eq!(r.residue(r.mul(&ma, &inv)), 1);
            }
        }
    }

    #[test]
    fn fixed_montgomery64_test() {
        const P: u64 = 18_446_744_073_709_551_557; // largest 64-bit prime
        let r = FixedMontgomery64::<P>::new(&P);
        assert!(r.is_zero(&r.transform(0)));
        assert!(r.is_zero(&r.transform(P)));
        for _ in 0..NRANDOM {
            let (a, b) = (random::<u64>(), random::<u64>());
            let (ma, mb) = (r.transform(a), r.transform(b));
            assert_eq!(r.residue(ma), a % P);

            let expected_mul = (a as u128 * b as u128 % P as u128) as u64;
            assert_eq!(r.residue(r.mul(&ma, &mb)), expected_mul);
            let expected_sqr = (a as u128 * a as u128 % P as u128) as u64;
            assert_eq!(r.residue(r.sqr(ma)), expected_sqr);

            if a % P != 0 {
                let inv = r.inv(ma).expect("P is prime, so nonzero values are invertible");
                assert_eq!(r.residue(r.mul(&ma, &inv)), 1);
            }
        }
    }

    #[test]
    #[should_panic(expected = "modulus does not match const generic parameter")]
    fn fixed_montgomery_mismatched_modulus() {
        let _ = FixedMontgomery32::<17>::new(&19);
    }
}