1#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
5use std::{
6 fmt::{self, Display},
7 iter::{Product, Sum},
8 num::{IntErrorKind, TryFromIntError},
9 ops::{
10 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
11 },
12 str::FromStr,
13};
14
15use crate::{ParseIntError, Z64, z64::TryDiv};
16
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(transparent)]
/// Element of the integers modulo `P`, stored as its canonical
/// representative in `0..P` inside a single `u32`.
pub struct Z32<const P: u32>(u32);
22
23impl<const P: u32> Z32<P> {
24 const INFO: Z32Info = Z32Info::new(P);
25
26 pub const MIN: Z32<P> = {
28 assert!(P > 0);
29 Self::new_unchecked(0)
30 };
31 pub const MAX: Z32<P> = {
33 assert!(P > 1);
34 Self::new_unchecked(P - 1)
35 };
36
37 pub const fn new(z: i32) -> Self {
41 let res = remi(z, P, Self::info().red_struct);
42 debug_assert!(res >= 0);
43 let res = res as u32;
44 debug_assert!(res < P);
45 Self::new_unchecked(res)
46 }
47
48 pub const fn new_unchecked(z: u32) -> Self {
55 assert!(P > 0);
56 debug_assert!(z < P);
57 Self(z)
58 }
59
60 pub const fn inv(&self) -> Self {
67 self.try_inv()
68 .expect("Number has no multiplicative inverse")
69 }
70
71 pub const fn try_inv(&self) -> Option<Self> {
74 let res = extended_gcd(self.0, Self::modulus());
75 if res.gcd != 1 {
76 return None;
77 }
78 let s = res.bezout[0];
79 let inv = if s < 0 {
80 debug_assert!(s + Self::modulus() as i32 >= 0);
81 s + Self::modulus() as i32
82 } else {
83 s
84 } as u32;
85 let inv = Self::new_unchecked(inv);
86 Some(inv)
87 }
88
89 pub const fn has_inv(&self) -> bool {
94 gcd(self.0, Self::modulus()) == 1
95 }
96
97 const fn info() -> &'static Z32Info {
98 &Self::INFO
99 }
100
101 pub const fn modulus() -> u32 {
103 P
104 }
105
106 #[allow(missing_docs)]
107 pub const fn modulus_inv() -> SpInverse32 {
108 Self::info().p_inv
109 }
110
111 pub fn powi(self, exp: i64) -> Self {
113 if exp < 0 {
114 self.powu((-exp) as u64).inv()
115 } else {
116 self.powu(exp as u64)
117 }
118 }
119
120 pub fn powu(mut self, mut exp: u64) -> Self {
122 let mut res = Self::new_unchecked(1);
123 while exp > 0 {
124 if exp & 1 != 0 {
125 res *= self
126 };
127 self *= self;
128 exp /= 2;
129 }
130 res
131 }
132
133 #[cfg(any(feature = "rand", feature = "num-traits"))]
134 pub(crate) const fn repr(self) -> u32 {
135 self.0
136 }
137}
138
139impl<const P: u32, const Q: u64> From<Z64<Q>> for Z32<P> {
140 fn from(z: Z64<Q>) -> Self {
141 u64::from(z).into()
142 }
143}
144
145impl<const P: u32> From<Z32<P>> for u128 {
146 fn from(i: Z32<P>) -> Self {
147 i.0 as _
148 }
149}
150
151impl<const P: u32> From<Z32<P>> for i128 {
152 fn from(i: Z32<P>) -> Self {
153 i.0 as _
154 }
155}
156
157impl<const P: u32> From<Z32<P>> for u64 {
158 fn from(i: Z32<P>) -> Self {
159 i.0 as _
160 }
161}
162
163impl<const P: u32> From<Z32<P>> for i64 {
164 fn from(i: Z32<P>) -> Self {
165 i.0 as _
166 }
167}
168
169impl<const P: u32> From<Z32<P>> for u32 {
170 fn from(i: Z32<P>) -> Self {
171 i.0
172 }
173}
174
175impl<const P: u32> From<Z32<P>> for i32 {
176 fn from(i: Z32<P>) -> Self {
177 i.0 as i32
178 }
179}
180
181impl<const P: u32> TryFrom<Z32<P>> for u16 {
182 type Error = TryFromIntError;
183
184 fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
185 i.0.try_into()
186 }
187}
188
189impl<const P: u32> TryFrom<Z32<P>> for i16 {
190 type Error = TryFromIntError;
191
192 fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
193 i.0.try_into()
194 }
195}
196
197impl<const P: u32> TryFrom<Z32<P>> for u8 {
198 type Error = TryFromIntError;
199
200 fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
201 i.0.try_into()
202 }
203}
204
205impl<const P: u32> TryFrom<Z32<P>> for i8 {
206 type Error = TryFromIntError;
207
208 fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
209 i.0.try_into()
210 }
211}
212
213impl<const P: u32> From<u128> for Z32<P> {
214 fn from(u: u128) -> Self {
215 (u.rem_euclid(P as u128) as u32).into()
216 }
217}
218
219impl<const P: u32> From<i128> for Z32<P> {
220 fn from(i: i128) -> Self {
221 (i.rem_euclid(P as i128) as u32).into()
222 }
223}
224
225impl<const P: u32> From<u64> for Z32<P> {
226 fn from(u: u64) -> Self {
227 (u.rem_euclid(P as u64) as u32).into()
228 }
229}
230
231impl<const P: u32> From<i64> for Z32<P> {
232 fn from(i: i64) -> Self {
233 (i.rem_euclid(P as i64) as u32).into()
234 }
235}
236
237impl<const P: u32> From<u32> for Z32<P> {
238 fn from(u: u32) -> Self {
239 let num = remu(u, Self::modulus(), Self::info().red_struct) as u32;
240 Self::new_unchecked(num)
241 }
242}
243
244impl<const P: u32> From<i32> for Z32<P> {
245 fn from(i: i32) -> Self {
246 Self::new(i)
247 }
248}
249
250impl<const P: u32> From<i16> for Z32<P> {
251 fn from(i: i16) -> Self {
252 Self::from(i as i32)
253 }
254}
255
256impl<const P: u32> From<u16> for Z32<P> {
257 fn from(u: u16) -> Self {
258 Self::from(u as u32)
259 }
260}
261
262impl<const P: u32> From<i8> for Z32<P> {
263 fn from(i: i8) -> Self {
264 Self::from(i as i32)
265 }
266}
267
268impl<const P: u32> From<u8> for Z32<P> {
269 fn from(u: u8) -> Self {
270 Self::from(u as u32)
271 }
272}
273
274impl<'a, const P: u32> TryFrom<&'a str> for Z32<P> {
275 type Error = ParseIntError;
276
277 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
278 s.parse()
279 }
280}
281
282impl<const P: u32> FromStr for Z32<P> {
283 type Err = ParseIntError;
284
285 fn from_str(s: &str) -> Result<Self, Self::Err> {
286 let z = s.parse()?;
287 if z >= P {
288 return Err(IntErrorKind::PosOverflow.into());
289 }
290 Ok(Self::new_unchecked(z))
293 }
294}
295
296impl<const P: u32> Display for Z32<P> {
297 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
298 write!(f, "{}", self.0)
299 }
300}
301
302impl<const P: u32> AddAssign for Z32<P> {
303 fn add_assign(&mut self, rhs: Self) {
304 *self = *self + rhs;
305 }
306}
307
308impl<const P: u32> AddAssign<&Z32<P>> for Z32<P> {
309 fn add_assign(&mut self, rhs: &Self) {
310 *self = *self + *rhs;
311 }
312}
313
314impl<const P: u32> SubAssign for Z32<P> {
315 fn sub_assign(&mut self, rhs: Self) {
316 *self = *self - rhs;
317 }
318}
319
320impl<const P: u32> SubAssign<&Z32<P>> for Z32<P> {
321 fn sub_assign(&mut self, rhs: &Self) {
322 *self -= *rhs;
323 }
324}
325
326impl<const P: u32> MulAssign for Z32<P> {
327 fn mul_assign(&mut self, rhs: Self) {
328 *self = *self * rhs;
329 }
330}
331
332impl<const P: u32> MulAssign<&Z32<P>> for Z32<P> {
333 fn mul_assign(&mut self, rhs: &Self) {
334 *self = *self * *rhs;
335 }
336}
337
338impl<const P: u32> DivAssign for Z32<P> {
339 fn div_assign(&mut self, rhs: Self) {
340 *self = *self / rhs;
341 }
342}
343
344impl<const P: u32> DivAssign<&Z32<P>> for Z32<P> {
345 fn div_assign(&mut self, rhs: &Self) {
346 *self = *self / *rhs;
347 }
348}
349
350impl<const P: u32> Add for Z32<P> {
351 type Output = Self;
352
353 fn add(self, rhs: Self) -> Self::Output {
354 let res = correct_excess((self.0 + rhs.0) as i32, Self::modulus());
355 debug_assert!(res >= 0);
356 let res = res as u32;
357 Self::new_unchecked(res)
358 }
359}
360
361impl<const P: u32> Add for &Z32<P> {
362 type Output = Z32<P>;
363
364 fn add(self, rhs: Self) -> Self::Output {
365 *self + *rhs
366 }
367}
368
369impl<const P: u32> Add<Z32<P>> for &Z32<P> {
370 type Output = Z32<P>;
371
372 fn add(self, rhs: Z32<P>) -> Self::Output {
373 *self + rhs
374 }
375}
376
377impl<const P: u32> Add<&Z32<P>> for Z32<P> {
378 type Output = Z32<P>;
379
380 fn add(self, rhs: &Z32<P>) -> Self::Output {
381 self + *rhs
382 }
383}
384
385impl<const P: u32> Sub for Z32<P> {
386 type Output = Self;
387
388 fn sub(self, rhs: Self) -> Self::Output {
389 let res =
390 correct_deficit(self.0 as i32 - rhs.0 as i32, Self::modulus());
391 debug_assert!(res >= 0);
392 let res = res as u32;
393 Self::new_unchecked(res)
394 }
395}
396
397impl<const P: u32> Sub for &Z32<P> {
398 type Output = Z32<P>;
399
400 fn sub(self, rhs: Self) -> Self::Output {
401 *self - *rhs
402 }
403}
404
405impl<const P: u32> Sub<Z32<P>> for &Z32<P> {
406 type Output = Z32<P>;
407
408 fn sub(self, rhs: Z32<P>) -> Self::Output {
409 *self - rhs
410 }
411}
412
413impl<const P: u32> Sub<&Z32<P>> for Z32<P> {
414 type Output = Z32<P>;
415
416 fn sub(self, rhs: &Z32<P>) -> Self::Output {
417 self - *rhs
418 }
419}
420
421impl<const P: u32> Neg for Z32<P> {
422 type Output = Self;
423
424 fn neg(self) -> Self::Output {
425 Self::default() - self
426 }
427}
428
429impl<const P: u32> Mul for Z32<P> {
430 type Output = Self;
431
432 fn mul(self, rhs: Self) -> Self::Output {
433 let num = mul_mod(self.0, rhs.0, Self::modulus(), Self::modulus_inv());
434 Self::new_unchecked(num)
435 }
436}
437
438impl<const P: u32> Mul for &Z32<P> {
439 type Output = Z32<P>;
440
441 fn mul(self, rhs: Self) -> Self::Output {
442 *self * *rhs
443 }
444}
445
446impl<const P: u32> Mul<Z32<P>> for &Z32<P> {
447 type Output = Z32<P>;
448
449 fn mul(self, rhs: Z32<P>) -> Self::Output {
450 *self * rhs
451 }
452}
453
454impl<const P: u32> Mul<&Z32<P>> for Z32<P> {
455 type Output = Z32<P>;
456
457 fn mul(self, rhs: &Z32<P>) -> Self::Output {
458 self * *rhs
459 }
460}
461
462impl<const P: u32> Div for Z32<P> {
463 type Output = Self;
464
465 #[allow(clippy::suspicious_arithmetic_impl)]
466 fn div(self, rhs: Self) -> Self::Output {
467 self * rhs.inv()
468 }
469}
470
/// Computes `a * b mod n` using the precomputed fixed-point inverse `ninv`.
///
/// `b` and `n` are shifted left by `ninv.shamt` so the modulus is
/// normalised to the top of its bit range before reduction; the result is
/// shifted back down afterwards.
const fn mul_mod(a: u32, b: u32, n: u32, ninv: SpInverse32) -> u32 {
    let res = normalised_mul_mod(
        a,
        (b as i32) << ninv.shamt,
        ((n as i32) << ninv.shamt) as u32,
        ninv.inv,
    ) >> ninv.shamt;
    res as u32
}
480
481impl<const P: u32> Div for &Z32<P> {
482 type Output = Z32<P>;
483
484 fn div(self, rhs: Self) -> Self::Output {
485 *self / *rhs
486 }
487}
488
489impl<const P: u32> Div<Z32<P>> for &Z32<P> {
490 type Output = Z32<P>;
491
492 fn div(self, rhs: Z32<P>) -> Self::Output {
493 *self / rhs
494 }
495}
496
497impl<const P: u32> Div<&Z32<P>> for Z32<P> {
498 type Output = Z32<P>;
499
500 fn div(self, rhs: &Z32<P>) -> Self::Output {
501 self / *rhs
502 }
503}
504
505impl<const P: u32> TryDiv for Z32<P> {
506 type Output = Self;
507
508 fn try_div(self, rhs: Self) -> Option<Self::Output> {
509 rhs.try_inv().map(|i| self * i)
510 }
511}
512
513impl<const P: u32> TryDiv for &Z32<P> {
514 type Output = Z32<P>;
515
516 fn try_div(self, rhs: Self) -> Option<Self::Output> {
517 (*self).try_div(*rhs)
518 }
519}
520
521impl<const P: u32> TryDiv<Z32<P>> for &Z32<P> {
522 type Output = Z32<P>;
523
524 fn try_div(self, rhs: Z32<P>) -> Option<Self::Output> {
525 (*self).try_div(rhs)
526 }
527}
528
529impl<const P: u32> TryDiv<&Z32<P>> for Z32<P> {
530 type Output = Z32<P>;
531
532 fn try_div(self, rhs: &Z32<P>) -> Option<Self::Output> {
533 self.try_div(*rhs)
534 }
535}
536
537impl<const P: u32> Sum for Z32<P> {
538 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
539 iter.fold(Self::new_unchecked(0), |a, b| a + b)
540 }
541}
542
543impl<const P: u32> Product for Z32<P> {
544 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
545 iter.fold(Self::new_unchecked(1), |a, b| a * b)
546 }
547}
548
/// Computes `a * b mod n` for a modulus `n` normalised to the top of the
/// `SP_NBITS` range, returning a representative in `0..n`.
///
/// `ninv` is the fixed-point reciprocal produced by
/// [`normalised_prep_mul_mod`].
const fn normalised_mul_mod(a: u32, b: i32, n: u32, ninv: u32) -> i32 {
    let u = a as u64 * b as u64;
    // high part of the product, used to approximate the quotient u / n
    let h = (u >> (SP_NBITS - 2)) as u32;
    let q = u64_mul_high(h, ninv) >> POST_SHIFT;
    let l = u as u32;
    // remainder candidate; the debug assertion below checks it is < 2n
    let r = l.wrapping_sub(q.wrapping_mul(n));
    debug_assert!(r < 2 * n);
    // final conditional subtraction brings the result into 0..n
    correct_excess(r as i32, n)
}
558
/// Reduces an unsigned `z` modulo `p` using the precomputed reciprocal
/// data in `red`, returning a representative in `0..p`.
const fn remu(z: u32, p: u32, red: ReduceStruct) -> i32 {
    // q ≈ z / p via the fixed-point reciprocal red.ninv
    let q = u64_mul_high(z, red.ninv);
    let r = (z - q.wrapping_mul(p)) as i32;
    // the approximation may leave one extra multiple of p
    correct_excess(r, p)
}
564
/// Reduces a signed `z` modulo `p`, returning a representative in `0..p`.
const fn remi(z: i32, p: u32, red: ReduceStruct) -> i32 {
    // reduce the low 31 bits as if the value were unsigned
    let zu = (z as u32) & ((1u32 << (u32::BITS - 1)) - 1);
    let r = remu(zu, p, red);
    // when z was negative, subtract 2^31 mod p (stored in red.sgn)
    // to account for the masked-off sign bit
    let s = i32_sign_mask(z) & (red.sgn as i32);
    correct_deficit(r - s, p)
}
571
/// Upper 32 bits of the full 64-bit product `a * b`.
const fn u64_mul_high(a: u32, b: u32) -> u32 {
    let wide = a as u64 * b as u64;
    u64_get_high(wide)
}

/// Upper 32 bits of `u`.
const fn u64_get_high(u: u64) -> u32 {
    (u >> u32::BITS) as u32
}
579
/// Maps `a` (expected in `0..2p`) into `0..p` by a single conditional
/// subtraction of `p`.
const fn correct_excess(a: i32, p: u32) -> i32 {
    let n = p as i32;
    let shifted = a - n;
    if shifted < 0 { shifted + n } else { shifted }
}
584
/// Maps `a` (expected in `-p..p`) into `0..p` by a single conditional
/// addition of `p`.
const fn correct_deficit(a: i32, p: u32) -> i32 {
    if a < 0 { a + p as i32 } else { a }
}
588
/// Output of [`extended_gcd`]: `gcd == a * bezout[0] + b * bezout[1]`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct ExtendedGCDResult {
    gcd: u32,
    bezout: [i32; 2],
}

/// Extended Euclidean algorithm, yielding the gcd of `a` and `b` together
/// with its Bezout coefficients.
const fn extended_gcd(a: u32, b: u32) -> ExtendedGCDResult {
    let (mut prev_r, mut cur_r) = (a, b);
    let (mut prev_s, mut cur_s) = (1i32, 0i32);
    let (mut prev_t, mut cur_t) = (0i32, 1i32);

    while cur_r != 0 {
        let q = prev_r / cur_r;
        (prev_r, cur_r) = (cur_r, prev_r - q * cur_r);
        (prev_s, cur_s) = (cur_s, prev_s - q as i32 * cur_s);
        (prev_t, cur_t) = (cur_t, prev_t - q as i32 * cur_t);
    }
    ExtendedGCDResult {
        gcd: prev_r,
        bezout: [prev_s, prev_t],
    }
}
614
/// Greatest common divisor via the Euclidean algorithm.
const fn gcd(mut a: u32, mut b: u32) -> u32 {
    loop {
        if b == 0 {
            return a;
        }
        let r = a % b;
        a = b;
        b = r;
    }
}
621
/// Maximum number of usable modulus bits (30); `Z32Info::new` rejects
/// moduli wider than this.
const SP_NBITS: u32 = u32::BITS - 2;
/// Shift amount used when building the fixed-point reciprocal
/// (see `normalised_prep_mul_mod`).
const PRE_SHIFT2: u32 = 2 * SP_NBITS + 1;
/// Right shift applied to quotient approximations in the mul-mod kernels.
const POST_SHIFT: u32 = 1;
625
/// Number of bits needed to represent `z` (0 for `z == 0`).
const fn used_bits(z: u32) -> u32 {
    let leading = z.leading_zeros();
    u32::BITS - leading
}
629
/// Precomputed per-modulus data shared by all `Z32<P>` operations.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct Z32Info {
    /// the modulus
    p: u32,
    /// fixed-point inverse of `p`, used for modular multiplication
    p_inv: SpInverse32,
    /// reciprocal data used when reducing raw integers modulo `p`
    red_struct: ReduceStruct,
}
636
impl Z32Info {
    /// Precomputes inverse and reduction data for modulus `p`.
    ///
    /// # Panics
    ///
    /// Panics (during const evaluation) if `p < 2` or `p` needs more than
    /// `SP_NBITS` (30) bits.
    const fn new(p: u32) -> Self {
        assert!(p > 1);
        assert!(used_bits(p) <= SP_NBITS);

        let p_inv = prep_mul_mod(p);
        let red_struct = prep_rem(p);
        Self {
            p,
            p_inv,
            red_struct,
        }
    }
}
651
/// Builds the fixed-point inverse of `p`, normalising `p` so that its top
/// bit lands at bit `SP_NBITS - 1`.
const fn prep_mul_mod(p: u32) -> SpInverse32 {
    // left shift that aligns p to the top of the SP_NBITS window
    let shamt = p.leading_zeros() - (u32::BITS - SP_NBITS);
    let inv = normalised_prep_mul_mod(p << shamt);
    SpInverse32 { inv, shamt }
}
657
/// Precomputed data for reducing raw `u32`/`i32` values modulo `p`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct ReduceStruct {
    /// fixed-point approximation of `2^32 / p` used to estimate quotients
    ninv: u32,
    /// `2^31 mod p`, used by `remi` to correct for the sign bit
    sgn: u32,
}
663
/// Precomputes the [`ReduceStruct`] for modulus `p`.
const fn prep_rem(p: u32) -> ReduceStruct {
    // q = floor(2^31 / p), r = 2^31 mod p
    let mut q = (1 << (u32::BITS - 1)) / p;
    let r = (1 << (u32::BITS - 1)) - q * p;

    // double to approximate floor(2^32 / p), fixing up the low quotient
    // bit from the doubled remainder
    q *= 2;
    q += correct_excess_quo(2 * r as i32, p as i32).0;

    ReduceStruct { ninv: q, sgn: r }
}
674
/// Single conditional reduction of `a` by `n`, returning the quotient bit
/// together with the remainder.
const fn correct_excess_quo(a: i32, n: i32) -> (u32, i32) {
    if a < n {
        (0, a)
    } else {
        (1, a - n)
    }
}
678
/// All-ones mask (`-1`) when `i` is negative, zero otherwise.
const fn i32_sign_mask(i: i32) -> i32 {
    if i < 0 { -1 } else { 0 }
}

/// All-ones mask (`-1`) when the top bit of `i` is set, zero otherwise.
const fn u32_sign_mask(i: u32) -> i32 {
    i32_sign_mask(i as i32)
}
686
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
/// Fixed-point inverse of a modulus together with its normalisation shift.
pub struct SpInverse32 {
    // fixed-point reciprocal of the normalised modulus
    inv: u32,
    // left shift that normalises the modulus to SP_NBITS bits
    shamt: u32,
}
693
/// Computes the fixed-point reciprocal of a normalised modulus `n`
/// (top bit at `SP_NBITS - 1`) used by [`normalised_mul_mod`],
/// approximating `2^PRE_SHIFT2 / n` in two quotient "digits".
const fn normalised_prep_mul_mod(n: u32) -> u32 {
    // largest dividend for the first quotient approximation
    const MAX: u64 = 1u64 << (2 * SP_NBITS - 1);
    let init_quot_approx = MAX / n as u64;

    let approx_rem = MAX - n as u64 * init_quot_approx;

    // scale the remainder up to PRE_SHIFT2 precision, biased by -1
    let approx_rem = (approx_rem << (PRE_SHIFT2 - 2 * SP_NBITS + 1)) - 1;

    // split into 32-bit words, folding the low word's top bit into the
    // high word (two's-complement style carry)
    let approx_rem_low = approx_rem as u32;
    let s1 = (approx_rem >> u32::BITS) as u32;
    let s2 = approx_rem_low >> (u32::BITS - 1);
    let approx_rem_high = s1.wrapping_add(s2);

    let approx_rem_low = approx_rem_low as i32;
    let approx_rem_high = approx_rem_high as i32;

    // bits per limb: 2^32
    let bpl = 1i64 << u32::BITS;

    // reassemble the (signed) scaled remainder
    let fr = approx_rem_low as i64 + approx_rem_high as i64 * bpl;

    // second quotient digit, rounded towards negative infinity
    let mut q1 = (fr / n as i64) as i32;
    if q1 < 0 {
        q1 -= 1
    }

    let mut q1 = q1 as u32;
    let approx_rem_low = approx_rem_low as u32;
    let sub = q1.wrapping_mul(n);

    let approx_rem = approx_rem_low.wrapping_sub(sub);

    // final correction: nudge q1 by -1, 0, or +1 depending on where the
    // remainder landed relative to 0 and n
    q1 += (1
        + u32_sign_mask(approx_rem)
        + u32_sign_mask(approx_rem.wrapping_sub(n))) as u32;

    // combine the two quotient digits into the final reciprocal
    ((init_quot_approx as u32) << (PRE_SHIFT2 - 2 * SP_NBITS + 1))
        .wrapping_add(q1)
}
755
/// A `Z32` value bundled with a precomputed scaled reciprocal, which
/// speeds up repeated multiplications by the same factor.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Z32FastMul<const P: u32> {
    // the underlying element
    val: Z32<P>,
    // fixed-point approximation of val / P used to estimate quotients
    val_over_mod_approx: u32,
}
777
778impl<const P: u32> From<Z32<P>> for Z32FastMul<P> {
779 fn from(val: Z32<P>) -> Self {
780 let val_over_mod_approx = Self::prep_mul_mod_precon(val.0);
781 Self {
782 val,
783 val_over_mod_approx,
784 }
785 }
786}
787
impl<const P: u32> Z32FastMul<P> {
    /// Computes the fixed-point approximation of `val / P` consumed by
    /// `mul_mod_precon`, working in the normalised (shifted) domain.
    fn prep_mul_mod_precon(val: u32) -> u32 {
        let p_inv = Z32::<P>::INFO.p_inv;
        normalized_prep_mul_mod_precon(
            val << p_inv.shamt,
            P << p_inv.shamt,
            p_inv.inv,
        ) << (u32::BITS - SP_NBITS)
    }
}
798
799fn normalized_prep_mul_mod_precon(val: u32, p: u32, p_inv: u32) -> u32 {
800 let h = val << 2;
801 let q = u64_mul_high(h, p_inv);
802 let q = q >> POST_SHIFT;
803 let l = val << SP_NBITS;
804 let r = l.wrapping_sub(q.wrapping_mul(p)); debug_assert!(r < 2 * p);
806
807 q.saturating_add_signed(1 + i32_sign_mask(r as i32 - p as i32)) }
809
810impl<const P: u32> Mul<Z32FastMul<P>> for Z32<P> {
811 type Output = Z32<P>;
812
813 fn mul(self, rhs: Z32FastMul<P>) -> Self::Output {
814 let res = mul_mod_precon(self.0, rhs.val.0, P, rhs.val_over_mod_approx);
815 Z32::new_unchecked(res as u32)
816 }
817}
818
819impl<const P: u32> Mul<Z32<P>> for Z32FastMul<P> {
820 type Output = Z32<P>;
821
822 fn mul(self, rhs: Z32<P>) -> Self::Output {
823 rhs * self
824 }
825}
826
827impl<const P: u32> Mul<Z32FastMul<P>> for &Z32<P> {
828 type Output = Z32<P>;
829
830 fn mul(self, rhs: Z32FastMul<P>) -> Self::Output {
831 *self * rhs
832 }
833}
834
835impl<const P: u32> Mul<Z32<P>> for &Z32FastMul<P> {
836 type Output = Z32<P>;
837
838 fn mul(self, rhs: Z32<P>) -> Self::Output {
839 *self * rhs
840 }
841}
842
843impl<const P: u32> Mul<&Z32FastMul<P>> for Z32<P> {
844 type Output = Z32<P>;
845
846 fn mul(self, rhs: &Z32FastMul<P>) -> Self::Output {
847 self * *rhs
848 }
849}
850
851impl<const P: u32> Mul<&Z32<P>> for Z32FastMul<P> {
852 type Output = Z32<P>;
853
854 fn mul(self, rhs: &Z32<P>) -> Self::Output {
855 self * *rhs
856 }
857}
858
859impl<'a, const P: u32> Mul<&'a Z32FastMul<P>> for &Z32<P> {
860 type Output = Z32<P>;
861
862 fn mul(self, rhs: &'a Z32FastMul<P>) -> Self::Output {
863 *self * *rhs
864 }
865}
866
867impl<'a, const P: u32> Mul<&'a Z32<P>> for &Z32FastMul<P> {
868 type Output = Z32<P>;
869
870 fn mul(self, rhs: &'a Z32<P>) -> Self::Output {
871 *self * *rhs
872 }
873}
874
875impl<const P: u32> MulAssign<Z32FastMul<P>> for Z32<P> {
876 fn mul_assign(&mut self, rhs: Z32FastMul<P>) {
877 *self = *self * rhs
878 }
879}
880
881impl<const P: u32> MulAssign<&Z32FastMul<P>> for Z32<P> {
882 fn mul_assign(&mut self, rhs: &Z32FastMul<P>) {
883 *self = *self * rhs
884 }
885}
886
/// Multiplies `lhs * rhs mod p`, estimating the quotient with a single
/// high multiply against the precomputed approximation of `rhs / p`.
fn mul_mod_precon(lhs: u32, rhs: u32, p: u32, rhs_over_mod_approx: u32) -> i32 {
    let q = u64_mul_high(lhs, rhs_over_mod_approx);
    // all three products are exact modulo 2^32, so wrapping ops are fine
    let lhs_times_rhs = lhs.wrapping_mul(rhs);
    let q_times_p = q.wrapping_mul(p);
    let r = lhs_times_rhs.wrapping_sub(q_times_p);
    // one conditional subtraction brings the result into 0..p
    correct_excess(r as i32, p)
}
894
/// Implements `From<$t> for Z32FastMul<P>` for each listed primitive type
/// by converting through `Z32<P>` first.
macro_rules! impl_fastmul_from {
    ( $( $t:ty ),* ) => {
        $(
            impl<const P: u32> From<$t> for Z32FastMul<P> {
                fn from(t: $t) -> Self {
                    Self::from(Z32::from(t))
                }
            }
        )*
    }
}

impl_fastmul_from!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
908
909impl<const P: u32> From<Z32FastMul<P>> for Z32<P> {
910 fn from(z: Z32FastMul<P>) -> Self {
911 z.val
912 }
913}
914
#[cfg(test)]
mod tests {

    use ::rand::{Rng, SeedableRng};
    use once_cell::sync::Lazy;
    use rug::{Integer, ops::Pow};

    use super::*;

    const PRIMES: [u32; 3] = [3, 65521, 1073741789];

    #[test]
    fn z32_has_inv() {
        type Z = Z32<6>;
        assert!(!Z::from(0).has_inv());
        assert!(Z::from(1).has_inv());
        assert!(!Z::from(2).has_inv());
        assert!(!Z::from(3).has_inv());
        assert!(!Z::from(4).has_inv());
        assert!(Z::from(5).has_inv());
        assert_eq!(Z::from(6), Z::from(0));
    }

    #[test]
    #[should_panic]
    fn z32_inv0() {
        type Z = Z32<6>;
        Z::from(0).inv();
    }

    #[test]
    #[should_panic]
    fn z32_inv2() {
        type Z = Z32<6>;
        Z::from(2).inv();
    }

    #[test]
    fn z32_constr() {
        let z: Z32<3> = 2.into();
        assert_eq!(u32::from(z), 2);
        let z: Z32<3> = (-1).into();
        assert_eq!(u32::from(z), 2);
        let z: Z32<3> = 5.into();
        assert_eq!(u32::from(z), 2);

        let z: Z32<3> = 0.into();
        assert_eq!(u32::from(z), 0);
        let z: Z32<3> = 3.into();
        assert_eq!(u32::from(z), 0);

        let z: Z32<3> = 2u32.into();
        assert_eq!(u32::from(z), 2);
        let z: Z32<3> = 5u32.into();
        assert_eq!(u32::from(z), 2);

        let z: Z32<3> = 0u32.into();
        assert_eq!(u32::from(z), 0);
        let z: Z32<3> = 3u32.into();
        assert_eq!(u32::from(z), 0);
    }

    // Reproducible random sample points shared by all tests.
    static POINTS: Lazy<[i32; 1000]> = Lazy::new(|| {
        let mut pts = [0; 1000];
        let mut rng = rand_xoshiro::Xoshiro256StarStar::seed_from_u64(0);
        for pt in &mut pts {
            *pt = rng.random();
        }
        pts
    });

    // i32 -> Z32<P> -> i32 round-trips to the canonical representative.
    fn check_conv<const P: u32>() {
        for pt in *POINTS {
            let z: Z32<P> = pt.into();
            let z: i32 = z.into();
            assert_eq!(z, pt.rem_euclid(P as i32));
        }
    }

    #[test]
    fn tst_conv() {
        check_conv::<{ PRIMES[0] }>();
        check_conv::<{ PRIMES[1] }>();
        check_conv::<{ PRIMES[2] }>();
    }

    // Modular addition agrees with wide signed arithmetic.
    fn check_add<const P: u32>() {
        for pt1 in *POINTS {
            let z1: Z32<P> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<P> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 + z2).into();
                let sum2 = (pt1 + pt2).rem_euclid(P as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }
    }

    #[test]
    fn tst_add() {
        check_add::<{ PRIMES[0] }>();
        check_add::<{ PRIMES[1] }>();
        check_add::<{ PRIMES[2] }>();
    }

    // Modular subtraction agrees with wide signed arithmetic.
    fn check_sub<const P: u32>() {
        for pt1 in *POINTS {
            let z1: Z32<P> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<P> = pt2.into();
                let pt2 = pt2 as i64;
                let diff1: i32 = (z1 - z2).into();
                let diff2 = (pt1 - pt2).rem_euclid(P as i64) as i32;
                assert_eq!(diff1, diff2);
            }
        }
    }

    #[test]
    fn tst_sub() {
        check_sub::<{ PRIMES[0] }>();
        check_sub::<{ PRIMES[1] }>();
        check_sub::<{ PRIMES[2] }>();
    }

    // Modular multiplication agrees with wide signed arithmetic.
    fn check_mul<const P: u32>() {
        for pt1 in *POINTS {
            let z1: Z32<P> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<P> = pt2.into();
                let pt2 = pt2 as i64;
                let prod1: i32 = (z1 * z2).into();
                let prod2 = (pt1 * pt2).rem_euclid(P as i64) as i32;
                assert_eq!(prod1, prod2);
            }
        }
    }

    #[test]
    fn tst_mul() {
        check_mul::<{ PRIMES[0] }>();
        check_mul::<{ PRIMES[1] }>();
        check_mul::<{ PRIMES[2] }>();
    }

    // Precomputed-reciprocal multiplication matches plain multiplication.
    fn check_fastmul<const P: u32>() {
        for pt1 in *POINTS {
            let z1: Z32<P> = pt1.into();
            let fast_z1 = Z32FastMul::from(z1);
            for pt2 in *POINTS {
                let z2: Z32<P> = pt2.into();
                assert_eq!(z1 * z2, fast_z1 * z2);
            }
        }
    }

    #[test]
    fn tst_fastmul() {
        check_fastmul::<{ PRIMES[0] }>();
        check_fastmul::<{ PRIMES[1] }>();
        check_fastmul::<{ PRIMES[2] }>();
    }

    // Division by a nonzero element is the inverse of multiplication.
    fn check_div<const P: u32>() {
        for pt1 in *POINTS {
            let z1: Z32<P> = pt1.into();
            for pt2 in *POINTS {
                let z2: Z32<P> = pt2.into();
                if i32::from(z2) == 0 {
                    continue;
                }
                let div = z1 / z2;
                assert_eq!(z1, div * z2);
            }
        }
    }

    #[test]
    fn tst_div() {
        check_div::<{ PRIMES[0] }>();
        check_div::<{ PRIMES[1] }>();
        check_div::<{ PRIMES[2] }>();
    }

    // powu agrees with arbitrary-precision exponentiation reduced mod P.
    fn check_pow<const P: u32>(pt: i32, exp: u8, pow: &Integer) {
        let ref_pow = (pow.clone() % P + P) % P;
        let ref_pow: u32 = ref_pow.try_into().unwrap();
        let z: Z32<P> = pt.into();
        let got: u32 = z.powu(exp as u64).into();
        assert_eq!(got, ref_pow);
    }

    #[test]
    fn tst_pow() {
        let mut rng = rand_xoshiro::Xoshiro256StarStar::seed_from_u64(2849);
        for pt in *POINTS {
            let base = Integer::from(pt);
            for _ in 0..100 {
                let exp: u8 = rng.random();
                let pow = base.clone().pow(exp as u32);
                check_pow::<{ PRIMES[0] }>(pt, exp, &pow);
                check_pow::<{ PRIMES[1] }>(pt, exp, &pow);
                check_pow::<{ PRIMES[2] }>(pt, exp, &pow);
            }
        }
    }
}