use crate::kernels::scientific::erf::erf as erf_fn;
use crate::utils::bitmask_to_simd_mask;
use minarrow::utils::is_simd_aligned;
use minarrow::{Bitmask, FloatArray, Vec64};
/// Generates a null-aware, element-wise `f64 -> f64` kernel family from a scalar expression.
///
/// `impl_vecmap!(name, name_to, name_elem, expr)` expands, in the invoking module, to:
///
/// * `name_elem(x)` - applies `expr` to one value.
/// * `name_to::<LANES>(input, output, null_mask, null_count)` - writes the mapped values into a
///   caller-provided buffer of the same length.
/// * `name::<LANES>(input, null_mask, null_count)` - allocates the output and wraps it in a
///   `FloatArray<f64>` that carries a clone of `null_mask`.
///
/// `expr` must be callable as `expr(x: f64) -> f64` (a closure literal or a function path).
///
/// Null handling: `null_count == Some(0)` means "no nulls" and the mask is ignored; with
/// `null_count == None` the mere presence of a mask selects the masked path. Null slots are
/// written as `f64::NAN` and `expr` is never evaluated on them.
///
/// The expansion names `Bitmask`, `FloatArray`, `Vec64`, `is_simd_aligned` and
/// `bitmask_to_simd_mask` unqualified, so they must be in scope where the macro is invoked.
/// `LANES` is the chunk width used by the `simd` feature's loops.
#[macro_export]
macro_rules! impl_vecmap {
    ($name:ident, $name_to:ident, $name_elem:ident, $expr:expr) => {
        /// Applies this kernel's scalar expression to a single value.
        #[inline(always)]
        pub fn $name_elem(x: f64) -> f64 {
            $expr(x)
        }

        /// Applies this kernel element-wise from `input` into `output`.
        ///
        /// Valid slots receive the mapped value; null slots (mask bit clear) receive
        /// `f64::NAN` without evaluating the expression, so the result is the same whichever
        /// code path (SIMD chunk, tail, scalar) handles a given slot.
        ///
        /// # Errors
        /// Returns `Err` if nulls are indicated but `null_mask` is `None`, or if the mask's
        /// byte buffer has fewer bits than `input` has elements.
        ///
        /// # Panics
        /// Panics if `input` and `output` have different lengths.
        #[inline(always)]
        pub fn $name_to<const LANES: usize>(
            input: &[f64],
            output: &mut [f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<(), &'static str> {
            let len = input.len();
            assert_eq!(
                len,
                output.len(),
                concat!(stringify!($name_to), ": input/output length mismatch")
            );
            if len == 0 {
                return Ok(());
            }
            let has_nulls = match null_count {
                Some(n) => n > 0,
                None => null_mask.is_some(),
            };

            if !has_nulls {
                #[cfg(feature = "simd")]
                {
                    if is_simd_aligned(input) {
                        use core::simd::Simd;
                        let mut i = 0;
                        while i + LANES <= len {
                            let v = Simd::<f64, LANES>::from_slice(&input[i..i + LANES]);
                            let mut r = Simd::<f64, LANES>::splat(0.0);
                            for lane in 0..LANES {
                                r[lane] = $expr(v[lane]);
                            }
                            output[i..i + LANES].copy_from_slice(r.as_array());
                            i += LANES;
                        }
                        for j in i..len {
                            output[j] = $expr(input[j]);
                        }
                        return Ok(());
                    }
                }
                for (dst, &x) in output.iter_mut().zip(input) {
                    *dst = $expr(x);
                }
                return Ok(());
            }

            let mb = null_mask.ok_or(concat!(
                stringify!($name_to),
                ": input mask required when nulls present"
            ))?;
            let mask_bytes = mb.as_bytes();
            // The unchecked bit reads below are only sound if the mask covers every input slot.
            if mask_bytes.len() < len.div_ceil(8) {
                return Err(concat!(
                    stringify!($name_to),
                    ": null mask too short for input"
                ));
            }

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(input) {
                    use core::simd::Mask;
                    let mut i = 0;
                    while i + LANES <= len {
                        let lane_valid: Mask<i8, LANES> =
                            bitmask_to_simd_mask::<LANES, i8>(mask_bytes, i, len);
                        // Null lanes stay NaN; the expression only runs on valid lanes.
                        // (Feeding NaN through the expression would let e.g. relu turn a null
                        // slot into 0.0 here but NaN in the scalar paths.)
                        let mut lanes = [f64::NAN; LANES];
                        for (j, slot) in lanes.iter_mut().enumerate() {
                            // SAFETY: `j < LANES`, the lane count of `lane_valid`.
                            if unsafe { lane_valid.test_unchecked(j) } {
                                *slot = $expr(input[i + j]);
                            }
                        }
                        output[i..i + LANES].copy_from_slice(&lanes);
                        i += LANES;
                    }
                    for idx in i..len {
                        // SAFETY: `idx < len` and the mask was checked to cover `len` bits.
                        output[idx] = if unsafe { mb.get_unchecked(idx) } {
                            $expr(input[idx])
                        } else {
                            f64::NAN
                        };
                    }
                    return Ok(());
                }
            }

            for (idx, (dst, &x)) in output.iter_mut().zip(input).enumerate() {
                // SAFETY: `idx < len` and the mask was checked to cover `len` bits.
                *dst = if unsafe { mb.get_unchecked(idx) } {
                    $expr(x)
                } else {
                    f64::NAN
                };
            }
            Ok(())
        }

        /// Applies this kernel element-wise, returning a newly allocated `FloatArray<f64>`.
        ///
        /// The result carries a clone of `null_mask` (none for empty input); null slots hold
        /// `f64::NAN`.
        ///
        /// # Errors
        /// Same conditions as the corresponding `*_to` function.
        #[inline(always)]
        pub fn $name<const LANES: usize>(
            input: &[f64],
            null_mask: Option<&Bitmask>,
            null_count: Option<usize>,
        ) -> Result<FloatArray<f64>, &'static str> {
            let len = input.len();
            if len == 0 {
                return Ok(FloatArray::<f64>::from_slice(&[]));
            }
            // Zero-fill rather than `set_len` on uninitialised memory: `Vec::set_len` requires
            // the new elements to be initialised, and the kernel receives them as `&mut [f64]`.
            let mut out = Vec64::with_capacity(len);
            out.resize(len, 0.0);
            $name_to::<LANES>(input, out.as_mut_slice(), null_mask, null_count)?;
            Ok(FloatArray::from_vec64(out, null_mask.cloned()))
        }
    };
}
// Basic arithmetic kernels.
impl_vecmap!(abs, abs_to, abs_elem, f64::abs);
impl_vecmap!(neg, neg_to, neg_elem, |value: f64| -value);
impl_vecmap!(recip, recip_to, recip_elem, f64::recip);

// Roots, exponentials and logarithms (thin wrappers over the `f64` methods).
impl_vecmap!(sqrt, sqrt_to, sqrt_elem, f64::sqrt);
impl_vecmap!(cbrt, cbrt_to, cbrt_elem, f64::cbrt);
impl_vecmap!(exp, exp_to, exp_elem, f64::exp);
impl_vecmap!(exp2, exp2_to, exp2_elem, f64::exp2);
impl_vecmap!(ln, ln_to, ln_elem, f64::ln);
impl_vecmap!(log2, log2_to, log2_elem, f64::log2);
impl_vecmap!(log10, log10_to, log10_elem, f64::log10);

// Trigonometric kernels (radians).
impl_vecmap!(sin, sin_to, sin_elem, f64::sin);
impl_vecmap!(cos, cos_to, cos_elem, f64::cos);
impl_vecmap!(tan, tan_to, tan_elem, f64::tan);
impl_vecmap!(asin, asin_to, asin_elem, f64::asin);
impl_vecmap!(acos, acos_to, acos_elem, f64::acos);
impl_vecmap!(atan, atan_to, atan_elem, f64::atan);

// Hyperbolic kernels.
impl_vecmap!(sinh, sinh_to, sinh_elem, f64::sinh);
impl_vecmap!(cosh, cosh_to, cosh_elem, f64::cosh);
impl_vecmap!(tanh, tanh_to, tanh_elem, f64::tanh);
impl_vecmap!(asinh, asinh_to, asinh_elem, f64::asinh);
impl_vecmap!(acosh, acosh_to, acosh_elem, f64::acosh);
impl_vecmap!(atanh, atanh_to, atanh_elem, f64::atanh);
impl_vecmap!(erf, erf_to, erf_elem, |x: f64| erf_fn(x));
impl_vecmap!(erfc, erfc_to, erfc_elem, |x: f64| 1.0 - erf_fn(x));
// Rounding kernels.
impl_vecmap!(ceil, ceil_to, ceil_elem, f64::ceil);
impl_vecmap!(floor, floor_to, floor_elem, f64::floor);
impl_vecmap!(trunc, trunc_to, trunc_elem, f64::trunc);
impl_vecmap!(round, round_to, round_elem, f64::round);

// Sign via `f64::signum`: +0.0 maps to 1.0 and -0.0 to -1.0 (it never returns 0.0).
impl_vecmap!(sign, sign_to, sign_elem, f64::signum);

// Logistic sigmoid 1 / (1 + e^-x).
impl_vecmap!(sigmoid, sigmoid_to, sigmoid_elem, |value: f64| {
    let denom = 1.0 + (-value).exp();
    1.0 / denom
});
// Softplus ln(1 + e^x), evaluated as max(x, 0) + ln(1 + e^-|x|): the exponent is never
// positive, so it cannot overflow for large x, and `ln_1p` keeps the tiny result for very
// negative x instead of rounding (1 + e^x) to 1.0 and returning 0. NaN propagates through
// the `ln_1p` term.
impl_vecmap!(softplus, softplus_to, softplus_elem, |x: f64| x.max(0.0)
    + (-x.abs()).exp().ln_1p());
// Rectified linear unit: any value that is not strictly positive (including NaN and -0.0)
// becomes 0.0.
impl_vecmap!(relu, relu_to, relu_elem, |value: f64| {
    if value > 0.0 {
        value
    } else {
        0.0
    }
});

// Gaussian error linear unit, exact erf form: 0.5 * x * (1 + erf(x / sqrt(2))).
impl_vecmap!(gelu, gelu_to, gelu_elem, |value: f64| {
    let one_plus_erf = 1.0 + erf_fn(value / std::f64::consts::SQRT_2);
    0.5 * value * one_plus_erf
});