use core:: {
cmp::Ordering,
mem,
};
use super::div::bitlen;
/// An unsigned 256-bit integer stored as two 128-bit halves.
///
/// Field order is significant: the derived `Ord` compares fields
/// lexicographically in declaration order, so `hi` must come before `lo`
/// to match the hand-written `PartialEq`/`PartialOrd` impls.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord)]
pub struct U256 {
    /// Most significant 128 bits.
    hi: u128,
    /// Least significant 128 bits.
    lo: u128,
}
/// Width in bits of one 64-bit half ("digit") of a `u128`.
const HALF_BITS: u32 = u128::BITS >> 1;
/// Mask selecting the low `HALF_BITS` bits of a `u128`.
const HMASK: u128 = (1 << HALF_BITS) - 1;
/// The value 0.
const ZERO: U256 = U256::new(0u128);
/// The value 1.
const ONE: U256 = U256::new(1u128);
/// Low 64 bits of `u`, zero-extended to `u128`.
const fn lo64(u: u128) -> u128 {
    HMASK & u
}
/// High 64 bits of `u`, moved down into the low half of a `u128`.
const fn hi64(u: u128) -> u128 {
    u.wrapping_shr(HALF_BITS)
}
impl U256 {
/// Width of the type in bits (256), derived from its in-memory size.
pub const BITS: u16 = (u8::BITS as usize * mem::size_of::<Self>()) as u16;
/// Widens a `u128` into a `U256` whose high half is zero.
pub const fn new(lo: u128) -> Self {
    Self { lo, hi: 0 }
}
/// Full 128 x 128 -> 256-bit product of `lhs` and `rhs`.
///
/// Each operand is split into 64-bit halves; the four partial products are
/// summed with explicit carry propagation so no intermediate overflows.
pub const fn from_prod(lhs: u128, rhs: u128) -> Self {
    let (a_lo, a_hi) = (lo64(lhs), hi64(lhs));
    let (b_lo, b_hi) = (lo64(rhs), hi64(rhs));
    let low = a_lo * b_lo;
    let high = a_hi * b_hi;
    // The two cross terms can sum past 2^128; keep that carry explicitly.
    let (cross, cross_carry) = (a_hi * b_lo).overflowing_add(a_lo * b_hi);
    // Bits 64..128 of the result, plus a possible carry into bit 128.
    let mid = lo64(cross) + hi64(low);
    // Everything from the cross terms that lands in the high half.
    let upper = (hi64(cross) | (cross_carry as u128) << HALF_BITS) + hi64(mid);
    Self {
        hi: high + upper,
        lo: mid << HALF_BITS | lo64(low),
    }
}
pub const fn mul128(mut self, rhs: u128) -> (Self, u16) {
let l0 = lo64(self.lo) * lo64(rhs);
let l1 = hi64(self.lo) * lo64(rhs);
let l1h = hi64(l1);
let l1l = lo64(l1) << HALF_BITS;
let l2 = lo64(self.hi) * lo64(rhs);
let l3 = hi64(self.hi) * lo64(rhs);
let l3l = lo64(l3) << HALF_BITS;
let l3h = hi64(l3);
let h0 = lo64(self.lo) * hi64(rhs);
let h0h = hi64(h0);
let h0l = lo64(h0) << HALF_BITS;
let h1 = hi64(self.lo) * hi64(rhs);
let h2 = lo64(self.hi) * hi64(rhs);
let h2l = lo64(h2) << HALF_BITS;
let h2h = hi64(h2);
let h3 = hi64(self.hi) * hi64(rhs);
let mut overflow = 0u128;
self.lo = match l0.overflowing_add(l1l) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
self.lo = match self.lo.overflowing_add(h0l) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
self.hi = match overflow.overflowing_add(l1h) {
(s, false) => {
overflow = 0;
s
},
(s, true) => {
overflow = 1;
s
},
};
self.hi = match self.hi.overflowing_add(l2) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
self.hi = match self.hi.overflowing_add(l3l) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
self.hi = match self.hi.overflowing_add(h0h) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
self.hi = match self.hi.overflowing_add(h1) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
self.hi = match self.hi.overflowing_add(h2l) {
(s, false) => s,
(s, true) => {
overflow += 1;
s
},
};
overflow += l3h + h2h + h3;
(self, if overflow != 0 {
let n = bitlen(overflow);
self = self.shr(n);
self.hi |= overflow << u128::BITS as u8 - n;
n as u16
} else {
0
})
}
/// Multiplies `self` by `5^exp5`, keeping the value within 256 bits.
///
/// Each step computes `5x = x + 4x`. Whenever the result no longer fits in
/// 256 bits, it is shifted right (truncating) just far enough to fit. The
/// number of bits shifted out is accumulated and returned, so the result
/// approximates `self * 5^exp5 / 2^round_down`.
pub const fn mul_exp5(mut self, mut exp5: u16) -> (Self, u16) {
    let mut round_down: u16 = 0;
    while exp5 > 0 {
        exp5 -= 1;
        // `shl(2)` discards the top two bits of x. Those bits belong at bits
        // 256..258 of 4x. Together with the addition's carry, they form the
        // part of 5x that lies above bit 255.
        let dropped = self.hi >> (u128::BITS - 2);
        let (sum, carry) = self.overflowing_add(self.shl(2));
        let excess = dropped + carry as u128; // 0..=4
        if excess == 0 {
            self = sum;
            continue;
        }
        // Shift out exactly as many bits as `excess` occupies (1..=3), then
        // place `excess` in the vacated top bits.
        let shift = u128::BITS - excess.leading_zeros();
        let mut scaled = sum.shr(shift as u8);
        scaled.hi |= excess << (u128::BITS - shift);
        self = scaled;
        round_down += shift as u16;
    }
    (self, round_down)
}
/// Logical left shift by `rhs` bits (`rhs` may be 0..=255); bits shifted
/// past bit 255 are discarded.
pub const fn shl(self, rhs: u8) -> Self {
    const HALF: u8 = u128::BITS as u8;
    if rhs >= HALF {
        // The low half moves entirely into the high half.
        return Self {
            hi: self.lo << (rhs - HALF),
            lo: 0,
        };
    }
    if rhs == 0 {
        return self;
    }
    Self {
        hi: (self.hi << rhs) | (self.lo >> (HALF - rhs)),
        lo: self.lo << rhs,
    }
}
/// Logical right shift by `rhs` bits (`rhs` may be 0..=255); bits shifted
/// below bit 0 are discarded.
pub const fn shr(self, rhs: u8) -> Self {
    if rhs == 0 {
        self
    } else if rhs < u128::BITS as u8 {
        Self {
            hi: self.hi >> rhs,
            // Bits leaving the bottom of `hi` enter the top of `lo`; the low
            // half itself must be shifted, not masked.
            lo: self.hi << (u128::BITS as u8 - rhs) | self.lo >> rhs,
        }
    } else {
        Self {
            hi: 0,
            lo: self.hi >> (rhs - u128::BITS as u8),
        }
    }
}
/// Bitwise AND of two values.
pub const fn bitand(self, rhs: Self) -> Self {
    let Self { hi, lo } = self;
    Self {
        lo: lo & rhs.lo,
        hi: hi & rhs.hi,
    }
}
/// Bitwise OR of two values.
pub const fn bitor(self, rhs: Self) -> Self {
    let Self { hi, lo } = self;
    Self {
        lo: lo | rhs.lo,
        hi: hi | rhs.hi,
    }
}
/// Bitwise complement of every one of the 256 bits.
pub const fn not(self) -> Self {
    let Self { hi, lo } = self;
    Self { lo: !lo, hi: !hi }
}
/// Numeric three-way comparison, usable in `const` contexts.
pub const fn cmp(&self, other: &Self) -> Ordering {
    // The high half decides unless it ties, in which case the low half does.
    let (lhs, rhs) = if self.hi == other.hi {
        (self.lo, other.lo)
    } else {
        (self.hi, other.hi)
    };
    if lhs < rhs {
        Ordering::Less
    } else if lhs > rhs {
        Ordering::Greater
    } else {
        Ordering::Equal
    }
}
/// Equality test usable in `const` contexts.
pub const fn eq(&self, other: &Self) -> bool {
    ((self.hi ^ other.hi) | (self.lo ^ other.lo)) == 0
}
/// Number of significant bits (0 for zero, up to 256).
pub const fn bitlen(self) -> u16 {
    // `bitlen` here is the free 128-bit helper imported from `super::div`.
    match bitlen(self.hi) {
        0 => bitlen(self.lo) as u16,
        high_len => u128::BITS as u16 + high_len as u16,
    }
}
pub const fn is_set(self, p: u8) -> bool {
let result = ONE.shl(p).bitand(self);
result.lo != 0 || result.hi != 0
}
/// Number of leading zero bits.
///
/// Note: the zero value has 256 leading zeros, which does not fit in `u8`
/// and wraps to 0, so callers must handle zero separately.
pub const fn leading_zeros(self) -> u8 {
    let count = match self.hi {
        0 => u128::BITS + self.lo.leading_zeros(),
        high => high.leading_zeros(),
    };
    count as u8
}
/// Number of trailing zero bits.
///
/// Note: the zero value has 256 trailing zeros, which does not fit in `u8`
/// and wraps to 0, so callers must handle zero separately.
pub const fn trailing_zeros(self) -> u8 {
    let count = if self.lo != 0 {
        self.lo.trailing_zeros()
    } else {
        u128::BITS + self.hi.trailing_zeros()
    };
    count as u8
}
/// Removes all factors of two.
///
/// Returns `(self >> k, k)` where `k` is the number of trailing zero bits.
/// Zero is returned unchanged as `(0, 0)`.
pub const fn extract2(self) -> (Self, u8) {
    if !self.eq(&ZERO) {
        let exp2 = self.trailing_zeros();
        return (self.shr(exp2), exp2);
    }
    (self, 0)
}
/// Removes all factors of five.
///
/// Returns `(q, exp5)` such that the dividend equals `q * 5^exp5` and `q` is
/// not divisible by 5. The dividend is `self`, or `2^256 + self` when
/// `extr` is set, matching `div_rem`.
///
/// Zero (without `extr`) is returned as `(0, 0)`. With `extr`, if the
/// dividend is not divisible by 5, `(self, 0)` is returned unchanged.
pub const fn extract5(self, extr: bool) -> (Self, u8) {
    const FIVE: U256 = U256::new(5);
    let mut q = self;
    let mut exp5: u8 = 0;
    if extr {
        let (quot, rem) = self.div_rem(FIVE, true);
        if !rem.eq(&ZERO) {
            return (self, 0);
        }
        q = quot;
        exp5 = 1;
    } else if self.eq(&ZERO) {
        return (self, 0);
    }
    loop {
        let (quot, rem) = q.div_rem(FIVE, false);
        if !rem.eq(&ZERO) {
            return (q, exp5);
        }
        q = quot;
        exp5 += 1;
    }
}
/// Narrows the value to 128 bits by discarding low bits.
///
/// Returns `(mantissa, exp2)`. When the high half is zero, the value is
/// returned as-is with `exp2 == 0`. Otherwise the value is shifted right by
/// `exp2 = bitlen - 128` bits so that it fits in a `u128`.
pub const fn round_down(self) -> (u128, u8) {
    if self.hi != 0 {
        let exp2 = (self.bitlen() - u128::BITS as u16) as u8;
        return (self.shr(exp2).lo, exp2);
    }
    (self.lo, 0)
}
/// Divides both `a` and `b` by their greatest common divisor.
///
/// Inherits `gcd`'s precondition that both values are odd and nonzero.
pub const fn reduce(a: Self, b: Self) -> (Self, Self) {
    let divisor = a.gcd(b);
    let (a_reduced, _) = a.div_rem(divisor, false);
    let (b_reduced, _) = b.div_rem(divisor, false);
    (a_reduced, b_reduced)
}
/// Greatest common divisor of two odd, nonzero values (binary GCD).
///
/// Repeatedly replaces the pair with (smaller value, |difference| with its
/// factors of two removed). Both values stay odd and the larger one strictly
/// decreases, so the loop ends when the two become equal.
///
/// Debug builds assert the odd/nonzero precondition. In release builds a
/// zero argument returns the other argument instead of looping forever.
pub const fn gcd(mut self, mut rhs: Self) -> Self {
    debug_assert!(
        !self.eq(&ZERO) && !rhs.eq(&ZERO) &&
        self.trailing_zeros() == 0 && rhs.trailing_zeros() == 0
    );
    if self.eq(&ZERO) {
        return rhs;
    }
    if rhs.eq(&ZERO) {
        return self;
    }
    loop {
        let (diff, self_is_smaller) = self.diff(rhs);
        if diff.eq(&ZERO) {
            break self;
        }
        // Keep the *smaller* operand. Keeping the larger one can cycle, e.g.
        // (7, 3) -> (7, 1) -> (7, 3). The difference of two odd numbers is
        // even, and its factors of two cannot divide the odd GCD.
        if !self_is_smaller {
            self = rhs;
        }
        rhs = diff.extract2().0;
    }
}
/// Wrapping 256-bit addition.
///
/// Returns the sum modulo 2^256 and whether a carry out of bit 255 occurred.
pub const fn overflowing_add(self, rhs: Self) -> (Self, bool) {
    let (lo, carry) = self.lo.overflowing_add(rhs.lo);
    // Either the high-half sum or adding the low-half carry may overflow.
    // Check both instead of using a plain `+`, which panics in debug builds
    // and silently loses the carry-out in release builds.
    let (hi, overflow_sum) = self.hi.overflowing_add(rhs.hi);
    let (hi, overflow_carry) = hi.overflowing_add(carry as u128);
    (Self { hi, lo }, overflow_sum || overflow_carry)
}
/// Subtracts `rhs`, assuming `self >= rhs`.
///
/// No range check is done beyond the built-in debug overflow panic on the
/// high half.
const fn unchecked_sub(self, rhs: Self) -> Self {
    let (low, borrow) = self.lo.overflowing_sub(rhs.lo);
    let high = self.hi - rhs.hi - borrow as u128;
    Self { lo: low, hi: high }
}
/// Absolute difference `|self - rhs|`, paired with `true` when `self < rhs`.
pub const fn diff(self, rhs: Self) -> (Self, bool) {
    match self.cmp(&rhs) {
        Ordering::Equal => (ZERO, false),
        Ordering::Less => (rhs.unchecked_sub(self), true),
        Ordering::Greater => (self.unchecked_sub(rhs), false),
    }
}
/// Divides by `rhs`, returning `(quotient, remainder)`.
///
/// With `extr` set, the dividend is `2^256 + self`, i.e. `self` with an
/// implicit set bit just above its 256 stored bits.
///
/// # Panics
/// Panics if `rhs` is zero.
pub const fn div_rem(self, rhs: Self, extr: bool) -> (Self, Self) {
    assert!(!rhs.eq(&ZERO));
    let wide_divisor = rhs.hi != 0;
    if !extr {
        // Trivial orderings need no division at all.
        match self.cmp(&rhs) {
            Ordering::Less => return (ZERO, self),
            Ordering::Equal => return (ONE, ZERO),
            Ordering::Greater => {}
        }
        // Both operands fit in 128 bits, so native division suffices.
        if !wide_divisor && self.hi == 0 {
            return (Self::new(self.lo / rhs.lo), Self::new(self.lo % rhs.lo));
        }
    }
    if wide_divisor {
        self.div_rem256(rhs, extr)
    } else {
        self.div_rem128(rhs.lo, extr)
    }
}
const fn div_rem128(self, rhs: u128, extr: bool) -> (Self, Self) {
let (mut qbits, mut numer) = if extr {
(u128::BITS as u16 + 1, self.hi >> 1 & 1 << u128::BITS - 1)
} else {
(u128::BITS as u16, self.hi)
};
let mut q = ZERO;
loop {
let (qq, r) = (numer / rhs, numer % rhs);
let n = u128::BITS as u8 - bitlen(r);
q = q.bitor(Self::new(qq));
if qbits == 0 {
return (q, Self::new(r));
}
if n as u16 >= qbits {
numer = r << qbits | !(!0 << qbits) & self.lo;
q = q.shl(qbits as u8);
qbits = 0;
} else {
numer = r << n | if qbits > u128::BITS as u16 {
(self.hi & 1) << n - 1 | self.lo >> u128::BITS as u8 - n + 1
} else {
self.lo >> u128::BITS as u8 - n
};
q = q.shl(n);
qbits -= n as u16;
}
}
}
const fn div_rem256(self, rhs: Self, extr: bool) -> (Self, Self) {
let dlen = rhs.bitlen();
let mut qbits = self.bitlen() - dlen;
let mut numer = if extr {
qbits += 1;
ONE.shl(dlen as u8 - 1).bitor(self.shr(qbits as u8 - 1))
} else {
self.shr(qbits as u8)
};
let mut q = ZERO;
let mut n = 0;
loop {
q = q.shl(n);
match numer.cmp(&rhs) {
Ordering::Less => {
qbits -= 1;
numer = numer.shl(1);
if self.is_set(qbits as u8) {
numer = numer.bitor(ONE);
}
n = 1;
},
_ => {
let r = numer.unchecked_sub(rhs);
q = q.shl(n).bitor(ONE);
n = (dlen - r.bitlen()) as u8;
qbits = match qbits.overflowing_sub(n as u16) {
(_, true) => {
let mask = ZERO.not().shl(qbits as u8).not();
break (q, r.shl(qbits as u8).bitor(self.bitand(mask)))
},
(qbits, false) => qbits,
};
numer = r.shl(n).bitor(
self.shr(qbits as u8).bitand(
ZERO.not().shl(n).not()
)
);
},
}
}
}
}
impl PartialEq for U256 {
fn eq(&self, other: &Self) -> bool {
U256::eq(self, other)
}
}
/// Total order delegated to the inherent `const fn cmp`.
impl PartialOrd for U256 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(U256::cmp(self, other))
    }
}