use num_traits::MulAdd;
use std::ops::{Add, Mul};
#[cfg(any(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "fma"
    ),
    all(target_arch = "aarch64", target_feature = "neon")
))]
/// Multiply-accumulate: returns `a * b + acc`.
///
/// This variant is compiled only for x86/x86_64 with the `fma` target feature
/// or aarch64 with `neon`. It delegates to `MulAdd::mul_add`. For `f32`/`f64`
/// that is presumably a fused multiply-add with a single rounding, but the
/// exact behavior is whatever the `MulAdd` impl for `T` does.
///
/// Note the argument order: the accumulator comes first.
#[inline(always)]
pub(crate) fn mlaf<T>(acc: T, a: T, b: T) -> T
where
    T: Copy + Mul<T, Output = T> + Add<T, Output = T> + MulAdd<T, Output = T>,
{
    a.mul_add(b, acc)
}
/// Multiply-accumulate fallback: returns `acc + a * b`.
///
/// Used when the target is neither x86/x86_64 with `fma` nor aarch64 with
/// `neon`. The product is formed first and then added to `acc`, so for floats
/// this rounds twice (unfused). The `MulAdd` bound is unused here but keeps the
/// signature identical to the fused variant.
#[cfg(not(any(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "fma"
    ),
    all(target_arch = "aarch64", target_feature = "neon")
)))]
#[inline(always)]
pub(crate) fn mlaf<T>(acc: T, a: T, b: T) -> T
where
    T: Copy + Mul<T, Output = T> + Add<T, Output = T> + MulAdd<T, Output = T>,
{
    let product = a * b;
    // Keep `acc` as the left operand, matching the original evaluation order.
    acc + product
}
/// Rounds `x` to the nearest integer, with halfway cases rounded away from zero.
///
/// Despite the `rint` name, this does not use ties-to-even:
/// `2.5 -> 3.0` and `-2.5 -> -3.0`.
///
/// - Inputs with magnitude `>= 2^23` are returned unchanged, since every such
///   `f32` is already an integer. This includes infinities.
/// - NaN maps to `0.0`.
/// - Results in `(-0.5, 0.5)` are `+0.0`.
#[inline]
pub(crate) const fn rintfk(x: f32) -> f32 {
    // 2^23: from here on, every finite f32 is an integer.
    const INTEGRAL_THRESHOLD: f32 = 8_388_608.0;
    if x >= INTEGRAL_THRESHOLD || x <= -INTEGRAL_THRESHOLD {
        return x;
    }
    // |x| < 2^23 fits in i32, so this truncates toward zero without saturating
    // (NaN casts to 0).
    let truncated = x as i32 as f32;
    // Exact: the fractional part of an f32 below 2^23 is representable. This
    // avoids `x + 0.5`, whose own rounding turns 0.49999997 into 1.0 and bumps
    // odd integers near 2^23 up by one.
    let frac = x - truncated;
    if frac >= 0.5 {
        truncated + 1.0
    } else if frac <= -0.5 {
        truncated - 1.0
    } else {
        truncated
    }
}
/// Returns `a * b + c` for `f32`, unfused: the product is rounded before the
/// addition.
///
/// Usable in `const` contexts.
#[inline(always)]
pub(crate) const fn fmlaf(a: f32, b: f32, c: f32) -> f32 {
    let product = a * b;
    c + product
}
#[inline(always)]
pub(crate) fn f_fmlaf(a: f32, b: f32, c: f32) -> f32 {
mlaf(c, a, b)
}
/// Returns `a * b + c` for `f64`, unfused: the product is rounded before the
/// addition.
///
/// Usable in `const` contexts.
#[inline(always)]
pub(crate) const fn fmla(a: f64, b: f64, c: f64) -> f64 {
    let product = a * b;
    c + product
}
#[inline(always)]
pub(crate) fn f_fmla(a: f64, b: f64, c: f64) -> f64 {
mlaf(c, a, b)
}
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn c_mlaf<T: Copy + Mul<T, Output = T> + Add<T, Output = T> + MulAdd<T, Output = T>>(
a: T,
b: T,
c: T,
) -> T {
mlaf(c, a, b)
}
/// Returns a value with the magnitude of `x` and the sign bit of `y`.
///
/// Operates purely on the IEEE-754 bit pattern, so signed zeros and NaN
/// payloads are handled bit-exactly.
#[inline]
pub const fn copysignfk(x: f32, y: f32) -> f32 {
    const SIGN_MASK: u32 = 1 << 31;
    let magnitude = x.to_bits() & !SIGN_MASK;
    let sign = y.to_bits() & SIGN_MASK;
    // The two masks are disjoint, so OR combines them exactly.
    f32::from_bits(magnitude | sign)
}
/// Builds `2^q` as an `f32` by writing the biased exponent field directly.
///
/// Correct only for `q` in the normal-exponent range `-126..=127`. Outside it
/// there is no clamping, so the resulting bit pattern is not `2^q`.
#[inline]
pub(crate) const fn pow2if(q: i32) -> f32 {
    const EXPONENT_BIAS: i32 = 127;
    const MANTISSA_BITS: u32 = 23;
    let biased_exponent = q.wrapping_add(EXPONENT_BIAS) as u32;
    f32::from_bits(biased_exponent << MANTISSA_BITS)
}
/// Rounds `x` to the nearest integer, with halfway cases rounded away from zero.
///
/// Despite the `rint` name, this does not use ties-to-even:
/// `2.5 -> 3.0` and `-2.5 -> -3.0`.
///
/// - Inputs with magnitude `>= 2^52` are returned unchanged, since every such
///   `f64` is already an integer. This includes infinities.
/// - NaN maps to `0.0`.
/// - Results in `(-0.5, 0.5)` are `+0.0`.
#[inline]
pub(crate) const fn rintk(x: f64) -> f64 {
    // 2^52: from here on, every finite f64 is an integer.
    const INTEGRAL_THRESHOLD: f64 = 4_503_599_627_370_496.0;
    if x >= INTEGRAL_THRESHOLD || x <= -INTEGRAL_THRESHOLD {
        return x;
    }
    // |x| < 2^52 fits in i64, so this truncates toward zero without saturating
    // (NaN casts to 0).
    let truncated = x as i64 as f64;
    // Exact: the fractional part of an f64 below 2^52 is representable. This
    // avoids `x + 0.5`, whose own rounding turns 0.49999999999999994 into 1.0
    // and bumps odd integers near 2^52 up by one.
    let frac = x - truncated;
    if frac >= 0.5 {
        truncated + 1.0
    } else if frac <= -0.5 {
        truncated - 1.0
    } else {
        truncated
    }
}
/// Builds `2^q` as an `f64` by writing the biased exponent field directly.
///
/// Correct only for `q` in the normal-exponent range `-1022..=1023`. Outside it
/// there is no clamping, so the resulting bit pattern is not `2^q`.
#[inline(always)]
pub(crate) const fn pow2i(q: i32) -> f64 {
    const EXPONENT_BIAS: i32 = 1023;
    const MANTISSA_BITS: u32 = 52;
    let biased_exponent = q.wrapping_add(EXPONENT_BIAS) as u64;
    f64::from_bits(biased_exponent << MANTISSA_BITS)
}
/// Returns a value with the magnitude of `x` and the sign bit of `y`.
///
/// Operates purely on the IEEE-754 bit pattern, so signed zeros and NaN
/// payloads are handled bit-exactly.
#[inline]
pub const fn copysignk(x: f64, y: f64) -> f64 {
    const SIGN_MASK: u64 = 1 << 63;
    let magnitude = x.to_bits() & !SIGN_MASK;
    let sign = y.to_bits() & SIGN_MASK;
    // The two masks are disjoint, so OR combines them exactly.
    f64::from_bits(magnitude | sign)
}
/// Returns the smallest positive normal `f64`, `2^-1022`.
///
/// Its bit pattern has only the lowest exponent bit set (`1 << 52`). This is
/// exactly [`f64::MIN_POSITIVE`].
#[inline]
pub(crate) const fn min_normal_f64() -> f64 {
    f64::MIN_POSITIVE
}