use num_bigint::{BigInt, ParseBigIntError, RandBigInt, Sign, ToBigInt};
use num_traits::{FromPrimitive, Num, One, Signed, ToBytes, ToPrimitive, Zero};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
cmp::Ordering,
fmt::{Display, Formatter},
iter::Sum,
ops::{Add, BitAnd, Div, Mul, Neg, Rem, Shr, Sub},
str::FromStr,
};
use Number::{BigNum, SmallNum};
/// A signed integer that keeps small values inline and falls back to
/// [`BigInt`] for everything else.
///
/// The `From` conversions in this file store a value as `SmallNum` whenever it
/// fits in an `i64` (widened to `i128`) and as `BigNum` otherwise. The `Ord`
/// impl relies on this: it treats a zero-valued `BigNum` as unreachable.
///
/// NOTE(review): the derived `Deserialize` and direct variant construction do
/// not enforce that normalisation — confirm callers never build an
/// out-of-range `SmallNum` or an in-range `BigNum` by hand.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Number {
/// Inline representation; holds the value widened to `i128`.
SmallNum(i128),
/// Arbitrary-precision representation backed by `num_bigint::BigInt`.
BigNum(BigInt),
}
/// Converts a [`Number`] to an `f64` via `ToPrimitive::to_f64`.
///
/// # Panics
/// Panics if the underlying `to_f64` call returns `None`.
impl From<Number> for f64 {
    fn from(number: Number) -> Self {
        let approx = match number {
            BigNum(big) => big.to_f64(),
            SmallNum(small) => small.to_f64(),
        };
        approx.unwrap()
    }
}
impl From<Number> for BigInt {
#[inline(always)]
fn from(value: Number) -> Self {
match value {
SmallNum(n) => n.into(),
BigNum(n) => n,
}
}
}
/// Wraps a `BigInt`, choosing the inline variant when the value fits in an
/// `i64` and keeping the `BigInt` otherwise.
impl From<BigInt> for Number {
    #[inline(always)]
    fn from(value: BigInt) -> Self {
        match value.to_i64() {
            Some(narrow) => SmallNum(i128::from(narrow)),
            None => BigNum(value),
        }
    }
}
impl From<f64> for Number {
fn from(value: f64) -> Self {
BigInt::from_f64(value).expect("BigInt from f64.").into()
}
}
impl From<f32> for Number {
fn from(value: f32) -> Self {
BigInt::from_f32(value).expect("BigInt from f32.").into()
}
}
impl<const N: usize> From<[u8; N]> for Number {
fn from(value: [u8; N]) -> Self {
BigInt::from_bytes_le(Sign::Plus, &value).into()
}
}
/// Copies the first `N` bytes of `val` into a fixed-size array.
///
/// If `val` is shorter than `N`, the remaining positions are zero-filled; if
/// it is longer, the excess bytes are dropped.
fn cut_off<const N: usize>(val: &[u8]) -> [u8; N] {
    // Start from all zeros so any tail beyond `val.len()` is already padded.
    let mut out = [0u8; N];
    let copied = val.len().min(N);
    // Both slices have length `copied`, so `copy_from_slice` cannot panic.
    out[..copied].copy_from_slice(&val[..copied]);
    out
}
/// Serialises a [`Number`] into exactly `N` little-endian bytes.
///
/// The bytes produced by `to_le_bytes` are passed through `cut_off`, which
/// zero-pads short inputs and truncates long ones.
impl<const N: usize> From<Number> for [u8; N] {
    fn from(value: Number) -> Self {
        match value {
            BigNum(big) => {
                let raw = big.to_le_bytes();
                cut_off(&raw)
            }
            SmallNum(small) => {
                let raw = small.to_le_bytes();
                cut_off(&raw)
            }
        }
    }
}
impl From<&str> for Number {
#[inline(always)]
fn from(value: &str) -> Self {
Number::from(BigInt::from_str(value).unwrap())
}
}
// Generates `From<$prim> for Number` (plus the `&$prim` variant) for types
// that are cast straight into the inline `SmallNum` representation.
macro_rules! impl_from_low {
    ($prim: ty) => {
        impl From<$prim> for Number {
            #[inline(always)]
            fn from(value: $prim) -> Self {
                let widened = value as i128;
                SmallNum(widened)
            }
        }
        impl_from_ref!($prim);
    };
}
// Generates `From<$prim> for Number` (plus the `&$prim` variant) for types
// whose values may not fit in an `i64`: in-range values become `SmallNum`,
// everything else is converted to a `BigInt` and stored as `BigNum`.
macro_rules! impl_from_high {
    ($prim: ty) => {
        impl From<$prim> for Number {
            #[inline(always)]
            fn from(value: $prim) -> Self {
                match value.to_i64() {
                    Some(narrow) => SmallNum(i128::from(narrow)),
                    None => BigNum(BigInt::from(value)),
                }
            }
        }
        impl_from_ref!($prim);
    };
}
// Generates `From<&$prim> for Number` by copying the referenced value and
// forwarding to the by-value `From<$prim>` impl.
macro_rules! impl_from_ref {
    ($prim: ty) => {
        impl<'a> From<&'a $prim> for Number {
            #[inline(always)]
            fn from(value: &'a $prim) -> Self {
                <Number as From<$prim>>::from(*value)
            }
        }
    };
}
// Types converted with a plain `as i128` cast into `SmallNum`.
impl_from_low!(bool);
impl_from_low!(u8);
impl_from_low!(u16);
impl_from_low!(u32);
impl_from_low!(i8);
impl_from_low!(i16);
impl_from_low!(i32);
// Types that are range-checked with `to_i64()` first: values that fit become
// `SmallNum`, the rest become `BigNum`.
impl_from_high!(usize);
impl_from_high!(u64);
impl_from_high!(u128);
impl_from_high!(isize);
impl_from_high!(i64);
impl_from_high!(i128);
// Applies the operator method `$method` to two `Number` operands (`$lhs`,
// `$rhs`), dispatching on every variant combination and converting the raw
// result back into a `Number` with `.into()`.
macro_rules! match_binary_op {
    ($method: ident, $lhs: ident, $rhs: ident) => {
        match ($lhs, $rhs) {
            (BigNum(x), BigNum(y)) => x.$method(y).into(),
            (BigNum(x), SmallNum(y)) => x.$method(y).into(),
            (SmallNum(x), BigNum(y)) => x.$method(y).into(),
            (SmallNum(x), SmallNum(y)) => x.$method(y).into(),
        }
    };
}
// Applies the operator method `$method` with a `Number` (`$num`) on the left
// and a primitive scalar expression on the right. The scalar is cast to
// `i128` for the inline variant and passed unchanged to the `BigInt` one.
macro_rules! match_single_op {
    ($method: ident, $num: ident, $scalar: expr) => {
        match $num {
            BigNum(big) => big.$method($scalar).into(),
            SmallNum(small) => small.$method($scalar as i128).into(),
        }
    };
}
// Applies the operator method `$method` with a primitive scalar expression on
// the left (cast to `i128`) and a `Number` (`$num`) on the right.
macro_rules! match_single_op_reverse {
    ($method: ident, $num: ident, $scalar: expr) => {
        match $num {
            BigNum(big) => ($scalar as i128).$method(big).into(),
            SmallNum(small) => ($scalar as i128).$method(small).into(),
        }
    };
}
// Generates every operand combination of the binary operator trait `$t`
// (method `$f`) involving `Number`:
//   Number  op Number / &Number / i32 / &i32
//   &Number op Number / &Number / i32 / &i32
//   i32     op Number / &Number
//   &i32    op Number / &Number
// All impls produce an owned `Number` and delegate to the `match_*` helper
// macros, which dispatch on the SmallNum / BigNum variants.
macro_rules! binary_op {
($t: ident, $f: ident) => {
// --- owned `Number` on the left-hand side ---
impl $t<Number> for Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: Number) -> Number {
match_binary_op!($f, self, rhs)
}
}
impl<'b> $t<&'b Number> for Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: &'b Number) -> Number {
match_binary_op!($f, self, rhs)
}
}
impl $t<i32> for Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: i32) -> Number {
match_single_op!($f, self, rhs)
}
}
impl $t<&i32> for Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: &i32) -> Number {
match_single_op!($f, self, *rhs)
}
}
// --- `i32` / `&i32` on the left, owned `Number` on the right ---
impl $t<Number> for i32 {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: Number) -> Number {
match_single_op_reverse!($f, rhs, self)
}
}
impl $t<Number> for &i32 {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: Number) -> Number {
match_single_op_reverse!($f, rhs, *self)
}
}
// --- borrowed `&Number` on the left-hand side ---
impl<'a> $t<Number> for &'a Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: Number) -> Number {
match_binary_op!($f, self, rhs)
}
}
impl<'a, 'b> $t<&'b Number> for &'a Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: &'b Number) -> Number {
match_binary_op!($f, self, rhs)
}
}
impl<'a> $t<i32> for &'a Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: i32) -> Number {
match_single_op!($f, self, rhs)
}
}
impl<'a> $t<&i32> for &'a Number {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: &i32) -> Number {
match_single_op!($f, self, *rhs)
}
}
// --- `i32` / `&i32` on the left, borrowed `&Number` on the right ---
impl<'a> $t<&'a Number> for i32 {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: &'a Number) -> Number {
match_single_op_reverse!($f, rhs, self)
}
}
impl<'a> $t<&'a Number> for &i32 {
type Output = Number;
#[inline(always)]
fn $f(self, rhs: &'a Number) -> Number {
match_single_op_reverse!($f, rhs, *self)
}
}
};
}
// Instantiate the full operand matrix for the five arithmetic operators.
binary_op!(Add, add);
binary_op!(Sub, sub);
binary_op!(Mul, mul);
binary_op!(Div, div);
binary_op!(Rem, rem);
impl BitAnd for Number {
type Output = Number;
fn bitand(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(SmallNum(a), SmallNum(b)) => SmallNum(a.bitand(b)),
(BigNum(a), SmallNum(b)) => a
.bitand(b.to_bigint().expect("i128 to BigInt always works"))
.into(),
(SmallNum(a), BigNum(b)) => b
.bitand(a.to_bigint().expect("i128 to BigInt always works."))
.into(),
(BigNum(a), BigNum(b)) => a.bitand(b).into(),
}
}
}
impl Shr<usize> for Number {
type Output = Number;
#[inline(always)]
fn shr(self, rhs: usize) -> Self::Output {
match self {
SmallNum(n) => (n >> rhs.min(127)).into(), BigNum(n) => (n >> rhs).into(),
}
}
}
impl Shr<&usize> for Number {
type Output = Number;
#[inline(always)]
fn shr(self, rhs: &usize) -> Self::Output {
match self {
SmallNum(n) => (n >> rhs).into(),
BigNum(n) => (n >> rhs).into(),
}
}
}
/// `Number` is totally ordered, so the partial comparison always succeeds and
/// simply wraps the result of [`Ord::cmp`].
impl PartialOrd<Self> for Number {
    #[inline(always)]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl PartialEq<i32> for Number {
fn eq(&self, other: &i32) -> bool {
self == &Number::from(*other)
}
}
impl PartialEq<Number> for i32 {
fn eq(&self, other: &Number) -> bool {
other == &Number::from(*self)
}
}
/// Orders a `Number` against an `i32` by lifting the `i32` into a `Number`
/// and using the total order.
impl PartialOrd<i32> for Number {
    #[inline(always)]
    fn partial_cmp(&self, other: &i32) -> Option<Ordering> {
        let rhs = Number::from(*other);
        Some(Ord::cmp(self, &rhs))
    }
}
/// Orders an `i32` against a `Number` by lifting the `i32` into a `Number`
/// and using the total order.
impl PartialOrd<Number> for i32 {
    #[inline(always)]
    fn partial_cmp(&self, other: &Number) -> Option<Ordering> {
        let lhs = Number::from(*self);
        Some(Ord::cmp(&lhs, other))
    }
}
impl Ord for Number {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(SmallNum(a), SmallNum(b)) => a.cmp(b),
(SmallNum(_), BigNum(b)) => match b.sign() {
Sign::Minus => Ordering::Greater,
Sign::NoSign => unreachable!("zero is a small num, not a big num"),
Sign::Plus => Ordering::Less,
},
(BigNum(a), SmallNum(_)) => match a.sign() {
Sign::Minus => Ordering::Less,
Sign::NoSign => unreachable!("zero is a small num, not a big num"),
Sign::Plus => Ordering::Greater,
},
(BigNum(a), BigNum(b)) => a.cmp(b),
}
}
}
impl Num for Number {
type FromStrRadixErr = ParseBigIntError;
fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
Ok(BigInt::from_str_radix(str, radix)?.into())
}
}
/// Parses a `Number` from text, honouring the `0x` (hex), `0o` (octal) and
/// `0b` (binary) prefixes; anything else is parsed as decimal.
impl FromStr for Number {
    type Err = ParseBigIntError;
    fn from_str(str: &str) -> Result<Self, Self::Err> {
        // Recognised literal prefixes and the radix each one selects.
        const PREFIXES: [(&str, u32); 3] = [("0x", 16), ("0o", 8), ("0b", 2)];

        for (prefix, radix) in PREFIXES {
            if let Some(digits) = str.strip_prefix(prefix) {
                return Self::from_str_radix(digits, radix);
            }
        }
        Self::from_str_radix(str, 10)
    }
}
impl Zero for Number {
#[inline(always)]
fn zero() -> Self {
SmallNum(0)
}
#[inline(always)]
fn is_zero(&self) -> bool {
*self == SmallNum(0)
}
}
/// One is represented as the inline value `SmallNum(1)`.
impl One for Number {
    #[inline(always)]
    fn one() -> Self {
        Self::SmallNum(1)
    }
}
impl Neg for Number {
type Output = Number;
#[inline(always)]
fn neg(self) -> Self::Output {
match self {
SmallNum(n) => (-n).into(),
BigNum(n) => (-n).into(),
}
}
}
impl Signed for Number {
#[inline(always)]
fn abs(&self) -> Self {
match self {
SmallNum(n) => n.abs().into(),
BigNum(n) => n.abs().into(),
}
}
#[inline(always)]
fn abs_sub(&self, other: &Self) -> Self {
if self <= other {
0.into()
} else {
self.clone() - other.clone()
}
}
#[inline(always)]
fn signum(&self) -> Self {
match self {
SmallNum(n) => n.signum().into(),
BigNum(n) => n.signum().into(),
}
}
#[inline(always)]
fn is_positive(&self) -> bool {
match self {
SmallNum(n) => n.is_positive(),
BigNum(n) => n.is_positive(),
}
}
#[inline(always)]
fn is_negative(&self) -> bool {
match self {
SmallNum(n) => n.is_negative(),
BigNum(n) => n.is_negative(),
}
}
}
impl Number {
#[inline(always)]
pub fn bit(&self, idx: usize) -> bool {
match self {
SmallNum(n) => ((n >> idx.min(127)) & 1) == 1,
BigNum(n) => n.bit(idx as u64),
}
}
#[inline(always)]
pub fn power_of_two(idx: usize) -> Number {
if idx < 63 {
SmallNum(1 << idx)
} else {
BigNum(BigInt::from(1) << idx)
}
}
#[inline(always)]
pub fn negative_power_of_two(idx: usize) -> Number {
if idx < 64 {
SmallNum(-1 << idx)
} else {
BigNum(BigInt::from(-1) << idx)
}
}
#[inline(always)]
pub fn bits(&self) -> usize {
match self {
SmallNum(n) => {
if *n == 0 {
1
} else {
1 + n.abs().ilog2() as usize
}
}
BigNum(n) => n.bits() as usize,
}
}
pub fn gen_range<R: Rng + ?Sized>(rng: &mut R, lower: &Number, upper: &Number) -> Number {
match (lower, upper) {
(SmallNum(l), SmallNum(u)) => rng.gen_range((*l)..(*u)).into(),
(BigNum(l), BigNum(u)) => rng.gen_bigint_range(l, u).into(),
(SmallNum(l), BigNum(u)) => rng.gen_bigint_range(&(*l).into(), u).into(),
(BigNum(l), SmallNum(u)) => rng.gen_bigint_range(l, &(*u).into()).into(),
}
}
}
/// Adds up an iterator of `Number`s, starting from `SmallNum(0)`.
impl Sum for Number {
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(SmallNum(0), |acc, term| acc + term)
    }
}
/// Formats the value using the `Display` impl of the underlying integer,
/// passing the caller's formatter straight through.
impl Display for Number {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            BigNum(big) => Display::fmt(big, f),
            SmallNum(small) => Display::fmt(small, f),
        }
    }
}
/// Checked conversions to primitive integers, each forwarded to the
/// `ToPrimitive` impl of the underlying `i128` or `BigInt`.
impl ToPrimitive for Number {
    /// Returns the value as an `i64`, or `None` if the inner conversion fails.
    #[inline(always)]
    fn to_i64(&self) -> Option<i64> {
        match self {
            BigNum(big) => big.to_i64(),
            SmallNum(small) => small.to_i64(),
        }
    }

    /// Returns the value as an `i128`, or `None` if the inner conversion fails.
    #[inline(always)]
    fn to_i128(&self) -> Option<i128> {
        match self {
            BigNum(big) => big.to_i128(),
            SmallNum(small) => small.to_i128(),
        }
    }

    /// Returns the value as a `u64`, or `None` if the inner conversion fails.
    #[inline(always)]
    fn to_u64(&self) -> Option<u64> {
        match self {
            BigNum(big) => big.to_u64(),
            SmallNum(small) => small.to_u64(),
        }
    }

    /// Returns the value as a `u128`, or `None` if the inner conversion fails.
    #[inline(always)]
    fn to_u128(&self) -> Option<u128> {
        match self {
            BigNum(big) => big.to_u128(),
            SmallNum(small) => small.to_u128(),
        }
    }
}
impl Number {
pub fn log2(&self) -> f64 {
f64::from(self.clone()).log2()
}
}
#[cfg(test)]
mod tests {
    use crate::utils::number::Number;

    /// Little-endian byte arrays convert to the integer they encode.
    #[test]
    fn conversions_from_u8_arr() {
        assert_eq!(Number::from([]), Number::from(0));
        assert_eq!(Number::from([123u8]), Number::from(123));
        assert_eq!(Number::from([44u8, 1]), Number::from(300));
    }

    /// Converts `num` into an `N`-byte array and checks it against `expected`.
    fn assert_le_bytes<const N: usize>(num: Number, expected: [u8; N]) {
        let actual = <[u8; N]>::from(num);
        assert_eq!(actual, expected);
    }

    /// Numbers convert to little-endian bytes, zero-padded to the array size.
    #[test]
    fn conversions_to_u8_arr() {
        assert_le_bytes(Number::from(0), [0u8; 3]);
        assert_le_bytes(Number::from(123), [123u8]);
        assert_le_bytes(Number::from(300), [44u8, 1u8]);
    }
}