use crate::{int::I256, uint::U256};
use core::mem::MaybeUninit;
/// Divides the 256-bit value `u1 * 2^128 + u0` by the 128-bit divisor `v`,
/// returning the 128-bit quotient and storing the remainder in `*r`.
///
/// This is the Hacker's Delight `divlu` routine lifted to 128-bit words. The
/// divisor is normalized so its top bit is set, and the dividend is shifted by
/// the same amount. Two 64-bit quotient digits are then estimated and each is
/// corrected, decrementing at most twice.
///
/// # Preconditions
///
/// `u1 < v` must hold, which also implies `v != 0`. Otherwise the true quotient
/// does not fit in 128 bits and the returned value is not a correct quotient.
/// This is checked with `debug_assert!`. Every caller in this module reduces
/// the high word below the divisor before calling.
#[inline(always)]
fn udiv256_by_128_to_128(u1: u128, u0: u128, mut v: u128, r: &mut u128) -> u128 {
    const N_UDWORD_BITS: u32 = 128;
    // Number base of the 64-bit "digits" the algorithm operates on.
    const B: u128 = 1 << (N_UDWORD_BITS / 2);

    // The quotient only fits in 128 bits when the high dividend word is
    // strictly below the divisor (this also excludes `v == 0`).
    debug_assert!(
        u1 < v,
        "udiv256_by_128_to_128: high dividend word must be below the divisor"
    );

    let (un1, un0): (u128, u128); // normalized dividend, two low digits
    let (vn1, vn0): (u128, u128); // normalized divisor digits
    let (mut q1, mut q0): (u128, u128); // quotient digits
    let (un128, un21, un10): (u128, u128, u128); // dividend digit pairs

    // Normalize so the divisor's top bit is set. Shifting the dividend by
    // the same amount leaves the quotient unchanged.
    let s = v.leading_zeros();
    if s > 0 {
        v <<= s;
        un128 = (u1 << s) | (u0 >> (N_UDWORD_BITS - s));
        un10 = u0 << s;
    } else {
        // No shift needed; also avoids the overflowing `u0 >> 128`.
        un128 = u1;
        un10 = u0;
    }

    // Split the divisor into two 64-bit digits.
    vn1 = v >> (N_UDWORD_BITS / 2);
    vn0 = v & 0xFFFF_FFFF_FFFF_FFFF;

    // Split the low half of the dividend into two 64-bit digits.
    un1 = un10 >> (N_UDWORD_BITS / 2);
    un0 = un10 & 0xFFFF_FFFF_FFFF_FFFF;

    // First quotient digit. The estimate is at most 2 too large. The
    // `q1 >= B` check short-circuits, so `q1 * vn0` is only computed with
    // `q1 < B` and cannot overflow.
    q1 = un128 / vn1;
    let mut rhat = un128 - q1 * vn1;
    while q1 >= B || q1 * vn0 > B * rhat + un1 {
        q1 -= 1;
        rhat += vn1;
        if rhat >= B {
            break;
        }
    }

    // Partial remainder after the first digit. The wrapping ops are exact
    // because the true value fits in 128 bits.
    un21 = un128
        .wrapping_mul(B)
        .wrapping_add(un1)
        .wrapping_sub(q1.wrapping_mul(v));

    // Second quotient digit, corrected the same way.
    q0 = un21 / vn1;
    rhat = un21 - q0 * vn1;
    while q0 >= B || q0 * vn0 > B * rhat + un0 {
        q0 -= 1;
        rhat += vn1;
        if rhat >= B {
            break;
        }
    }

    // Final remainder, shifted back to undo the normalization.
    *r = (un21
        .wrapping_mul(B)
        .wrapping_add(un0)
        .wrapping_sub(q0.wrapping_mul(v)))
        >> s;
    q1 * B + q0
}
/// Unsigned 256-bit division with an optional remainder.
///
/// Writes `a / b` into `res` and, when `rem` is `Some`, writes `a % b` into it.
/// The strategy depends on operand size:
/// - both operands fit in 128 bits: native `u128` division;
/// - `b > a`: the quotient is zero and the remainder is `a`;
/// - `b` fits in 128 bits: reduce the high word, then one 256-by-128 step;
/// - otherwise: Knuth's Algorithm D via `div_mod_knuth`.
///
/// # Panics
///
/// Panics if `b` is zero.
#[allow(clippy::many_single_char_names)]
pub fn udivmod4(
    res: &mut MaybeUninit<U256>,
    a: &U256,
    b: &U256,
    rem: Option<&mut MaybeUninit<U256>>,
) {
    // Fast path: both operands fit in one 128-bit word. The remainder is only
    // computed when the caller asked for it.
    if (a.high() | b.high()) == 0 {
        if let Some(out) = rem {
            out.write(U256::from_words(0, a.low() % b.low()));
        }
        res.write(U256::from_words(0, a.low() / b.low()));
        return;
    }

    let numer = *a;
    let denom = *b;

    let (quot, remain): (U256, U256) = if denom > numer {
        // Divisor exceeds dividend: nothing divides out.
        (U256::ZERO, numer)
    } else if *denom.high() == 0 {
        // 128-bit divisor. The 256-by-128 step needs its high input word
        // below the divisor. Reduce it first unless it already is.
        let d = *denom.low();
        let hi = *numer.high();
        let (q_hi, carried_hi) = if hi < d { (0, hi) } else { (hi / d, hi % d) };
        let mut r_lo: u128 = 0;
        let q_lo = udiv256_by_128_to_128(carried_hi, *numer.low(), d, &mut r_lo);
        (U256::from_words(q_hi, q_lo), U256::from_words(0, r_lo))
    } else {
        // Full two-word divisor.
        div_mod_knuth(&numer, &denom)
    };

    if let Some(out) = rem {
        out.write(remain);
    }
    res.write(quot);
}
/// Divides `u` by `v` using Knuth's Algorithm D (TAOCP Vol. 2, §4.3.1),
/// specialized to a divisor of exactly two 128-bit words. Returns
/// `(quotient, remainder)`.
///
/// Precondition: `v.high()` must be non-zero. `udivmod4` handles the
/// `v.high() == 0` case separately before calling this. If the precondition
/// is violated, the normalization shift is 128 and the `debug_assert!`s below
/// fire in debug builds. In release builds the result should not be relied on.
///
/// Because `v >= 2^128` and `u < 2^256`, the quotient is below `2^128`. The
/// main loop of Algorithm D therefore runs exactly once (unrolled below), and
/// only the low word of the returned quotient is ever set.
#[inline]
pub fn div_mod_knuth(u: &U256, v: &U256) -> (U256, U256) {
    const N_UDWORD_BITS: u32 = 128;
    /// Returns `a << shift` widened to three 128-bit words, least significant
    /// word first. Requires `shift < 128`.
    #[inline]
    fn full_shl(a: &U256, shift: u32) -> [u128; 3] {
        debug_assert!(shift < N_UDWORD_BITS);
        let mut u = [0_u128; 3];
        let u_lo = a.low() << shift;
        // Words 1 and 2 of the result are `a >> (128 - shift)`. For
        // `shift == 0` this is a 128-bit shift of a 256-bit value.
        let u_hi = a >> (N_UDWORD_BITS - shift);
        u[0] = u_lo;
        u[1] = *u_hi.low();
        u[2] = *u_hi.high();
        u
    }
    /// Returns the low 256 bits of the three-word value `u` shifted right by
    /// `shift`. Undoes `full_shl` for results that fit in 256 bits. Requires
    /// `shift < 128`.
    #[inline]
    fn full_shr(u: &[u128; 3], shift: u32) -> U256 {
        debug_assert!(shift < N_UDWORD_BITS);
        let mut res = U256::ZERO;
        *res.low_mut() = u[0] >> shift;
        *res.high_mut() = u[1] >> shift;
        // Carry bits down from the next word. Skipped for `shift == 0`,
        // where `<< 128` would overflow.
        if shift > 0 {
            let sh = N_UDWORD_BITS - shift;
            *res.low_mut() |= u[1] << sh;
            *res.high_mut() |= u[2] << sh;
        }
        res
    }
    /// Splits `a` into its `(low, high)` 64-bit halves.
    #[inline]
    const fn split_u128_to_u128(a: u128) -> (u128, u128) {
        (a & 0xFFFFFFFFFFFFFFFF, a >> (N_UDWORD_BITS / 2))
    }
    /// Full 128x128 -> 256-bit product of `a` and `b`, returned as
    /// `(low, high)` words. Schoolbook multiplication on 64-bit halves; each
    /// partial sum fits in a `u128`.
    #[inline]
    const fn fullmul_u128(a: u128, b: u128) -> (u128, u128) {
        let (a0, a1) = split_u128_to_u128(a);
        let (b0, b1) = split_u128_to_u128(b);
        let mut t = a0 * b0;
        let mut k: u128;
        let w3: u128;
        (w3, k) = split_u128_to_u128(t);
        t = a1 * b0 + k;
        let (w1, w2) = split_u128_to_u128(t);
        t = a0 * b1 + w1;
        k = t >> 64;
        let w_hi = a1 * b1 + w2 + k;
        let w_lo = (t << 64) + w3;
        (w_lo, w_hi)
    }
    /// Full 256x128 -> 384-bit product `a * b` as three words, least
    /// significant first.
    #[inline]
    fn fullmul_u256_u128(a: &U256, b: u128) -> [u128; 3] {
        let mut acc = [0_u128; 3];
        let mut lo: u128;
        let mut carry: u128;
        let c: bool;
        if b != 0 {
            (lo, carry) = fullmul_u128(*a.low(), b);
            acc[0] = lo;
            acc[1] = carry;
            (lo, carry) = fullmul_u128(*a.high(), b);
            (acc[1], c) = acc[1].overflowing_add(lo);
            acc[2] = carry + c as u128;
        }
        acc
    }
    /// Returns `a + b + c` together with the carry out.
    #[inline]
    const fn add_carry(a: u128, b: u128, c: bool) -> (u128, bool) {
        let (res1, overflow1) = b.overflowing_add(c as u128);
        let (res2, overflow2) = u128::overflowing_add(a, res1);
        (res2, overflow1 || overflow2)
    }
    /// Returns `a - b - c` together with the borrow out.
    #[inline]
    const fn sub_carry(a: u128, b: u128, c: bool) -> (u128, bool) {
        let (res1, overflow1) = b.overflowing_add(c as u128);
        let (res2, overflow2) = u128::overflowing_sub(a, res1);
        (res2, overflow1 || overflow2)
    }
    // D1. Normalize: shift both operands so the divisor's top bit is set.
    // The quotient is unchanged; the remainder is shifted back at the end.
    let shift = v.high().leading_zeros();
    debug_assert!(shift < N_UDWORD_BITS);
    let v = v << shift;
    debug_assert!(v.high() >> (N_UDWORD_BITS - 1) == 1);
    // `u` holds the (shifted) running remainder as three words.
    let mut u = full_shl(u, shift);
    let mut q = U256::ZERO;
    let v_n_1 = *v.high();
    let v_n_2 = *v.low();
    // D2/D7: single iteration with j = 0.
    let mut r_hat: u128 = 0;
    let u_jn = u[2];
    // D3. Estimate the quotient digit
    // q_hat = min(b - 1, (u_jn * b + u[1]) / v_n_1), with b = 2^128.
    // With v normalized, q_hat - 2 <= q <= q_hat (Knuth's Theorem B).
    let mut q_hat = if u_jn < v_n_1 {
        let mut q_hat = udiv256_by_128_to_128(u_jn, u[1], v_n_1, &mut r_hat);
        let mut overflow: bool;
        // Refine while q_hat * v_n_2 > b * r_hat + u[0]. This takes at most
        // two iterations and leaves q_hat at most one too large.
        loop {
            let another_iteration = {
                let (lo, hi) = fullmul_u128(q_hat, v_n_2);
                hi > r_hat || (hi == r_hat && lo > u[0])
            };
            if !another_iteration {
                break;
            }
            q_hat -= 1;
            (r_hat, overflow) = r_hat.overflowing_add(v_n_1);
            // Once r_hat >= b the test above can no longer succeed.
            if overflow {
                break;
            }
        }
        q_hat
    } else {
        // General Algorithm D fallback: here q_hat >= q >= q_hat - 1.
        // NOTE(review): for 256-bit inputs this branch appears unreachable.
        // u[2] holds only the top `shift` bits of `u << shift`, so
        // u[2] < 2^shift <= 2^127 <= v_n_1.
        u128::MAX
    };
    // D4. Multiply and subtract: u -= q_hat * v.
    let q_hat_v = fullmul_u256_u128(&v, q_hat);
    let mut c = false;
    (u[0], c) = sub_carry(u[0], q_hat_v[0], c);
    (u[1], c) = sub_carry(u[1], q_hat_v[1], c);
    (u[2], c) = sub_carry(u[2], q_hat_v[2], c);
    // D6. A final borrow means q_hat was one too large: decrement it and
    // add v back once.
    if c {
        q_hat -= 1;
        c = false;
        (u[0], c) = add_carry(u[0], *v.low(), c);
        (u[1], c) = add_carry(u[1], *v.high(), c);
        u[2] = u[2].wrapping_add(c as u128);
    }
    // D5. Store the quotient digit (the quotient has a single word).
    *q.low_mut() = q_hat;
    // D8. Un-normalize the remainder.
    let remainder = full_shr(&u, shift);
    (q, remainder)
}
#[inline]
pub fn udiv2(r: &mut U256, a: &U256) {
let (a, b) = (*r, a);
let res = unsafe { &mut *(r as *mut U256).cast() };
udivmod4(res, &a, b, None);
}
/// Unsigned division into an uninitialized slot: writes `a / b` into `r` and
/// discards the remainder.
///
/// # Panics
///
/// Panics if `b` is zero.
#[inline]
pub fn udiv3(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
    let discard_remainder: Option<&mut MaybeUninit<U256>> = None;
    udivmod4(r, a, b, discard_remainder);
}
/// In-place unsigned remainder: sets `*r` to `*r % *a`.
///
/// # Panics
///
/// Panics if `*a` is zero.
#[inline]
pub fn urem2(r: &mut U256, a: &U256) {
    let dividend = *r;
    // The quotient is required by `udivmod4` but not needed here.
    let mut scratch_quotient = MaybeUninit::<U256>::uninit();
    // SAFETY: `MaybeUninit<U256>` has the same layout as `U256`, `*r` starts
    // out initialized, and `udivmod4` only stores fully-initialized values
    // through this view via `MaybeUninit::write`.
    let out = unsafe { &mut *(r as *mut U256).cast::<MaybeUninit<U256>>() };
    udivmod4(&mut scratch_quotient, &dividend, a, Some(out));
}
/// Unsigned remainder into an uninitialized slot: writes `a % b` into `r`.
///
/// # Panics
///
/// Panics if `b` is zero.
#[inline]
pub fn urem3(r: &mut MaybeUninit<U256>, a: &U256, b: &U256) {
    let mut unused_quotient = MaybeUninit::<U256>::uninit();
    udivmod4(&mut unused_quotient, a, b, Some(r));
}
/// Signed 256-bit division with an optional remainder.
///
/// Writes `a / b` into `res` and, when `rem` is `Some`, writes `a % b` into
/// it. The division runs on magnitudes via `udivmod4`, and the signs are
/// restored afterwards:
/// - the quotient is negated when the operand signs differ, so it truncates
///   toward zero;
/// - the remainder takes the sign of `a`.
///
/// Negation uses wrapping arithmetic, so (assuming a two's-complement
/// `I256`) `I256::MIN / -1` presumably wraps to `I256::MIN` here instead of
/// panicking. NOTE(review): confirm that callers check this overflow.
///
/// # Panics
///
/// Panics if `b` is zero (from `udivmod4`).
pub fn idivmod4(
    res: &mut MaybeUninit<I256>,
    a: &I256,
    b: &I256,
    rem: Option<&mut MaybeUninit<I256>>,
) {
    const BITS_IN_TWORD_M1: u32 = 255;
    // Sign masks: all ones when negative, zero otherwise (assumes `>>` on
    // I256 is an arithmetic shift, as for primitive signed integers).
    let s_a = a >> BITS_IN_TWORD_M1;
    let mut s_b = b >> BITS_IN_TWORD_M1;
    // Absolute values: `(x ^ s) - s` negates exactly when `s` is all ones.
    let a = (a ^ s_a).wrapping_sub(s_a);
    let b = (b ^ s_b).wrapping_sub(s_b);
    // Sign mask of the quotient: negative iff the operand signs differ.
    s_b ^= s_a;
    // NOTE(review): `rem` is used again below, so `cast!(optuninit: ..)`
    // presumably reborrows rather than moves it. Confirm against the macro.
    udivmod4(
        cast!(uninit: res),
        cast!(ref: &a),
        cast!(ref: &b),
        cast!(optuninit: rem),
    );
    // SAFETY: `udivmod4` writes the quotient on every non-panicking path.
    let q = unsafe { res.assume_init_ref() };
    let q = (q ^ s_b).wrapping_sub(s_b);
    res.write(q);
    if let Some(rem) = rem {
        // SAFETY: `udivmod4` writes the remainder whenever `rem` is `Some`.
        let r = unsafe { rem.assume_init_ref() };
        // The remainder takes the sign of the dividend.
        let r = (r ^ s_a).wrapping_sub(s_a);
        rem.write(r);
    }
}
#[inline]
pub fn idiv2(r: &mut I256, a: &I256) {
let (a, b) = (*r, a);
let res = unsafe { &mut *(r as *mut I256).cast() };
idivmod4(res, &a, b, None);
}
/// Signed division into an uninitialized slot: writes `a / b` into `r` and
/// discards the remainder.
///
/// # Panics
///
/// Panics if `b` is zero.
#[inline]
pub fn idiv3(r: &mut MaybeUninit<I256>, a: &I256, b: &I256) {
    let discard_remainder: Option<&mut MaybeUninit<I256>> = None;
    idivmod4(r, a, b, discard_remainder);
}
/// In-place signed remainder: sets `*r` to `*r % *a`. The result takes the
/// sign of the dividend.
///
/// # Panics
///
/// Panics if `*a` is zero.
#[inline]
pub fn irem2(r: &mut I256, a: &I256) {
    let dividend = *r;
    // The quotient is required by `idivmod4` but not needed here.
    let mut scratch_quotient = MaybeUninit::<I256>::uninit();
    // SAFETY: `MaybeUninit<I256>` has the same layout as `I256`, `*r` starts
    // out initialized, and `idivmod4` only stores fully-initialized values
    // through this view via `MaybeUninit::write`.
    let out = unsafe { &mut *(r as *mut I256).cast::<MaybeUninit<I256>>() };
    idivmod4(&mut scratch_quotient, &dividend, a, Some(out));
}
/// Signed remainder into an uninitialized slot: writes `a % b` into `r`.
///
/// # Panics
///
/// Panics if `b` is zero.
#[inline]
pub fn irem3(r: &mut MaybeUninit<I256>, a: &I256, b: &I256) {
    let mut unused_quotient = MaybeUninit::<I256>::uninit();
    idivmod4(&mut unused_quotient, a, b, Some(r));
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsU256;

    /// Computes `a / b` through `udiv3`.
    fn udiv(a: impl AsU256, b: impl AsU256) -> U256 {
        let mut r = MaybeUninit::uninit();
        udiv3(&mut r, &a.as_u256(), &b.as_u256());
        unsafe { r.assume_init() }
    }

    /// Computes `a % b` through `urem3`.
    fn urem(a: impl AsU256, b: impl AsU256) -> U256 {
        let mut r = MaybeUninit::uninit();
        urem3(&mut r, &a.as_u256(), &b.as_u256());
        unsafe { r.assume_init() }
    }

    #[test]
    fn division() {
        assert_eq!(udiv(100, 9), 11);
        assert_eq!(udiv(!0u128, U256::ONE << 128u32), 0);
        assert_eq!(udiv(U256::from_words(100, 0), U256::from_words(10, 0)), 10);
        assert_eq!(udiv(U256::from_words(100, 1337), U256::ONE << 130u32), 25);
        assert_eq!(
            udiv(U256::from_words(1337, !0), U256::from_words(63, 0)),
            21
        );
        assert_eq!(
            udiv(U256::from_words(42, 0), U256::ONE),
            U256::from_words(42, 0),
        );
        assert_eq!(
            udiv(U256::from_words(42, 42), U256::ONE << 42),
            42u128 << (128 - 42),
        );
        assert_eq!(
            udiv(U256::from_words(1337, !0), 0xc0ffee),
            35996389033280467545299711090127855,
        );
        assert_eq!(
            udiv(U256::from_words(42, 0), 99),
            144362216269489045105674075880144089708,
        );
        assert_eq!(
            udiv(U256::from_words(100, 100), U256::from_words(1000, 1000)),
            0,
        );
        assert_eq!(
            udiv(U256::from_words(1337, !0), U256::from_words(43, !0)),
            30,
        );
    }

    #[test]
    fn edge_cases() {
        // Equal wide operands go through the Knuth path and divide evenly.
        let x = U256::from_words(7, 3);
        assert_eq!(udiv(x, x), 1);
        assert_eq!(urem(x, x), 0);

        // 2^256 - 1 = (2^128 - 1) * (2^128 + 1). This exercises the path that
        // reduces the high word before the 256-by-128 step.
        let all_ones = U256::from_words(!0, !0);
        assert_eq!(
            udiv(all_ones, U256::from_words(0, !0)),
            U256::from_words(1, 1),
        );
        assert_eq!(urem(all_ones, U256::from_words(0, !0)), 0);

        // Dividing by 2^128 keeps the high word; the remainder is the low word.
        assert_eq!(udiv(all_ones, U256::ONE << 128u32), !0u128);
        assert_eq!(urem(all_ones, U256::ONE << 128u32), !0u128);
    }

    #[test]
    #[should_panic]
    fn division_by_zero() {
        udiv(1, 0);
    }

    #[test]
    fn remainder() {
        assert_eq!(urem(100, 9), 1);
        assert_eq!(urem(!0u128, U256::ONE << 128u32), !0u128);
        assert_eq!(urem(U256::from_words(100, 0), U256::from_words(10, 0)), 0);
        assert_eq!(urem(U256::from_words(100, 1337), U256::ONE << 130u32), 1337);
        assert_eq!(
            urem(U256::from_words(1337, !0), U256::from_words(63, 0)),
            U256::from_words(14, !0),
        );
        assert_eq!(urem(U256::from_words(42, 0), U256::ONE), 0);
        assert_eq!(urem(U256::from_words(42, 42), U256::ONE << 42), 42);
        assert_eq!(urem(U256::from_words(1337, !0), 0xc0ffee), 1910477);
        assert_eq!(urem(U256::from_words(42, 0), 99), 60);
        assert_eq!(
            urem(U256::from_words(100, 100), U256::from_words(1000, 1000)),
            U256::from_words(100, 100),
        );
        assert_eq!(
            urem(U256::from_words(1337, !0), U256::from_words(43, !0)),
            U256::from_words(18, 29),
        );
    }

    #[test]
    #[should_panic]
    fn remainder_by_zero() {
        urem(1, 0);
    }
}