use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
/// Division of a double-width integer by a half-width divisor, yielding a half-width quotient
/// and a half-width remainder.
///
/// Both results are `Self::H`, which is only possible when the quotient is known to fit in half
/// the width of `Self`; the methods below differ in who is responsible for checking that.
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
    /// Compute `(self / n, self % n)` as half-width values without checking that the quotient
    /// fits.
    ///
    /// # Safety
    ///
    /// The caller must guarantee `self.hi() < n`. This rules out a zero divisor and ensures the
    /// quotient fits in `Self::H`. Implementations may treat a violation as unreachable, so
    /// breaking this precondition is undefined behavior.
    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);

    /// Checked form of [`Self::unchecked_narrowing_div_rem`].
    ///
    /// Returns `Some((quotient, remainder))` when `self.hi() < n`. Returns `None` otherwise, which
    /// covers every case where the quotient would not fit in `Self::H`, as well as `n == 0`.
    fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
        if self.hi() >= n {
            return None;
        }
        // SAFETY: the guard above has just established `self.hi() < n`, which is exactly the
        // precondition of `unchecked_narrowing_div_rem`.
        Some(unsafe { self.unchecked_narrowing_div_rem(n) })
    }
}
/// Implement `NarrowingDiv` for a primitive integer type.
///
/// The implementation widens the divisor back to the full width and uses the built-in `/` and
/// `%` operators.
macro_rules! impl_narrowing_div_primitive {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`, so this is never reached.
                    unsafe { core::hint::unreachable_unchecked() }
                }
                // Divide at full width. Because `self.hi() < n`, both results fit in `Self::H`
                // and the narrowing casts below lose no information.
                let divisor = n.widen();
                let quotient = self / divisor;
                let remainder = self % divisor;
                (quotient.cast(), remainder.cast())
            }
        }
    };
}
/// Implement `NarrowingDiv` for a type by splitting it into half-width digits.
///
/// The work is done as two successive 3-by-2 digit divisions via `div_three_digits_by_two`.
macro_rules! impl_narrowing_div_recurse {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`, so this is never reached.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                // Normalize: shift so the divisor's most significant bit is set, as
                // `div_three_digits_by_two` requires. Because `self.hi() < n`, shifting the
                // dividend by the same amount cannot push any set bits out of the top.
                let shift = n.leading_zeros();
                let dividend = self << shift;
                let divisor = n << shift;

                // Split the dividend into its top half (`top`), followed by the upper and
                // lower quarter-width digits of its bottom half.
                let top = dividend.hi();
                let (lower, upper) = dividend.lo().lo_hi();

                // SAFETY: `divisor` is normalized, and `top < divisor` follows from
                // `self.hi() < n` after both sides were shifted by the same amount.
                let (q_hi, rem) = unsafe { div_three_digits_by_two(upper, top, divisor) };
                // SAFETY: `divisor` is still normalized, and the remainder of the previous step
                // is strictly less than `divisor`.
                let (q_lo, rem) = unsafe { div_three_digits_by_two(lower, rem, divisor) };

                // The remainder is scaled by the normalization shift; undo it.
                (Self::H::from_lo_hi(q_lo, q_hi), rem >> shift)
            }
        }
    };
}
// Primitive widths: divide with the built-in `/` and `%` at full width, then narrow the results.
impl_narrowing_div_primitive!(u16);
impl_narrowing_div_primitive!(u32);
impl_narrowing_div_primitive!(u64);
impl_narrowing_div_primitive!(u128);
// `u256` is divided digit-by-digit through two 3-by-2 digit divisions.
impl_narrowing_div_recurse!(u256);
/// Divide a three-digit number by a two-digit number and return `(quotient, remainder)`.
///
/// Let `B` be the radix of one `U` digit (one more than `U::MAX`). The dividend is `a * B + a0`:
/// `a` supplies the two high digits and `a0` the lowest digit. The divisor `n` has two digits.
///
/// The quotient is first estimated from the top two dividend digits and the divisor's high
/// digit. It is then corrected downwards; per the debug assertion below, at most two
/// corrections are needed.
///
/// # Safety
///
/// `n` must be normalized, meaning `n.leading_zeros() == 0`, and `a < n` must hold. The second
/// condition guarantees the quotient fits in a single digit `U`. Violating either condition
/// reaches `unreachable_unchecked`, which is undefined behavior.
unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
where
U: HInt,
U::D: Int + NarrowingDiv,
{
// SAFETY: the caller guarantees both preconditions; this only informs the optimizer.
if n.leading_zeros() > 0 || a >= n {
unsafe { core::hint::unreachable_unchecked() }
}
// Digit view, high:low: n = n1:n0 and a = a2:a1.
let (n0, n1) = n.lo_hi();
let (a1, a2) = a.lo_hi();
let mut q;
let mut r;
let mut wrap;
if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
// Here a2 < n1. The estimate q = floor((a2:a1) / n1) fits in one digit.
q = q0;
// r = r1:a0 = (a * B + a0) - q * n1 * B. The n0 part is still missing.
r = U::D::from_lo_hi(a0, r1);
// Subtract the remaining q * n0. A borrow means the estimate was too large.
let d = q.widen_mul(n0);
(r, wrap) = r.overflowing_sub(d);
if !wrap {
return (q, r);
}
q -= U::ONE;
} else {
// a2 >= n1 together with a < n forces a2 == n1 and a1 < n0.
debug_assert!(a2 == n1 && a1 < n0);
// Here the quotient is at most U::MAX. `r` is set to the (negative, wrapped) remainder
// for the candidate quotient B, which is (a1 - n0):a0 mod B^2. `q` is set to U::MAX.
r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
q = U::MAX;
}
// Both branches arrive here with r == (a * B + a0) - (q + 1) * n, negative and wrapped
// modulo B^2. Adding `n` gives the remainder for `q`. A carry out means that remainder is
// non-negative, so `q` is the true quotient.
(r, wrap) = r.overflowing_add(n);
if wrap {
return (q, r);
}
// The remainder is still negative, so step the quotient down once more.
q -= U::ONE;
(r, wrap) = r.overflowing_add(n);
debug_assert!(wrap, "estimated quotient should be off by at most two");
(q, r)
}
#[cfg(test)]
mod test {
    use super::{HInt, NarrowingDiv};

    /// Exhaustively check `u16 / u8` narrowing division on dividends of the form `x * y + k`.
    /// For these, the expected quotient (`x`) and remainder (`k`, with `k < y`) are known.
    #[test]
    fn inverse_mul() {
        for x in 0..=u8::MAX {
            for y in 1..=u8::MAX {
                let xy = x.widen_mul(y);
                assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
                assert_eq!(
                    (xy + u16::from(y - 1)).checked_narrowing_div_rem(y),
                    Some((x, y - 1))
                );
                if y > 1 {
                    assert_eq!((xy + 1).checked_narrowing_div_rem(y), Some((x, 1)));
                    assert_eq!(
                        (xy + u16::from(y - 2)).checked_narrowing_div_rem(y),
                        Some((x, y - 2))
                    );
                }
            }
        }
    }

    /// Check the boundaries of the checked form. It must reject dividends whose high half is
    /// not below the divisor, including a zero divisor. It must accept the largest dividend
    /// whose quotient still fits, which is `u8::MAX`.
    #[test]
    fn checked_bounds() {
        for y in 0..=u8::MAX {
            // `hi() == y`, so the quotient would be at least 256 and must be rejected.
            let overflow = u16::from(y) << 8;
            assert_eq!(overflow.checked_narrowing_div_rem(y), None);
            if y > 0 {
                // `y * 256 - 1 == 255 * y + (y - 1)` is the largest dividend with `hi() < y`.
                assert_eq!(
                    (overflow - 1).checked_narrowing_div_rem(y),
                    Some((u8::MAX, y - 1))
                );
            }
        }
        assert_eq!(u16::MAX.checked_narrowing_div_rem(u8::MAX), None);
    }
}