use super::fixedpoint::{Q0_31, Q1_30, Q2_29, Q5_26};
use num_traits::PrimInt;
/// Decomposes a strictly positive `f32` scale into a fixed-point multiplier and a shift.
///
/// The exponent and mantissa are read directly from the IEEE-754 bit pattern of `scale`:
///
/// * if every mantissa bit is zero, `scale` is an exact power of two and the result is
///   `(0, 127 - biased_exponent)`;
/// * otherwise the mantissa is re-biased into `[0.5, 1.0)`, multiplied by `2^31` and rounded
///   to form the multiplier, and the shift is `127 - biased_exponent - 1`.
///
/// Returns `None` when `scale` is zero, negative or NaN.
///
/// Subnormal inputs are decoded as if they carried the implicit leading one of a normal
/// float, so the decomposition is only meaningful for normal values.
pub fn convert_scale_to_mult_shift(scale: f32) -> Option<(i32, isize)> {
    // NaN compares false against everything, so `scale <= 0.0` alone would let it through and
    // its bit pattern (sign bit included) would be decoded as if it were a real scale.
    if scale.is_nan() || scale <= 0.0 {
        return None;
    }

    let bits = scale.to_bits();
    // The sign bit is known to be clear here, so the bits above the mantissa are exactly the
    // biased exponent.
    let biased_exponent = (bits >> 23) as isize;
    let mantissa = bits & 0x007f_ffff;

    if mantissa == 0 {
        return Some((0, 127 - biased_exponent));
    }

    // Pair the mantissa with a biased exponent of 126 to obtain a value in [0.5, 1.0).
    let normalized = f32::from_bits(mantissa | 0x3f00_0000);
    let multiplier = (normalized * (1u32 << 31) as f32).round() as i32;
    Some((multiplier, 127 - biased_exponent - 1))
}
/// Computes a fixed-point reciprocal of `x`.
///
/// `x` is shifted left by its number of leading zero bits, `2^31` is subtracted, and the
/// result is passed as a raw `Q0_31` value to `one_over_one_plus_x_for_x_in_0_1`.
///
/// Returns the raw `Q0_31` result together with `fixed_point` minus the number of leading
/// zero bits of `x`.
///
/// # Panics
///
/// Panics if `fixed_point` is zero.
pub(crate) fn get_reciprocal(x: i32, fixed_point: usize) -> (i32, usize) {
    assert!(fixed_point > 0);

    let bits = x as u32;
    let leading_zeros = bits.leading_zeros() as usize;
    let num_bits_over_unit = fixed_point - leading_zeros;

    // After the shift the top bit is set for any non-zero `x`; subtracting 2^31 clears it.
    let normalized_minus_one = (bits << leading_zeros) - (1_u32 << 31);
    let reciprocal =
        Q0_31::from_raw(normalized_minus_one as i32).one_over_one_plus_x_for_x_in_0_1();

    (reciprocal.as_raw(), num_bits_over_unit)
}
/// Evaluates `1 / (1 + a)` for a raw `Q0_31` input and returns a raw `Q0_31` result.
///
/// The computation works on half the denominator, `d = (a + 1) / 2`. It starts from the
/// linear estimate `48/17 - 32/17 * d` and refines it with three Newton-Raphson steps of the
/// form `x <- x + x * (1 - d * x)`.
pub fn one_over_one_plus_x_for_x_in_0_1(a: i32) -> i32 {
    let k_48_over_17 = Q2_29::from_raw(1515870810);
    let k_neg_32_over_17 = Q2_29::from_raw(-1010580540);
    let unit_q2_29 = Q2_29::one();

    let half_denominator = Q0_31::from_raw(a).rounding_half_sum(Q0_31::one());

    let mut estimate: Q2_29 = k_48_over_17 + half_denominator * k_neg_32_over_17;
    for _ in 0..3 {
        let error = unit_q2_29 - half_denominator * estimate;
        let correction = (estimate * error).rescale::<2>();
        estimate = estimate + correction;
    }

    // Reading the Q2_29 bits as Q1_30 halves the value without touching the raw integer.
    Q1_30::from_raw(estimate.as_raw()).rescale::<0>().as_raw()
}
/// Converts a raw fixed-point value from a format with `src_integer_bits` integer bits to one
/// with `dst_integer_bits` integer bits.
///
/// This is a saturating, rounding multiplication by `2^(src_integer_bits - dst_integer_bits)`.
pub(crate) fn rescale(x: i32, src_integer_bits: usize, dst_integer_bits: usize) -> i32 {
    let src = src_integer_bits as i32;
    let dst = dst_integer_bits as i32;
    saturating_rounding_multiply_by_pot(x, src - dst)
}
pub fn exp_on_negative_values(a: i32) -> i32 {
    let a_q5_36 = Q5_26::from_raw(a);
    let k_one_quarter = Q5_26::constant_pot(-2); let mask = k_one_quarter - Q5_26::from_raw(1); let a_mod_quarter_minus_one_quarter = (a_q5_36 & mask) - k_one_quarter;
    let rescaled_a_mod_quarter_minus_one_quarter = a_mod_quarter_minus_one_quarter.rescale::<0>();
    let mut result = rescaled_a_mod_quarter_minus_one_quarter
        .exp_on_interval_between_negative_one_quarter_and_0_excl();
    let remainder = (a_mod_quarter_minus_one_quarter - a_q5_36).as_raw();
    macro_rules! exp_barrel_shifter {
        ($exponent: expr, $quantized_value: expr) => {
            if 5 > $exponent {
                let k_shift_amount = 26 + $exponent;
                let mask = mask_if_non_zero(remainder & (1 << k_shift_amount));
                result = Q0_31::select_using_mask(
                    mask,
                    result * Q0_31::from_raw($quantized_value),
                    result,
                );
            }
        };
    }
    exp_barrel_shifter!(-2, 1672461947); exp_barrel_shifter!(-1, 1302514674); exp_barrel_shifter!(0, 790015084); exp_barrel_shifter!(1, 290630308); exp_barrel_shifter!(2, 39332535); exp_barrel_shifter!(3, 720401); exp_barrel_shifter!(4, 242); let mask = a_q5_36.mask_if_zero();
    let res_in_q0_31 = Q0_31::select_using_mask(mask.as_raw(), Q0_31::one(), result);
    res_in_q0_31.as_raw()
}
/// Approximates `exp(a)` for a raw `Q0_31` argument and returns a raw `Q0_31` result.
///
/// With `x = a + 1/8`, this evaluates `exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24)`, a
/// fourth-order Taylor expansion centred on `-1/8`.
pub(crate) fn exp_on_interval_between_negative_one_quarter_and_0_excl(a: i32) -> i32 {
    let exp_neg_one_eighth = Q0_31::from_raw(1895147668);
    let one_third = Q0_31::from_raw(715827883);

    let x = Q0_31::from_raw(a) + Q0_31::constant_pot(-3);
    let x2 = x * x;
    let x3 = x2 * x;
    let x4 = x2 * x2;
    let x4_over_4 = x4.saturating_rounding_multiply_by_pot(-2);

    // ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2
    let doubled_terms = ((x4_over_4 + x3) * one_third) + x2;
    let higher_order_terms =
        Q0_31::from_raw(saturating_rounding_multiply_by_pot(doubled_terms.as_raw(), -1));

    (exp_neg_one_eighth + exp_neg_one_eighth * (x + higher_order_terms)).as_raw()
}
/// Divides `x` by `2^exponent`, rounding to the nearest integer with ties away from zero.
///
/// # Panics
///
/// Panics unless `0 <= exponent <= 31`.
pub fn rounding_divide_by_pot(x: i32, exponent: i32) -> i32 {
    assert!(exponent >= 0);
    assert!(exponent <= 31);

    // Built in 64 bits so that `exponent == 31` does not overflow.
    let low_bits_mask = ((1_i64 << exponent) - 1) as i32;
    let quotient = x >> exponent as usize;
    let discarded = x & low_bits_mask;

    // The arithmetic shift rounds towards negative infinity, so a negative quotient needs a
    // strictly larger remainder before it is bumped back up.
    let threshold = (low_bits_mask >> 1) + i32::from(quotient < 0);

    if discarded > threshold {
        quotient + 1
    } else {
        quotient
    }
}
#[allow(clippy::comparison_chain)] pub fn saturating_rounding_multiply_by_pot(x: i32, exponent: i32) -> i32 {
    if exponent == 0 {
        x
    } else if exponent < 0 {
        rounding_divide_by_pot(x, -exponent)
    } else {
        let min = i32::MIN;
        let max = i32::MAX;
        let threshold = (1 << (32 - 1 - exponent)) - 1;
        let positive_mask = mask_if_non_zero((x > threshold) as i32);
        let negative_mask = mask_if_non_zero((x < -threshold) as i32);
        let mut result = x << exponent as usize;
        result = select_using_mask(positive_mask, max, result);
        result = select_using_mask(negative_mask, min, result);
        result
    }
}
/// Returns `(a + b) / 2`, rounded to the nearest integer with ties away from zero.
///
/// The sum is formed in 64 bits, so it cannot overflow.
pub fn rounding_half_sum(a: i32, b: i32) -> i32 {
    let total = i64::from(a) + i64::from(b);
    // Nudge by one in the direction of the sum's sign before the truncating division.
    let nudge = if total < 0 { -1 } else { 1 };
    ((total + nudge) / 2) as i32
}
/// Returns an all-ones mask (`!0`) when `x` is non-zero, and `0` otherwise.
pub fn mask_if_non_zero(x: i32) -> i32 {
    match x {
        0 => 0,
        _ => !0,
    }
}
/// Returns an all-ones mask (`!0`) when `x` is zero, and `0` otherwise.
pub fn mask_if_zero(x: i32) -> i32 {
    // `true` converts to 1, whose negation is the all-ones pattern; `false` stays 0.
    -i32::from(x == 0)
}
/// Bitwise select: each result bit comes from `a` where `mask` has a one and from `b` where
/// `mask` has a zero.
pub fn select_using_mask(mask: i32, a: i32, b: i32) -> i32 {
    let from_a = a & mask;
    let from_b = b & !mask;
    // The two halves never share a set bit, so OR combines them without interference.
    from_a | from_b
}
/// Returns `a * b / 2^31` (the high 32 bits of the doubled 64-bit product), rounded to nearest.
///
/// The single overflowing case, `i32::MIN * i32::MIN`, saturates to `i32::MAX`.
pub fn saturating_rounding_doubling_high_mul(a: i32, b: i32) -> i32 {
    if a == i32::MIN && b == i32::MIN {
        return i32::MAX;
    }

    let product = i64::from(a) * i64::from(b);
    // The division below truncates towards zero, so the rounding nudge depends on the sign.
    let nudge: i64 = if product < 0 { 1 - (1 << 30) } else { 1 << 30 };
    ((product + nudge) / (1_i64 << 31)) as i32
}
/// Reports whether the integer type `T` is signed, i.e. whether its minimum value is below
/// zero.
pub fn is_signed<T: PrimInt>() -> bool {
    T::min_value() < T::zero()
}
#[cfg(test)]
mod test {
    use super::*;

    /// An exactly divisible input needs no rounding.
    #[test]
    fn test_rounding_divide_by_pot_1() {
        assert_eq!(rounding_divide_by_pot(128, 2), 32);
    }

    /// A remainder below one half is rounded down.
    #[test]
    fn test_rounding_divide_by_pot_2() {
        assert_eq!(rounding_divide_by_pot(129, 2), 32);
    }

    /// The half sum of two equal values is that value.
    #[test]
    fn test_rounding_half_sum_1() {
        assert_eq!(rounding_half_sum(22, 22), 22)
    }

    /// The half sum of two small distinct values.
    #[test]
    fn test_rounding_half_sum_2() {
        let lhs = 6;
        let rhs = 1_i32 << 3;
        assert_eq!(rounding_half_sum(lhs, rhs), 7)
    }

    /// Operands whose plain `i32` sum would overflow.
    #[test]
    fn test_rounding_half_sum_3() {
        let lhs = 1610612736;
        let rhs = i32::MAX;
        assert_eq!(rounding_half_sum(lhs, rhs), 1879048192)
    }

    /// A product of operands with opposite signs.
    #[test]
    fn test_saturating_rounding_doubling_high_mul() {
        let lhs: i32 = 1879048192;
        let rhs: i32 = -631612838;
        assert_eq!(saturating_rounding_doubling_high_mul(lhs, rhs), -552661233);
    }
}