use super::FixedPoint;
use crate::{fixed_point::debug_print, mask, to_fixed, to_float, UInt};
/// Signed fixed-point number stored as a raw two's-complement bit pattern.
///
/// Layout: `b` total bits = 1 sign bit + `m` integer bits + `b - m - 1`
/// fractional bits.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Fx {
    /// Raw two's-complement bit pattern of width `b`.
    pub val: UInt,
    /// Number of integer bits, not counting the sign bit.
    pub m: i32,
    /// Total number of bits, including the sign bit.
    pub b: i32,
    /// `false` once a loss of precision has been recorded (e.g. `Mul`
    /// clears it when non-zero low bits are shifted out).
    pub is_exact: bool,
}
impl Fx {
    /// Creates a fixed-point value from its raw bit pattern.
    ///
    /// * `val` - raw two's-complement bits (not validated or masked here).
    /// * `m` - number of integer bits, excluding the sign bit.
    /// * `b` - total number of bits, including the sign bit.
    /// * `is_exact` - whether the value carries no recorded precision loss.
    ///
    /// # Panics
    /// Panics if `b < m + 1`, i.e. there is no room for the sign bit.
    pub fn new(val: UInt, m: i32, b: i32, is_exact: bool) -> Self {
        if b < 1 + m {
            panic!("Total num of bits must be larger than num of integer bits + sign.")
        }
        Self {
            val,
            m,
            b,
            is_exact,
        }
    }

    /// Number of fractional bits, `n = b - m - 1`.
    ///
    /// The sign bit is excluded, consistent with `eval`, `Mul` and `to_Fx`,
    /// which all use `b - m - 1` as the fractional width.
    pub fn get_frac_bits(&self) -> i32 {
        self.b - self.m - 1
    }
}
impl FixedPoint for Fx {
fn eval(&self) -> f64 {
to_float(self.val, self.b, self.m, self.b - self.m - 1).unwrap()
}
}
impl std::ops::Add for Fx {
    type Output = Self;

    /// Adds two fixed-point values that share the same `m`/`b` format.
    ///
    /// The raw bit patterns are summed modulo `2^b`. Overflow is detected on
    /// the integers themselves, so the range check is exact for every width.
    /// A float-based check loses precision once `b` exceeds the float
    /// mantissa and would reject valid sums near the top of the range.
    ///
    /// A zero result is returned with `is_exact = true`; otherwise exactness
    /// is the AND of both operands' flags.
    ///
    /// # Panics
    /// - if `m` or `b` differ between the operands;
    /// - if the magnitude of the sum is `>= 2^m`. This includes the most
    ///   negative code (`-2^m`), which is rejected so every result can be
    ///   negated.
    fn add(self, rhs: Self) -> Self::Output {
        if self.m != rhs.m || self.b != rhs.b {
            panic!("`m` and `n` field of each fx obj has to match.")
        }
        let (m, b) = (self.m, self.b);

        // `wrapping_add`: when `b` equals the bit width of `UInt` the raw sum
        // can exceed `UInt::MAX`, which would panic in debug builds.
        let fixed_sum = self.val.wrapping_add(rhs.val) & mask(b as u32);

        let sign_lhs = self.val >> (b - 1);
        let sign_rhs = rhs.val >> (b - 1);
        let sign_sum = fixed_sum >> (b - 1);
        // Two's-complement overflow: both operands share a sign and the
        // truncated result carries the other sign.
        let overflowed = sign_lhs == sign_rhs && sign_sum != sign_lhs;
        // Only the sign bit set => the value is exactly -2^m.
        let is_most_negative = sign_sum != 0 && (fixed_sum & mask(b as u32 - 1)) == 0;
        if overflowed || is_most_negative {
            // Decoded only for the error message.
            let sum_eval = self.eval() + rhs.eval();
            panic!("{} can't fit into {} integer bits", sum_eval, m);
        }

        if fixed_sum == 0 {
            return Self {
                val: 0,
                m,
                b,
                is_exact: true,
            };
        }
        Self {
            val: fixed_sum,
            m,
            b,
            is_exact: self.is_exact && rhs.is_exact,
        }
    }
}
impl std::ops::Sub for Fx {
    type Output = Self;

    /// Computes `self - rhs` as `self + (-rhs)`.
    ///
    /// # Panics
    /// Propagates any panic raised by the `Add` implementation.
    fn sub(self, rhs: Self) -> Self::Output {
        let negated_rhs = -rhs;
        self + negated_rhs
    }
}
impl std::ops::Neg for Fx {
    type Output = Self;

    /// Two's-complement negation within `b` bits.
    ///
    /// Zero is returned unchanged. `m`, `b` and `is_exact` are carried over.
    /// Note that the bit pattern with only the sign bit set maps to itself.
    fn neg(self) -> Self::Output {
        let negated = match self.val {
            0 => return self,
            raw => (!raw + 1) & mask(self.b as u32),
        };
        Self {
            val: negated,
            ..self
        }
    }
}
impl std::ops::Mul for Fx {
    type Output = Self;

    /// Multiplies two fixed-point values that share the same `m`/`b` format.
    ///
    /// The operand magnitudes are multiplied in `u128`, then truncated toward
    /// zero by shifting out `n = b - m - 1` fractional bits. The sign is then
    /// reapplied with `Neg`, so a product that truncates to zero stays zero
    /// without overflowing.
    ///
    /// `is_exact` is true only if both operands are exact and no set bits were
    /// discarded by the shift.
    ///
    /// # Panics
    /// - if `m` or `b` differ between the operands;
    /// - if the magnitude of the product (computed from `eval`) is `>= 2^m`;
    /// - if the widened product exceeds 128 bits (only reachable when `UInt`
    ///   is itself 128 bits wide and `b > 65`).
    fn mul(self, rhs: Self) -> Self::Output {
        if self.m != rhs.m || self.b != rhs.b {
            panic!("`m` and `n` field of each fx obj has to match.")
        }
        let (m, b) = (self.m, self.b);
        let n = b - m - 1;
        let mul_eval = self.eval() * rhs.eval();
        if mul_eval.abs().log2() >= m as f64 {
            panic!("{} can't fit into {} integer bits", mul_eval, m);
        }

        let width_mask = mask(b as u32);
        let sign_val1 = self.val >> (b - 1);
        let sign_val2 = rhs.val >> (b - 1);
        // An operand with its sign bit set is non-zero, so `!v + 1` cannot overflow.
        let abs_val1 = if sign_val1 == 0 {
            self.val
        } else {
            (!self.val + 1) & width_mask
        };
        let abs_val2 = if sign_val2 == 0 {
            rhs.val
        } else {
            (!rhs.val + 1) & width_mask
        };

        // Widen before multiplying: the full product needs up to 2*(b-1) bits,
        // which overflows `UInt` once `b` exceeds roughly half its width.
        let abs_val_mul = (abs_val1 as u128)
            .checked_mul(abs_val2 as u128)
            .expect("fixed-point product exceeds 128 bits");
        let is_exact =
            self.is_exact && rhs.is_exact && (abs_val_mul & (mask(n as u32) as u128)) == 0;

        // The range check above is meant to keep the truncated magnitude
        // below 2^(b-1), so it fits back into `val`.
        let magnitude = Self {
            val: std::convert::TryInto::try_into(abs_val_mul >> n)
                .expect("truncated fixed-point product must fit in `b` bits"),
            m,
            b,
            is_exact,
        };
        // `Neg` passes zero through unchanged, avoiding the `!0 + 1` overflow
        // when a negative product truncates to zero (e.g. `0 * negative`).
        if (sign_val1 ^ sign_val2) == 0 {
            magnitude
        } else {
            -magnitude
        }
    }
}
impl std::fmt::Debug for Fx {
    /// Renders through the crate's `debug_print` helper, passing the raw
    /// value, `m`, `b` and the exactness flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            debug_print(self.val, self.m, self.b, self.is_exact)
        )
    }
}
impl std::fmt::Display for Fx {
    /// Formats as `Fx<m,b>(val)`, where `val` is the raw bit pattern rather
    /// than the decoded number.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write straight into the formatter; no intermediate `String` needed.
        write!(f, "Fx<{},{}>({})", self.m, self.b, self.val)
    }
}
/// Quantizes `x` into an [`Fx`] with `m` integer bits and `b` total bits.
///
/// Delegates to `to_fixed` with `n = b - m - 1` fractional bits and forwards
/// `round` unchanged. The total width of the result is rebuilt as
/// `m + n + 1` from the returned format.
///
/// # Errors
/// Returns the error string produced by `to_fixed` unchanged.
///
/// # Panics
/// Panics (via [`Fx::new`]) if the returned format has a negative `n`.
// The function name mirrors the `Fx` type it builds, hence the lint allowance.
#[allow(non_snake_case)]
pub fn to_Fx(x: f64, m: i32, b: i32, round: bool) -> Result<Fx, String> {
    crate::to_fixed(x, m, b - m - 1, round)
        .map(|fx| Fx::new(fx.val, fx.m, fx.m + fx.n + 1, fx.is_exact))
}