use core::cmp::Ordering;
use core::fmt;
use oxinum_core::Sign;
use oxinum_int::native::BigUint;
/// Coarse category of a [`BigFloat`] value.
///
/// The sign of an infinity is not part of the class; it is stored on the
/// owning `BigFloat`. The special-value constructors store a zero mantissa
/// for the non-finite classes.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum FloatClass {
    /// An ordinary number, zero included. This is the `Default` class.
    #[default]
    Finite,
    /// Positive or negative infinity.
    Infinite,
    /// Not-a-number.
    Nan,
}
/// How discarded low-order mantissa bits are folded back into the kept part
/// when a value is rounded to a smaller number of significant bits.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RoundingMode {
    /// Round to nearest; an exact tie goes to the even (low bit 0) result.
    HalfEven,
    /// Round to nearest; an exact tie goes away from zero.
    HalfAway,
    /// Round to nearest; an exact tie goes toward zero.
    HalfToZero,
    /// Truncate: never increase the magnitude.
    ToZero,
    /// Round toward positive infinity.
    ToInf,
    /// Round toward negative infinity.
    ToNegInf,
    /// Increase the magnitude whenever any discarded bit is set.
    AwayFromZero,
}
/// Arbitrary-precision binary floating-point number.
///
/// A finite value is `(-1)^sign * mantissa * 2^exponent`. After rounding,
/// a nonzero finite mantissa has exactly `precision` significant bits.
/// Zero is stored canonically as `mantissa == 0`, `sign == Positive`,
/// `exponent == 0`.
#[derive(Clone)]
pub struct BigFloat {
    /// Finite / infinite / NaN category.
    pub(crate) class: FloatClass,
    /// Sign of the value; also carries the sign of an infinity.
    pub(crate) sign: Sign,
    /// Unsigned integer significand.
    pub(crate) mantissa: BigUint,
    /// Power-of-two scale applied to `mantissa`.
    pub(crate) exponent: i64,
    /// Target number of significant mantissa bits (always > 0).
    pub(crate) precision: u32,
}
impl BigFloat {
pub fn zero(prec: u32) -> Self {
assert!(prec > 0, "BigFloat precision must be > 0");
Self {
class: FloatClass::Finite,
sign: Sign::Positive,
mantissa: BigUint::zero(),
exponent: 0,
precision: prec,
}
}
pub fn nan(prec: u32) -> Self {
assert!(prec > 0, "BigFloat precision must be > 0");
Self {
class: FloatClass::Nan,
sign: Sign::Positive,
mantissa: BigUint::zero(),
exponent: 0,
precision: prec,
}
}
pub fn infinity(prec: u32) -> Self {
assert!(prec > 0, "BigFloat precision must be > 0");
Self {
class: FloatClass::Infinite,
sign: Sign::Positive,
mantissa: BigUint::zero(),
exponent: 0,
precision: prec,
}
}
pub fn neg_infinity(prec: u32) -> Self {
assert!(prec > 0, "BigFloat precision must be > 0");
Self {
class: FloatClass::Infinite,
sign: Sign::Negative,
mantissa: BigUint::zero(),
exponent: 0,
precision: prec,
}
}
pub fn from_parts(
sign: Sign,
mantissa: BigUint,
exponent: i64,
prec: u32,
mode: RoundingMode,
) -> Self {
assert!(prec > 0, "BigFloat precision must be > 0");
if mantissa.is_zero() {
return Self::zero(prec);
}
let mut out = Self {
class: FloatClass::Finite,
sign,
mantissa,
exponent,
precision: prec,
};
out.canonicalize_normalize();
out.round_to_precision_in_place(prec, mode);
out
}
#[inline]
pub fn precision(&self) -> u32 {
self.precision
}
#[inline]
pub fn sign(&self) -> Sign {
self.sign
}
#[inline]
pub fn mantissa(&self) -> &BigUint {
&self.mantissa
}
#[inline]
pub fn exponent(&self) -> i64 {
self.exponent
}
#[inline]
pub fn is_zero(&self) -> bool {
matches!(self.class, FloatClass::Finite) && self.mantissa.is_zero()
}
#[inline]
pub fn is_finite(&self) -> bool {
matches!(self.class, FloatClass::Finite)
}
#[inline]
pub fn is_infinite(&self) -> bool {
matches!(self.class, FloatClass::Infinite)
}
#[inline]
pub fn is_nan(&self) -> bool {
matches!(self.class, FloatClass::Nan)
}
#[inline]
pub fn is_normal(&self) -> bool {
matches!(self.class, FloatClass::Finite) && !self.mantissa.is_zero()
}
pub fn classify(&self) -> core::num::FpCategory {
use core::num::FpCategory;
match self.class {
FloatClass::Nan => FpCategory::Nan,
FloatClass::Infinite => FpCategory::Infinite,
FloatClass::Finite if self.mantissa.is_zero() => FpCategory::Zero,
FloatClass::Finite => FpCategory::Normal,
}
}
#[inline]
pub fn is_sign_positive(&self) -> bool {
self.sign == Sign::Positive
}
#[inline]
pub fn is_sign_negative(&self) -> bool {
self.sign == Sign::Negative
}
pub fn signum(&self) -> i32 {
if self.is_zero() {
0
} else if self.sign == Sign::Negative {
-1
} else {
1
}
}
pub fn abs(&self) -> Self {
let mut out = self.clone();
out.sign = Sign::Positive;
out
}
pub fn neg(&self) -> Self {
if self.is_nan() {
return self.clone();
}
if self.is_zero() {
return self.clone();
}
let mut out = self.clone();
out.sign = match self.sign {
Sign::Positive => Sign::Negative,
Sign::Negative => Sign::Positive,
};
out
}
#[must_use]
pub fn with_precision(self, prec: u32, mode: RoundingMode) -> Self {
self.round_to_precision(prec, mode)
}
#[must_use]
pub fn round_to_precision(mut self, prec: u32, mode: RoundingMode) -> Self {
self.round_to_precision_in_place(prec, mode);
self
}
pub(crate) fn round_to_precision_in_place(&mut self, prec: u32, mode: RoundingMode) {
assert!(prec > 0, "BigFloat precision must be > 0");
if !self.is_finite() {
self.precision = prec;
return;
}
self.precision = prec;
if self.mantissa.is_zero() {
self.sign = Sign::Positive;
self.exponent = 0;
return;
}
self.absorb_trailing_zeros();
let cur_bits = self.mantissa.bit_length();
let target = prec as u64;
match cur_bits.cmp(&target) {
Ordering::Less => {
let shift = target - cur_bits;
self.mantissa = self.mantissa.shl_bits(shift);
self.exponent = self.exponent.saturating_sub(shift as i64);
}
Ordering::Equal => { }
Ordering::Greater => {
let drop = cur_bits - target;
self.round_drop_low_bits(drop, mode);
}
}
debug_assert!(
self.mantissa.is_zero() || self.mantissa.bit_length() == self.precision as u64,
"BigFloat normalize invariant violated after round_to_precision",
);
debug_assert!(
!self.mantissa.is_zero() || self.sign == Sign::Positive,
"BigFloat canonical-zero invariant violated",
);
}
pub(crate) fn absorb_trailing_zeros(&mut self) {
if self.mantissa.is_zero() {
return;
}
let tz = self.mantissa.trailing_zeros();
if tz > 0 {
self.mantissa = self.mantissa.shr_bits(tz);
self.exponent = self.exponent.saturating_add(tz as i64);
}
}
pub(crate) fn canonicalize_normalize(&mut self) {
if self.mantissa.is_zero() {
self.sign = Sign::Positive;
self.exponent = 0;
return;
}
self.absorb_trailing_zeros();
let cur_bits = self.mantissa.bit_length();
let target = self.precision as u64;
if cur_bits < target {
let shift = target - cur_bits;
self.mantissa = self.mantissa.shl_bits(shift);
self.exponent = self.exponent.saturating_sub(shift as i64);
}
}
fn round_drop_low_bits(&mut self, drop: u64, mode: RoundingMode) {
debug_assert!(drop > 0);
let round_bit = self.mantissa.test_bit(drop - 1);
let sticky = if drop >= 2 {
(self.mantissa.trailing_zeros()) < (drop - 1)
} else {
false
};
let mut quotient = self.mantissa.shr_bits(drop);
let negative = self.sign == Sign::Negative;
let increment = match mode {
RoundingMode::ToZero => false,
RoundingMode::AwayFromZero => round_bit || sticky,
RoundingMode::ToInf => !negative && (round_bit || sticky),
RoundingMode::ToNegInf => negative && (round_bit || sticky),
RoundingMode::HalfAway => round_bit,
RoundingMode::HalfToZero => round_bit && sticky,
RoundingMode::HalfEven => {
if !round_bit {
false
} else if sticky {
true
} else {
quotient.test_bit(0)
}
}
};
if increment {
let one = BigUint::one();
quotient = "ient + &one;
}
self.exponent = self.exponent.saturating_add(drop as i64);
self.mantissa = quotient;
if self.mantissa.is_zero() {
self.sign = Sign::Positive;
self.exponent = 0;
return;
}
let cur_bits = self.mantissa.bit_length();
let target = self.precision as u64;
match cur_bits.cmp(&target) {
Ordering::Equal => {}
Ordering::Greater => {
let extra = cur_bits - target;
debug_assert_eq!(
extra, 1,
"rounding increment should overflow by at most one bit"
);
self.mantissa = self.mantissa.shr_bits(extra);
self.exponent = self.exponent.saturating_add(extra as i64);
}
Ordering::Less => {
let shift = target - cur_bits;
self.mantissa = self.mantissa.shl_bits(shift);
self.exponent = self.exponent.saturating_sub(shift as i64);
}
}
}
}
impl PartialEq for BigFloat {
fn eq(&self, other: &Self) -> bool {
match (self.class, other.class) {
(FloatClass::Nan, _) | (_, FloatClass::Nan) => false,
(FloatClass::Infinite, FloatClass::Infinite) => self.sign == other.sign,
(FloatClass::Infinite, _) | (_, FloatClass::Infinite) => false,
(FloatClass::Finite, FloatClass::Finite) => {
if self.is_zero() && other.is_zero() {
return true;
}
if self.is_zero() != other.is_zero() {
return false;
}
self.sign == other.sign
&& self.exponent == other.exponent
&& self.mantissa == other.mantissa
}
}
}
}
/// Ordering helpers.
impl BigFloat {
    /// Numeric ordering of two finite values.
    ///
    /// Zeros compare equal to each other. Zero versus a nonzero value is
    /// decided by the nonzero value's sign. Otherwise the signs decide, and
    /// magnitudes break ties (reversed when both values are negative).
    pub(crate) fn cmp_finite(&self, other: &Self) -> Ordering {
        let lhs_zero = self.is_zero();
        let rhs_zero = other.is_zero();
        if lhs_zero && rhs_zero {
            return Ordering::Equal;
        }
        if lhs_zero {
            return if other.sign == Sign::Negative {
                Ordering::Greater
            } else {
                Ordering::Less
            };
        }
        if rhs_zero {
            return if self.sign == Sign::Negative {
                Ordering::Less
            } else {
                Ordering::Greater
            };
        }
        match (self.sign, other.sign) {
            (Sign::Positive, Sign::Negative) => Ordering::Greater,
            (Sign::Negative, Sign::Positive) => Ordering::Less,
            (Sign::Positive, Sign::Positive) => cmp_magnitudes(self, other),
            // Larger magnitude means smaller value when both are negative.
            (Sign::Negative, Sign::Negative) => cmp_magnitudes(self, other).reverse(),
        }
    }

    /// Total order: `-inf < finite values < +inf < NaN`.
    ///
    /// All NaNs compare equal to each other regardless of sign, and
    /// same-signed infinities compare equal.
    pub fn total_cmp(&self, other: &Self) -> Ordering {
        let bucket = |x: &BigFloat| -> u8 {
            match (x.class, x.sign) {
                (FloatClass::Infinite, Sign::Negative) => 0,
                (FloatClass::Finite, _) => 1,
                (FloatClass::Infinite, Sign::Positive) => 2,
                (FloatClass::Nan, _) => 3,
            }
        };
        // Equal buckets imply the same class, so only finite pairs need a
        // value comparison.
        bucket(self).cmp(&bucket(other)).then_with(|| {
            if self.is_finite() {
                self.cmp_finite(other)
            } else {
                Ordering::Equal
            }
        })
    }
}
impl PartialOrd for BigFloat {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (self.class, other.class) {
(FloatClass::Nan, _) | (_, FloatClass::Nan) => None,
(FloatClass::Infinite, FloatClass::Infinite) => Some(match (self.sign, other.sign) {
(Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative) => {
Ordering::Equal
}
(Sign::Negative, Sign::Positive) => Ordering::Less,
(Sign::Positive, Sign::Negative) => Ordering::Greater,
}),
(FloatClass::Infinite, FloatClass::Finite) => Some(if self.sign == Sign::Negative {
Ordering::Less
} else {
Ordering::Greater
}),
(FloatClass::Finite, FloatClass::Infinite) => Some(if other.sign == Sign::Negative {
Ordering::Greater
} else {
Ordering::Less
}),
(FloatClass::Finite, FloatClass::Finite) => Some(self.cmp_finite(other)),
}
}
}
/// Compares `|a|` with `|b|`. Both operands must have nonzero mantissas.
///
/// It first compares the position of each leading set bit
/// (`exponent + bit_length - 1`). On a tie it aligns both mantissas to the
/// smaller exponent and compares them as integers.
pub(crate) fn cmp_magnitudes(a: &BigFloat, b: &BigFloat) -> Ordering {
    let msb = |x: &BigFloat| {
        x.exponent
            .saturating_add(x.mantissa.bit_length() as i64 - 1)
    };
    msb(a).cmp(&msb(b)).then_with(|| {
        if a.exponent >= b.exponent {
            let widened = a.mantissa.shl_bits((a.exponent - b.exponent) as u64);
            widened.cmp(&b.mantissa)
        } else {
            let widened = b.mantissa.shl_bits((b.exponent - a.exponent) as u64);
            a.mantissa.cmp(&widened)
        }
    })
}
/// Prints `NaN`, `inf` / `-inf`, or a finite value in binary.
///
/// A finite value is printed as an optional `-`, the marker `0xb`, the
/// mantissa's bits (most significant first), `p`, and the decimal
/// power-of-two exponent. Zero prints as `0xb0p0`.
/// NOTE(review): `0xb…p…` resembles a C hex-float literal; confirm that no
/// consumer parses this output as hexadecimal.
impl fmt::Display for BigFloat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.class {
            FloatClass::Nan => f.write_str("NaN"),
            FloatClass::Infinite if self.sign == Sign::Negative => f.write_str("-inf"),
            FloatClass::Infinite => f.write_str("inf"),
            FloatClass::Finite if self.is_zero() => f.write_str("0xb0p0"),
            FloatClass::Finite => {
                if self.sign == Sign::Negative {
                    f.write_str("-")?;
                }
                f.write_str("0xb")?;
                for bit in (0..self.mantissa.bit_length()).rev() {
                    f.write_str(if self.mantissa.test_bit(bit) { "1" } else { "0" })?;
                }
                write!(f, "p{}", self.exponent)
            }
        }
    }
}
/// Debug output that lists every stored field. The mantissa is printed with
/// its `Display` form.
impl fmt::Debug for BigFloat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            class,
            sign,
            mantissa,
            exponent,
            precision,
        } = self;
        write!(
            f,
            "BigFloat {{ class: {class:?}, sign: {sign:?}, mantissa: {mantissa}, exponent: {exponent}, precision: {precision} }}"
        )
    }
}