include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));
#[cfg(feature = "datetime")]
use crate::DatetimeAVT;
#[cfg(feature = "datetime")]
use crate::DatetimeArray;
use crate::enums::error::KernelError;
use crate::enums::operators::ArithmeticOperator::{self};
#[cfg(feature = "simd")]
use crate::kernels::arithmetic::simd::{
float_dense_body_f32_simd, float_dense_body_f64_simd, float_masked_body_f32_simd,
float_masked_body_f64_simd, fma_dense_body_f32_simd, fma_dense_body_f64_simd,
fma_masked_body_f32_simd, fma_masked_body_f64_simd, int_dense_body_simd, int_masked_body_simd,
};
use crate::kernels::arithmetic::std::{
float_dense_body_std, float_masked_body_std, int_dense_body_std, int_masked_body_std,
};
#[cfg(feature = "datetime")]
use crate::kernels::bitmask::merge_bitmasks_to_new;
use crate::structs::variants::float::FloatArray;
use crate::structs::variants::integer::IntegerArray;
use crate::utils::confirm_equal_len;
#[cfg(feature = "simd")]
use crate::utils::is_simd_aligned;
use crate::{Bitmask, Vec64};
macro_rules! impl_apply_int {
    ($fn_name:ident, $ty:ty, $lanes:expr) => {
        #[doc = concat!(
            "Element-wise integer arithmetic over two `&[", stringify!($ty),
            "]` slices, producing an `IntegerArray<", stringify!($ty), ">`."
        )]
        ///
        /// `op` selects the operation and is forwarded unchanged to the kernel.
        ///
        /// # Dispatch
        #[doc = concat!(
            "With the `simd` feature enabled and both inputs passing `is_simd_aligned`, ",
            "the SIMD kernels run with `", stringify!($lanes), "` lanes; ",
            "in every other case the scalar `int_*_body_std` kernels run."
        )]
        ///
        /// # Nulls
        /// When `mask` is `Some`, a fresh all-set `Bitmask` of the input length is handed
        /// to the masked kernel and returned as the result's `null_mask`. When `mask` is
        /// `None`, the dense kernel runs and the result carries no `null_mask`.
        ///
        /// # Errors
        /// Propagates whatever error `confirm_equal_len` reports for the lengths of
        /// `lhs` and `rhs`.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            op: ArithmeticOperator,
            mask: Option<&Bitmask>
        ) -> Result<IntegerArray<$ty>, KernelError> {
            let n = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", n, rhs.len())?;

            let mut values = Vec64::<$ty>::with_capacity(n);
            // SAFETY: capacity is `n`; the selected kernel is expected to write every
            // one of the `n` slots before they are read.
            // NOTE(review): that expectation lives in the kernels, not here - confirm.
            unsafe { values.set_len(n) };

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    let null_mask = match mask {
                        Some(validity) => {
                            let mut result_validity = crate::Bitmask::new_set_all(n, true);
                            int_masked_body_simd::<$ty, $lanes>(
                                op,
                                lhs,
                                rhs,
                                validity,
                                &mut values,
                                &mut result_validity,
                            );
                            Some(result_validity)
                        }
                        None => {
                            int_dense_body_simd::<$ty, $lanes>(op, lhs, rhs, &mut values);
                            None
                        }
                    };
                    return Ok(IntegerArray {
                        data: values.into(),
                        null_mask,
                    });
                }
            }

            // Scalar path: `simd` disabled, or at least one input is not SIMD-aligned.
            let null_mask = match mask {
                Some(validity) => {
                    let mut result_validity = crate::Bitmask::new_set_all(n, true);
                    int_masked_body_std::<$ty>(
                        op,
                        lhs,
                        rhs,
                        validity,
                        &mut values,
                        &mut result_validity,
                    );
                    Some(result_validity)
                }
                None => {
                    int_dense_body_std::<$ty>(op, lhs, rhs, &mut values);
                    None
                }
            };
            Ok(IntegerArray {
                data: values.into(),
                null_mask,
            })
        }
    };
}
macro_rules! impl_apply_float {
    ($fn_name:ident, $ty:ty, $lanes:expr, $dense_body_simd:ident, $masked_body_simd:ident) => {
        #[doc = concat!(
            "Element-wise float arithmetic over two `&[", stringify!($ty),
            "]` slices, producing a `FloatArray<", stringify!($ty), ">`."
        )]
        ///
        /// `op` selects the operation and is forwarded unchanged to the kernel.
        ///
        /// # Dispatch
        #[doc = concat!(
            "With the `simd` feature enabled and both inputs passing `is_simd_aligned`, ",
            "`", stringify!($dense_body_simd), "` (no mask) or `",
            stringify!($masked_body_simd), "` (mask supplied) runs with `",
            stringify!($lanes), "` lanes. Otherwise the scalar ",
            "`float_dense_body_std` / `float_masked_body_std` kernels run."
        )]
        ///
        /// # Nulls
        /// A supplied `mask` yields a result whose `null_mask` is the output mask the
        /// masked kernel filled in (it starts as all-set, sized to the input length).
        /// Without a `mask`, the result has no `null_mask`.
        ///
        /// # Errors
        /// Propagates whatever error `confirm_equal_len` reports for the lengths of
        /// `lhs` and `rhs`.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            op: ArithmeticOperator,
            mask: Option<&Bitmask>
        ) -> Result<FloatArray<$ty>, KernelError> {
            let n = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", n, rhs.len())?;

            let mut values = Vec64::<$ty>::with_capacity(n);
            // SAFETY: capacity is `n`; the selected kernel is expected to write every
            // one of the `n` slots before they are read.
            // NOTE(review): that expectation lives in the kernels, not here - confirm.
            unsafe { values.set_len(n) };

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) {
                    let null_mask = if let Some(validity) = mask {
                        let mut result_validity = crate::Bitmask::new_set_all(n, true);
                        $masked_body_simd::<$lanes>(
                            op,
                            lhs,
                            rhs,
                            validity,
                            &mut values,
                            &mut result_validity,
                        );
                        Some(result_validity)
                    } else {
                        $dense_body_simd::<$lanes>(op, lhs, rhs, &mut values);
                        None
                    };
                    return Ok(FloatArray {
                        data: values.into(),
                        null_mask,
                    });
                }
            }

            // Scalar path: `simd` disabled, or at least one input is not SIMD-aligned.
            let null_mask = if let Some(validity) = mask {
                let mut result_validity = crate::Bitmask::new_set_all(n, true);
                float_masked_body_std::<$ty>(
                    op,
                    lhs,
                    rhs,
                    validity,
                    &mut values,
                    &mut result_validity,
                );
                Some(result_validity)
            } else {
                float_dense_body_std::<$ty>(op, lhs, rhs, &mut values);
                None
            };
            Ok(FloatArray {
                data: values.into(),
                null_mask,
            })
        }
    };
}
macro_rules! impl_apply_fma_float {
    ($fn_name:ident, $ty:ty, $lanes:expr, $dense_simd:ident, $masked_simd:ident) => {
        #[doc = concat!(
            "Element-wise multiply-add (`lhs * rhs + acc`) over `&[", stringify!($ty),
            "]` slices, returning a `FloatArray<", stringify!($ty), ">`."
        )]
        ///
        /// # Dispatch
        #[doc = concat!(
            "With the `simd` feature enabled and `lhs`, `rhs` and `acc` all passing ",
            "`is_simd_aligned`, the work is handed to `", stringify!($dense_simd),
            "` (no mask) or `", stringify!($masked_simd), "` (mask supplied) with `",
            stringify!($lanes), "` lanes. Otherwise a scalar loop evaluates ",
            "`lhs[i] * rhs[i] + acc[i]` as a separate multiply and add."
        )]
        ///
        /// # Nulls
        /// When `mask` is `Some`, the scalar path writes `0` to every position whose mask
        /// bit is unset and clears the matching bit in the returned `null_mask`; all
        /// other positions receive the computed value. When `mask` is `None`, the result
        /// has no `null_mask`.
        ///
        /// # Errors
        /// Propagates whatever error `confirm_equal_len` reports when checking `rhs`
        /// and then `acc` against the length of `lhs`.
        #[inline(always)]
        pub fn $fn_name(
            lhs: &[$ty],
            rhs: &[$ty],
            acc: &[$ty],
            mask: Option<&Bitmask>
        ) -> Result<FloatArray<$ty>, KernelError> {
            let len = lhs.len();
            confirm_equal_len("apply numeric: length mismatch", len, rhs.len())?;
            confirm_equal_len("acc length mismatch", len, acc.len())?;

            let mut out = Vec64::<$ty>::with_capacity(len);
            // SAFETY: capacity is `len`. The scalar loops below assign every index in
            // `0..len`; the SIMD kernels are expected to do the same.
            // NOTE(review): the SIMD kernels' full-write guarantee is not visible here.
            unsafe { out.set_len(len) };

            #[cfg(feature = "simd")]
            {
                if is_simd_aligned(lhs) && is_simd_aligned(rhs) && is_simd_aligned(acc) {
                    let null_mask = match mask {
                        Some(mask) => {
                            // The output mask is only built when it will be returned.
                            let mut out_mask = crate::Bitmask::new_set_all(len, true);
                            $masked_simd::<$lanes>(lhs, rhs, acc, mask, &mut out, &mut out_mask);
                            Some(out_mask)
                        }
                        None => {
                            $dense_simd::<$lanes>(lhs, rhs, acc, &mut out);
                            None
                        }
                    };
                    return Ok(FloatArray {
                        data: out.into(),
                        null_mask,
                    });
                }
            }

            // Scalar path: `simd` disabled, or at least one input is not SIMD-aligned.
            let null_mask = match mask {
                Some(mask) => {
                    let mut out_mask = crate::Bitmask::new_set_all(len, true);
                    for i in 0..len {
                        // SAFETY: relies on `mask` holding at least `len` bits.
                        // NOTE(review): that is a caller obligation; it is not checked here.
                        if unsafe { mask.get_unchecked(i) } {
                            out[i] = lhs[i] * rhs[i] + acc[i];
                        } else {
                            out[i] = 0.0;
                            out_mask.set(i, false);
                        }
                    }
                    Some(out_mask)
                }
                None => {
                    for i in 0..len {
                        out[i] = lhs[i] * rhs[i] + acc[i];
                    }
                    None
                }
            };
            Ok(FloatArray {
                data: out.into(),
                null_mask,
            })
        }
    };
}
#[cfg(feature = "datetime")]
macro_rules! impl_apply_datetime {
    ($fn_name:ident, $ty:ty, $lanes:expr) => {
        #[doc = concat!(
            "Element-wise `ArithmeticOperator` over two `DatetimeAVT<", stringify!($ty),
            ">` windows, producing a `DatetimeArray<", stringify!($ty), ">`."
        )]
        ///
        /// Each input is destructured as `(array, offset, len)`; the operands are the
        /// `data[offset..offset + len]` sub-slices of the two arrays.
        ///
        /// # Dispatch
        #[doc = concat!(
            "With the `simd` feature enabled and both windows passing `is_simd_aligned`, ",
            "the integer SIMD kernels run with `", stringify!($lanes), "` lanes; ",
            "otherwise the scalar `int_*_body_std` kernels run. This is the same gate ",
            "the slice-based integer and float dispatchers in this file apply."
        )]
        ///
        /// # Nulls
        /// The two arrays' `null_mask`s are combined with `merge_bitmasks_to_new`. If that
        /// yields a mask, the masked kernel runs and its output mask is attached to the
        /// result; otherwise the dense kernel runs and the result has no mask.
        ///
        /// # Errors
        /// Propagates whatever error `confirm_equal_len` reports for the two window lengths.
        ///
        /// # Panics
        /// Slicing `data[offset..offset + len]` panics if a window exceeds its array.
        #[inline(always)]
        pub fn $fn_name(
            lhs: DatetimeAVT<$ty>,
            rhs: DatetimeAVT<$ty>,
            op: ArithmeticOperator,
        ) -> Result<DatetimeArray<$ty>, KernelError> {
            let (larr, loff, llen) = lhs;
            let (rarr, roff, rlen) = rhs;
            confirm_equal_len("apply_datetime: length mismatch", llen, rlen)?;

            // NOTE(review): the masks are merged over bits `0..llen` with no reference to
            // `loff` / `roff`, while the data below is windowed by those offsets. Confirm
            // that validity stays aligned with the data for non-zero offsets.
            let merged_mask =
                merge_bitmasks_to_new(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), llen);
            let ldata = &larr.data[loff..loff + llen];
            let rdata = &rarr.data[roff..roff + rlen];

            let mut out = Vec64::<$ty>::with_capacity(llen);
            // SAFETY: capacity is `llen`; the selected kernel is expected to write every
            // one of the `llen` slots before they are read.
            // NOTE(review): that expectation lives in the kernels, not here - confirm.
            unsafe {
                out.set_len(llen);
            }

            #[cfg(feature = "simd")]
            {
                // Offset windows need not start on a SIMD-aligned boundary, so take the
                // SIMD kernels only when both windows pass the alignment check.
                if is_simd_aligned(ldata) && is_simd_aligned(rdata) {
                    return match merged_mask.as_ref() {
                        Some(mask) => {
                            let mut result_mask = crate::Bitmask::new_set_all(llen, true);
                            int_masked_body_simd::<$ty, $lanes>(
                                op,
                                ldata,
                                rdata,
                                mask,
                                &mut out,
                                &mut result_mask,
                            );
                            Ok(DatetimeArray::from_vec64(out, Some(result_mask), None))
                        }
                        None => {
                            int_dense_body_simd::<$ty, $lanes>(op, ldata, rdata, &mut out);
                            Ok(DatetimeArray::from_vec64(out, None, None))
                        }
                    };
                }
            }

            // Scalar path: `simd` disabled, or at least one window is not SIMD-aligned.
            match merged_mask.as_ref() {
                Some(mask) => {
                    let mut result_mask = crate::Bitmask::new_set_all(llen, true);
                    int_masked_body_std::<$ty>(
                        op,
                        ldata,
                        rdata,
                        mask,
                        &mut out,
                        &mut result_mask,
                    );
                    Ok(DatetimeArray::from_vec64(out, Some(result_mask), None))
                }
                None => {
                    int_dense_body_std::<$ty>(op, ldata, rdata, &mut out);
                    Ok(DatetimeArray::from_vec64(out, None, None))
                }
            }
        }
    };
}
// Concrete dispatch functions generated from the macros above.
// Each invocation supplies the public function name, the element type, and a lane-count
// expression (`W8` / `W16` / `W32` / `W64`).
// NOTE(review): the `W*` names presumably come from the build-generated `simd_lanes.rs`
// included at the top of this file - confirm.

// Integer arithmetic: 32- and 64-bit types are always available.
impl_apply_int!(apply_int_i32, i32, W32);
impl_apply_int!(apply_int_u32, u32, W32);
impl_apply_int!(apply_int_i64, i64, W64);
impl_apply_int!(apply_int_u64, u64, W64);
// 8- and 16-bit integer variants are compiled only with `extended_numeric_types`.
#[cfg(feature = "extended_numeric_types")]
impl_apply_int!(apply_int_i16, i16, W16);
#[cfg(feature = "extended_numeric_types")]
impl_apply_int!(apply_int_u16, u16, W16);
#[cfg(feature = "extended_numeric_types")]
impl_apply_int!(apply_int_i8, i8, W8);
#[cfg(feature = "extended_numeric_types")]
impl_apply_int!(apply_int_u8, u8, W8);
// Float arithmetic: the last two arguments name the dense and masked SIMD kernels for
// the element type.
impl_apply_float!(
apply_float_f32,
f32,
W32,
float_dense_body_f32_simd,
float_masked_body_f32_simd
);
impl_apply_float!(
apply_float_f64,
f64,
W64,
float_dense_body_f64_simd,
float_masked_body_f64_simd
);
// Multiply-add (`lhs * rhs + acc`): same argument layout as the float dispatchers, with
// the FMA SIMD kernels.
impl_apply_fma_float!(
apply_fma_f32,
f32,
W32,
fma_dense_body_f32_simd,
fma_masked_body_f32_simd
);
impl_apply_fma_float!(
apply_fma_f64,
f64,
W64,
fma_dense_body_f64_simd,
fma_masked_body_f64_simd
);
// Datetime arithmetic over 32- and 64-bit integer storage; compiled only with `datetime`.
#[cfg(feature = "datetime")]
impl_apply_datetime!(apply_datetime_i32, i32, W32);
#[cfg(feature = "datetime")]
impl_apply_datetime!(apply_datetime_u32, u32, W32);
#[cfg(feature = "datetime")]
impl_apply_datetime!(apply_datetime_i64, i64, W64);
#[cfg(feature = "datetime")]
impl_apply_datetime!(apply_datetime_u64, u64, W64);