#[cfg(any(
all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "fma"
),
all(target_arch = "aarch64", target_feature = "neon")
))]
use num_traits::MulAdd;
/// Multiply-accumulate: computes `acc + a * b`.
///
/// This variant is compiled on x86/x86_64 with the `fma` target feature and
/// on aarch64 with `neon`. It forwards to [`MulAdd::mul_add`]. For `f32`/`f64`
/// that is presumably the fused `mul_add` (single rounding). NOTE(review): the
/// exact rounding depends on the `MulAdd` impl for `T`.
#[cfg(any(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "fma"
    ),
    all(target_arch = "aarch64", target_feature = "neon")
))]
#[inline(always)]
pub fn mlaf<T>(acc: T, a: T, b: T) -> T
where
    T: MulAdd<T, Output = T>,
{
    // `x.mul_add(y, z)` evaluates `x * y + z`, so the accumulator goes last.
    a.mul_add(b, acc)
}
/// Multiply-accumulate: computes `acc + a * b`.
///
/// This is the portable fallback, compiled when neither the x86/x86_64 `fma`
/// configuration nor the aarch64 `neon` configuration applies. The multiply
/// and the add are two separate operations. For floating-point types each
/// step is rounded on its own, unlike the fused variant.
#[cfg(not(any(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "fma"
    ),
    all(target_arch = "aarch64", target_feature = "neon")
)))]
#[inline(always)]
pub fn mlaf<T>(acc: T, a: T, b: T) -> T
where
    T: core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
{
    let product = a * b;
    acc + product
}
/// Negated multiply-accumulate: computes `acc + a * (-b)`. For ordinary
/// numeric types this is `acc - a * b`.
///
/// This variant is compiled on x86/x86_64 with `fma` and on aarch64 with
/// `neon`. It negates `b` and then delegates to `mlaf`, so the multiply-add
/// goes through `MulAdd::mul_add`.
#[cfg(any(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "fma"
    ),
    all(target_arch = "aarch64", target_feature = "neon")
))]
#[inline(always)]
pub fn neg_mlaf<T>(acc: T, a: T, b: T) -> T
where
    T: MulAdd<T, Output = T> + core::ops::Neg<Output = T>,
{
    let neg_b = -b;
    mlaf(acc, a, neg_b)
}
/// Negated multiply-accumulate: computes `acc + a * (-b)`. For ordinary
/// numeric types this is `acc - a * b`.
///
/// This is the portable fallback, compiled when neither the x86/x86_64 `fma`
/// configuration nor the aarch64 `neon` configuration applies. The multiply
/// and the add are separate operations.
#[cfg(not(any(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "fma"
    ),
    all(target_arch = "aarch64", target_feature = "neon")
)))]
#[inline(always)]
pub fn neg_mlaf<T>(acc: T, a: T, b: T) -> T
where
    T: core::ops::Add<Output = T> + core::ops::Mul<Output = T> + core::ops::Neg<Output = T>,
{
    let scaled = a * -b;
    acc + scaled
}
#[inline(always)]
pub fn fmla<T>(a: T, b: T, c: T) -> T
where
T: MlaCompatible,
{
T::mla(c, a, b)
}
/// Types that provide a multiply-accumulate operation. This is the dispatch
/// point used by `fmla`.
pub trait MlaCompatible
where
    Self: Sized,
{
    /// Combines `acc`, `a` and `b` as a multiply-accumulate.
    ///
    /// The `f32`/`f64` impls in this module compute `acc + a * b` via `mlaf`.
    fn mla(acc: Self, a: Self, b: Self) -> Self;
}
/// `f32` multiply-accumulate `acc + a * b`, delegating to `mlaf`. That is the
/// `MulAdd::mul_add` path on the FMA/NEON configurations and a plain
/// multiply-then-add otherwise.
impl MlaCompatible for f32 {
    #[inline(always)]
    fn mla(acc: Self, a: Self, b: Self) -> Self {
        mlaf::<f32>(acc, a, b)
    }
}
/// `f64` multiply-accumulate `acc + a * b`, delegating to `mlaf`. That is the
/// `MulAdd::mul_add` path on the FMA/NEON configurations and a plain
/// multiply-then-add otherwise.
impl MlaCompatible for f64 {
    #[inline(always)]
    fn mla(acc: Self, a: Self, b: Self) -> Self {
        mlaf::<f64>(acc, a, b)
    }
}
// Unit tests for the multiply-accumulate helpers. All inputs and expected
// results are exactly representable, so exact float equality holds on both
// the fused and the unfused code paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mlaf_f32() {
        let result = mlaf(1.0f32, 2.0f32, 3.0f32);
        assert_eq!(result, 7.0f32);
    }

    #[test]
    fn test_mlaf_f64() {
        let result = mlaf(1.0f64, 2.0f64, 3.0f64);
        assert_eq!(result, 7.0f64);
    }

    #[test]
    fn test_mlaf_zero_product_returns_acc() {
        assert_eq!(mlaf(5.0f32, 0.0f32, 100.0f32), 5.0f32);
        assert_eq!(mlaf(5.0f64, 100.0f64, 0.0f64), 5.0f64);
    }

    #[test]
    fn test_neg_mlaf_f32() {
        let result = neg_mlaf(10.0f32, 2.0f32, 3.0f32);
        assert_eq!(result, 4.0f32);
    }

    #[test]
    fn test_neg_mlaf_f64() {
        let result = neg_mlaf(10.0f64, 2.0f64, 3.0f64);
        assert_eq!(result, 4.0f64);
    }

    #[test]
    fn test_neg_mlaf_negative_result() {
        assert_eq!(neg_mlaf(1.0f32, 2.0f32, 3.0f32), -5.0f32);
        assert_eq!(neg_mlaf(1.0f64, 2.0f64, 3.0f64), -5.0f64);
    }

    #[test]
    fn test_fmla_f32() {
        let result = fmla(2.0f32, 3.0f32, 1.0f32);
        assert_eq!(result, 7.0f32);
    }

    #[test]
    fn test_fmla_f64() {
        let result = fmla(2.0f64, 3.0f64, 1.0f64);
        assert_eq!(result, 7.0f64);
    }

    #[test]
    fn test_fmla_accumulator_is_last_argument() {
        // fmla(a, b, c) == a * b + c = 16. Other argument orders would give
        // c * a + b = 23 or b * c + a = 32, so this pins the order.
        assert_eq!(fmla(2.0f32, 3.0f32, 10.0f32), 16.0f32);
        assert_eq!(fmla(2.0f64, 3.0f64, 10.0f64), 16.0f64);
    }
}