Skip to main content

num_prime/
mint.rs

//! Wrapper around an integer that makes modular arithmetic efficient while still exposing the
//! same API as normal integers.
3
4use core::ops::{Add, Div, Mul, Neg, Rem, Shr, Sub};
5use either::{Either, Left, Right};
6use num_integer::{Integer, Roots};
7use num_modular::{
8    ModularCoreOps, ModularInteger, ModularPow, ModularSymbols, ModularUnaryOps, Montgomery,
9    ReducedInt, Reducer,
10};
11use num_traits::{FromPrimitive, Num, One, Pow, ToPrimitive, Zero};
12
13use crate::{BitTest, ExactRoots};
14
15/// Integer with fast modular arithmetics support, based on [`num_modular::MontgomeryInt`] under the hood
16///
17/// This struct only designed to be working with this crate. Most binary operators assume that
18/// the modulus of two operands (when in montgomery form) are the same, and most implicit conversions
19/// between conventional form and montgomery form will be forbidden
20#[derive(Debug, Clone, Copy)]
21pub struct Mint<T: Integer, R: Reducer<T>>(Either<T, ReducedInt<T, R>>);
22
23impl<T: Integer, R: Reducer<T>> From<T> for Mint<T, R> {
24    #[inline(always)]
25    fn from(v: T) -> Self {
26        Self(Left(v))
27    }
28}
29impl<T: Integer, R: Reducer<T>> From<ReducedInt<T, R>> for Mint<T, R> {
30    #[inline(always)]
31    fn from(v: ReducedInt<T, R>) -> Self {
32        Self(Right(v))
33    }
34}
35
36#[inline(always)]
37fn left_only<T: Integer, R: Reducer<T>>(lhs: Mint<T, R>, rhs: Mint<T, R>) -> (T, T) {
38    match (lhs.0, rhs.0) {
39        (Left(v1), Left(v2)) => (v1, v2),
40        (_, _) => unreachable!(),
41    }
42}
43
44#[inline(always)]
45fn left_ref_only<'a, T: Integer, R: Reducer<T>>(
46    lhs: &'a Mint<T, R>,
47    rhs: &'a Mint<T, R>,
48) -> (&'a T, &'a T) {
49    match (&lhs.0, &rhs.0) {
50        (Left(v1), Left(v2)) => (v1, v2),
51        (_, _) => unreachable!(),
52    }
53}
54
// Generates a by-reference binary method that forwards to the same method on the inner
// plain integers. Both operands must be in the conventional (`Left`) form, otherwise the
// generated method panics.
//
// `name`          -> the result is wrapped back into `Self(Left(..))`
// `name => Type`  -> the inner result is returned as is
macro_rules! forward_binops_left_ref_only {
    ($op:ident) => {
        #[inline(always)]
        fn $op(&self, other: &Self) -> Self {
            match (&self.0, &other.0) {
                (Left(a), Left(b)) => Self(Left(a.$op(b))),
                _ => unreachable!(),
            }
        }
    };
    ($op:ident => $out:ty) => {
        #[inline(always)]
        fn $op(&self, other: &Self) -> $out {
            match (&self.0, &other.0) {
                (Left(a), Left(b)) => a.$op(b),
                _ => unreachable!(),
            }
        }
    };
}
71
// Generates a by-reference unary method that forwards to the same method on the inner
// value: directly for the conventional (`Left`) form, and on the value returned by
// `residue()` for the reduced (`Right`) form.
macro_rules! forward_uops_ref {
    ($op:ident => $out:ty) => {
        #[inline(always)]
        fn $op(&self) -> $out {
            match self.0 {
                Left(ref plain) => plain.$op(),
                Right(ref reduced) => reduced.residue().$op(),
            }
        }
    };
}
83
84impl<T: Integer + Clone, R: Reducer<T>> PartialEq for Mint<T, R> {
85    fn eq(&self, other: &Self) -> bool {
86        match (&self.0, &other.0) {
87            (Left(v1), Left(v2)) => v1 == v2,
88            (Right(v1), Right(v2)) => v1 == v2,
89            (_, _) => unreachable!(), // force optimization of equality test
90        }
91    }
92}
93impl<T: Integer + Clone, R: Reducer<T>> Eq for Mint<T, R> {}
94
95impl<T: Integer + Clone, R: Reducer<T> + Clone> PartialOrd for Mint<T, R> {
96    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
97        Some(self.cmp(other))
98    }
99}
100impl<T: Integer + Clone, R: Reducer<T> + Clone> Ord for Mint<T, R> {
101    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
102        match (&self.0, &other.0) {
103            (Left(v1), Left(v2)) => v1.cmp(v2),
104            (Left(v1), Right(v2)) => v1.cmp(&v2.residue()),
105            (Right(v1), Left(v2)) => v1.residue().cmp(v2),
106            (Right(v1), Right(v2)) => v1.residue().cmp(&v2.residue()),
107        }
108    }
109}
110
111impl<T: Integer + Clone, R: Reducer<T> + Clone> Mint<T, R> {
112    #[inline(always)]
113    pub fn value(&self) -> T {
114        match &self.0 {
115            Left(v) => v.clone(),
116            Right(m) => m.residue(),
117        }
118    }
119}
120
// Forward a binary operator for all four owned/borrowed operand combinations.
//
// Two plain operands produce a plain result. As soon as one operand is in reduced form the
// other one is converted with `convert` and the result is kept in reduced form, i.e. the
// result is a `ReducedInt` whenever possible.
macro_rules! forward_binops_right {
    (impl $bin_trait:ident, $op:ident) => {
        // value (op) value
        impl<T, R> $bin_trait for Mint<T, R>
        where
            T: Integer + Clone,
            R: Reducer<T> + Clone,
        {
            type Output = Self;
            #[inline]
            fn $op(self, rhs: Self) -> Self::Output {
                let combined = match (self.0, rhs.0) {
                    (Left(a), Left(b)) => Left(a.$op(b)),
                    (Left(a), Right(b)) => Right(b.convert(a).$op(b)),
                    (Right(a), Left(b)) => {
                        // convert first: `a` is borrowed here and moved below
                        let b = a.convert(b);
                        Right(a.$op(b))
                    }
                    (Right(a), Right(b)) => Right(a.$op(b)),
                };
                Self(combined)
            }
        }

        // value (op) reference
        impl<T, R> $bin_trait<&Self> for Mint<T, R>
        where
            T: Integer + Clone + for<'r> $bin_trait<&'r T, Output = T>,
            R: Reducer<T> + Clone,
        {
            type Output = Mint<T, R>;
            #[inline]
            fn $op(self, rhs: &Self) -> Self::Output {
                let combined = match (self.0, &rhs.0) {
                    (Left(a), Left(b)) => Left(a.$op(b)),
                    (Left(a), Right(b)) => Right(b.convert(a).$op(b)),
                    (Right(a), Left(b)) => {
                        let b = a.convert(b.clone());
                        Right(a.$op(b))
                    }
                    (Right(a), Right(b)) => Right(a.$op(b)),
                };
                Mint(combined)
            }
        }

        // reference (op) value
        impl<T, R> $bin_trait<Mint<T, R>> for &Mint<T, R>
        where
            T: Integer + Clone,
            R: Reducer<T> + Clone,
        {
            type Output = Mint<T, R>;
            // FIXME: additional clone here due to https://github.com/rust-lang/rust/issues/39959
            // (same for ref & ref operation below, and those for Div and Rem)
            #[inline]
            fn $op(self, rhs: Mint<T, R>) -> Self::Output {
                let combined = match (&self.0, rhs.0) {
                    (Left(a), Left(b)) => Left(a.clone().$op(b)),
                    (Left(a), Right(b)) => Right(b.convert(a.clone()).$op(b)),
                    (Right(a), Left(b)) => {
                        let b = a.convert(b);
                        Right(a.clone().$op(b))
                    }
                    (Right(a), Right(b)) => Right(a.$op(b)),
                };
                Mint(combined)
            }
        }

        // reference (op) reference
        impl<T, R> $bin_trait<&Mint<T, R>> for &Mint<T, R>
        where
            T: Integer + Clone + for<'r> $bin_trait<&'r T, Output = T>,
            R: Reducer<T> + Clone,
        {
            type Output = Mint<T, R>;
            #[inline]
            fn $op(self, rhs: &Mint<T, R>) -> Self::Output {
                let combined = match (&self.0, &rhs.0) {
                    (Left(a), Left(b)) => Left(a.clone().$op(b)),
                    (Left(a), Right(b)) => Right(b.convert(a.clone()).$op(b)),
                    (Right(a), Left(b)) => {
                        let b = a.convert(b.clone());
                        Right(a.clone().$op(b))
                    }
                    (Right(a), Right(b)) => Right(a.$op(b)),
                };
                Mint(combined)
            }
        }
    };
}
198
// Add, Sub and Mul are generated by `forward_binops_right!` for every owned/borrowed
// operand combination; Div and Rem are implemented by hand below.
forward_binops_right!(impl Add, add);
forward_binops_right!(impl Sub, sub);
forward_binops_right!(impl Mul, mul);
202
203impl<T: Integer + Clone, R: Reducer<T>> Div for Mint<T, R> {
204    type Output = Self;
205
206    #[inline]
207    fn div(self, rhs: Self) -> Self::Output {
208        let (v1, v2) = left_only(self, rhs);
209        Self(Left(v1.div(v2)))
210    }
211}
212impl<T: Integer + Clone + for<'r> Div<&'r T, Output = T>, R: Reducer<T>> Div<&Self> for Mint<T, R> {
213    type Output = Self;
214
215    #[inline]
216    fn div(self, rhs: &Self) -> Self::Output {
217        match (self.0, &rhs.0) {
218            (Left(v1), Left(v2)) => Self(Left(v1.div(v2))),
219            (_, _) => unreachable!(),
220        }
221    }
222}
223impl<T: Integer + Clone, R: Reducer<T>> Div<Mint<T, R>> for &Mint<T, R> {
224    type Output = Mint<T, R>;
225
226    #[inline]
227    fn div(self, rhs: Mint<T, R>) -> Self::Output {
228        match (&self.0, rhs.0) {
229            (Left(v1), Left(v2)) => Mint(Left(v1.clone().div(v2))),
230            (_, _) => unreachable!(),
231        }
232    }
233}
234impl<T: Integer + Clone + for<'r> Div<&'r T, Output = T>, R: Reducer<T>> Div<&Mint<T, R>>
235    for &Mint<T, R>
236{
237    type Output = Mint<T, R>;
238    #[inline]
239    fn div(self, rhs: &Mint<T, R>) -> Self::Output {
240        match (&self.0, &rhs.0) {
241            (Left(v1), Left(v2)) => Mint(Left(v1.clone().div(v2))),
242            (_, _) => unreachable!(),
243        }
244    }
245}
246
247impl<T: Integer + Clone, R: Reducer<T> + Clone> Rem for Mint<T, R> {
248    type Output = Self;
249
250    #[inline]
251    fn rem(self, rhs: Self) -> Self::Output {
252        match (self.0, rhs.0) {
253            (Left(v1), Left(v2)) => Self(Right(ReducedInt::new(v1, &v2))),
254            (Right(v1), Left(v2)) => {
255                debug_assert!(v1.modulus() == v2);
256                Self(Right(v1))
257            }
258            (_, _) => unreachable!(),
259        }
260    }
261}
262impl<T: Integer + Clone, R: Reducer<T> + Clone> Rem<&Self> for Mint<T, R> {
263    type Output = Self;
264
265    #[inline]
266    fn rem(self, rhs: &Self) -> Self::Output {
267        match (self.0, &rhs.0) {
268            (Left(v1), Left(v2)) => Self(Right(ReducedInt::new(v1, v2))),
269            (Right(v1), Left(v2)) => {
270                debug_assert!(&v1.modulus() == v2);
271                Self(Right(v1))
272            }
273            (_, _) => unreachable!(),
274        }
275    }
276}
277impl<T: Integer + Clone, R: Reducer<T> + Clone> Rem<Mint<T, R>> for &Mint<T, R> {
278    type Output = Mint<T, R>;
279
280    #[inline]
281    fn rem(self, rhs: Mint<T, R>) -> Self::Output {
282        match (&self.0, rhs.0) {
283            (Left(v1), Left(v2)) => Mint(Right(ReducedInt::new(v1.clone(), &v2))),
284            (Right(v1), Left(v2)) => {
285                debug_assert!(v1.modulus() == v2);
286                Mint(Right(v1.clone()))
287            }
288            (_, _) => unreachable!(),
289        }
290    }
291}
292impl<T: Integer + Clone, R: Reducer<T> + Clone> Rem<&Mint<T, R>> for &Mint<T, R> {
293    type Output = Mint<T, R>;
294
295    #[inline]
296    fn rem(self, rhs: &Mint<T, R>) -> Self::Output {
297        match (&self.0, &rhs.0) {
298            (Left(v1), Left(v2)) => Mint(Right(ReducedInt::new(v1.clone(), v2))),
299            (Right(v1), Left(v2)) => {
300                debug_assert!(&v1.modulus() == v2);
301                Mint(Right(v1.clone()))
302            }
303            (_, _) => unreachable!(),
304        }
305    }
306}
307
308impl<T: Integer + Clone, R: Reducer<T> + Clone> Zero for Mint<T, R> {
309    #[inline(always)]
310    fn zero() -> Self {
311        Self(Left(T::zero()))
312    }
313    #[inline(always)]
314    fn is_zero(&self) -> bool {
315        match &self.0 {
316            Left(v) => v.is_zero(),
317            Right(m) => m.is_zero(),
318        }
319    }
320}
321
322impl<T: Integer + Clone, R: Reducer<T> + Clone> One for Mint<T, R> {
323    #[inline(always)]
324    fn one() -> Self {
325        Self(Left(T::one()))
326    }
327    forward_uops_ref!(is_one => bool);
328}
329
330impl<T: Integer + Clone, R: Reducer<T> + Clone> Num for Mint<T, R> {
331    type FromStrRadixErr = <T as Num>::FromStrRadixErr;
332
333    #[inline(always)]
334    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
335        T::from_str_radix(str, radix).map(|v| Self(Left(v)))
336    }
337}
338
339impl<T: Integer + Clone, R: Reducer<T> + Clone> Integer for Mint<T, R> {
340    forward_binops_left_ref_only!(div_floor);
341    forward_binops_left_ref_only!(mod_floor);
342    forward_binops_left_ref_only!(lcm);
343    forward_binops_left_ref_only!(is_multiple_of => bool);
344    forward_uops_ref!(is_even => bool);
345    forward_uops_ref!(is_odd => bool);
346
347    #[inline(always)]
348    fn div_rem(&self, other: &Self) -> (Self, Self) {
349        let (v1, v2) = left_ref_only(self, other);
350        let (q, r) = v1.div_rem(v2);
351        (Self(Left(q)), Self(Left(r)))
352    }
353    #[inline(always)]
354    fn gcd(&self, other: &Self) -> Self {
355        Self(Left(match (&self.0, &other.0) {
356            (Left(v1), Left(v2)) => v1.gcd(v2),
357            (Right(v1), Left(v2)) => v1.residue().gcd(v2),
358            (Left(v1), Right(v2)) => v1.gcd(&v2.residue()),
359            (Right(v1), Right(v2)) => v1.residue().gcd(&v2.residue()),
360        }))
361    }
362}
363
364impl<T: Integer + Clone + Roots, R: Reducer<T> + Clone> Roots for Mint<T, R> {
365    #[inline]
366    fn nth_root(&self, n: u32) -> Self {
367        match &self.0 {
368            Left(v) => Self(Left(v.nth_root(n))),
369            Right(_) => unreachable!(),
370        }
371    }
372}
373
374impl<T: Integer + Clone + FromPrimitive, R: Reducer<T>> FromPrimitive for Mint<T, R> {
375    #[inline]
376    fn from_f64(n: f64) -> Option<Self> {
377        T::from_f64(n).map(|v| Self(Left(v)))
378    }
379    #[inline]
380    fn from_i64(n: i64) -> Option<Self> {
381        T::from_i64(n).map(|v| Self(Left(v)))
382    }
383    #[inline]
384    fn from_u64(n: u64) -> Option<Self> {
385        T::from_u64(n).map(|v| Self(Left(v)))
386    }
387}
388
389impl<T: Integer + Clone + ToPrimitive, R: Reducer<T> + Clone> ToPrimitive for Mint<T, R> {
390    #[inline]
391    fn to_f64(&self) -> Option<f64> {
392        match &self.0 {
393            Left(v) => v.to_f64(),
394            Right(m) => m.residue().to_f64(),
395        }
396    }
397    #[inline]
398    fn to_i64(&self) -> Option<i64> {
399        match &self.0 {
400            Left(v) => v.to_i64(),
401            Right(m) => m.residue().to_i64(),
402        }
403    }
404    #[inline]
405    fn to_u64(&self) -> Option<u64> {
406        match &self.0 {
407            Left(v) => v.to_u64(),
408            Right(m) => m.residue().to_u64(),
409        }
410    }
411}
412
413impl<T: Integer + Clone + Pow<u32, Output = T>, R: Reducer<T>> Pow<u32> for Mint<T, R> {
414    type Output = Self;
415    #[inline]
416    fn pow(self, rhs: u32) -> Self::Output {
417        match self.0 {
418            Left(v) => Self(Left(v.pow(rhs))),
419            Right(_) => unreachable!(),
420        }
421    }
422}
423
424impl<T: Integer + Clone + ExactRoots, R: Reducer<T> + Clone> ExactRoots for Mint<T, R> {
425    #[inline]
426    fn nth_root_exact(&self, n: u32) -> Option<Self> {
427        match &self.0 {
428            Left(v) => v.nth_root_exact(n).map(|v| Self(Left(v))),
429            Right(_) => unreachable!(),
430        }
431    }
432}
433
434impl<T: Integer + Clone + BitTest, R: Reducer<T>> BitTest for Mint<T, R> {
435    #[inline]
436    fn bit(&self, position: usize) -> bool {
437        match &self.0 {
438            Left(v) => v.bit(position),
439            Right(_) => unreachable!(),
440        }
441    }
442    #[inline]
443    fn bits(&self) -> usize {
444        match &self.0 {
445            Left(v) => v.bits(),
446            Right(_) => unreachable!(),
447        }
448    }
449    #[inline]
450    fn trailing_zeros(&self) -> usize {
451        match &self.0 {
452            Left(v) => v.trailing_zeros(),
453            Right(_) => unreachable!(),
454        }
455    }
456}
457
458impl<T: Integer + Clone + Shr<usize, Output = T>, R: Reducer<T>> Shr<usize> for Mint<T, R> {
459    type Output = Self;
460    #[inline]
461    fn shr(self, rhs: usize) -> Self::Output {
462        match self.0 {
463            Left(v) => Self(Left(v >> rhs)),
464            Right(_) => unreachable!(),
465        }
466    }
467}
468impl<T: Integer + Clone + Shr<usize, Output = T>, R: Reducer<T>> Shr<usize> for &Mint<T, R> {
469    type Output = Mint<T, R>;
470    #[inline]
471    fn shr(self, rhs: usize) -> Self::Output {
472        match &self.0 {
473            Left(v) => Mint(Left(v.clone() >> rhs)),
474            Right(_) => unreachable!(),
475        }
476    }
477}
478
479impl<T: Integer + Clone, R: Reducer<T> + Clone> ModularCoreOps<&Self, &Self> for Mint<T, R> {
480    type Output = Self;
481    #[inline]
482    fn addm(self, rhs: &Self, m: &Self) -> Self::Output {
483        match (self.0, &rhs.0, &m.0) {
484            (Right(v1), Right(v2), Left(m)) => {
485                debug_assert!(&v1.modulus() == m && &v2.modulus() == m);
486                Self(Right(v1 + v2))
487            }
488            (_, _, _) => unreachable!(),
489        }
490    }
491    #[inline]
492    fn subm(self, rhs: &Self, m: &Self) -> Self::Output {
493        match (self.0, &rhs.0, &m.0) {
494            (Right(v1), Right(v2), Left(m)) => {
495                debug_assert!(&v1.modulus() == m && &v2.modulus() == m);
496                Self(Right(v1 - v2))
497            }
498            (_, _, _) => unreachable!(),
499        }
500    }
501    #[inline]
502    fn mulm(self, rhs: &Self, m: &Self) -> Self::Output {
503        match (self.0, &rhs.0, &m.0) {
504            (Right(v1), Right(v2), Left(m)) => {
505                debug_assert!(&v1.modulus() == m && &v2.modulus() == m);
506                Self(Right(v1 * v2))
507            }
508            (_, _, _) => unreachable!(),
509        }
510    }
511}
512impl<'b, T: Integer + Clone, R: Reducer<T> + Clone> ModularCoreOps<&'b Mint<T, R>, &'b Mint<T, R>>
513    for &Mint<T, R>
514{
515    type Output = Mint<T, R>;
516    #[inline]
517    fn addm(self, rhs: &Mint<T, R>, m: &Mint<T, R>) -> Self::Output {
518        match (&self.0, &rhs.0, &m.0) {
519            (Right(v1), Right(v2), Left(m)) => {
520                debug_assert!(&v1.modulus() == m && &v2.modulus() == m);
521                Mint(Right(v1 + v2))
522            }
523            (_, _, _) => unreachable!(),
524        }
525    }
526    #[inline]
527    fn subm(self, rhs: &Mint<T, R>, m: &Mint<T, R>) -> Self::Output {
528        match (&self.0, &rhs.0, &m.0) {
529            (Right(v1), Right(v2), Left(m)) => {
530                debug_assert!(&v1.modulus() == m && &v2.modulus() == m);
531                Mint(Right(v1 - v2))
532            }
533            (_, _, _) => unreachable!(),
534        }
535    }
536    #[inline]
537    fn mulm(self, rhs: &Mint<T, R>, m: &Mint<T, R>) -> Self::Output {
538        match (&self.0, &rhs.0, &m.0) {
539            (Right(v1), Right(v2), Left(m)) => {
540                debug_assert!(&v1.modulus() == m && &v2.modulus() == m);
541                Mint(Right(v1 * v2))
542            }
543            (_, _, _) => unreachable!(),
544        }
545    }
546}
547
548impl<T: Integer + Clone, R: Reducer<T> + Clone> ModularUnaryOps<&Self> for Mint<T, R> {
549    type Output = Self;
550    #[inline]
551    fn negm(self, m: &Self) -> Self::Output {
552        Self(Right(match (self.0, &m.0) {
553            (Left(v), Left(m)) => ReducedInt::new(v, m).neg(),
554            (Right(v), Left(m)) => {
555                debug_assert!(&v.modulus() == m);
556                v.neg()
557            }
558            (_, Right(_)) => unreachable!(),
559        }))
560    }
561    fn invm(self, _: &Self) -> Option<Self::Output> {
562        unreachable!() // not used in this crate
563    }
564    #[inline]
565    fn dblm(self, m: &Self) -> Self::Output {
566        Self(Right(match (self.0, &m.0) {
567            (Left(v), Left(m)) => ReducedInt::new(v, m).double(),
568            (Right(v), Left(m)) => {
569                debug_assert!(&v.modulus() == m);
570                v.double()
571            }
572            (_, Right(_)) => unreachable!(),
573        }))
574    }
575    #[inline]
576    fn sqm(self, m: &Self) -> Self::Output {
577        Self(Right(match (self.0, &m.0) {
578            (Left(v), Left(m)) => ReducedInt::new(v, m).square(),
579            (Right(v), Left(m)) => {
580                debug_assert!(&v.modulus() == m);
581                v.square()
582            }
583            (_, Right(_)) => unreachable!(),
584        }))
585    }
586}
587impl<T: Integer + Clone, R: Reducer<T> + Clone> ModularUnaryOps<&Mint<T, R>> for &Mint<T, R> {
588    type Output = Mint<T, R>;
589    #[inline]
590    fn negm(self, m: &Mint<T, R>) -> Self::Output {
591        Mint(Right(match (&self.0, &m.0) {
592            (Left(v), Left(m)) => ReducedInt::new(v.clone(), m).neg(),
593            (Right(v), Left(m)) => {
594                debug_assert!(&v.modulus() == m);
595                v.clone().neg()
596            }
597            (_, Right(_)) => unreachable!(),
598        }))
599    }
600    fn invm(self, _: &Mint<T, R>) -> Option<Self::Output> {
601        unreachable!() // not used in this crate
602    }
603    #[inline]
604    fn dblm(self, m: &Mint<T, R>) -> Self::Output {
605        Mint(Right(match (&self.0, &m.0) {
606            (Left(v), Left(m)) => ReducedInt::new(v.clone(), m).double(),
607            (Right(v), Left(m)) => {
608                debug_assert!(&v.modulus() == m);
609                v.clone().double()
610            }
611            (_, Right(_)) => unreachable!(),
612        }))
613    }
614    #[inline]
615    fn sqm(self, m: &Mint<T, R>) -> Self::Output {
616        Mint(Right(match (&self.0, &m.0) {
617            (Left(v), Left(m)) => ReducedInt::new(v.clone(), m).square(),
618            (Right(v), Left(m)) => {
619                debug_assert!(&v.modulus() == m);
620                v.clone().square()
621            }
622            (_, Right(_)) => unreachable!(),
623        }))
624    }
625}
626
627impl<T: Integer + Clone + for<'r> ModularSymbols<&'r T>, R: Reducer<T> + Clone>
628    ModularSymbols<&Self> for Mint<T, R>
629{
630    #[inline]
631    fn checked_jacobi(&self, n: &Self) -> Option<i8> {
632        match (&self.0, &n.0) {
633            (Left(a), Left(n)) => a.checked_jacobi(n),
634            (Right(a), Left(n)) => a.residue().checked_jacobi(n),
635            (_, Right(_)) => unreachable!(),
636        }
637    }
638    #[inline]
639    fn checked_legendre(&self, n: &Self) -> Option<i8> {
640        match (&self.0, &n.0) {
641            (Left(a), Left(n)) => a.checked_legendre(n),
642            (Right(a), Left(n)) => a.residue().checked_legendre(n),
643            (_, Right(_)) => unreachable!(),
644        }
645    }
646    #[inline]
647    fn kronecker(&self, n: &Self) -> i8 {
648        match (&self.0, &n.0) {
649            (Left(a), Left(n)) => a.kronecker(n),
650            (Right(a), Left(n)) => a.residue().kronecker(n),
651            (_, Right(_)) => unreachable!(),
652        }
653    }
654}
655
656impl<T: Integer + Clone, R: Reducer<T> + Clone> ModularPow<&Self, &Self> for Mint<T, R> {
657    type Output = Self;
658    #[inline]
659    fn powm(self, exp: &Self, m: &Self) -> Self::Output {
660        Self(Right(match (self.0, &exp.0, &m.0) {
661            (Left(v), Left(e), Left(m)) => ReducedInt::new(v, m).pow(&e.clone()),
662            (Right(v), Left(e), Left(m)) => {
663                debug_assert!(&v.modulus() == m);
664                v.pow(&e.clone())
665            }
666            (_, _, _) => unreachable!(),
667        }))
668    }
669}
670
/// [`Mint`] specialised to the [`Montgomery`] reducer.
pub type SmallMint<T> = Mint<T, Montgomery<T>>;
672
673#[cfg(test)]
674#[allow(clippy::op_ref)]
675mod tests {
676    use super::*;
677
678    #[test]
679    fn test_basics() {
680        let a: SmallMint<u32> = 19.into();
681        let b: SmallMint<u32> = 8.into();
682        assert_eq!(a + b, 27.into());
683    }
684
685    // --- Sub, Mul, Div, Rem operators ---
686    #[test]
687    fn test_sub() {
688        let a: SmallMint<u32> = 19.into();
689        let b: SmallMint<u32> = 8.into();
690        assert_eq!(a - b, SmallMint::from(11u32));
691        // ref variants
692        let a: SmallMint<u32> = 19.into();
693        let b: SmallMint<u32> = 8.into();
694        assert_eq!(&a - &b, SmallMint::from(11u32));
695        assert_eq!(&a - b, SmallMint::from(11u32));
696        let b2: SmallMint<u32> = 8.into();
697        assert_eq!(a - &b2, SmallMint::from(11u32));
698    }
699
700    #[test]
701    fn test_mul() {
702        let a: SmallMint<u32> = 7.into();
703        let b: SmallMint<u32> = 6.into();
704        assert_eq!(a * b, SmallMint::from(42u32));
705        let a: SmallMint<u32> = 7.into();
706        let b: SmallMint<u32> = 6.into();
707        assert_eq!(&a * &b, SmallMint::from(42u32));
708    }
709
710    #[test]
711    fn test_div() {
712        let a: SmallMint<u32> = 42.into();
713        let b: SmallMint<u32> = 7.into();
714        assert_eq!(a / b, SmallMint::from(6u32));
715        // ref variants
716        let a: SmallMint<u32> = 42.into();
717        let b: SmallMint<u32> = 7.into();
718        assert_eq!(&a / &b, SmallMint::from(6u32));
719        assert_eq!(&a / b, SmallMint::from(6u32));
720        let b2: SmallMint<u32> = 7.into();
721        assert_eq!(a / &b2, SmallMint::from(6u32));
722    }
723
724    #[test]
725    fn test_rem_creates_right_variant() {
726        let a: SmallMint<u32> = 19.into();
727        let m: SmallMint<u32> = 7.into();
728        let r = a % m; // creates Right variant (ReducedInt)
729        assert_eq!(r.value(), 5);
730        // ref variants
731        let a: SmallMint<u32> = 19.into();
732        let m: SmallMint<u32> = 7.into();
733        assert_eq!((&a % &m).value(), 5);
734        assert_eq!((&a % m).value(), 5);
735        let m2: SmallMint<u32> = 7.into();
736        assert_eq!((a % &m2).value(), 5);
737    }
738
739    // --- Zero / One ---
740    #[test]
741    fn test_zero_one() {
742        let z: SmallMint<u32> = Zero::zero();
743        assert!(z.is_zero());
744        assert!(!SmallMint::from(1u32).is_zero());
745
746        let o: SmallMint<u32> = One::one();
747        assert!(o.is_one());
748        assert!(!SmallMint::from(2u32).is_one());
749    }
750
751    // --- Num::from_str_radix ---
752    #[test]
753    fn test_from_str_radix() {
754        let v: SmallMint<u32> = Num::from_str_radix("ff", 16).unwrap();
755        assert_eq!(v, SmallMint::from(255u32));
756        let v: SmallMint<u32> = Num::from_str_radix("42", 10).unwrap();
757        assert_eq!(v, SmallMint::from(42u32));
758    }
759
760    // --- Integer methods ---
761    #[test]
762    fn test_integer_methods() {
763        let a: SmallMint<u32> = 17.into();
764        let b: SmallMint<u32> = 5.into();
765
766        assert_eq!(a.div_floor(&b), SmallMint::from(3u32));
767        assert_eq!(a.mod_floor(&b), SmallMint::from(2u32));
768        assert_eq!(a.gcd(&b), SmallMint::from(1u32));
769        assert_eq!(a.lcm(&b), SmallMint::from(85u32));
770
771        let (q, r) = a.div_rem(&b);
772        assert_eq!(q, SmallMint::from(3u32));
773        assert_eq!(r, SmallMint::from(2u32));
774
775        assert!(!SmallMint::from(7u32).is_even());
776        assert!(SmallMint::from(8u32).is_even());
777        assert!(SmallMint::from(7u32).is_odd());
778        assert!(!SmallMint::from(8u32).is_odd());
779
780        assert!(SmallMint::from(15u32).is_multiple_of(&SmallMint::from(5u32)));
781        assert!(!SmallMint::from(14u32).is_multiple_of(&SmallMint::from(5u32)));
782    }
783
784    // --- Ord / PartialOrd ---
785    #[test]
786    fn test_ord() {
787        let a: SmallMint<u32> = 10.into();
788        let b: SmallMint<u32> = 20.into();
789        assert!(a < b);
790        assert!(b > a);
791        assert_eq!(a.cmp(&a), std::cmp::Ordering::Equal);
792    }
793
794    #[test]
795    fn test_ord_mixed_variants() {
796        // Left vs Right and Right vs Left
797        let left: SmallMint<u32> = 5.into();
798        let val: SmallMint<u32> = 12.into();
799        let modulus: SmallMint<u32> = 7.into();
800        let right = val % modulus; // Right variant, residue = 5
801        assert_eq!(left.cmp(&right), std::cmp::Ordering::Equal);
802        assert_eq!(right.cmp(&left), std::cmp::Ordering::Equal);
803    }
804
805    // --- FromPrimitive ---
806    #[test]
807    fn test_from_primitive() {
808        let v: SmallMint<u64> = FromPrimitive::from_u64(42).unwrap();
809        assert_eq!(v, SmallMint::from(42u64));
810
811        let v: SmallMint<u64> = FromPrimitive::from_i64(42).unwrap();
812        assert_eq!(v, SmallMint::from(42u64));
813
814        let v: SmallMint<u64> = FromPrimitive::from_f64(42.0).unwrap();
815        assert_eq!(v, SmallMint::from(42u64));
816    }
817
818    // --- ToPrimitive ---
819    #[test]
820    fn test_to_primitive_left() {
821        let v: SmallMint<u64> = 42.into();
822        assert_eq!(v.to_u64(), Some(42));
823        assert_eq!(v.to_i64(), Some(42));
824        assert_eq!(v.to_f64(), Some(42.0));
825    }
826
827    #[test]
828    fn test_to_primitive_right() {
829        // Right variant (after modular reduction)
830        let v: SmallMint<u64> = 19.into();
831        let m: SmallMint<u64> = 7.into();
832        let r = v % m; // Right variant, residue = 5
833        assert_eq!(r.to_u64(), Some(5));
834        assert_eq!(r.to_i64(), Some(5));
835        assert_eq!(r.to_f64(), Some(5.0));
836    }
837
838    // --- Pow ---
839    #[test]
840    fn test_pow() {
841        let v: SmallMint<u64> = 3.into();
842        let result: SmallMint<u64> = Pow::pow(v, 4u32);
843        assert_eq!(result, SmallMint::from(81u64));
844    }
845
846    // --- BitTest ---
847    #[test]
848    fn test_bittest() {
849        let v: SmallMint<u64> = 0b1010u64.into();
850        assert!(v.bit(1));
851        assert!(!v.bit(0));
852        assert!(v.bit(3));
853        assert_eq!(v.bits(), 4);
854        assert_eq!(v.trailing_zeros(), 1);
855    }
856
857    // --- Shr ---
858    #[test]
859    fn test_shr() {
860        let v: SmallMint<u64> = 16.into();
861        assert_eq!(v >> 2, SmallMint::from(4u64));
862        // ref variant
863        let v: SmallMint<u64> = 16.into();
864        assert_eq!(&v >> 2, SmallMint::from(4u64));
865    }
866
867    // --- Roots ---
868    #[test]
869    fn test_nth_root() {
870        let v: SmallMint<u64> = 27.into();
871        assert_eq!(v.nth_root(3), SmallMint::from(3u64));
872    }
873
874    // --- ExactRoots ---
875    #[test]
876    fn test_nth_root_exact() {
877        let v: SmallMint<u64> = 49.into();
878        assert_eq!(v.nth_root_exact(2).map(|v| v.value()), Some(7));
879        let v: SmallMint<u64> = 50.into();
880        assert!(v.nth_root_exact(2).is_none());
881    }
882
883    // --- value() on Left and Right variants ---
884    #[test]
885    fn test_value_left() {
886        let v: SmallMint<u32> = 42.into();
887        assert_eq!(v.value(), 42);
888    }
889
890    #[test]
891    fn test_value_right() {
892        let v: SmallMint<u32> = 19.into();
893        let m: SmallMint<u32> = 7.into();
894        let r = v % m;
895        assert_eq!(r.value(), 5);
896    }
897
898    // --- Modular ops (addm, subm, mulm, negm, dblm, sqm, powm) ---
899    #[test]
900    fn test_modular_addm() {
901        use num_modular::ModularCoreOps;
902        let m: SmallMint<u32> = 7.into();
903        let a = SmallMint::from(5u32) % &m;
904        let b = SmallMint::from(4u32) % &m;
905        let result = a.addm(&b, &m);
906        assert_eq!(result.value(), 2); // (5+4) % 7 = 2
907    }
908
909    #[test]
910    fn test_modular_addm_ref() {
911        use num_modular::ModularCoreOps;
912        let m: SmallMint<u32> = 7.into();
913        let a = SmallMint::from(5u32) % &m;
914        let b = SmallMint::from(4u32) % &m;
915        let result = (&a).addm(&b, &m);
916        assert_eq!(result.value(), 2);
917    }
918
919    #[test]
920    fn test_modular_subm() {
921        use num_modular::ModularCoreOps;
922        let m: SmallMint<u32> = 7.into();
923        let a = SmallMint::from(3u32) % &m;
924        let b = SmallMint::from(5u32) % &m;
925        let result = a.subm(&b, &m);
926        assert_eq!(result.value(), 5); // (3-5) % 7 = 5
927    }
928
929    #[test]
930    fn test_modular_subm_ref() {
931        use num_modular::ModularCoreOps;
932        let m: SmallMint<u32> = 7.into();
933        let a = SmallMint::from(3u32) % &m;
934        let b = SmallMint::from(5u32) % &m;
935        let result = (&a).subm(&b, &m);
936        assert_eq!(result.value(), 5);
937    }
938
939    #[test]
940    fn test_modular_mulm() {
941        use num_modular::ModularCoreOps;
942        let m: SmallMint<u32> = 7.into();
943        let a = SmallMint::from(5u32) % &m;
944        let b = SmallMint::from(4u32) % &m;
945        let result = a.mulm(&b, &m);
946        assert_eq!(result.value(), 6); // (5*4) % 7 = 20 % 7 = 6
947    }
948
949    #[test]
950    fn test_modular_mulm_ref() {
951        use num_modular::ModularCoreOps;
952        let m: SmallMint<u32> = 7.into();
953        let a = SmallMint::from(5u32) % &m;
954        let b = SmallMint::from(4u32) % &m;
955        let result = (&a).mulm(&b, &m);
956        assert_eq!(result.value(), 6);
957    }
958
959    #[test]
960    fn test_modular_negm() {
961        use num_modular::ModularUnaryOps;
962        let m: SmallMint<u32> = 7.into();
963        // From Left variant
964        let a: SmallMint<u32> = 3.into();
965        let result = a.negm(&m);
966        assert_eq!(result.value(), 4); // -3 % 7 = 4
967                                       // From Right variant
968        let m: SmallMint<u32> = 7.into();
969        let a = SmallMint::from(3u32) % &m;
970        let result = a.negm(&m);
971        assert_eq!(result.value(), 4);
972    }
973
974    #[test]
975    fn test_modular_negm_ref() {
976        use num_modular::ModularUnaryOps;
977        let m: SmallMint<u32> = 7.into();
978        let a: SmallMint<u32> = 3.into();
979        let result = (&a).negm(&m);
980        assert_eq!(result.value(), 4);
981        // Right variant ref
982        let a = SmallMint::from(3u32) % &m;
983        let result = (&a).negm(&m);
984        assert_eq!(result.value(), 4);
985    }
986
987    #[test]
988    fn test_modular_dblm() {
989        use num_modular::ModularUnaryOps;
990        let m: SmallMint<u32> = 7.into();
991        let a: SmallMint<u32> = 5.into();
992        let result = a.dblm(&m);
993        assert_eq!(result.value(), 3); // (5*2) % 7 = 3
994                                       // Right variant
995        let m: SmallMint<u32> = 7.into();
996        let a = SmallMint::from(5u32) % &m;
997        let result = a.dblm(&m);
998        assert_eq!(result.value(), 3);
999    }
1000
1001    #[test]
1002    fn test_modular_dblm_ref() {
1003        use num_modular::ModularUnaryOps;
1004        let m: SmallMint<u32> = 7.into();
1005        let a: SmallMint<u32> = 5.into();
1006        let result = (&a).dblm(&m);
1007        assert_eq!(result.value(), 3);
1008        let a = SmallMint::from(5u32) % &m;
1009        let result = (&a).dblm(&m);
1010        assert_eq!(result.value(), 3);
1011    }
1012
1013    #[test]
1014    fn test_modular_sqm() {
1015        use num_modular::ModularUnaryOps;
1016        let m: SmallMint<u32> = 7.into();
1017        let a: SmallMint<u32> = 4.into();
1018        let result = a.sqm(&m);
1019        assert_eq!(result.value(), 2); // (4^2) % 7 = 16 % 7 = 2
1020                                       // Right variant
1021        let m: SmallMint<u32> = 7.into();
1022        let a = SmallMint::from(4u32) % &m;
1023        let result = a.sqm(&m);
1024        assert_eq!(result.value(), 2);
1025    }
1026
1027    #[test]
1028    fn test_modular_sqm_ref() {
1029        use num_modular::ModularUnaryOps;
1030        let m: SmallMint<u32> = 7.into();
1031        let a: SmallMint<u32> = 4.into();
1032        let result = (&a).sqm(&m);
1033        assert_eq!(result.value(), 2);
1034        let a = SmallMint::from(4u32) % &m;
1035        let result = (&a).sqm(&m);
1036        assert_eq!(result.value(), 2);
1037    }
1038
1039    #[test]
1040    fn test_modular_powm() {
1041        use num_modular::ModularPow;
1042        let m: SmallMint<u32> = 7.into();
1043        // Left variant
1044        let base: SmallMint<u32> = 3.into();
1045        let exp: SmallMint<u32> = 4.into();
1046        let result = base.powm(&exp, &m);
1047        assert_eq!(result.value(), 4); // 3^4 % 7 = 81 % 7 = 4
1048                                       // Right variant
1049        let base = SmallMint::from(3u32) % &m;
1050        let result = base.powm(&exp, &m);
1051        assert_eq!(result.value(), 4);
1052    }
1053
1054    // --- ModularSymbols (jacobi, kronecker) ---
1055    #[test]
1056    fn test_jacobi() {
1057        use num_modular::ModularSymbols;
1058        let a: SmallMint<u64> = 2.into();
1059        let n: SmallMint<u64> = 7.into();
1060        assert_eq!(a.checked_jacobi(&n), Some(1)); // 2 is QR mod 7
1061        let a: SmallMint<u64> = 3.into();
1062        assert_eq!(a.checked_jacobi(&n), Some(-1)); // 3 is QNR mod 7
1063    }
1064
1065    #[test]
1066    fn test_jacobi_right_variant() {
1067        use num_modular::ModularSymbols;
1068        let n: SmallMint<u64> = 7.into();
1069        let a = SmallMint::from(2u64) % &n; // Right variant
1070        assert_eq!(a.checked_jacobi(&n), Some(1));
1071    }
1072
1073    #[test]
1074    fn test_kronecker() {
1075        use num_modular::ModularSymbols;
1076        let a: SmallMint<u64> = 2.into();
1077        let n: SmallMint<u64> = 7.into();
1078        assert_eq!(a.kronecker(&n), 1);
1079    }
1080
1081    #[test]
1082    fn test_kronecker_right_variant() {
1083        use num_modular::ModularSymbols;
1084        let n: SmallMint<u64> = 7.into();
1085        let a = SmallMint::from(2u64) % &n;
1086        assert_eq!(a.kronecker(&n), 1);
1087    }
1088
1089    #[test]
1090    fn test_legendre() {
1091        use num_modular::ModularSymbols;
1092        let a: SmallMint<u64> = 2.into();
1093        let n: SmallMint<u64> = 7.into();
1094        assert_eq!(a.checked_legendre(&n), Some(1));
1095    }
1096
1097    #[test]
1098    fn test_legendre_right_variant() {
1099        use num_modular::ModularSymbols;
1100        let n: SmallMint<u64> = 7.into();
1101        let a = SmallMint::from(3u64) % &n;
1102        assert_eq!(a.checked_legendre(&n), Some(-1));
1103    }
1104
1105    // --- is_zero on Right variant ---
1106    #[test]
1107    fn test_is_zero_right_variant() {
1108        let m: SmallMint<u32> = 7.into();
1109        let a = SmallMint::from(7u32) % &m;
1110        assert!(a.is_zero());
1111        let b = SmallMint::from(3u32) % &m;
1112        assert!(!b.is_zero());
1113    }
1114
1115    // --- is_one on Right variant ---
1116    #[test]
1117    fn test_is_one_right_variant() {
1118        let m: SmallMint<u32> = 7.into();
1119        let a = SmallMint::from(8u32) % &m;
1120        assert!(a.is_one());
1121    }
1122
1123    // --- is_even/is_odd on Right variant ---
1124    #[test]
1125    fn test_even_odd_right_variant() {
1126        let m: SmallMint<u32> = 7.into();
1127        let a = SmallMint::from(9u32) % &m; // residue 2
1128        assert!(a.is_even());
1129        assert!(!a.is_odd());
1130    }
1131
1132    // --- Add/Sub/Mul with Right variants ---
1133    #[test]
1134    fn test_add_right_right() {
1135        let m: SmallMint<u32> = 7.into();
1136        let a = SmallMint::from(3u32) % &m;
1137        let b = SmallMint::from(5u32) % &m;
1138        let result = a + b;
1139        assert_eq!(result.value(), 1); // (3+5) % 7 = 1
1140    }
1141
1142    #[test]
1143    fn test_add_left_right() {
1144        let m: SmallMint<u32> = 7.into();
1145        let a: SmallMint<u32> = 3.into();
1146        let b = SmallMint::from(5u32) % &m;
1147        let result = a + b;
1148        assert_eq!(result.value(), 1);
1149    }
1150
1151    #[test]
1152    fn test_add_right_left() {
1153        let m: SmallMint<u32> = 7.into();
1154        let a = SmallMint::from(3u32) % &m;
1155        let b: SmallMint<u32> = 5.into();
1156        let result = a + b;
1157        assert_eq!(result.value(), 1);
1158    }
1159
1160    #[test]
1161    fn test_sub_right_right() {
1162        let m: SmallMint<u32> = 7.into();
1163        let a = SmallMint::from(5u32) % &m;
1164        let b = SmallMint::from(3u32) % &m;
1165        let result = a - b;
1166        assert_eq!(result.value(), 2);
1167    }
1168
1169    #[test]
1170    fn test_mul_right_right() {
1171        let m: SmallMint<u32> = 7.into();
1172        let a = SmallMint::from(3u32) % &m;
1173        let b = SmallMint::from(4u32) % &m;
1174        let result = a * b;
1175        assert_eq!(result.value(), 5); // (3*4) % 7 = 12 % 7 = 5
1176    }
1177
1178    // --- gcd with Right variants ---
1179    #[test]
1180    fn test_gcd_right_left() {
1181        let m: SmallMint<u32> = 7.into();
1182        let a = SmallMint::from(6u32) % &m; // Right, residue=6
1183        let b: SmallMint<u32> = 3.into();
1184        assert_eq!(a.gcd(&b), SmallMint::from(3u32));
1185    }
1186
1187    #[test]
1188    fn test_gcd_left_right() {
1189        let m: SmallMint<u32> = 7.into();
1190        let a: SmallMint<u32> = 6.into();
1191        let b = SmallMint::from(9u32) % &m; // Right, residue=2
1192        assert_eq!(a.gcd(&b), SmallMint::from(2u32));
1193    }
1194
1195    #[test]
1196    fn test_gcd_right_right() {
1197        let m: SmallMint<u32> = 7.into();
1198        let a = SmallMint::from(6u32) % &m;
1199        let b = SmallMint::from(9u32) % &m; // residue=2
1200        assert_eq!(a.gcd(&b), SmallMint::from(2u32));
1201    }
1202
1203    // --- Rem where Right already matches modulus ---
1204    #[test]
1205    fn test_rem_right_left_same_modulus() {
1206        let m: SmallMint<u32> = 7.into();
1207        let a = SmallMint::from(19u32) % &m; // Right variant
1208        let result = a % m; // Right % Left with matching modulus
1209        assert_eq!(result.value(), 5);
1210    }
1211}