use uint::construct_uint;
construct_uint! {
/// 128-bit unsigned integer made of two `u64` limbs, least-significant limb first.
pub struct U128(2);
}
construct_uint! {
/// 256-bit unsigned integer made of four `u64` limbs, least-significant limb first.
pub struct U256(4);
}
pub trait Upcast {
fn as_u256(self) -> U256;
}
impl Upcast for U128 {
fn as_u256(self) -> U256 {
U256([self.0[0], self.0[1], 0, 0])
}
}
/// Narrowing conversions out of a wide integer.
///
/// The `checked_*` methods return `None` when the value does not fit the
/// target width; the unchecked methods panic in that case.
pub trait Downcast {
    /// Convert to `u128`.
    ///
    /// # Panics
    /// Panics if the value does not fit in a `u128`.
    fn as_u128(self) -> u128;
    /// Convert to `u64`.
    ///
    /// # Panics
    /// Panics if the value does not fit in a `u64`.
    fn as_u64(self) -> u64;
    /// Convert to `u128`, or `None` if the value is larger than `u128::MAX`.
    fn checked_as_u128(self) -> Option<u128>;
    /// Convert to `u64`, or `None` if the value is larger than `u64::MAX`.
    fn checked_as_u64(self) -> Option<u64>;
    /// `true` when every limb is zero.
    fn is_zero(self) -> bool;
}

impl Downcast for U256 {
    #[inline]
    fn as_u128(self) -> u128 {
        self.checked_as_u128().unwrap()
    }

    #[inline]
    fn as_u64(self) -> u64 {
        self.checked_as_u64().unwrap()
    }

    #[inline]
    fn checked_as_u128(self) -> Option<u128> {
        // Limb 0 is the least significant word.
        let [w0, w1, w2, w3] = self.0;
        if (w2 | w3) == 0 {
            Some((u128::from(w1) << 64) | u128::from(w0))
        } else {
            None
        }
    }

    #[inline]
    fn checked_as_u64(self) -> Option<u64> {
        let [w0, w1, w2, w3] = self.0;
        if (w1 | w2 | w3) == 0 {
            Some(w0)
        } else {
            None
        }
    }

    #[inline]
    fn is_zero(self) -> bool {
        self.0.iter().all(|&limb| limb == 0)
    }
}
/// Split a `u128` into its two 64-bit halves, and rebuild it from them.
pub trait LowHigh {
    /// Lower 64 bits of `self`.
    fn lo(self) -> u64;
    /// Upper 64 bits of `self`.
    fn hi(self) -> u64;
    /// Lower 64 bits of `self`, kept as a `u128` with the upper half cleared.
    fn lo_u128(self) -> u128;
    /// Upper 64 bits of `self`, moved down into the low half of a `u128`.
    fn hi_u128(self) -> u128;
    /// Build a `u128` whose upper half is `hi` and lower half is `lo`.
    fn from_hi_lo(hi: u64, lo: u64) -> u128;
}

/// Mask selecting the lower 64 bits of a `u128`.
const U64_MAX: u128 = u64::MAX as u128;

impl LowHigh for u128 {
    #[inline]
    fn lo(self) -> u64 {
        // A narrowing `as` cast keeps exactly the low 64 bits.
        self as u64
    }

    #[inline]
    fn hi(self) -> u64 {
        self.hi_u128() as u64
    }

    #[inline]
    fn lo_u128(self) -> u128 {
        self & U64_MAX
    }

    #[inline]
    fn hi_u128(self) -> u128 {
        self >> 64
    }

    #[inline]
    fn from_hi_lo(hi: u64, lo: u64) -> u128 {
        (u128::from(hi) << 64) | u128::from(lo)
    }
}
/// Logical shifts for multi-limb unsigned integers.
///
/// Bits moved past either end are discarded and vacated positions are zero-filled.
pub trait Shift {
    /// Type produced by every shift.
    type Output;
    /// Shift left by `num` bits.
    fn shift_left(self, num: u32) -> Self::Output;
    /// Shift left by exactly one 64-bit limb, dropping the most significant limb.
    fn shift_word_left(self) -> Self::Output;
    /// Like [`Shift::shift_word_left`], but returns `None` instead of dropping a
    /// non-zero most significant limb.
    fn checked_shift_word_left(self) -> Option<Self::Output>;
    /// Shift right by `num` bits.
    fn shift_right(self, num: u32) -> Self::Output;
    /// Shift right by exactly one 64-bit limb, dropping the least significant limb.
    fn shift_word_right(self) -> Self::Output;
}

impl Shift for U256 {
    type Output = U256;

    /// Shift left by `num` bits; any `num >= 256` yields zero.
    fn shift_left(self, num: u32) -> Self::Output {
        if num >= 256 {
            return U256([0, 0, 0, 0]);
        }
        // Whole limbs first. This must also consume exact multiples of 64:
        // leaving a 64-bit remainder for the code below would evaluate
        // `x << 64`, which overflows (panics in debug, acts as `x << 0` in release).
        let mut result = self;
        for _ in 0..num / 64 {
            result = result.shift_word_left();
        }
        let bits = num % 64;
        if bits == 0 {
            return result;
        }
        // Remaining 1..=63 bits. The bits leaving the top of a limb enter the
        // bottom of the next limb up. Work from the top limb down so every read
        // still sees its lower neighbour's pre-shift value.
        let spill = 64 - bits;
        result.0[3] = (result.0[3] << bits) | (result.0[2] >> spill);
        result.0[2] = (result.0[2] << bits) | (result.0[1] >> spill);
        result.0[1] = (result.0[1] << bits) | (result.0[0] >> spill);
        result.0[0] <<= bits;
        result
    }

    #[inline]
    fn shift_word_left(self) -> Self::Output {
        U256([0, self.0[0], self.0[1], self.0[2]])
    }

    #[inline]
    fn checked_shift_word_left(self) -> Option<Self::Output> {
        // Shifting a non-zero top limb out would lose information.
        if self.0[3] != 0 {
            None
        } else {
            Some(self.shift_word_left())
        }
    }

    /// Shift right by `num` bits; any `num >= 256` yields zero.
    fn shift_right(self, num: u32) -> Self::Output {
        if num >= 256 {
            return U256([0, 0, 0, 0]);
        }
        // Whole limbs first, including exact multiples of 64 (see `shift_left`).
        let mut result = self;
        for _ in 0..num / 64 {
            result = result.shift_word_right();
        }
        let bits = num % 64;
        if bits == 0 {
            return result;
        }
        // Remaining 1..=63 bits. The bits leaving the bottom of a limb enter the
        // top of the limb below. Work from the bottom limb up so every read still
        // sees its upper neighbour's pre-shift value.
        let spill = 64 - bits;
        result.0[0] = (result.0[0] >> bits) | (result.0[1] << spill);
        result.0[1] = (result.0[1] >> bits) | (result.0[2] << spill);
        result.0[2] = (result.0[2] >> bits) | (result.0[3] << spill);
        result.0[3] >>= bits;
        result
    }

    #[inline]
    fn shift_word_right(self) -> Self::Output {
        U256([self.0[1], self.0[2], self.0[3], 0])
    }
}
#[cfg(test)]
mod test_low_high {
    use proptest::prelude::*;
    use super::*;

    #[test]
    fn test_up_cast() {
        let n = U128::from(128u128);
        let n256 = n.as_u256();
        assert_eq!(n256, U256::from(128));
        let n2 = U128::from(u128::MAX);
        let n256 = n2.as_u256();
        assert_eq!(n256, U256::from(u128::MAX));
    }

    #[test]
    fn test_down_cast() {
        let n = U256::from(u128::MAX);
        let n_128 = n.as_u128();
        assert_eq!(u128::MAX, n_128);
        // One past u128::MAX no longer fits in 128 bits, and u128::MAX does not fit in 64.
        assert_eq!((n + 1).checked_as_u128(), None);
        assert_eq!(n.checked_as_u64(), None);
        assert_eq!(U256::from(u64::MAX).checked_as_u64(), Some(u64::MAX));
    }

    #[test]
    #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
    fn test_down_cast_panic() {
        let n = U256::from(u128::MAX);
        n.as_u64();
    }

    #[test]
    fn test_lo() {
        let v1 = 1u128;
        let v_min = 0u128;
        let v_32_max = u32::MAX as u128;
        let v_64_max = u64::MAX as u128;
        let v_96_max = ((u32::MAX as u128) << 64) + u64::MAX as u128;
        let v_128_max = u128::MAX;
        assert_eq!(v1.lo(), 1);
        assert_eq!(v_min.lo(), 0);
        assert_eq!(v_32_max.lo(), u32::MAX as u64);
        assert_eq!(v_64_max.lo(), u64::MAX);
        assert_eq!(v_96_max.lo(), u64::MAX);
        assert_eq!(v_128_max.lo(), u64::MAX);
    }

    #[test]
    fn test_hi() {
        let v1 = 1u128;
        let v_min = 0u128;
        let v_96_max = ((u32::MAX as u128) << 64) + u64::MAX as u128;
        let v_128_max = u128::MAX;
        assert_eq!(v1.hi(), 0);
        assert_eq!(v_min.hi(), 0);
        assert_eq!(v_96_max.hi(), u32::MAX as u64);
        assert_eq!(v_128_max.hi(), u64::MAX);
    }

    #[test]
    fn test_hi_u128() {
        let v1 = 1u128;
        let v_min = 0u128;
        let v_96_max = ((u32::MAX as u128) << 64) + u64::MAX as u128;
        let v_128_max = u128::MAX;
        assert_eq!(v1.hi_u128(), 0u128);
        assert_eq!(v_min.hi_u128(), 0u128);
        assert_eq!(v_96_max.hi_u128(), u32::MAX as u128);
        assert_eq!(v_128_max.hi_u128(), u64::MAX as u128);
    }

    #[test]
    fn test_lo_u128() {
        let v1 = 1u128;
        let v_min = 0u128;
        let v_96_max = ((u32::MAX as u128) << 64) + u64::MAX as u128;
        let v_128_max = u128::MAX;
        assert_eq!(v1.lo_u128(), 1u128);
        assert_eq!(v_min.lo_u128(), 0u128);
        assert_eq!(v_96_max.lo_u128(), u64::MAX as u128);
        assert_eq!(v_128_max.lo_u128(), u64::MAX as u128);
    }

    #[test]
    fn test_from_hi_lo() {
        assert_eq!(u128::from_hi_lo(0, 0), 0);
        assert_eq!(u128::from_hi_lo(0, u64::MAX), u64::MAX as u128);
        assert_eq!(u128::from_hi_lo(u64::MAX, 0), (u64::MAX as u128) << 64);
        assert_eq!(u128::from_hi_lo(u64::MAX, u64::MAX), u128::MAX);
    }

    proptest! {
        // `any::<u128>()` covers the whole domain; the old half-open range
        // `u128::MIN..u128::MAX` never generated `u128::MAX`.
        #[test]
        fn fuzz_test(n in any::<u128>()) {
            assert_eq!(n.lo(), (n & U64_MAX) as u64);
            assert_eq!(n.hi(), (n >> 64) as u64);
            assert_eq!(n.lo_u128(), n & U64_MAX);
            assert_eq!(n.hi_u128(), n >> 64);
            // Splitting into halves and reassembling must be lossless.
            assert_eq!(u128::from_hi_lo(n.hi(), n.lo()), n);
        }
    }
}
#[cfg(test)]
mod test_shift {
    use proptest::prelude::*;
    // Brings `U256`, `Shift` and `LowHigh` into scope. Without it the module
    // does not compile.
    use super::*;

    proptest! {
        // Cross-check the hand-written shifts against `uint`'s `<<` / `>>` operators.
        #[test]
        fn fuzz_test(n in any::<u128>()) {
            let v = U256([n.lo(), n.hi(), n.lo(), n.hi()]);
            assert_eq!(v.shift_word_left(), v << 64u32);
            assert_eq!(v.shift_word_right(), v >> 64u32);
            // The top limb of `v` is `n.hi()`, so the checked shift refuses exactly when it is non-zero.
            let expected = if n.hi() == 0 { Some(v << 64u32) } else { None };
            assert_eq!(v.checked_shift_word_left(), expected);
            // Cover every limb boundary (64, 128, 192) and the 256 cut-off.
            for i in 0u32..=256 {
                assert_eq!(v.shift_left(i), v << i);
                assert_eq!(v.shift_right(i), v >> i);
            }
        }
    }
}
#[cfg(test)]
mod normal_test {
    use super::*;

    /// 2^128 * 2^128 = 2^256 needs 257 bits, so `U256` multiplication must panic.
    #[test]
    #[should_panic(expected = "arithmetic operation overflow")]
    fn u256_mul() {
        let two_pow_128 = U256::from(u128::MAX) + 1;
        let product = two_pow_128 * two_pow_128;
        println!("{}", product);
    }
}