1use crate::reduced::{impl_reduced_binary_pow, Vanilla};
31use crate::{DivExact, ModularUnaryOps, Reducer};
32
/// Precomputed constants for dividing a single word by a fixed divisor with a
/// widening multiplication and shifts.
///
/// The divisor itself is not stored: the generated `div_rem` takes it as an
/// explicit argument again.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[must_use]
pub struct PreMulInv1by1<T> {
    /// Multiplier computed by the generated `new` (low word of a double-word
    /// quotient, plus one).
    m: T,

    /// Final right shift applied by `div_rem`; the generated `new` sets it to
    /// `n - 1`, where `n` is the bit length of `divisor - 1`.
    shift: u32,
}
48
// Generates the inherent API of `PreMulInv1by1<$T>` and the matching `DivExact`
// implementation for the word type `$T`.
//
// The expansion uses `split`, `merge`, `extend`, `ones` and `wmul`, which must be
// in scope at the invocation site.
macro_rules! impl_premulinv_1by1_for {
    ($T:ty) => {
        impl PreMulInv1by1<$T> {
            /// Precomputes the multiplier and shift for `divisor`.
            ///
            /// `divisor` is expected to be greater than 1; this is only checked
            /// in debug builds.
            pub const fn new(divisor: $T) -> Self {
                debug_assert!(divisor > 1);

                // Bit length of `divisor - 1`.
                let bit_len = <$T>::BITS - (divisor - 1).leading_zeros();

                // Double word with a zero low word and
                // `ones(bit_len) - (divisor - 1)` as its high word.
                let numerator = merge(0, ones(bit_len) - (divisor - 1));
                let (quot_lo, _quot_hi) = split(numerator / extend(divisor));
                // The quotient is expected to fit in a single word.
                debug_assert!(_quot_hi == 0);

                Self {
                    shift: bit_len - 1,
                    m: quot_lo + 1,
                }
            }

            /// Returns `(q, r)` where `q` is derived from the precomputed
            /// constants and `r = a - q * d`.
            ///
            /// `d` is expected to be the divisor that was passed to `new`.
            #[inline]
            pub const fn div_rem(&self, a: $T, d: $T) -> ($T, $T) {
                // High word of the double-word product `m * a`.
                let (_, prod_hi) = split(wmul(self.m, a));
                // Average-style correction `prod_hi + (a - prod_hi) / 2`,
                // followed by the precomputed shift.
                let quotient = (prod_hi + ((a - prod_hi) >> 1)) >> self.shift;
                let remainder = a - quotient * d;
                (quotient, remainder)
            }
        }

        impl DivExact<$T, PreMulInv1by1<$T>> for $T {
            type Output = $T;

            /// Returns `Some(quotient)` when `pre.div_rem(self, d)` reports a
            /// zero remainder and `None` otherwise.
            #[inline]
            fn div_exact(self, d: $T, pre: &PreMulInv1by1<$T>) -> Option<Self::Output> {
                match pre.div_rem(self, d) {
                    (quotient, 0) => Some(quotient),
                    _ => None,
                }
            }
        }
    };
}
124
/// A single-word divisor together with its precomputed inverse, used to divide
/// a double word by that divisor.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[must_use]
pub struct Normalized2by1Divisor<T> {
    /// The divisor; the generated `new` asserts that its top bit is set.
    divisor: T,

    /// Inverse of `divisor` as returned by the generated `invert_word`.
    m: T,
}
139
// Generates the inherent API of `Normalized2by1Divisor<$T>`, where `$D` is the
// double-width type of the word type `$T`.
//
// The expansion uses `split`, `extend` and `wmul`, which must be in scope at the
// invocation site.
macro_rules! impl_normdiv_2by1_for {
    ($T:ty, $D:ty) => {
        impl Normalized2by1Divisor<$T> {
            /// Returns the low word of `<$D>::MAX / divisor`.
            ///
            /// The high word of that quotient is expected to be 1, which is
            /// checked in debug builds only.
            #[inline]
            pub const fn invert_word(divisor: $T) -> $T {
                let (inverse, _top) = split(<$D>::MAX / extend(divisor));
                debug_assert!(_top == 1);
                inverse
            }

            /// Builds the divider for `divisor`.
            ///
            /// # Panics
            ///
            /// Panics if the top bit of `divisor` is not set.
            #[inline]
            pub const fn new(divisor: $T) -> Self {
                assert!(divisor.leading_zeros() == 0);
                let m = Self::invert_word(divisor);
                Self { divisor, m }
            }

            /// Returns `(a / divisor, a % divisor)` for a single-word `a`.
            #[inline]
            pub const fn div_rem_1by1(&self, a: $T) -> ($T, $T) {
                if a >= self.divisor {
                    // The divisor has its top bit set (asserted in `new`), so
                    // the quotient cannot exceed 1.
                    return (1, a - self.divisor);
                }
                (0, a)
            }

            /// Returns `(a / divisor, a % divisor)` for a double-word `a`.
            ///
            /// The high word of `a` must be smaller than the divisor (checked in
            /// debug builds only), which keeps the quotient within one word.
            #[inline]
            pub const fn div_rem_2by1(&self, a: $D) -> ($T, $T) {
                let d = self.divisor;
                let (a_lo, a_hi) = split(a);
                debug_assert!(a_hi < d);

                // Quotient estimate: both words of `m * a_hi + a`.
                let (est_lo, est_hi) = split(wmul(self.m, a_hi) + a);

                // First candidate is `est_hi + 1`; the matching remainder is
                // computed modulo the word size.
                let quot = est_hi.wrapping_add(1);
                let rem = a_lo.wrapping_sub(quot.wrapping_mul(d));

                // `mask` is all ones when `est_lo < rem` (the double-word
                // subtraction borrows) and zero otherwise. Adding it subtracts 1
                // from the quotient, and `mask & d` adds the divisor back.
                let (_, mask) = split(extend(est_lo).wrapping_sub(extend(rem)));
                let mut quot = quot.wrapping_add(mask);
                let mut rem = rem.wrapping_add(mask & d);

                // Final correction if the remainder is still not below the divisor.
                if rem >= d {
                    quot += 1;
                    rem -= d;
                }

                (quot, rem)
            }
        }
    };
}
242
243#[must_use]
245#[derive(Debug, Clone, Copy, PartialEq, Eq)]
246pub struct PreMulInv2by1<T> {
247 div: Normalized2by1Divisor<T>,
248 shift: u32,
249}
250
251impl<T> PreMulInv2by1<T> {
252 #[inline]
253 pub const fn divider(&self) -> &Normalized2by1Divisor<T> {
254 &self.div
255 }
256 #[inline]
257 pub const fn shift(&self) -> u32 {
258 self.shift
259 }
260}
261
// Generates the constructor of `PreMulInv2by1<$T>` and its `Reducer<$T>`
// implementation for the word type `$T`.
//
// Reduced values are kept shifted left by `self.shift`, matching the shifted
// divisor stored in `self.div`. The expansion uses `extend`, `ones`, `wmul` and
// `wsqr`, which must be in scope at the invocation site.
macro_rules! impl_premulinv_2by1_reducer_for {
    ($T:ty) => {
        impl PreMulInv2by1<$T> {
            /// Shifts `divisor` left by its number of leading zeros and builds
            /// the divider for the shifted value.
            #[inline]
            pub const fn new(divisor: $T) -> Self {
                let shift = divisor.leading_zeros();
                let div = Normalized2by1Divisor::<$T>::new(divisor << shift);
                Self { div, shift }
            }

            /// Returns the divisor as stored, i.e. after the left shift applied
            /// in `new`.
            #[inline]
            pub const fn divisor(&self) -> $T {
                self.div.divisor
            }
        }

        impl Reducer<$T> for PreMulInv2by1<$T> {
            /// Builds the reducer for the modulus `*m` via the inherent constructor.
            #[inline]
            fn new(m: &$T) -> Self {
                PreMulInv2by1::<$T>::new(*m)
            }

            /// Returns the remainder of `target << shift` divided by the stored
            /// (shifted) divisor.
            #[inline]
            fn transform(&self, target: $T) -> $T {
                let (_, rem) = match self.shift {
                    // No shift was applied, so one conditional subtraction suffices.
                    0 => self.div.div_rem_1by1(target),
                    shift => self.div.div_rem_2by1(extend(target) << shift),
                };
                rem
            }

            /// Returns `true` when `target` is below the stored divisor and its
            /// low `shift` bits are all zero.
            #[inline]
            fn check(&self, target: &$T) -> bool {
                if *target >= self.div.divisor {
                    return false;
                }
                target & ones(self.shift) == 0
            }

            /// Undoes the left shift of a reduced value.
            #[inline]
            fn residue(&self, target: $T) -> $T {
                target >> self.shift
            }

            /// Returns the stored divisor shifted back right by `shift`.
            #[inline]
            fn modulus(&self) -> $T {
                self.div.divisor >> self.shift
            }

            /// Returns `true` when `target` is zero.
            #[inline]
            fn is_zero(&self, target: &$T) -> bool {
                *target == 0
            }

            /// Delegates to `Vanilla::add` with the stored divisor as modulus.
            #[inline(always)]
            fn add(&self, lhs: &$T, rhs: &$T) -> $T {
                let modulus = &self.div.divisor;
                Vanilla::<$T>::add(modulus, *lhs, *rhs)
            }

            /// Delegates to `Vanilla::dbl` with the stored divisor as modulus.
            #[inline(always)]
            fn dbl(&self, target: $T) -> $T {
                let modulus = &self.div.divisor;
                Vanilla::<$T>::dbl(modulus, target)
            }

            /// Delegates to `Vanilla::sub` with the stored divisor as modulus.
            #[inline(always)]
            fn sub(&self, lhs: &$T, rhs: &$T) -> $T {
                let modulus = &self.div.divisor;
                Vanilla::<$T>::sub(modulus, *lhs, *rhs)
            }

            /// Delegates to `Vanilla::neg` with the stored divisor as modulus.
            #[inline(always)]
            fn neg(&self, target: $T) -> $T {
                let modulus = &self.div.divisor;
                Vanilla::<$T>::neg(modulus, target)
            }

            /// Inverts the unshifted residue modulo `self.modulus()` and shifts
            /// the result back; returns `None` when `invm` finds no inverse.
            #[inline(always)]
            fn inv(&self, target: $T) -> Option<$T> {
                let plain = self.residue(target);
                let inverse = plain.invm(&self.modulus())?;
                Some(inverse << self.shift)
            }

            /// Multiplies two reduced values: one operand is unshifted first so
            /// the double-word product carries the shift exactly once.
            #[inline]
            fn mul(&self, lhs: &$T, rhs: &$T) -> $T {
                let product = wmul(lhs >> self.shift, *rhs);
                let (_, rem) = self.div.div_rem_2by1(product);
                rem
            }

            /// Squares a reduced value: the double-word square is shifted right
            /// once before the division.
            #[inline]
            fn sqr(&self, target: $T) -> $T {
                let square = wsqr(target) >> self.shift;
                let (_, rem) = self.div.div_rem_2by1(square);
                rem
            }

            impl_reduced_binary_pow!($T);
        }
    };
}
345
/// A double-word divisor (`D`) together with a precomputed single-word inverse
/// (`T`), used to divide three- and four-word values by that divisor.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[must_use]
pub struct Normalized3by2Divisor<T, D> {
    /// The divisor; the generated `new` asserts that its top bit is set.
    divisor: D,

    /// Inverse of `divisor` as returned by the generated `invert_double_word`.
    m: T,
}
361
// Generates the inherent API of `Normalized3by2Divisor<$T, $D>`, where `$D` is
// the double-width type of the word type `$T`.
//
// The expansion uses `split`, `merge`, `extend` and `wmul`, which must be in
// scope at the invocation site.
macro_rules! impl_normdiv_3by2_for {
    ($T:ty, $D:ty) => {
        impl Normalized3by2Divisor<$T, $D> {
            /// Computes the single-word inverse used by `div_rem_3by2`.
            ///
            /// Starts from the 2-by-1 inverse of the divisor's high word and
            /// decrements it while folding in the low word.
            // NOTE(review): this looks like the 3-by-2 reciprocal from Möller &
            // Granlund, "Improved division by invariant integers" — confirm.
            #[inline]
            pub const fn invert_double_word(divisor: $D) -> $T {
                let (div_lo, div_hi) = split(divisor);
                let mut inv = Normalized2by1Divisor::<$T>::invert_word(div_hi);

                // Fold the low divisor word into `div_hi * inv` (mod word size).
                let (mut acc, carried) = div_hi.wrapping_mul(inv).overflowing_add(div_lo);
                if carried {
                    inv -= 1;
                    if acc >= div_hi {
                        inv -= 1;
                        acc -= div_hi;
                    }
                    acc = acc.wrapping_sub(div_hi);
                }

                // Fold in the high word of `inv * div_lo`.
                let (prod_lo, prod_hi) = split(extend(inv) * extend(div_lo));
                let (acc, carried) = acc.overflowing_add(prod_hi);
                if carried {
                    inv -= 1;
                    if merge(prod_lo, acc) >= divisor {
                        inv -= 1;
                    }
                }

                inv
            }

            /// Builds the divider for `divisor`.
            ///
            /// # Panics
            ///
            /// Panics if the top bit of `divisor` is not set.
            #[inline]
            pub const fn new(divisor: $D) -> Self {
                assert!(divisor.leading_zeros() == 0);
                let m = Self::invert_double_word(divisor);
                Self { divisor, m }
            }

            /// Returns `(a / divisor, a % divisor)` for a double-word `a`.
            #[inline]
            pub const fn div_rem_2by2(&self, a: $D) -> ($D, $D) {
                if a >= self.divisor {
                    // The divisor has its top bit set (asserted in `new`), so
                    // the quotient cannot exceed 1.
                    return (1, a - self.divisor);
                }
                (0, a)
            }

            /// Divides the three-word value formed by `a_hi` (upper two words)
            /// and `a_lo` (lowest word) by the divisor, returning
            /// `(quotient, remainder)`.
            ///
            /// `a_hi` must be smaller than the divisor; this is only checked in
            /// debug builds.
            pub const fn div_rem_3by2(&self, a_lo: $T, a_hi: $D) -> ($T, $D) {
                let d = self.divisor;
                debug_assert!(a_hi < d);
                let (a_mid, a_top) = split(a_hi);
                let (d_lo, d_hi) = split(d);

                // Quotient estimate: both words of `m * a_top + a_hi`.
                let (est_lo, est_hi) = split(wmul(self.m, a_top) + a_hi);
                let rem_mid = a_mid.wrapping_sub(est_hi.wrapping_mul(d_hi));
                let cross = wmul(d_lo, est_hi);
                // Remainder for the candidate quotient `est_hi + 1`.
                let rem = merge(a_lo, rem_mid).wrapping_sub(cross).wrapping_sub(d);

                // `mask` is all ones when the remainder's high word is below
                // `est_lo`, and zero otherwise. In the first case the quotient
                // is bumped by one; in the second the divisor is added back.
                let (_, rem_top) = split(rem);
                let (_, mask) = split(extend(rem_top).wrapping_sub(extend(est_lo)));
                let mut quot = est_hi.wrapping_sub(mask);
                let mut rem = rem.wrapping_add(merge(!mask, !mask) & d);

                // Final correction if the remainder is still not below the divisor.
                if rem >= d {
                    quot += 1;
                    rem -= d;
                }

                (quot, rem)
            }

            /// Divides the four-word value `(a_hi, a_lo)` by the divisor using
            /// two chained 3-by-2 steps, returning `(quotient, remainder)`.
            ///
            /// `a_hi` is passed straight to `div_rem_3by2` and is therefore
            /// subject to the same precondition.
            pub const fn div_rem_4by2(&self, a_lo: $D, a_hi: $D) -> ($D, $D) {
                let (low_word, high_word) = split(a_lo);
                // Upper quotient word first; its remainder feeds the second step.
                let (quot_hi, partial) = self.div_rem_3by2(high_word, a_hi);
                let (quot_lo, rem) = self.div_rem_3by2(low_word, partial);
                (merge(quot_lo, quot_hi), rem)
            }
        }
    };
}
463
464#[must_use]
466#[derive(Debug, Clone, Copy, PartialEq, Eq)]
467pub struct PreMulInv3by2<T, D> {
468 div: Normalized3by2Divisor<T, D>,
469 shift: u32,
470}
471
472impl<T, D> PreMulInv3by2<T, D> {
473 #[inline]
474 pub const fn divider(&self) -> &Normalized3by2Divisor<T, D> {
475 &self.div
476 }
477 #[inline]
478 pub const fn shift(&self) -> u32 {
479 self.shift
480 }
481}
482
// Generates the constructor of `PreMulInv3by2<$T, $D>` and its `Reducer<$D>`
// implementation, where `$D` is the double-width type of the word type `$T`.
//
// Reduced values are kept shifted left by `self.shift`, matching the shifted
// divisor stored in `self.div`. The expansion uses `split`, `extend`, `ones` and
// the `DoubleWordModule` path, which must be in scope at the invocation site.
macro_rules! impl_premulinv_3by2_reducer_for {
    ($T:ty, $D:ty) => {
        impl PreMulInv3by2<$T, $D> {
            /// Shifts `divisor` left by its number of leading zeros and builds
            /// the divider for the shifted value.
            #[inline]
            pub const fn new(divisor: $D) -> Self {
                let shift = divisor.leading_zeros();
                let div = Normalized3by2Divisor::<$T, $D>::new(divisor << shift);
                Self { div, shift }
            }

            /// Returns the divisor as stored, i.e. after the left shift applied
            /// in `new`.
            #[inline]
            pub const fn divisor(&self) -> $D {
                self.div.divisor
            }
        }

        impl Reducer<$D> for PreMulInv3by2<$T, $D> {
            /// Builds the reducer for the modulus `*m`.
            ///
            /// # Panics
            ///
            /// Panics if `*m` fits in a single word (`*m <= <$T>::MAX`).
            #[inline]
            fn new(m: &$D) -> Self {
                assert!(*m > <$T>::MAX as $D);
                // Same construction as the inherent constructor.
                PreMulInv3by2::<$T, $D>::new(*m)
            }

            /// Returns the remainder of `target << shift` divided by the stored
            /// (shifted) divisor.
            #[inline]
            fn transform(&self, target: $D) -> $D {
                if self.shift == 0 {
                    // No shift was applied, so one conditional subtraction suffices.
                    return self.div.div_rem_2by2(target).1;
                }

                // Build the three-word value `target << shift`: the lowest word
                // goes to `low`, the bits shifted out of it are carried into
                // the upper two words.
                let (lo, hi) = split(target);
                let (low, carry) = split(extend(lo) << self.shift);
                let upper = (extend(hi) << self.shift) | extend(carry);
                let (_, rem) = self.div.div_rem_3by2(low, upper);
                rem
            }

            /// Returns `true` when `target` is below the stored divisor and the
            /// low `shift` bits of its low word are all zero.
            #[inline]
            fn check(&self, target: &$D) -> bool {
                if *target >= self.div.divisor {
                    return false;
                }
                let (low_word, _) = split(*target);
                low_word & ones(self.shift) == 0
            }

            /// Undoes the left shift of a reduced value.
            #[inline]
            fn residue(&self, target: $D) -> $D {
                target >> self.shift
            }

            /// Returns the stored divisor shifted back right by `shift`.
            #[inline]
            fn modulus(&self) -> $D {
                self.div.divisor >> self.shift
            }

            /// Returns `true` when `target` is zero.
            #[inline]
            fn is_zero(&self, target: &$D) -> bool {
                *target == 0
            }

            /// Delegates to `Vanilla::add` with the stored divisor as modulus.
            #[inline(always)]
            fn add(&self, lhs: &$D, rhs: &$D) -> $D {
                let modulus = &self.div.divisor;
                Vanilla::<$D>::add(modulus, *lhs, *rhs)
            }

            /// Delegates to `Vanilla::dbl` with the stored divisor as modulus.
            #[inline(always)]
            fn dbl(&self, target: $D) -> $D {
                let modulus = &self.div.divisor;
                Vanilla::<$D>::dbl(modulus, target)
            }

            /// Delegates to `Vanilla::sub` with the stored divisor as modulus.
            #[inline(always)]
            fn sub(&self, lhs: &$D, rhs: &$D) -> $D {
                let modulus = &self.div.divisor;
                Vanilla::<$D>::sub(modulus, *lhs, *rhs)
            }

            /// Delegates to `Vanilla::neg` with the stored divisor as modulus.
            #[inline(always)]
            fn neg(&self, target: $D) -> $D {
                let modulus = &self.div.divisor;
                Vanilla::<$D>::neg(modulus, target)
            }

            /// Inverts the unshifted residue modulo `self.modulus()` and shifts
            /// the result back; returns `None` when `invm` finds no inverse.
            #[inline(always)]
            fn inv(&self, target: $D) -> Option<$D> {
                let plain = self.residue(target);
                let inverse = plain.invm(&self.modulus())?;
                Some(inverse << self.shift)
            }

            /// Multiplies two reduced values: one operand is unshifted first so
            /// the widened product carries the shift exactly once.
            #[inline]
            fn mul(&self, lhs: &$D, rhs: &$D) -> $D {
                let product = DoubleWordModule::wmul(lhs >> self.shift, *rhs);
                let (prod_lo, prod_hi) = DoubleWordModule::split(product);
                let (_, rem) = self.div.div_rem_4by2(prod_lo, prod_hi);
                rem
            }

            /// Squares a reduced value: the widened square is shifted right once
            /// before the division.
            #[inline]
            fn sqr(&self, target: $D) -> $D {
                let square = DoubleWordModule::wsqr(target) >> self.shift;
                let (sq_lo, sq_hi) = DoubleWordModule::split(square);
                let (_, rem) = self.div.div_rem_4by2(sq_lo, sq_hi);
                rem
            }

            impl_reduced_binary_pow!($D);
        }
    };
}
576
577macro_rules! collect_impls {
578 ($T:ident, $ns:ident) => {
579 mod $ns {
580 use super::*;
581 use crate::word::$T::*;
582
583 impl_premulinv_1by1_for!(Word);
584 impl_normdiv_2by1_for!(Word, DoubleWord);
585 impl_premulinv_2by1_reducer_for!(Word);
586 impl_normdiv_3by2_for!(Word, DoubleWord);
587 impl_premulinv_3by2_reducer_for!(Word, DoubleWord);
588 }
589 };
590}
591collect_impls!(u8, u8_impl);
592collect_impls!(u16, u16_impl);
593collect_impls!(u32, u32_impl);
594collect_impls!(u64, u64_impl);
595collect_impls!(usize, usize_impl);
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reduced::tests::ReducedTester;
    use rand::prelude::*;

    /// Random divisors of every bit length from 2 up to the word size: `div_rem`
    /// must agree with the native operators, and `div_exact` with the remainder.
    #[test]
    // NOTE(review): presumably needed because `div_exact` can collide with an
    // unstable std method of the same name — confirm.
    #[allow(unstable_name_collisions)]
    fn test_mul_inv_1by1() {
        type Word = u64;
        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..400000 {
            // Pick a divisor with exactly `bits` significant bits.
            let bits = rng.gen_range(2..=Word::BITS);
            let upper = Word::MAX >> (Word::BITS - bits);
            let d = rng.gen_range(upper / 2 + 1..=upper);
            let fast_div = PreMulInv1by1::<Word>::new(d);

            let n: Word = rng.gen();
            let (q, r) = fast_div.div_rem(n, d);
            assert_eq!(q, n / d);
            assert_eq!(r, n % d);

            let expected = if r == 0 { Some(q) } else { None };
            assert_eq!(n.div_exact(d, &fast_div), expected);
        }
    }

    /// Builds dividends as `q * d + r` with `r < d` and checks that the 2-by-1
    /// division recovers `(q, r)`.
    #[test]
    fn test_mul_inv_2by1() {
        type Word = u64;
        type Divider = Normalized2by1Divisor<Word>;
        use crate::word::u64::*;

        // Zero dividend with the largest divisor.
        let max_divider = Divider::new(Word::MAX);
        assert_eq!(max_divider.div_rem_2by1(0), (0, 0));

        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..200000 {
            // Divisors with the top bit set.
            let d = rng.gen_range(Word::MAX / 2 + 1..=Word::MAX);
            let q: Word = rng.gen();
            let r = rng.gen_range(0..d);

            let (lo, hi) = split(wmul(q, d) + extend(r));
            let dividend = merge(lo, hi);

            let fast_div = Divider::new(d);
            assert_eq!(fast_div.div_rem_2by1(dividend), (q, r));
        }
    }

    /// Builds three-word dividends as `q * d + r` with `r < d`, word by word,
    /// and checks that the 3-by-2 division recovers `(q, r)`.
    #[test]
    fn test_mul_inv_3by2() {
        type Word = u64;
        type DoubleWord = u128;
        type Divider = Normalized3by2Divisor<Word, DoubleWord>;
        use crate::word::u64::*;

        // Zero dividend with the largest divisor.
        let max_divider = Divider::new(DoubleWord::MAX);
        assert_eq!(max_divider.div_rem_3by2(0, 0), (0, 0));

        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..100000 {
            // Divisors with the top bit set.
            let d = rng.gen_range(DoubleWord::MAX / 2 + 1..=DoubleWord::MAX);
            let r = rng.gen_range(0..d);
            let q: Word = rng.gen();

            // Schoolbook multiplication of `q` by the two divisor words, adding
            // the remainder words and propagating the carry.
            let (d_lo, d_hi) = split(d);
            let (r_lo, r_hi) = split(r);
            let (low, carry) = split(wmul(q, d_lo) + extend(r_lo));
            let (mid, top) = split(wmul(q, d_hi) + extend(r_hi) + extend(carry));
            let upper = merge(mid, top);

            let fast_div = Divider::new(d);
            assert_eq!(
                fast_div.div_rem_3by2(low, upper),
                (q, r),
                "failed at {:?} / {}",
                (low, upper),
                d
            );
        }
    }

    /// Builds four-word dividends as `q * d + r` with `r < d` and checks that
    /// the 4-by-2 division recovers `(q, r)`.
    #[test]
    fn test_mul_inv_4by2() {
        type Word = u64;
        type DoubleWord = u128;
        type Divider = Normalized3by2Divisor<Word, DoubleWord>;
        use crate::word::u128::*;

        let mut rng = StdRng::seed_from_u64(1);
        for _ in 0..20000 {
            // Divisors with the top bit set.
            let d = rng.gen_range(DoubleWord::MAX / 2 + 1..=DoubleWord::MAX);
            let q = rng.gen();
            let r = rng.gen_range(0..d);

            let (dividend_lo, dividend_hi) = split(wmul(q, d) + r as DoubleWord);

            let fast_div = Divider::new(d);
            assert_eq!(fast_div.div_rem_4by2(dividend_lo, dividend_hi), (q, r));
        }
    }

    /// Runs the shared reducer test harness against the 2-by-1 reducer for
    /// every supported word type, ten rounds each.
    #[test]
    fn test_2by1_against_modops() {
        for _round in 0..10 {
            ReducedTester::<u8>::test_against_modops::<PreMulInv2by1<u8>>(0);
            ReducedTester::<u16>::test_against_modops::<PreMulInv2by1<u16>>(0);
            ReducedTester::<u32>::test_against_modops::<PreMulInv2by1<u32>>(0);
            ReducedTester::<u64>::test_against_modops::<PreMulInv2by1<u64>>(0);
            ReducedTester::<usize>::test_against_modops::<PreMulInv2by1<usize>>(0);
        }
    }

    /// Runs the shared reducer test harness against the 3-by-2 reducer for
    /// every supported word / double-word pair, ten rounds each.
    #[test]
    fn test_3by2_against_modops() {
        for _round in 0..10 {
            ReducedTester::<u16>::test_against_modops::<PreMulInv3by2<u8, u16>>(2);
            ReducedTester::<u32>::test_against_modops::<PreMulInv3by2<u16, u32>>(2);
            ReducedTester::<u64>::test_against_modops::<PreMulInv3by2<u32, u64>>(2);
            ReducedTester::<u128>::test_against_modops::<PreMulInv3by2<u64, u128>>(2);
        }
    }
}