// ffnt/z32.rs

1// TODO: code duplication with z64
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
5use std::{
6    fmt::{self, Display},
7    iter::{Product, Sum},
8    num::{IntErrorKind, TryFromIntError},
9    ops::{
10        Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
11    },
12    str::FromStr,
13};
14
15use crate::{ParseIntError, Z64, z64::TryDiv};
16
17/// Element of a finite field with a 32 bit characteristic `P`
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
20#[repr(transparent)]
21pub struct Z32<const P: u32>(u32);
22
23impl<const P: u32> Z32<P> {
24    const INFO: Z32Info = Z32Info::new(P);
25
26    /// Minimum field element, i.e. 0
27    pub const MIN: Z32<P> = {
28        assert!(P > 0);
29        Self::new_unchecked(0)
30    };
31    /// Maximum field element, i.e. `P - 1`
32    pub const MAX: Z32<P> = {
33        assert!(P > 1);
34        Self::new_unchecked(P - 1)
35    };
36
37    /// Create a new field element corresponding to some integer
38    ///
39    /// The integer is reduced modulo the field characteristic `P`
40    pub const fn new(z: i32) -> Self {
41        let res = remi(z, P, Self::info().red_struct);
42        debug_assert!(res >= 0);
43        let res = res as u32;
44        debug_assert!(res < P);
45        Self::new_unchecked(res)
46    }
47
48    /// Create a new field element corresponding to some integer
49    /// without modular reduction
50    ///
51    /// # Safety
52    ///
53    /// The argument should be less than `P`
54    pub const fn new_unchecked(z: u32) -> Self {
55        assert!(P > 0);
56        debug_assert!(z < P);
57        Self(z)
58    }
59
60    /// The multiplicative inverse `1/z` of a field element `z`.
61    ///
62    /// # Panics
63    ///
64    /// Panics if `z` is not invertible. If the characteristic `P` is
65    /// a prime power this happens only if `z` is zero.
66    pub const fn inv(&self) -> Self {
67        self.try_inv()
68            .expect("Number has no multiplicative inverse")
69    }
70
71    /// The multiplicative inverse `1/z` of a field element `z` or
72    /// `None` if the inverse does not exist
73    pub const fn try_inv(&self) -> Option<Self> {
74        let res = extended_gcd(self.0, Self::modulus());
75        if res.gcd != 1 {
76            return None;
77        }
78        let s = res.bezout[0];
79        let inv = if s < 0 {
80            debug_assert!(s + Self::modulus() as i32 >= 0);
81            s + Self::modulus() as i32
82        } else {
83            s
84        } as u32;
85        let inv = Self::new_unchecked(inv);
86        Some(inv)
87    }
88
89    /// Check if a field element `z` has a multiplicative inverse `1/z`
90    ///
91    /// If you know that the characteristic is a prime it is usually
92    /// better to check if `z` is zero.
93    pub const fn has_inv(&self) -> bool {
94        gcd(self.0, Self::modulus()) == 1
95    }
96
97    const fn info() -> &'static Z32Info {
98        &Self::INFO
99    }
100
101    /// The field characteristic `P`
102    pub const fn modulus() -> u32 {
103        P
104    }
105
106    #[allow(missing_docs)]
107    pub const fn modulus_inv() -> SpInverse32 {
108        Self::info().p_inv
109    }
110
111    /// `z` to some integer power `exp`
112    pub fn powi(self, exp: i64) -> Self {
113        if exp < 0 {
114            self.powu((-exp) as u64).inv()
115        } else {
116            self.powu(exp as u64)
117        }
118    }
119
120    /// `z` to some integer power `exp`
121    pub fn powu(mut self, mut exp: u64) -> Self {
122        let mut res = Self::new_unchecked(1);
123        while exp > 0 {
124            if exp & 1 != 0 {
125                res *= self
126            };
127            self *= self;
128            exp /= 2;
129        }
130        res
131    }
132
133    #[cfg(any(feature = "rand", feature = "num-traits"))]
134    pub(crate) const fn repr(self) -> u32 {
135        self.0
136    }
137}
138
139impl<const P: u32, const Q: u64> From<Z64<Q>> for Z32<P> {
140    fn from(z: Z64<Q>) -> Self {
141        u64::from(z).into()
142    }
143}
144
// Lossless widening conversions out of `Z32`: the canonical
// representative is a `u32` in `[0, P)`, so it always fits.
impl<const P: u32> From<Z32<P>> for u128 {
    fn from(i: Z32<P>) -> Self {
        i.0 as _
    }
}

impl<const P: u32> From<Z32<P>> for i128 {
    fn from(i: Z32<P>) -> Self {
        i.0 as _
    }
}

impl<const P: u32> From<Z32<P>> for u64 {
    fn from(i: Z32<P>) -> Self {
        i.0 as _
    }
}

impl<const P: u32> From<Z32<P>> for i64 {
    fn from(i: Z32<P>) -> Self {
        i.0 as _
    }
}

impl<const P: u32> From<Z32<P>> for u32 {
    fn from(i: Z32<P>) -> Self {
        i.0
    }
}

// `i32` is also lossless: `Z32Info::new` asserts that `P` uses at most
// `SP_NBITS` (30) bits, so the representative never exceeds `i32::MAX`.
impl<const P: u32> From<Z32<P>> for i32 {
    fn from(i: Z32<P>) -> Self {
        i.0 as i32
    }
}

// Narrowing conversions can fail for large `P`, hence `TryFrom`.
impl<const P: u32> TryFrom<Z32<P>> for u16 {
    type Error = TryFromIntError;

    fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
        i.0.try_into()
    }
}

impl<const P: u32> TryFrom<Z32<P>> for i16 {
    type Error = TryFromIntError;

    fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
        i.0.try_into()
    }
}

impl<const P: u32> TryFrom<Z32<P>> for u8 {
    type Error = TryFromIntError;

    fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
        i.0.try_into()
    }
}

impl<const P: u32> TryFrom<Z32<P>> for i8 {
    type Error = TryFromIntError;

    fn try_from(i: Z32<P>) -> Result<Self, Self::Error> {
        i.0.try_into()
    }
}
212
// Conversions into `Z32` reduce the input modulo `P`. The wide types
// use `rem_euclid` (which always yields a non-negative remainder);
// `u32`/`i32` go through the precomputed fast reduction instead.
impl<const P: u32> From<u128> for Z32<P> {
    fn from(u: u128) -> Self {
        (u.rem_euclid(P as u128) as u32).into()
    }
}

impl<const P: u32> From<i128> for Z32<P> {
    fn from(i: i128) -> Self {
        (i.rem_euclid(P as i128) as u32).into()
    }
}

impl<const P: u32> From<u64> for Z32<P> {
    fn from(u: u64) -> Self {
        (u.rem_euclid(P as u64) as u32).into()
    }
}

impl<const P: u32> From<i64> for Z32<P> {
    fn from(i: i64) -> Self {
        (i.rem_euclid(P as i64) as u32).into()
    }
}

impl<const P: u32> From<u32> for Z32<P> {
    fn from(u: u32) -> Self {
        // fast division-free reduction using the precomputed
        // approximation of 2^32 / P
        let num = remu(u, Self::modulus(), Self::info().red_struct) as u32;
        Self::new_unchecked(num)
    }
}

impl<const P: u32> From<i32> for Z32<P> {
    fn from(i: i32) -> Self {
        Self::new(i)
    }
}

// Smaller integer types widen to 32 bits first.
impl<const P: u32> From<i16> for Z32<P> {
    fn from(i: i16) -> Self {
        Self::from(i as i32)
    }
}

impl<const P: u32> From<u16> for Z32<P> {
    fn from(u: u16) -> Self {
        Self::from(u as u32)
    }
}

impl<const P: u32> From<i8> for Z32<P> {
    fn from(i: i8) -> Self {
        Self::from(i as i32)
    }
}

impl<const P: u32> From<u8> for Z32<P> {
    fn from(u: u8) -> Self {
        Self::from(u as u32)
    }
}
273
impl<'a, const P: u32> TryFrom<&'a str> for Z32<P> {
    type Error = ParseIntError;

    /// Parse via [`FromStr`]; see the `FromStr` impl below for the
    /// accepted format.
    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
        s.parse()
    }
}

impl<const P: u32> FromStr for Z32<P> {
    type Err = ParseIntError;

    /// Parse a decimal integer that must already lie in the canonical
    /// range `[0, P)`.
    ///
    /// Values `>= P` are rejected with a `PosOverflow` error instead
    /// of being reduced modulo `P`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let z = s.parse()?;
        if z >= P {
            return Err(IntErrorKind::PosOverflow.into());
        }
        // # Safety
        // we just checked that z < P
        Ok(Self::new_unchecked(z))
    }
}
295
296impl<const P: u32> Display for Z32<P> {
297    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
298        write!(f, "{}", self.0)
299    }
300}
301
// Compound-assignment operators all delegate to the corresponding
// binary operator on copies (`Z32` is `Copy`).
impl<const P: u32> AddAssign for Z32<P> {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

impl<const P: u32> AddAssign<&Z32<P>> for Z32<P> {
    fn add_assign(&mut self, rhs: &Self) {
        *self = *self + *rhs;
    }
}

impl<const P: u32> SubAssign for Z32<P> {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}

impl<const P: u32> SubAssign<&Z32<P>> for Z32<P> {
    fn sub_assign(&mut self, rhs: &Self) {
        // delegates to `SubAssign<Z32<P>>` above
        *self -= *rhs;
    }
}

impl<const P: u32> MulAssign for Z32<P> {
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}

impl<const P: u32> MulAssign<&Z32<P>> for Z32<P> {
    fn mul_assign(&mut self, rhs: &Self) {
        *self = *self * *rhs;
    }
}

impl<const P: u32> DivAssign for Z32<P> {
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs;
    }
}

impl<const P: u32> DivAssign<&Z32<P>> for Z32<P> {
    fn div_assign(&mut self, rhs: &Self) {
        *self = *self / *rhs;
    }
}
349
impl<const P: u32> Add for Z32<P> {
    type Output = Self;

    /// Modular addition.
    ///
    /// Both operands are in `[0, P)` and `P` uses at most 30 bits
    /// (asserted in `Z32Info::new`), so the sum fits in `i32` and a
    /// single conditional subtraction brings it back into range.
    fn add(self, rhs: Self) -> Self::Output {
        let res = correct_excess((self.0 + rhs.0) as i32, Self::modulus());
        debug_assert!(res >= 0);
        let res = res as u32;
        Self::new_unchecked(res)
    }
}

impl<const P: u32> Add for &Z32<P> {
    type Output = Z32<P>;

    fn add(self, rhs: Self) -> Self::Output {
        *self + *rhs
    }
}

impl<const P: u32> Add<Z32<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn add(self, rhs: Z32<P>) -> Self::Output {
        *self + rhs
    }
}

impl<const P: u32> Add<&Z32<P>> for Z32<P> {
    type Output = Z32<P>;

    fn add(self, rhs: &Z32<P>) -> Self::Output {
        self + *rhs
    }
}

impl<const P: u32> Sub for Z32<P> {
    type Output = Self;

    /// Modular subtraction: the raw difference lies in `(-P, P)`,
    /// so adding `P` back once (when negative) restores `[0, P)`.
    fn sub(self, rhs: Self) -> Self::Output {
        let res =
            correct_deficit(self.0 as i32 - rhs.0 as i32, Self::modulus());
        debug_assert!(res >= 0);
        let res = res as u32;
        Self::new_unchecked(res)
    }
}

impl<const P: u32> Sub for &Z32<P> {
    type Output = Z32<P>;

    fn sub(self, rhs: Self) -> Self::Output {
        *self - *rhs
    }
}

impl<const P: u32> Sub<Z32<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn sub(self, rhs: Z32<P>) -> Self::Output {
        *self - rhs
    }
}

impl<const P: u32> Sub<&Z32<P>> for Z32<P> {
    type Output = Z32<P>;

    fn sub(self, rhs: &Z32<P>) -> Self::Output {
        self - *rhs
    }
}
420
421impl<const P: u32> Neg for Z32<P> {
422    type Output = Self;
423
424    fn neg(self) -> Self::Output {
425        Self::default() - self
426    }
427}
428
impl<const P: u32> Mul for Z32<P> {
    type Output = Self;

    /// Modular multiplication via the precomputed (division-free)
    /// NTL-style inverse of the modulus.
    fn mul(self, rhs: Self) -> Self::Output {
        let num = mul_mod(self.0, rhs.0, Self::modulus(), Self::modulus_inv());
        Self::new_unchecked(num)
    }
}

impl<const P: u32> Mul for &Z32<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: Self) -> Self::Output {
        *self * *rhs
    }
}

impl<const P: u32> Mul<Z32<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: Z32<P>) -> Self::Output {
        *self * rhs
    }
}

impl<const P: u32> Mul<&Z32<P>> for Z32<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: &Z32<P>) -> Self::Output {
        self * *rhs
    }
}
461
impl<const P: u32> Div for Z32<P> {
    type Output = Self;

    /// Division as multiplication by the multiplicative inverse.
    ///
    /// Panics if `rhs` has no inverse (see [`Z32::inv`]); use
    /// `try_div` for a fallible alternative.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self::Output {
        self * rhs.inv()
    }
}
470
// Compute `a * b mod n` using the precomputed inverse data.
// `b` and `n` are shifted left by `ninv.shamt` so that the modulus
// occupies the full `SP_NBITS` width expected by the normalised core
// routine; the result is shifted back down afterwards.
const fn mul_mod(a: u32, b: u32, n: u32, ninv: SpInverse32) -> u32 {
    let res = normalised_mul_mod(
        a,
        (b as i32) << ninv.shamt,
        ((n as i32) << ninv.shamt) as u32,
        ninv.inv,
    ) >> ninv.shamt;
    res as u32
}
480
// Reference variants of division; all delegate to `Div for Z32<P>`.
impl<const P: u32> Div for &Z32<P> {
    type Output = Z32<P>;

    fn div(self, rhs: Self) -> Self::Output {
        *self / *rhs
    }
}

impl<const P: u32> Div<Z32<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn div(self, rhs: Z32<P>) -> Self::Output {
        *self / rhs
    }
}

impl<const P: u32> Div<&Z32<P>> for Z32<P> {
    type Output = Z32<P>;

    fn div(self, rhs: &Z32<P>) -> Self::Output {
        self / *rhs
    }
}
504
impl<const P: u32> TryDiv for Z32<P> {
    type Output = Self;

    /// Fallible division: `None` if `rhs` has no multiplicative inverse.
    fn try_div(self, rhs: Self) -> Option<Self::Output> {
        rhs.try_inv().map(|i| self * i)
    }
}

impl<const P: u32> TryDiv for &Z32<P> {
    type Output = Z32<P>;

    fn try_div(self, rhs: Self) -> Option<Self::Output> {
        (*self).try_div(*rhs)
    }
}

impl<const P: u32> TryDiv<Z32<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn try_div(self, rhs: Z32<P>) -> Option<Self::Output> {
        (*self).try_div(rhs)
    }
}

impl<const P: u32> TryDiv<&Z32<P>> for Z32<P> {
    type Output = Z32<P>;

    fn try_div(self, rhs: &Z32<P>) -> Option<Self::Output> {
        self.try_div(*rhs)
    }
}
536
537impl<const P: u32> Sum for Z32<P> {
538    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
539        iter.fold(Self::new_unchecked(0), |a, b| a + b)
540    }
541}
542
543impl<const P: u32> Product for Z32<P> {
544    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
545        iter.fold(Self::new_unchecked(1), |a, b| a * b)
546    }
547}
548
// Core modular multiplication for a normalised (pre-shifted) modulus
// `n` with precomputed approximate inverse `ninv`.
// Estimates the quotient from the high bits of the product, subtracts
// `q * n`, and fixes the at-most-one-off remainder with a single
// conditional subtraction.
const fn normalised_mul_mod(a: u32, b: i32, n: u32, ninv: u32) -> i32 {
    let u = a as u64 * b as u64;
    // high part of the product used for the quotient estimate
    let h = (u >> (SP_NBITS - 2)) as u32;
    let q = u64_mul_high(h, ninv) >> POST_SHIFT;
    let l = u as u32;
    // remainder candidate; wrapping arithmetic is intentional since
    // only the low 32 bits are meaningful here
    let r = l.wrapping_sub(q.wrapping_mul(n));
    debug_assert!(r < 2 * n);
    correct_excess(r as i32, n)
}
558
// Reduce an unsigned `z` modulo `p` without division, using the
// precomputed approximation of `2^32 / p` stored in `red.ninv`.
const fn remu(z: u32, p: u32, red: ReduceStruct) -> i32 {
    let q = u64_mul_high(z, red.ninv);
    // quotient estimate may be one too small; fix below
    let r = (z - q.wrapping_mul(p)) as i32;
    correct_excess(r, p)
}
564
// Reduce a signed `z` modulo `p` into `[0, p)`.
const fn remi(z: i32, p: u32, red: ReduceStruct) -> i32 {
    // reduce the low 31 bits as an unsigned value
    let zu = (z as u32) & ((1u32 << (u32::BITS - 1)) - 1);
    let r = remu(zu, p, red);
    // in two's complement, a set sign bit means the true value is
    // `zu - 2^31`, so subtract `2^31 mod p` (stored in `red.sgn`)
    let s = i32_sign_mask(z) & (red.sgn as i32);
    correct_deficit(r - s, p)
}
571
572const fn u64_mul_high(a: u32, b: u32) -> u32 {
573    u64_get_high(a as u64 * b as u64)
574}
575
// Upper 32 bits of a 64-bit value.
const fn u64_get_high(u: u64) -> u32 {
    let shifted = u >> u32::BITS;
    shifted as u32
}
579
580const fn correct_excess(a: i32, p: u32) -> i32 {
581    let n = p as i32;
582    (a - n) + (i32_sign_mask(a - n) & n)
583}
584
585const fn correct_deficit(a: i32, p: u32) -> i32 {
586    a + (i32_sign_mask(a) & (p as i32))
587}
588
/// Result of the extended Euclidean algorithm on two `u32` inputs.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct ExtendedGCDResult {
    // greatest common divisor of the two inputs
    gcd: u32,
    // Bezout coefficients [s, t] with s*a + t*b == gcd
    bezout: [i32; 2],
}

// Extended Euclidean algorithm: computes `gcd(a, b)` together with
// Bezout coefficients `s`, `t` such that `s*a + t*b == gcd(a, b)`.
const fn extended_gcd(a: u32, b: u32) -> ExtendedGCDResult {
    // remainder sequence (r0, r1) and the matching coefficient
    // sequences (s0, s1) and (t0, t1)
    let (mut r0, mut r1) = (a, b);
    let (mut s0, mut s1) = (1i32, 0i32);
    let (mut t0, mut t1) = (0i32, 1i32);

    while r1 != 0 {
        let q = r0 / r1;
        let r2 = r0 % r1; // identical to r0 - q * r1
        let s2 = s0 - q as i32 * s1;
        let t2 = t0 - q as i32 * t1;
        (r0, r1) = (r1, r2);
        (s0, s1) = (s1, s2);
        (t0, t1) = (t1, t2);
    }
    ExtendedGCDResult {
        gcd: r0,
        bezout: [s0, t0],
    }
}
614
// Greatest common divisor via the plain Euclidean algorithm.
const fn gcd(mut a: u32, mut b: u32) -> u32 {
    while b != 0 {
        let rem = a % b;
        a = b;
        b = rem;
    }
    a
}
621
// Maximum number of bits allowed in the modulus (30 for u32);
// corresponds to NTL's `SP_NBITS`.
const SP_NBITS: u32 = u32::BITS - 2;
// Shift used when precomputing the modulus inverse
// (cf. NTL's `NTL_PRE_SHIFT2`).
const PRE_SHIFT2: u32 = 2 * SP_NBITS + 1;
// Shift applied to the quotient estimate in the normalised routines.
const POST_SHIFT: u32 = 1;
625
// Number of significant bits in `z` (0 for `z == 0`).
const fn used_bits(z: u32) -> u32 {
    let unused = z.leading_zeros();
    u32::BITS - unused
}
629
/// Precomputed per-modulus constants for fast modular arithmetic.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct Z32Info {
    // the modulus `P` itself
    p: u32,
    // precomputed data for fast multiplication modulo `p`
    p_inv: SpInverse32,
    // precomputed data for fast reduction of integers modulo `p`
    red_struct: ReduceStruct,
}
636
impl Z32Info {
    /// Compute all per-modulus constants for modulus `p`.
    ///
    /// Evaluated at compile time (via `Z32::INFO`); fails compilation
    /// for `p <= 1` or a modulus wider than `SP_NBITS` (30) bits.
    const fn new(p: u32) -> Self {
        assert!(p > 1);
        assert!(used_bits(p) <= SP_NBITS);

        let p_inv = prep_mul_mod(p);
        let red_struct = prep_rem(p);
        Self {
            p,
            p_inv,
            red_struct,
        }
    }
}
651
// Precompute the multiplication helper for modulus `p`: `shamt`
// normalises `p` to exactly `SP_NBITS` bits, and `inv` is the
// approximate inverse of the normalised modulus.
const fn prep_mul_mod(p: u32) -> SpInverse32 {
    let shamt = p.leading_zeros() - (u32::BITS - SP_NBITS);
    let inv = normalised_prep_mul_mod(p << shamt);
    SpInverse32 { inv, shamt }
}
657
/// Precomputed constants for reducing 32-bit integers modulo `p`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct ReduceStruct {
    // approximation of `2^32 / p`, used to estimate quotients
    ninv: u32,
    // `2^31 mod p`, used to correct for the sign bit of signed inputs
    sgn: u32,
}
663
// Precompute the reduction helper for modulus `p`:
// `ninv ~ 2^32 / p` (built from `2^31 / p` by doubling and one
// correction step) and `sgn = 2^31 mod p`.
const fn prep_rem(p: u32) -> ReduceStruct {
    let mut q = (1 << (u32::BITS - 1)) / p;
    // r = 2^31 % p
    let r = (1 << (u32::BITS - 1)) - q * p;

    // double the quotient and add the carry from doubling the remainder
    q *= 2;
    q += correct_excess_quo(2 * r as i32, p as i32).0;

    ReduceStruct { ninv: q, sgn: r }
}
674
// Single conditional reduction step: returns the quotient bit
// together with the reduced value, i.e. `(1, a - n)` when `a >= n`
// and `(0, a)` otherwise.
const fn correct_excess_quo(a: i32, n: i32) -> (u32, i32) {
    if a < n {
        (0, a)
    } else {
        (1, a - n)
    }
}
678
// Branch-free sign mask: `-1` (all bits set) if `i` is negative,
// `0` otherwise. The arithmetic right shift replicates the sign bit.
const fn i32_sign_mask(i: i32) -> i32 {
    i >> (i32::BITS - 1)
}
682
// Sign mask of `i` reinterpreted as a signed 32-bit integer:
// `-1` when the top bit is set, `0` otherwise.
const fn u32_sign_mask(i: u32) -> i32 {
    i32_sign_mask(i as i32)
}
686
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SpInverse32 {
    // precomputed approximate inverse of the normalised modulus
    inv: u32,
    // shift that normalises the modulus to `SP_NBITS` bits
    shamt: u32,
}
693
// Adapted from NTL's sp_NormalizedPrepMulMod
//
// Floating-point arithmetic replaced by u64 / i64 to allow `const`.
// The performance impact is not a huge concern since this function
// is only evaluated at compile time and only once for each prime
// field order (unlike NTL, which computes it at runtime).
//
// This only works since this function is `const` and can be therefore
// used to compute individual `const INFO` inside `Z32<P>` for each
// `P`. The alternatives `lazy_static!` or `OnceCell` would not be
// recomputed, but instead incorrectly shared between `Z32<P>` with
// different `P`!
const fn normalised_prep_mul_mod(n: u32) -> u32 {
    // NOTE: this is an initial approximation
    //       the true quotient is <= 2^SP_NBITS
    const MAX: u64 = 1u64 << (2 * SP_NBITS - 1);
    let init_quot_approx = MAX / n as u64;

    let approx_rem = MAX - n as u64 * init_quot_approx;

    let approx_rem = (approx_rem << (PRE_SHIFT2 - 2 * SP_NBITS + 1)) - 1;

    // split the remainder into low/high words, with the high word
    // adjusted by the carry out of the low word's top bit
    let approx_rem_low = approx_rem as u32;
    let s1 = (approx_rem >> u32::BITS) as u32;
    let s2 = approx_rem_low >> (u32::BITS - 1);
    let approx_rem_high = s1.wrapping_add(s2);

    let approx_rem_low = approx_rem_low as i32;
    let approx_rem_high = approx_rem_high as i32;

    let bpl = 1i64 << u32::BITS;

    let fr = approx_rem_low as i64 + approx_rem_high as i64 * bpl;

    // now convert fr*ninv to a long
    // but we have to be careful: fr may be negative.
    // the result should still give floor(r/n) pm 1,
    // and is computed in a way that avoids branching

    let mut q1 = (fr / n as i64) as i32;
    if q1 < 0 {
        // This counteracts the round-to-zero behavior of conversion
        // to i32.  It should be compiled into branch-free code.
        q1 -= 1
    }

    let mut q1 = q1 as u32;
    let approx_rem_low = approx_rem_low as u32;
    let sub = q1.wrapping_mul(n);

    let approx_rem = approx_rem_low.wrapping_sub(sub);

    // adjust the quotient by -1/0/+1 depending on the remainder's range
    q1 += (1
        + u32_sign_mask(approx_rem)
        + u32_sign_mask(approx_rem.wrapping_sub(n))) as u32;

    ((init_quot_approx as u32) << (PRE_SHIFT2 - 2 * SP_NBITS + 1))
        .wrapping_add(q1)

    // NTL_PRE_SHIFT1 is 0, so no further shift required
}
755
/// [Z32] variant for fast repeated multiplication
///
/// Use this variant to speed up repeated multiplication by the same value:
/// ```
/// # use ffnt::{Z32, z32::Z32FastMul};
/// const P: u32 = 10007;
/// let mut numbers = Vec::from_iter((1..1000).map(Z32::<P>::from));
///
/// let factor: Z32FastMul<P> = 12.into();
/// for number in &mut numbers {
///     // same as `number *= Z32::from(12)`, but faster
///     *number *= &factor;
/// }
/// ```
///
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Z32FastMul<const P: u32> {
    // the factor itself
    val: Z32<P>,
    // precomputed scaled approximation of `val / P`
    // (cf. NTL's "mulmod with preconditioning")
    val_over_mod_approx: u32,
}
777
778impl<const P: u32> From<Z32<P>> for Z32FastMul<P> {
779    fn from(val: Z32<P>) -> Self {
780        let val_over_mod_approx = Self::prep_mul_mod_precon(val.0);
781        Self {
782            val,
783            val_over_mod_approx,
784        }
785    }
786}
787
impl<const P: u32> Z32FastMul<P> {
    // Precompute the scaled approximation of `val / P` consumed by
    // `mul_mod_precon`. Inputs are shifted into normalised position
    // using the modulus' precomputed shift amount.
    fn prep_mul_mod_precon(val: u32) -> u32 {
        let p_inv = Z32::<P>::INFO.p_inv;
        normalized_prep_mul_mod_precon(
            val << p_inv.shamt,
            P << p_inv.shamt,
            p_inv.inv,
        ) << (u32::BITS - SP_NBITS)
    }
}
798
// Compute the quotient approximation of `val / p` for a normalised
// modulus `p` with precomputed inverse `p_inv`, adjusting the initial
// estimate by the remainder's position relative to `p`.
fn normalized_prep_mul_mod_precon(val: u32, p: u32, p_inv: u32) -> u32 {
    let h = val << 2;
    let q = u64_mul_high(h, p_inv);
    let q = q >> POST_SHIFT;
    let l = val << SP_NBITS;
    let r = l.wrapping_sub(q.wrapping_mul(p)); // r in [0..2*p)
    debug_assert!(r < 2 * p);

    q.saturating_add_signed(1 + i32_sign_mask(r as i32 - p as i32)) // NOTE: not shifted
}
809
impl<const P: u32> Mul<Z32FastMul<P>> for Z32<P> {
    type Output = Z32<P>;

    /// Multiplication using the precomputed quotient approximation;
    /// this is the core impl all other `Z32FastMul` multiplications
    /// delegate to.
    fn mul(self, rhs: Z32FastMul<P>) -> Self::Output {
        let res = mul_mod_precon(self.0, rhs.val.0, P, rhs.val_over_mod_approx);
        Z32::new_unchecked(res as u32)
    }
}

impl<const P: u32> Mul<Z32<P>> for Z32FastMul<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: Z32<P>) -> Self::Output {
        rhs * self
    }
}

impl<const P: u32> Mul<Z32FastMul<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: Z32FastMul<P>) -> Self::Output {
        *self * rhs
    }
}

impl<const P: u32> Mul<Z32<P>> for &Z32FastMul<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: Z32<P>) -> Self::Output {
        *self * rhs
    }
}

impl<const P: u32> Mul<&Z32FastMul<P>> for Z32<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: &Z32FastMul<P>) -> Self::Output {
        self * *rhs
    }
}

impl<const P: u32> Mul<&Z32<P>> for Z32FastMul<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: &Z32<P>) -> Self::Output {
        self * *rhs
    }
}

impl<'a, const P: u32> Mul<&'a Z32FastMul<P>> for &Z32<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: &'a Z32FastMul<P>) -> Self::Output {
        *self * *rhs
    }
}

impl<'a, const P: u32> Mul<&'a Z32<P>> for &Z32FastMul<P> {
    type Output = Z32<P>;

    fn mul(self, rhs: &'a Z32<P>) -> Self::Output {
        *self * *rhs
    }
}

impl<const P: u32> MulAssign<Z32FastMul<P>> for Z32<P> {
    fn mul_assign(&mut self, rhs: Z32FastMul<P>) {
        *self = *self * rhs
    }
}

impl<const P: u32> MulAssign<&Z32FastMul<P>> for Z32<P> {
    fn mul_assign(&mut self, rhs: &Z32FastMul<P>) {
        *self = *self * rhs
    }
}
886
// Compute `lhs * rhs mod p` using the precomputed approximation of
// `rhs / p`: estimate the quotient, subtract `q * p` in wrapping
// 32-bit arithmetic, and fix the remainder with one conditional
// subtraction.
fn mul_mod_precon(lhs: u32, rhs: u32, p: u32, rhs_over_mod_approx: u32) -> i32 {
    let q = u64_mul_high(lhs, rhs_over_mod_approx);
    let lhs_times_rhs = lhs.wrapping_mul(rhs);
    let q_times_p = q.wrapping_mul(p);
    let r = lhs_times_rhs.wrapping_sub(q_times_p);
    correct_excess(r as i32, p)
}
894
// Implement `From<integer>` for `Z32FastMul` by first converting to
// `Z32` (which reduces modulo `P`) and then precomputing.
macro_rules! impl_fastmul_from {
    ( $( $t:ty ),* ) => {
        $(
            impl<const P: u32> From<$t> for Z32FastMul<P> {
                fn from(t: $t) -> Self {
                    Self::from(Z32::from(t))
                }
            }
        )*
    }
}

impl_fastmul_from!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
908
909impl<const P: u32> From<Z32FastMul<P>> for Z32<P> {
910    fn from(z: Z32FastMul<P>) -> Self {
911        z.val
912    }
913}
914
#[cfg(test)]
mod tests {

    use ::rand::{Rng, SeedableRng};
    use once_cell::sync::Lazy;
    use rug::{Integer, ops::Pow};

    use super::*;

    // Moduli used throughout: a tiny prime, a 16-bit prime, and a
    // prime close to the 30-bit limit.
    const PRIMES: [u32; 3] = [3, 65521, 1073741789];

    #[test]
    fn z32_has_inv() {
        // Z32<6> is a ring, not a field: only residues coprime to 6
        // are invertible.
        type Z = Z32<6>;
        assert!(!Z::from(0).has_inv());
        assert!(Z::from(1).has_inv());
        assert!(!Z::from(2).has_inv());
        assert!(!Z::from(3).has_inv());
        assert!(!Z::from(4).has_inv());
        assert!(Z::from(5).has_inv());
        assert_eq!(Z::from(6), Z::from(0));
    }

    #[test]
    #[should_panic]
    fn z32_inv0() {
        type Z = Z32<6>;
        Z::from(0).inv();
    }

    #[test]
    #[should_panic]
    fn z32_inv2() {
        type Z = Z32<6>;
        Z::from(2).inv();
    }

    #[test]
    fn z32_constr() {
        // signed and unsigned inputs must reduce to the same residue
        let z: Z32<3> = 2.into();
        assert_eq!(u32::from(z), 2);
        let z: Z32<3> = (-1).into();
        assert_eq!(u32::from(z), 2);
        let z: Z32<3> = 5.into();
        assert_eq!(u32::from(z), 2);

        let z: Z32<3> = 0.into();
        assert_eq!(u32::from(z), 0);
        let z: Z32<3> = 3.into();
        assert_eq!(u32::from(z), 0);

        let z: Z32<3> = 2u32.into();
        assert_eq!(u32::from(z), 2);
        let z: Z32<3> = 5u32.into();
        assert_eq!(u32::from(z), 2);

        let z: Z32<3> = 0u32.into();
        assert_eq!(u32::from(z), 0);
        let z: Z32<3> = 3u32.into();
        assert_eq!(u32::from(z), 0);
    }

    // Fixed-seed random sample points shared by the reference tests below.
    static POINTS: Lazy<[i32; 1000]> = Lazy::new(|| {
        let mut pts = [0; 1000];
        let mut rng = rand_xoshiro::Xoshiro256StarStar::seed_from_u64(0);
        for pt in &mut pts {
            *pt = rng.random();
        }
        pts
    });

    #[test]
    fn tst_conv() {
        // compare against `rem_euclid` as the reference reduction
        for pt in *POINTS {
            let z: Z32<{ PRIMES[0] }> = pt.into();
            let z: i32 = z.into();
            assert_eq!(z, pt.rem_euclid(PRIMES[0] as i32));
        }

        for pt in *POINTS {
            let z: Z32<{ PRIMES[1] }> = pt.into();
            let z: i32 = z.into();
            assert_eq!(z, pt.rem_euclid(PRIMES[1] as i32));
        }

        for pt in *POINTS {
            let z: Z32<{ PRIMES[2] }> = pt.into();
            let z: i32 = z.into();
            assert_eq!(z, pt.rem_euclid(PRIMES[2] as i32));
        }
    }

    #[test]
    fn tst_add() {
        // reference computation in i64 to avoid overflow
        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[0] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[0] }> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 + z2).into();
                let sum2 = (pt1 + pt2).rem_euclid(PRIMES[0] as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[1] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[1] }> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 + z2).into();
                let sum2 = (pt1 + pt2).rem_euclid(PRIMES[1] as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[2] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[2] }> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 + z2).into();
                let sum2 = (pt1 + pt2).rem_euclid(PRIMES[2] as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }
    }

    #[test]
    fn tst_sub() {
        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[0] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[0] }> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 - z2).into();
                let sum2 = (pt1 - pt2).rem_euclid(PRIMES[0] as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[1] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[1] }> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 - z2).into();
                let sum2 = (pt1 - pt2).rem_euclid(PRIMES[1] as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[2] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[2] }> = pt2.into();
                let pt2 = pt2 as i64;
                let sum1: i32 = (z1 - z2).into();
                let sum2 = (pt1 - pt2).rem_euclid(PRIMES[2] as i64) as i32;
                assert_eq!(sum1, sum2);
            }
        }
    }

    #[test]
    fn tst_mul() {
        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[0] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[0] }> = pt2.into();
                let pt2 = pt2 as i64;
                let prod1: i32 = (z1 * z2).into();
                let prod2 = (pt1 * pt2).rem_euclid(PRIMES[0] as i64) as i32;
                assert_eq!(prod1, prod2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[1] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[1] }> = pt2.into();
                let pt2 = pt2 as i64;
                let prod1: i32 = (z1 * z2).into();
                let prod2 = (pt1 * pt2).rem_euclid(PRIMES[1] as i64) as i32;
                assert_eq!(prod1, prod2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[2] }> = pt1.into();
            let pt1 = pt1 as i64;
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[2] }> = pt2.into();
                let pt2 = pt2 as i64;
                let prod1: i32 = (z1 * z2).into();
                let prod2 = (pt1 * pt2).rem_euclid(PRIMES[2] as i64) as i32;
                assert_eq!(prod1, prod2);
            }
        }
    }

    #[test]
    fn tst_fastmul() {
        // `Z32FastMul` must agree with plain `Z32` multiplication
        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[0] }> = pt1.into();
            let fast_z1 = Z32FastMul::from(z1);
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[0] }> = pt2.into();
                assert_eq!(z1 * z2, fast_z1 * z2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[1] }> = pt1.into();
            let fast_z1 = Z32FastMul::from(z1);
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[1] }> = pt2.into();
                assert_eq!(z1 * z2, fast_z1 * z2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[2] }> = pt1.into();
            let fast_z1 = Z32FastMul::from(z1);
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[2] }> = pt2.into();
                assert_eq!(z1 * z2, fast_z1 * z2);
            }
        }
    }

    #[test]
    fn tst_div() {
        // division must invert multiplication for non-zero divisors
        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[0] }> = pt1.into();
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[0] }> = pt2.into();
                if i32::from(z2) == 0 {
                    continue;
                }
                let div = z1 / z2;
                assert_eq!(z1, div * z2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[1] }> = pt1.into();
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[1] }> = pt2.into();
                if i32::from(z2) == 0 {
                    continue;
                }
                let div = z1 / z2;
                assert_eq!(z1, div * z2);
            }
        }

        for pt1 in *POINTS {
            let z1: Z32<{ PRIMES[2] }> = pt1.into();
            for pt2 in *POINTS {
                let z2: Z32<{ PRIMES[2] }> = pt2.into();
                if i32::from(z2) == 0 {
                    continue;
                }
                let div = z1 / z2;
                assert_eq!(z1, div * z2);
            }
        }
    }

    #[test]
    fn tst_pow() {
        // compare `powu` against arbitrary-precision exponentiation (rug)
        let mut rng = rand_xoshiro::Xoshiro256StarStar::seed_from_u64(2849);
        for pt1 in *POINTS {
            let base = Integer::from(pt1);
            for _ in 0..100 {
                let exp: u8 = rng.random();
                let pow = base.clone().pow(exp as u32);
                // ensure remainder is positive and less than the mod
                let ref_pow0 =
                    (pow.clone() % PRIMES[0] + PRIMES[0]) % PRIMES[0];
                let ref_pow0: u32 = ref_pow0.try_into().unwrap();
                let z: Z32<{ PRIMES[0] }> = pt1.into();
                let pow0: u32 = z.powu(exp as u64).into();
                assert_eq!(pow0, ref_pow0);

                let ref_pow0 =
                    (pow.clone() % PRIMES[1] + PRIMES[1]) % PRIMES[1];
                let ref_pow0: u32 = ref_pow0.try_into().unwrap();
                let z: Z32<{ PRIMES[1] }> = pt1.into();
                let pow0: u32 = z.powu(exp as u64).into();
                assert_eq!(pow0, ref_pow0);

                let ref_pow0 = (pow % PRIMES[2] + PRIMES[2]) % PRIMES[2];
                let ref_pow0: u32 = ref_pow0.try_into().unwrap();
                let z: Z32<{ PRIMES[2] }> = pt1.into();
                let pow0: u32 = z.powu(exp as u64).into();
                assert_eq!(pow0, ref_pow0);
            }
        }
    }
}