use num_traits::MulAdd;
use std::ops::{Add, Mul};
/// Floating-point types that can report the constant √2 in their own precision.
pub trait ConstSqrt2 {
    /// Returns √2 as a value of `Self`.
    fn const_sqrt2() -> Self;
}

impl ConstSqrt2 for f32 {
    /// Returns `std::f32::consts::SQRT_2`.
    #[inline]
    fn const_sqrt2() -> f32 {
        use std::f32::consts::SQRT_2 as ROOT_TWO;
        ROOT_TWO
    }
}

impl ConstSqrt2 for f64 {
    /// Returns `std::f64::consts::SQRT_2`.
    #[inline]
    fn const_sqrt2() -> f64 {
        use std::f64::consts::SQRT_2 as ROOT_TWO;
        ROOT_TWO
    }
}
/// Floating-point types that can report the constant π in their own precision.
pub trait ConstPI {
    /// Returns π as a value of `Self`.
    fn const_pi() -> Self;
}

impl ConstPI for f32 {
    /// Returns `std::f32::consts::PI`.
    #[inline]
    fn const_pi() -> f32 {
        use std::f32::consts::PI as HALF_TURN;
        HALF_TURN
    }
}

impl ConstPI for f64 {
    /// Returns `std::f64::consts::PI`.
    #[inline]
    fn const_pi() -> f64 {
        use std::f64::consts::PI as HALF_TURN;
        HALF_TURN
    }
}
#[cfg(any(
all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "fma"
),
all(target_arch = "aarch64", feature = "neon")
))]
#[inline(always)]
pub(crate) fn mla<T: Copy + Mul<T, Output = T> + Add<T, Output = T> + MulAdd<T, Output = T>>(
a: T,
b: T,
acc: T,
) -> T {
MulAdd::mul_add(a, b, acc)
}
#[inline(always)]
#[cfg(not(any(
all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "fma"
),
all(target_arch = "aarch64", feature = "neon")
)))]
pub(crate) fn mla<T: Copy + Mul<T, Output = T> + Add<T, Output = T> + MulAdd<T, Output = T>>(
a: T,
b: T,
acc: T,
) -> T {
a * b + acc
}