use super::fixedpoint::{Q0_31, Q1_30, Q2_29, Q5_26};
use num_traits::PrimInt;
/// Decomposes a positive `scale` into a Q0.31 multiplier and a right-shift amount.
///
/// Returns `Some((multiplier, shift))` such that `scale == multiplier * 2^-31 * 2^-shift`
/// (up to rounding of the multiplier), where `multiplier` is a Q0.31 value strictly between
/// 0.5 and 1.0. When `scale` is an exact power of two the multiplier is reported as `0`, and
/// then `scale == 2^-shift` exactly.
///
/// Returns `None` when `scale` is not a positive, normal, finite `f32` (zero, negative,
/// subnormal, infinite or NaN). The bit-level decomposition below assumes a normal IEEE-754
/// encoding, so such inputs would otherwise produce meaningless results.
pub fn convert_scale_to_mult_shift(scale: f32) -> Option<(i32, isize)> {
    // `is_normal` rejects zero, subnormals, infinities and NaN (NaN would slip past a plain
    // `scale <= 0.0` check); the sign test then rejects negative values.
    if !scale.is_normal() || scale < 0.0 {
        return None;
    }
    let scale_bits = scale.to_bits();
    // Biased 8-bit exponent; the sign bit is known to be 0 here.
    let biased_exponent = ((scale_bits >> 23) & 0xff) as isize;
    // The 23 explicit mantissa bits (the implicit leading 1 is not stored).
    let mantissa_bits = scale_bits & 0x007f_ffff;
    if mantissa_bits == 0 {
        // Exact power of two: scale == 2^(biased_exponent - 127).
        Some((0, 127 - biased_exponent))
    } else {
        // OR-ing the mantissa into the bit pattern of 0.5 (0x3f00_0000) yields the normalized
        // mantissa halved, i.e. a value in (0.5, 1).
        let normalized = f32::from_bits(mantissa_bits | 0x3f00_0000);
        let multiplier = (normalized * (1u32 << 31) as f32).round() as i32;
        // One less than in the power-of-two case because the mantissa was halved above.
        Some((multiplier, 127 - biased_exponent - 1))
    }
}
/// Computes the reciprocal of the raw fixed-point value `x`, where `fixed_point` is the number
/// of integer bits of `x`'s format (the exponent arithmetic below assumes this).
///
/// `x` is normalized by shifting out its leading zeros, which gives `1 + f` with `f` in
/// `[0, 1)`, so that the real value of `x` is `(1 + f) * 2^num_bits_over_unit`. Returns
/// `(recip, num_bits_over_unit)` where `recip` is the raw Q0.31 value of `1 / (1 + f)`. The
/// full reciprocal is therefore `recip * 2^-31 * 2^-num_bits_over_unit`.
///
/// # Panics
///
/// Panics if `fixed_point` is zero, if `x` is zero, or if the real value of `x` is below
/// `1.0` (its leading-zero count exceeds `fixed_point`). Previously the last two cases only
/// panicked in debug builds and silently wrapped in release builds.
pub(crate) fn get_reciprocal(x: i32, fixed_point: usize) -> (i32, usize) {
    assert!(fixed_point > 0);
    assert!(x != 0, "get_reciprocal: cannot take the reciprocal of zero");
    // Shifting left by the leading-zero count moves the most significant set bit to bit 31.
    let headroom_plus_one = (x as u32).leading_zeros() as usize;
    // Power-of-two exponent of x's real value; it must be non-negative since it is returned
    // as a `usize`.
    let num_bits_over_unit = fixed_point
        .checked_sub(headroom_plus_one)
        .expect("get_reciprocal: x must represent a value of at least 1.0");
    // Normalized value minus 1.0: the fractional part `f` as a Q0.31 raw value. Cannot
    // underflow because bit 31 is set after the shift.
    let shifted_sum_minus_one = ((x as u32) << headroom_plus_one) - (1_u32 << 31);
    let shifted_scale =
        Q0_31::from_raw(shifted_sum_minus_one as i32).one_over_one_plus_x_for_x_in_0_1();
    (shifted_scale.as_raw(), num_bits_over_unit)
}
/// Computes `1 / (1 + a)` for `a` in `[0, 1)`, taking and returning raw Q0.31 values.
///
/// The input is folded into `d = (a + 1) / 2` (a rounded half-sum with one), which lies in
/// `[0.5, 1)`. `1 / d` is then refined in Q2.29 with three Newton-Raphson steps
/// `x <- x + x * (1 - d * x)`, starting from the linear estimate `48/17 - 32/17 * d`. Because
/// `1 / (1 + a) == (1 / d) / 2`, the final estimate is halved and rescaled back to Q0.31.
pub fn one_over_one_plus_x_for_x_in_0_1(a: i32) -> i32 {
    let d = Q0_31::from_raw(a).rounding_half_sum(Q0_31::one());
    // Raw Q2.29 encodings of 48/17 and -32/17.
    let initial_estimate: Q2_29 =
        Q2_29::from_raw(1515870810) + d * Q2_29::from_raw(-1010580540);
    let one = Q2_29::one();
    let reciprocal_of_d = (0..3).fold(initial_estimate, |estimate: Q2_29, _| {
        let error = one - d * estimate;
        estimate + (estimate * error).rescale::<2>()
    });
    // Reading the Q2.29 raw bits as Q1.30 divides the value by two.
    Q1_30::from_raw(reciprocal_of_d.as_raw())
        .rescale::<0>()
        .as_raw()
}
/// Re-expresses the raw fixed-point value `x` with `src_integer_bits` integer bits in a format
/// with `dst_integer_bits` integer bits.
///
/// This multiplies the raw value by `2^(src_integer_bits - dst_integer_bits)`. It saturates when
/// the destination has fewer integer bits and the value does not fit, and rounds to nearest
/// (ties away from zero) when the destination has more integer bits.
pub(crate) fn rescale(x: i32, src_integer_bits: usize, dst_integer_bits: usize) -> i32 {
    saturating_rounding_multiply_by_pot(x, src_integer_bits as i32 - dst_integer_bits as i32)
}
pub fn exp_on_negative_values(a: i32) -> i32 {
let a_q5_36 = Q5_26::from_raw(a);
let k_one_quarter = Q5_26::constant_pot(-2); let mask = k_one_quarter - Q5_26::from_raw(1); let a_mod_quarter_minus_one_quarter = (a_q5_36 & mask) - k_one_quarter;
let rescaled_a_mod_quarter_minus_one_quarter = a_mod_quarter_minus_one_quarter.rescale::<0>();
let mut result = rescaled_a_mod_quarter_minus_one_quarter
.exp_on_interval_between_negative_one_quarter_and_0_excl();
let remainder = (a_mod_quarter_minus_one_quarter - a_q5_36).as_raw();
macro_rules! exp_barrel_shifter {
($exponent: expr, $quantized_value: expr) => {
if 5 > $exponent {
let k_shift_amount = 26 + $exponent;
let mask = mask_if_non_zero(remainder & (1 << k_shift_amount));
result = Q0_31::select_using_mask(
mask,
result * Q0_31::from_raw($quantized_value),
result,
);
}
};
}
exp_barrel_shifter!(-2, 1672461947); exp_barrel_shifter!(-1, 1302514674); exp_barrel_shifter!(0, 790015084); exp_barrel_shifter!(1, 290630308); exp_barrel_shifter!(2, 39332535); exp_barrel_shifter!(3, 720401); exp_barrel_shifter!(4, 242);
let mask = a_q5_36.mask_if_zero();
let res_in_q0_31 = Q0_31::select_using_mask(mask.as_raw(), Q0_31::one(), result);
res_in_q0_31.as_raw()
}
/// Computes `exp(a)` for `a` in `[-1/4, 0)`, taking and returning raw Q0.31 values.
///
/// Evaluates the Taylor expansion of `exp` around `-1/8` up to the fourth-order term:
/// with `x = a + 1/8`, `exp(a) ≈ exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24)`.
pub(crate) fn exp_on_interval_between_negative_one_quarter_and_0_excl(a: i32) -> i32 {
    // Raw Q0.31 encodings of exp(-1/8) and 1/3.
    let exp_neg_eighth = Q0_31::from_raw(1895147668);
    let third = Q0_31::from_raw(715827883);
    // Move to the expansion point: 1/8 == 2^-3.
    let x = Q0_31::from_raw(a) + Q0_31::constant_pot(-3);
    let x_sq = x * x;
    let x_cu = x_sq * x;
    let x_fourth = x_sq * x_sq;
    let x_fourth_over_4 = x_fourth.saturating_rounding_multiply_by_pot(-2);
    // ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2.
    let inner = ((x_fourth_over_4 + x_cu) * third) + x_sq;
    let poly_tail = Q0_31::from_raw(saturating_rounding_multiply_by_pot(inner.as_raw(), -1));
    (exp_neg_eighth + exp_neg_eighth * (x + poly_tail)).as_raw()
}
/// Divides `x` by `2^exponent`, rounding to nearest with ties away from zero.
///
/// # Panics
///
/// Panics unless `0 <= exponent <= 31`.
pub fn rounding_divide_by_pot(x: i32, exponent: i32) -> i32 {
    assert!(exponent >= 0);
    assert!(exponent <= 31);
    // Low `exponent` bits: the part discarded by the arithmetic shift below.
    let low_bits_mask = ((1_i64 << exponent) - 1) as i32;
    let floored = x >> exponent;
    // Round up when the discarded part exceeds one half. For negative values the threshold is
    // one higher, so exact halves stay at the floor, i.e. away from zero.
    let half = (low_bits_mask >> 1) + i32::from(floored < 0);
    floored + i32::from((x & low_bits_mask) > half)
}
#[allow(clippy::comparison_chain)] pub fn saturating_rounding_multiply_by_pot(x: i32, exponent: i32) -> i32 {
if exponent == 0 {
x
} else if exponent < 0 {
rounding_divide_by_pot(x, -exponent)
} else {
let min = i32::MIN;
let max = i32::MAX;
let threshold = (1 << (32 - 1 - exponent)) - 1;
let positive_mask = mask_if_non_zero((x > threshold) as i32);
let negative_mask = mask_if_non_zero((x < -threshold) as i32);
let mut result = x << exponent as usize;
result = select_using_mask(positive_mask, max, result);
result = select_using_mask(negative_mask, min, result);
result
}
}
/// Returns `(a + b) / 2` rounded to nearest, with ties rounded away from zero.
///
/// The sum is formed in 64 bits, so it cannot overflow.
pub fn rounding_half_sum(a: i32, b: i32) -> i32 {
    let total = i64::from(a) + i64::from(b);
    // Nudge away from zero before the truncating division.
    let nudged = if total < 0 { total - 1 } else { total + 1 };
    (nudged / 2) as i32
}
/// Returns an all-ones mask (`-1`) when `x` is non-zero, and `0` otherwise.
pub fn mask_if_non_zero(x: i32) -> i32 {
    -i32::from(x != 0)
}
/// Returns an all-ones mask (`-1`) when `x` is zero, and `0` otherwise.
pub fn mask_if_zero(x: i32) -> i32 {
    -i32::from(x == 0)
}
/// Bitwise select: takes each bit from `a` where `mask` is 1 and from `b` where it is 0.
pub fn select_using_mask(mask: i32, a: i32, b: i32) -> i32 {
    // Flip exactly the bits of `b` that differ from `a` and are selected by `mask`.
    b ^ (mask & (a ^ b))
}
/// Returns the high 32 bits of `2 * a * b`, rounded to nearest, i.e. the Q0.31 product of `a`
/// and `b`.
///
/// The only product that cannot be represented, `i32::MIN * i32::MIN` (= +1.0), saturates to
/// `i32::MAX`.
pub fn saturating_rounding_doubling_high_mul(a: i32, b: i32) -> i32 {
    if a == i32::MIN && b == i32::MIN {
        return i32::MAX;
    }
    let wide = i64::from(a) * i64::from(b);
    // Half an output unit (2^30 of 2^31), biased so the truncating division rounds to nearest.
    let nudge: i64 = if wide < 0 { 1 - (1 << 30) } else { 1 << 30 };
    ((wide + nudge) / (1_i64 << 31)) as i32
}
/// Returns `true` if the primitive integer type `T` can represent negative values.
pub fn is_signed<T: PrimInt>() -> bool {
    T::min_value() < T::zero()
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_rounding_divide_by_pot_1() {
        let x = 128;
        let res = rounding_divide_by_pot(x, 2);
        assert_eq!(res, 32);
    }

    #[test]
    fn test_rounding_divide_by_pot_2() {
        let x = 129;
        let res = rounding_divide_by_pot(x, 2);
        assert_eq!(res, 32);
    }

    #[test]
    fn test_rounding_divide_by_pot_rounds_ties_away_from_zero() {
        // 130 / 4 = 32.5 and -130 / 4 = -32.5.
        assert_eq!(rounding_divide_by_pot(130, 2), 33);
        assert_eq!(rounding_divide_by_pot(-130, 2), -33);
        // -129 / 4 = -32.25 rounds towards zero.
        assert_eq!(rounding_divide_by_pot(-129, 2), -32);
        assert_eq!(rounding_divide_by_pot(-1, 1), -1);
        assert_eq!(rounding_divide_by_pot(1, 1), 1);
    }

    #[test]
    fn test_rounding_divide_by_pot_extremes() {
        assert_eq!(rounding_divide_by_pot(12345, 0), 12345);
        assert_eq!(rounding_divide_by_pot(i32::MAX, 31), 1);
        assert_eq!(rounding_divide_by_pot(i32::MIN, 31), -1);
    }

    #[test]
    #[should_panic]
    fn test_rounding_divide_by_pot_rejects_large_exponent() {
        rounding_divide_by_pot(1, 32);
    }

    #[test]
    fn test_saturating_rounding_multiply_by_pot() {
        assert_eq!(saturating_rounding_multiply_by_pot(3, 0), 3);
        assert_eq!(saturating_rounding_multiply_by_pot(3, 2), 12);
        assert_eq!(saturating_rounding_multiply_by_pot(-3, 2), -12);
        assert_eq!(saturating_rounding_multiply_by_pot(129, -2), 32);
        assert_eq!(saturating_rounding_multiply_by_pot((1 << 29) - 1, 2), i32::MAX - 3);
        assert_eq!(saturating_rounding_multiply_by_pot(1 << 29, 2), i32::MAX);
        assert_eq!(saturating_rounding_multiply_by_pot(-(1 << 29), 2), i32::MIN);
        assert_eq!(saturating_rounding_multiply_by_pot(1, 31), i32::MAX);
        assert_eq!(saturating_rounding_multiply_by_pot(-1, 31), i32::MIN);
        assert_eq!(saturating_rounding_multiply_by_pot(0, 31), 0);
    }

    #[test]
    fn test_rounding_half_sum_1() {
        let a = 22;
        let b = 22;
        let res = rounding_half_sum(a, b);
        assert_eq!(res, 22)
    }

    #[test]
    fn test_rounding_half_sum_2() {
        let a = 6;
        let b = 1_i32 << 3;
        let expected_res = 7;
        let res = rounding_half_sum(a, b);
        assert_eq!(res, expected_res)
    }

    #[test]
    fn test_rounding_half_sum_3() {
        let a = 1610612736;
        let b = i32::MAX;
        let expected_res = 1879048192;
        let res = rounding_half_sum(a, b);
        assert_eq!(res, expected_res)
    }

    #[test]
    fn test_rounding_half_sum_negative_and_extremes() {
        assert_eq!(rounding_half_sum(-3, 0), -2);
        assert_eq!(rounding_half_sum(i32::MIN, i32::MIN), i32::MIN);
        assert_eq!(rounding_half_sum(i32::MAX, i32::MAX), i32::MAX);
    }

    #[test]
    fn test_masks_and_select() {
        assert_eq!(mask_if_non_zero(5), -1);
        assert_eq!(mask_if_non_zero(0), 0);
        assert_eq!(mask_if_zero(0), -1);
        assert_eq!(mask_if_zero(7), 0);
        assert_eq!(select_using_mask(-1, 10, 20), 10);
        assert_eq!(select_using_mask(0, 10, 20), 20);
        assert_eq!(select_using_mask(0x0000_ffff, 0x1234_5678, 0x7abc_def0), 0x7abc_5678);
    }

    #[test]
    fn test_saturating_rounding_doubling_high_mul() {
        let a: i32 = 1879048192;
        let b: i32 = -631612838;
        let expected_res = -552661233;
        let res = saturating_rounding_doubling_high_mul(a, b);
        assert_eq!(res, expected_res);
    }

    #[test]
    fn test_saturating_rounding_doubling_high_mul_edge_cases() {
        assert_eq!(saturating_rounding_doubling_high_mul(i32::MIN, i32::MIN), i32::MAX);
        assert_eq!(saturating_rounding_doubling_high_mul(1 << 30, 1 << 30), 1 << 29);
        assert_eq!(saturating_rounding_doubling_high_mul(-(1 << 30), 1 << 30), -(1 << 29));
        assert_eq!(saturating_rounding_doubling_high_mul(0, 12345), 0);
    }

    #[test]
    fn test_convert_scale_to_mult_shift() {
        assert_eq!(convert_scale_to_mult_shift(1.0), Some((0, 0)));
        assert_eq!(convert_scale_to_mult_shift(0.5), Some((0, 1)));
        assert_eq!(convert_scale_to_mult_shift(4.0), Some((0, -2)));
        assert_eq!(convert_scale_to_mult_shift(0.75), Some((1610612736, 0)));
        assert_eq!(convert_scale_to_mult_shift(3.0), Some((1610612736, -2)));
        assert_eq!(convert_scale_to_mult_shift(0.0), None);
        assert_eq!(convert_scale_to_mult_shift(-1.0), None);
    }

    #[test]
    fn test_is_signed() {
        assert!(is_signed::<i32>());
        assert!(is_signed::<i8>());
        assert!(!is_signed::<u8>());
        assert!(!is_signed::<u64>());
    }
}