extern crate alloc;
use super::bigint::BigInt;
use super::bigint::LossFraction;
use core::cmp::Ordering;
/// Selects how a result that is not exactly representable gets rounded.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RoundingMode {
    /// No explicit rounding: never rounds away from zero, overflows to infinity.
    None,
    /// Round to nearest; on a tie pick the value with an even mantissa.
    NearestTiesToEven,
    /// Round to nearest; on a tie round away from zero.
    NearestTiesToAway,
    /// Round toward zero (truncate).
    Zero,
    /// Round toward positive infinity.
    Positive,
    /// Round toward negative infinity.
    Negative,
}

impl RoundingMode {
    /// Parses a rounding mode from its exact (case-sensitive) variant name.
    ///
    /// Returns `None` for unknown names. The name `"None"` is deliberately
    /// not accepted.
    pub fn from_string(s: &str) -> Option<Self> {
        const NAMED: [(&str, RoundingMode); 5] = [
            ("NearestTiesToEven", RoundingMode::NearestTiesToEven),
            ("NearestTiesToAway", RoundingMode::NearestTiesToAway),
            ("Zero", RoundingMode::Zero),
            ("Positive", RoundingMode::Positive),
            ("Negative", RoundingMode::Negative),
        ];
        NAMED
            .iter()
            .find(|entry| entry.0 == s)
            .map(|entry| entry.1)
    }
}
/// Describes a floating-point format: exponent width, precision and the
/// rounding mode attached to it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Semantics {
    /// Width of the exponent field in bits.
    pub exponent: usize,
    /// Significand width in bits, counting the implicit leading bit.
    pub precision: usize,
    /// Rounding mode carried by this format.
    pub mode: RoundingMode,
}

impl Semantics {
    /// Builds a format description from its three components.
    pub const fn new(
        exponent: usize,
        precision: usize,
        mode: RoundingMode,
    ) -> Self {
        Self {
            exponent,
            precision,
            mode,
        }
    }

    /// Significand width including the implicit leading bit.
    pub fn get_precision(&self) -> usize {
        self.precision
    }

    /// Number of explicitly stored mantissa bits (precision minus one).
    pub fn get_mantissa_len(&self) -> usize {
        self.get_precision() - 1
    }

    /// Width of the exponent field in bits.
    pub fn get_exponent_len(&self) -> usize {
        self.exponent
    }

    /// Rounding mode attached to this format.
    pub fn get_rounding_mode(&self) -> RoundingMode {
        self.mode
    }

    /// Returns a copy with `more` extra bits of precision.
    pub fn increase_precision(&self, more: usize) -> Semantics {
        Semantics {
            precision: self.precision + more,
            ..*self
        }
    }

    /// Returns a copy whose precision grows by `more` plus the bit length of
    /// the current precision.
    pub fn grow_log(&self, more: usize) -> Semantics {
        self.increase_precision(more + self.log_precision())
    }

    /// Number of bits needed to write `precision` in binary.
    pub fn log_precision(&self) -> usize {
        let leading = (self.precision as u64).leading_zeros();
        (64 - leading) as usize
    }

    /// Returns a copy with `more` extra exponent bits.
    pub fn increase_exponent(&self, more: usize) -> Semantics {
        Semantics {
            exponent: self.exponent + more,
            ..*self
        }
    }

    /// Returns a copy that uses rounding mode `rm`.
    pub fn with_rm(&self, rm: RoundingMode) -> Semantics {
        Semantics { mode: rm, ..*self }
    }

    /// Exponent bias: `2^(exponent - 1) - 1`.
    pub(crate) fn get_bias(&self) -> i64 {
        let half_range = 1u64 << (self.get_exponent_len() - 1);
        (half_range - 1) as i64
    }
}
/// Classification of a `Float` value; `exp`/`mantissa` only carry meaning
/// for `Normal`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Category {
/// Infinity; the sign is stored separately in the `Float`.
Infinity,
/// Not-a-number; never compares equal or ordered.
NaN,
/// A finite non-zero value. Values clamped to the minimum exponent with a
/// shortened mantissa (subnormal-like) also use this category.
Normal,
/// Zero; the sign is stored separately, and +0 and -0 compare equal.
Zero,
}
/// An arbitrary-precision floating-point number whose format is described
/// by its `Semantics`.
#[derive(Debug, Clone)]
pub struct Float {
/// Format of this number (exponent width, precision, rounding mode).
sem: Semantics,
/// Sign bit; `true` means negative.
sign: bool,
/// Unbiased exponent, kept within `get_exp_bounds()` by `normalize`.
/// NOTE(review): the value appears to be `mantissa * 2^(exp - mantissa_len)`
/// (`Float::one` uses exp 0 with mantissa `1 << mantissa_len`); confirm.
exp: i64,
/// Significand including the leading bit; the `zero`/`inf`/`nan`
/// constructors set it to zero.
mantissa: BigInt,
/// Kind of value; decides whether `exp`/`mantissa` are meaningful.
category: Category,
}
impl Float {
    /// Returns the number of stored mantissa bits (precision minus the
    /// implicit leading bit).
    pub(crate) fn get_mantissa_len(&self) -> usize {
        self.sem.get_mantissa_len()
    }

    /// Returns the width in bits of this number's exponent field.
    pub(crate) fn get_exponent_len(&self) -> usize {
        self.sem.get_exponent_len()
    }

    /// Creates a `Normal` number from a sign, exponent and mantissa.
    ///
    /// A zero `mantissa` yields a zero carrying `sign` instead. The parts are
    /// stored as given; no normalization or rounding happens here.
    pub fn from_parts(
        sem: Semantics,
        sign: bool,
        exp: i64,
        mantissa: BigInt,
    ) -> Self {
        if mantissa.is_zero() {
            return Float::zero(sem, sign);
        }
        Self::raw(sem, sign, exp, mantissa, Category::Normal)
    }

    /// Assembles a number from all of its fields verbatim, without any
    /// validation or canonicalization.
    pub(crate) fn raw(
        sem: Semantics,
        sign: bool,
        exp: i64,
        mantissa: BigInt,
        category: Category,
    ) -> Self {
        Float {
            sem,
            sign,
            exp,
            mantissa,
            category,
        }
    }

    /// Returns a zero with the given sign.
    pub fn zero(sem: Semantics, sign: bool) -> Self {
        Self::raw(sem, sign, 0, BigInt::zero(), Category::Zero)
    }

    /// Returns one (minus one when `sign` is set). The exponent is 0 and only
    /// the leading bit of the `precision`-bit mantissa is set.
    pub fn one(sem: Semantics, sign: bool) -> Self {
        let mut one = BigInt::one();
        one.shift_left(sem.get_mantissa_len());
        Self::raw(sem, sign, 0, one, Category::Normal)
    }

    /// Returns an infinity with the given sign.
    pub fn inf(sem: Semantics, sign: bool) -> Self {
        Self::raw(sem, sign, 0, BigInt::zero(), Category::Infinity)
    }

    /// Returns a NaN with the given sign.
    pub fn nan(sem: Semantics, sign: bool) -> Self {
        Self::raw(sem, sign, 0, BigInt::zero(), Category::NaN)
    }

    /// Returns true if the sign bit is set. This holds for -0, -inf and a
    /// negative NaN too.
    pub fn is_negative(&self) -> bool {
        self.sign
    }

    /// Returns true for positive or negative infinity.
    pub fn is_inf(&self) -> bool {
        self.category == Category::Infinity
    }

    /// Returns true for NaN.
    pub fn is_nan(&self) -> bool {
        self.category == Category::NaN
    }

    /// Returns true for positive or negative zero.
    pub fn is_zero(&self) -> bool {
        self.category == Category::Zero
    }

    /// Returns true for finite non-zero values.
    pub fn is_normal(&self) -> bool {
        self.category == Category::Normal
    }

    /// Returns the format of this number.
    pub fn get_semantics(&self) -> Semantics {
        self.sem
    }

    /// Returns the rounding mode of this number's format.
    pub fn get_rounding_mode(&self) -> RoundingMode {
        self.sem.get_rounding_mode()
    }

    /// Overwrites the sign bit.
    pub fn set_sign(&mut self, sign: bool) {
        self.sign = sign
    }

    /// Returns the sign bit (`true` means negative).
    pub fn get_sign(&self) -> bool {
        self.sign
    }

    /// Returns a copy of the mantissa.
    pub fn get_mantissa(&self) -> BigInt {
        self.mantissa.clone()
    }

    /// Returns the unbiased exponent.
    pub fn get_exp(&self) -> i64 {
        self.exp
    }

    /// Returns the category of this number.
    pub fn get_category(&self) -> Category {
        self.category
    }

    /// Returns a copy with the sign flipped. The category, exponent and
    /// mantissa are kept.
    pub fn neg(&self) -> Self {
        Self::raw(
            self.sem,
            !self.sign,
            self.exp,
            self.mantissa.clone(),
            self.category,
        )
    }

    /// Shifts the mantissa left until its most significant set bit sits at
    /// the top of the `precision`-bit range.
    ///
    /// The exponent is lowered by the same amount, so the represented value
    /// is unchanged.
    pub(super) fn align_mantissa(&mut self) {
        let bits =
            self.sem.get_precision() as i64 - self.mantissa.msb_index() as i64;
        if bits > 0 {
            // A left shift multiplies the significand by 2^bits, so the
            // exponent must drop by `bits` to preserve the value. This is the
            // same convention as `shift_significand_left`. The previous
            // `exp += bits` scaled the number by 4^bits.
            self.exp -= bits;
            self.mantissa.shift_left(bits as usize);
        }
    }

    /// Prints a one-line debug representation to stdout.
    #[cfg(feature = "std")]
    pub fn dump(&self) {
        use std::println;
        let sign = if self.sign { "-" } else { "+" };
        match self.category {
            Category::NaN => {
                println!("[{}NaN]", sign);
            }
            Category::Infinity => {
                println!("[{}Inf]", sign);
            }
            Category::Zero => {
                println!("[{}0.0]", sign);
            }
            Category::Normal => {
                let m = self.mantissa.as_binary();
                println!("FP[{} E={:4} M = {}]", sign, self.exp, m);
            }
        }
    }

    /// No-op without the `std` feature (no stdout available).
    #[cfg(not(feature = "std"))]
    pub fn dump(&self) {}

    /// Exponent bias of this number's format.
    pub(crate) fn get_bias(&self) -> i64 {
        self.sem.get_bias()
    }

    /// Returns the `(min, max)` legal unbiased exponents of this format,
    /// for example `(-126, 127)` for FP32.
    pub fn get_exp_bounds(&self) -> (i64, i64) {
        let bias = self.get_bias();
        let exp_min: i64 = 1 - bias;
        let exp_max: i64 = (1 << self.get_exponent_len()) - bias - 2;
        (exp_min, exp_max)
    }
}
// Short alias for the default rounding mode, used by the constants below
// and by the tests in this file.
use RoundingMode::NearestTiesToEven as nte;
// Predefined formats: Semantics::new(exponent bits, precision bits including
// the implicit leading bit, rounding mode).
/// bfloat16: 8 exponent bits, 8 bits of precision.
pub const BF16: Semantics = Semantics::new(8, 8, nte);
/// IEEE 754 binary16 (half precision).
pub const FP16: Semantics = Semantics::new(5, 11, nte);
/// IEEE 754 binary32 (single precision).
pub const FP32: Semantics = Semantics::new(8, 24, nte);
/// IEEE 754 binary64 (double precision).
pub const FP64: Semantics = Semantics::new(11, 53, nte);
/// IEEE 754 binary128 (quadruple precision).
pub const FP128: Semantics = Semantics::new(15, 113, nte);
/// IEEE 754 binary256 (octuple precision).
pub const FP256: Semantics = Semantics::new(19, 237, nte);
/// Shifts `val` right by `bits` and reports what was discarded.
///
/// Returns the shifted value together with a classification of the dropped
/// low bits relative to one half of the new last place.
pub(crate) fn shift_right_with_loss(
    val: &BigInt,
    bits: usize,
) -> (BigInt, LossFraction) {
    let mut shifted = val.clone();
    // Classify the discarded bits first, while they are still present.
    let discarded = shifted.get_loss_kind_for_bit(bits);
    shifted.shift_right(bits);
    (shifted, discarded)
}
/// Merges two loss classifications into one.
///
/// `msb` describes the more significant discarded bits and `lsb` the less
/// significant bits discarded earlier. Non-zero trailing bits push an exact
/// zero up to "less than half", and an exact half up to "more than half".
fn combine_loss_fraction(msb: LossFraction, lsb: LossFraction) -> LossFraction {
    // All-zero trailing bits cannot change the classification.
    if lsb.is_exactly_zero() {
        return msb;
    }
    if msb.is_exactly_zero() {
        LossFraction::LessThanHalf
    } else if msb.is_exactly_half() {
        LossFraction::MoreThanHalf
    } else {
        msb
    }
}
#[test]
fn shift_right_fraction() {
    // Shift right by three bits and classify the three bits that fall off.
    let lost = |bits: u64| shift_right_with_loss(&BigInt::from_u64(bits), 3).1;
    assert!(lost(0b10000000).is_exactly_zero()); // dropped 000
    assert!(lost(0b10000111).is_mt_half()); // dropped 111 > 100
    assert!(lost(0b10000100).is_exactly_half()); // dropped 100
    assert!(lost(0b10000001).is_lt_half()); // dropped 001 < 100
}
impl Float {
    /// Replaces `self` with the result of an exponent overflow under `rm`.
    ///
    /// Round-to-nearest modes (and `None`) give infinity. Directed modes that
    /// round toward zero for this sign give the largest finite magnitude.
    fn overflow(&mut self, rm: RoundingMode) {
        let bounds = self.get_exp_bounds();
        let inf = Self::inf(self.sem, self.sign);
        // Largest finite magnitude: all `precision` significand bits set,
        // including the leading bit, at the maximum exponent. Using only
        // `mantissa_len` ones left the leading bit clear. That gave a
        // non-normalized value about half the true maximum.
        let max = Self::from_parts(
            self.sem,
            self.sign,
            bounds.1,
            BigInt::all1s(self.sem.get_precision()),
        );
        *self = match rm {
            RoundingMode::None
            | RoundingMode::NearestTiesToEven
            | RoundingMode::NearestTiesToAway => inf,
            RoundingMode::Zero => max,
            // Rounding toward +inf saturates only negative numbers.
            RoundingMode::Positive => {
                if self.sign {
                    max
                } else {
                    inf
                }
            }
            // Rounding toward -inf saturates only positive numbers.
            RoundingMode::Negative => {
                if self.sign {
                    inf
                } else {
                    max
                }
            }
        }
    }

    /// Debug-only check that the exponent is inside the format's bounds and
    /// the mantissa is below `BigInt::one_hot(precision)`. NOTE(review):
    /// presumably that is 2^precision; confirm.
    pub(crate) fn check_bounds(&self) {
        let bounds = self.get_exp_bounds();
        debug_assert!(self.exp >= bounds.0);
        debug_assert!(self.exp <= bounds.1);
        let max_mantissa = BigInt::one_hot(self.sem.get_precision());
        debug_assert!(self.mantissa.lt(&max_mantissa));
    }

    /// Shifts the mantissa left by `amt` and lowers the exponent to match.
    pub(crate) fn shift_significand_left(&mut self, amt: u64) {
        self.exp -= amt as i64;
        self.mantissa.shift_left(amt as usize);
    }

    /// Shifts the mantissa right by `amt`, raises the exponent to match, and
    /// returns a classification of the bits shifted out.
    pub(crate) fn shift_significand_right(&mut self, amt: u64) -> LossFraction {
        self.exp += amt as i64;
        let res = shift_right_with_loss(&self.mantissa, amt as usize);
        self.mantissa = res.0;
        res.1
    }

    /// Decides whether the truncated mantissa must be bumped by one ulp
    /// (away from zero), given rounding mode `rm` and the discarded `loss`.
    ///
    /// The number must already be normal or zero.
    pub(crate) fn need_round_away_from_zero(
        &self,
        rm: RoundingMode,
        loss: LossFraction,
    ) -> bool {
        debug_assert!(self.is_normal() || self.is_zero());
        match rm {
            RoundingMode::Positive => !self.sign,
            RoundingMode::Negative => self.sign,
            RoundingMode::Zero => false,
            RoundingMode::None => false,
            RoundingMode::NearestTiesToAway => loss.is_gte_half(),
            RoundingMode::NearestTiesToEven => {
                if loss.is_mt_half() {
                    return true;
                }
                // On an exact tie, round so the result mantissa is even.
                loss.is_exactly_half() && self.mantissa.is_odd()
            }
        }
    }

    /// Returns true if both numbers have the same category and, for normal
    /// numbers, the same exponent and mantissa (signs are ignored).
    pub(crate) fn same_absolute_value(&self, other: &Self) -> bool {
        if self.category != other.category {
            return false;
        }
        match self.category {
            Category::Infinity => true,
            Category::NaN => true,
            Category::Zero => true,
            Category::Normal => {
                self.exp == other.exp && self.mantissa == other.mantissa
            }
        }
    }

    /// Brings a normal number of arbitrary mantissa width back into this
    /// format and rounds it with `rm`.
    ///
    /// `loss` describes bits already discarded below the current mantissa.
    /// Non-normal values are left untouched.
    pub(crate) fn normalize(&mut self, rm: RoundingMode, loss: LossFraction) {
        if !self.is_normal() {
            return;
        }
        let mut loss = loss;
        let bounds = self.get_exp_bounds();
        let nmsb = self.mantissa.msb_index() as i64;

        // Step 1: adjust the exponent so the MSB lands at bit `precision`.
        if nmsb > 0 {
            let mut exp_change = nmsb - self.sem.get_precision() as i64;

            // Too large for the format.
            if self.exp + exp_change > bounds.1 {
                self.overflow(rm);
                self.check_bounds();
                return;
            }

            // Never go below the minimum exponent. Such values keep a
            // shorter (subnormal-like) mantissa instead.
            if self.exp + exp_change < bounds.0 {
                exp_change = bounds.0 - self.exp;
            }

            if exp_change < 0 {
                // Widening the mantissa is exact, so no rounding is needed.
                debug_assert!(loss.is_exactly_zero(), "losing information");
                self.shift_significand_left((-exp_change) as u64);
                return;
            }

            if exp_change > 0 {
                // The bits shifted out now are more significant than the
                // previously lost ones; merge the two classifications.
                let loss2 = self.shift_significand_right(exp_change as u64);
                loss = combine_loss_fraction(loss2, loss);
            }
        }

        // Step 2: round.
        if loss.is_exactly_zero() {
            // Exact result: only canonicalize an empty mantissa to zero.
            if self.mantissa.is_zero() {
                *self = Self::zero(self.sem, self.sign);
            }
            return;
        }

        if self.need_round_away_from_zero(rm, loss) {
            // Rounding a fully shifted-out value up yields the smallest
            // value at the minimum exponent.
            if self.mantissa.is_zero() {
                self.exp = bounds.0
            }
            let one = BigInt::one();
            self.mantissa = self.mantissa.clone() + one;
            // Did the increment carry out of the `precision`-bit range?
            let mut m = self.mantissa.clone();
            m.shift_right(self.sem.get_precision());
            if !m.is_zero() {
                if self.exp < bounds.1 {
                    // The mantissa is exactly 2^precision here, so the bit
                    // shifted out is zero and the reported loss is ignored.
                    self.shift_significand_right(1);
                } else {
                    *self = Self::inf(self.sem, self.sign);
                    return;
                }
            }
        }

        // Canonicalize a mantissa that rounded down to nothing.
        if self.mantissa.is_zero() {
            *self = Self::zero(self.sem, self.sign);
        }
    }
}
impl PartialEq for Float {
    /// IEEE-style equality.
    ///
    /// NaN is never equal to anything, including itself. Any zero equals any
    /// other zero regardless of sign. Infinities and normal numbers must
    /// match in category, sign, exponent and mantissa.
    fn eq(&self, other: &Self) -> bool {
        match self.category {
            Category::NaN => false,
            Category::Zero => other.is_zero(),
            Category::Infinity | Category::Normal => {
                self.category == other.category
                    && self.sign == other.sign
                    && self.exp == other.exp
                    && self.mantissa == other.mantissa
            }
        }
    }
}
impl PartialOrd for Float {
    /// Orders two numbers of the same semantics.
    ///
    /// Returns `None` when either side is NaN. Zeros are equal regardless of
    /// sign. Otherwise magnitudes are ordered
    /// (zero < normal by `(exp, mantissa)` < infinity), and a negative sign
    /// reverses that order.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        debug_assert_eq!(self.get_semantics(), other.get_semantics());
        // `ord` compares magnitudes; negative numbers flip it.
        let signed = |negative: bool, ord: Ordering| {
            Some(if negative { ord.reverse() } else { ord })
        };
        match (self.category, other.category) {
            (Category::NaN, _) | (_, Category::NaN) => None,
            (Category::Zero, Category::Zero) => Some(Ordering::Equal),
            // Same class but opposite signs: the negative one is smaller.
            (Category::Infinity, Category::Infinity)
            | (Category::Normal, Category::Normal)
                if self.sign != other.sign =>
            {
                signed(self.sign, Ordering::Greater)
            }
            (Category::Infinity, Category::Infinity) => Some(Ordering::Equal),
            (Category::Normal, Category::Normal) => {
                let magnitude = self
                    .exp
                    .cmp(&other.exp)
                    .then_with(|| self.mantissa.cmp(&other.mantissa));
                signed(self.sign, magnitude)
            }
            // `self` has the strictly larger magnitude.
            (Category::Infinity, _) | (Category::Normal, Category::Zero) => {
                signed(self.sign, Ordering::Greater)
            }
            // `other` has the strictly larger magnitude.
            (_, Category::Infinity) | (Category::Zero, Category::Normal) => {
                signed(other.sign, Ordering::Less)
            }
        }
    }
}
#[cfg(feature = "std")]
#[test]
fn test_comparisons() {
    use super::utils;
    // Every comparison between Floats must agree with the native f64 one.
    for lhs in utils::get_special_test_values() {
        for rhs in utils::get_special_test_values() {
            let (flhs, frhs) = (Float::from_f64(lhs), Float::from_f64(rhs));
            assert_eq!(lhs < rhs, flhs < frhs, "<");
            assert_eq!(lhs == rhs, flhs == frhs, "==");
            assert_eq!(lhs > rhs, flhs > frhs, ">");
        }
    }
}
#[test]
fn test_one_imm() {
    // A non-standard 12-bit-precision format must still represent 1.0 exactly.
    let one = Float::one(Semantics::new(10, 12, nte), false);
    assert_eq!(one.as_f64(), 1.0);
}
#[test]
pub fn test_bigint_ctor() {
    // 65519 lies just below the FP16 midpoint 65520 between 65504 and
    // 65536, so both construction paths round down to 65504.
    let from_int = Float::from_bigint(FP16, BigInt::from_u64(65519));
    assert_eq!(from_int.cast(FP32).to_i64(), 65504);
    let from_double = Float::from_f64(65519.).cast(FP16);
    assert_eq!(from_double.to_i64(), 65504);

    // A wide-exponent, narrow-precision format still holds 2^14 exactly.
    let wide_exp = Semantics::new(40, 10, nte);
    let value: u64 = 1 << 14;
    let num = Float::from_bigint(wide_exp, BigInt::from_u64(value));
    assert_eq!(num.to_i64(), value as i64);
}
#[test]
pub fn test_semantics_size() {
    // Bit length of each format's precision (11, 24, 53, 113).
    let expected = [(FP16, 4), (FP32, 5), (FP64, 6), (FP128, 7)];
    for (sem, bits) in expected.iter() {
        assert_eq!(sem.log_precision(), *bits);
    }
}