use crate::Bitmask;
use crate::enums::operators::ArithmeticOperator;
use num_traits::{Float, PrimInt, ToPrimitive, WrappingAdd, WrappingMul, WrappingSub};
/// Applies `op` element-wise to two dense integer slices: `out[i] = lhs[i] op rhs[i]`.
///
/// Overflow behaves the same in debug and release builds:
/// - Add, Subtract, Multiply and Power wrap (two's complement).
/// - The single overflowing quotient `MIN / -1` wraps to `MIN`, with remainder `0`, for
///   Divide, Remainder and FloorDiv instead of panicking.
///
/// Other operator semantics:
/// - Power treats a negative exponent as `0`, so the result is `1`.
/// - FloorDiv rounds the quotient toward negative infinity.
/// - Remainder is truncating, like Rust's `%`.
///
/// # Panics
/// - If `rhs` or `out` is shorter than `lhs`.
/// - With "Division by zero", "Remainder by zero" or "Floor division by zero" when a divisor
///   is zero.
#[inline(always)]
pub fn int_dense_body_std<T: PrimInt + ToPrimitive + WrappingAdd + WrappingSub + WrappingMul>(
    op: ArithmeticOperator,
    lhs: &[T],
    rhs: &[T],
    out: &mut [T],
) {
    // `(quotient, remainder)` of `a / b`, or `None` when `b` is zero. Never panics on overflow.
    fn div_rem_wrapping<U: PrimInt + WrappingSub>(a: U, b: U) -> Option<(U, U)> {
        if b == U::zero() {
            return None;
        }
        Some(match a.checked_div(&b) {
            Some(q) => (q, a % b),
            // With a non-zero divisor only `MIN / -1` overflows: the wrapped quotient is
            // `-MIN == MIN` and the remainder is 0.
            None => (U::zero().wrapping_sub(&a), U::zero()),
        })
    }

    // Wrapping `base^exp` by square-and-multiply over the bits of `exp`.
    // A negative `exp` yields 1, i.e. it is treated as exponent 0.
    fn pow_wrapping<U: PrimInt + WrappingMul>(base: U, exp: U) -> U {
        let mut acc = U::one();
        if exp < U::zero() {
            return acc;
        }
        let (mut base, mut exp) = (base, exp);
        while exp > U::zero() {
            if (exp & U::one()) == U::one() {
                acc = acc.wrapping_mul(&base);
            }
            exp = exp >> 1usize;
            if exp > U::zero() {
                base = base.wrapping_mul(&base);
            }
        }
        acc
    }

    let n = lhs.len();
    // Reslice once: a short `rhs`/`out` panics here, and the loop needs no per-element bounds checks.
    let rhs = &rhs[..n];
    let out = &mut out[..n];
    for ((dst, &a), &b) in out.iter_mut().zip(lhs).zip(rhs) {
        *dst = match op {
            ArithmeticOperator::Add => a.wrapping_add(&b),
            ArithmeticOperator::Subtract => a.wrapping_sub(&b),
            ArithmeticOperator::Multiply => a.wrapping_mul(&b),
            ArithmeticOperator::Divide => match div_rem_wrapping(a, b) {
                Some((q, _)) => q,
                None => panic!("Division by zero"),
            },
            ArithmeticOperator::Remainder => match div_rem_wrapping(a, b) {
                Some((_, r)) => r,
                None => panic!("Remainder by zero"),
            },
            ArithmeticOperator::Power => pow_wrapping(a, b),
            ArithmeticOperator::FloorDiv => match div_rem_wrapping(a, b) {
                // The truncated quotient rounds toward zero. Step down by one when there is a
                // remainder and the operands have opposite signs (sign bit of `a ^ b` set).
                Some((q, r)) if r != T::zero() && (a ^ b) < T::zero() => q - T::one(),
                Some((q, _)) => q,
                None => panic!("Floor division by zero"),
            },
        };
    }
}
/// Applies `op` element-wise to the lanes of `lhs`/`rhs` whose `mask` bit is set.
///
/// Lane outcomes:
/// - A valid lane with a defined result gets `out[i] = lhs[i] op rhs[i]` and `out_mask` bit
///   `i` set.
/// - An invalid input lane, or a zero divisor for Divide, Remainder or FloorDiv, gets
///   `out[i] = 0` and `out_mask` bit `i` cleared.
///
/// Overflow behaves the same in debug and release builds:
/// - Add, Subtract, Multiply and Power wrap (two's complement).
/// - `MIN / -1` wraps to `MIN`, with remainder `0`, instead of panicking.
///
/// Power treats a negative exponent as `0`, so the result is `1`.
///
/// # Panics
/// If `rhs` or `out` is shorter than `lhs`.
///
/// NOTE(review): `mask` and `out_mask` are read and written with unchecked indexing for every
/// index of `lhs`. Nothing here verifies they are that long — confirm that callers guarantee it.
#[inline(always)]
pub fn int_masked_body_std<T: PrimInt + ToPrimitive + WrappingAdd + WrappingSub + WrappingMul>(
    op: ArithmeticOperator,
    lhs: &[T],
    rhs: &[T],
    mask: &Bitmask,
    out: &mut [T],
    out_mask: &mut Bitmask,
) {
    // `(quotient, remainder)` of `a / b`, or `None` when `b` is zero. Never panics on overflow.
    fn div_rem_wrapping<U: PrimInt + WrappingSub>(a: U, b: U) -> Option<(U, U)> {
        if b == U::zero() {
            return None;
        }
        Some(match a.checked_div(&b) {
            Some(q) => (q, a % b),
            // With a non-zero divisor only `MIN / -1` overflows: the wrapped quotient is
            // `-MIN == MIN` and the remainder is 0.
            None => (U::zero().wrapping_sub(&a), U::zero()),
        })
    }

    // Wrapping `base^exp` by square-and-multiply over the bits of `exp`.
    // A negative `exp` yields 1, i.e. it is treated as exponent 0.
    fn pow_wrapping<U: PrimInt + WrappingMul>(base: U, exp: U) -> U {
        let mut acc = U::one();
        if exp < U::zero() {
            return acc;
        }
        let (mut base, mut exp) = (base, exp);
        while exp > U::zero() {
            if (exp & U::one()) == U::one() {
                acc = acc.wrapping_mul(&base);
            }
            exp = exp >> 1usize;
            if exp > U::zero() {
                base = base.wrapping_mul(&base);
            }
        }
        acc
    }

    let n = lhs.len();
    // Reslice once: a short `rhs`/`out` panics here, and the loop needs no per-element bounds checks.
    let rhs = &rhs[..n];
    let out = &mut out[..n];
    for i in 0..n {
        // SAFETY: relies on `mask` holding at least `lhs.len()` bits (see the NOTE above).
        let valid = unsafe { mask.get_unchecked(i) };
        // `None` marks a lane whose output is null: either masked out or a zero divisor.
        let result: Option<T> = if valid {
            let (a, b) = (lhs[i], rhs[i]);
            match op {
                ArithmeticOperator::Add => Some(a.wrapping_add(&b)),
                ArithmeticOperator::Subtract => Some(a.wrapping_sub(&b)),
                ArithmeticOperator::Multiply => Some(a.wrapping_mul(&b)),
                ArithmeticOperator::Divide => div_rem_wrapping(a, b).map(|(q, _)| q),
                ArithmeticOperator::Remainder => div_rem_wrapping(a, b).map(|(_, r)| r),
                ArithmeticOperator::Power => Some(pow_wrapping(a, b)),
                ArithmeticOperator::FloorDiv => div_rem_wrapping(a, b).map(|(q, r)| {
                    // The truncated quotient rounds toward zero. Step down by one when there is
                    // a remainder and the operands have opposite signs.
                    if r != T::zero() && (a ^ b) < T::zero() {
                        q - T::one()
                    } else {
                        q
                    }
                }),
            }
        } else {
            None
        };
        out[i] = result.unwrap_or_else(T::zero);
        // SAFETY: relies on `out_mask` holding at least `lhs.len()` bits (see the NOTE above).
        unsafe {
            out_mask.set_unchecked(i, result.is_some());
        }
    }
}
/// Applies `op` element-wise to two dense float slices: `out[i] = lhs[i] op rhs[i]`.
///
/// Results follow IEEE-754 float arithmetic: division by zero yields ±inf or NaN instead of
/// panicking.
///
/// Operator semantics:
/// - Remainder is truncating, like Rust's `%`.
/// - FloorDiv is `floor(lhs / rhs)`.
/// - Power uses `powf`. It handles negative bases with integral exponents (e.g. `(-2)^2`) and
///   `0^0`, where the `exp(rhs * ln(lhs))` formulation produced NaN.
///
/// # Panics
/// If `rhs` or `out` is shorter than `lhs`.
#[inline(always)]
pub fn float_dense_body_std<T: Float>(op: ArithmeticOperator, lhs: &[T], rhs: &[T], out: &mut [T]) {
    let n = lhs.len();
    // Reslice once: a short `rhs`/`out` panics here, and the loop needs no per-element bounds checks.
    let rhs = &rhs[..n];
    let out = &mut out[..n];
    for ((dst, &a), &b) in out.iter_mut().zip(lhs).zip(rhs) {
        *dst = match op {
            ArithmeticOperator::Add => a + b,
            ArithmeticOperator::Subtract => a - b,
            ArithmeticOperator::Multiply => a * b,
            ArithmeticOperator::Divide => a / b,
            ArithmeticOperator::Remainder => a % b,
            ArithmeticOperator::Power => a.powf(b),
            ArithmeticOperator::FloorDiv => (a / b).floor(),
        };
    }
}
/// Applies `op` element-wise to the lanes of `lhs`/`rhs` whose `mask` bit is set.
///
/// Lane outcomes:
/// - A valid lane gets `out[i] = lhs[i] op rhs[i]` and `out_mask` bit `i` set.
/// - Other lanes get `out[i] = 0` and `out_mask` bit `i` cleared.
///
/// Results follow IEEE-754 float arithmetic: a zero divisor gives ±inf or NaN and the lane
/// stays valid.
///
/// Operator semantics:
/// - Remainder is truncating, like Rust's `%`.
/// - FloorDiv is `floor(lhs / rhs)`.
/// - Power uses `powf`, which handles negative bases with integral exponents and `0^0`.
///
/// # Panics
/// If `rhs` or `out` is shorter than `lhs`.
///
/// NOTE(review): `mask` and `out_mask` are read and written with unchecked indexing for every
/// index of `lhs`. Nothing here verifies they are that long — confirm that callers guarantee it.
#[inline(always)]
pub fn float_masked_body_std<T: Float>(
    op: ArithmeticOperator,
    lhs: &[T],
    rhs: &[T],
    mask: &Bitmask,
    out: &mut [T],
    out_mask: &mut Bitmask,
) {
    let n = lhs.len();
    // Reslice once: a short `rhs`/`out` panics here, and the loop needs no per-element bounds checks.
    let rhs = &rhs[..n];
    let out = &mut out[..n];
    for i in 0..n {
        // SAFETY: relies on `mask` holding at least `lhs.len()` bits (see the NOTE above).
        let valid = unsafe { mask.get_unchecked(i) };
        out[i] = if valid {
            let (a, b) = (lhs[i], rhs[i]);
            match op {
                ArithmeticOperator::Add => a + b,
                ArithmeticOperator::Subtract => a - b,
                ArithmeticOperator::Multiply => a * b,
                ArithmeticOperator::Divide => a / b,
                ArithmeticOperator::Remainder => a % b,
                ArithmeticOperator::Power => a.powf(b),
                ArithmeticOperator::FloorDiv => (a / b).floor(),
            }
        } else {
            T::zero()
        };
        // SAFETY: relies on `out_mask` holding at least `lhs.len()` bits (see the NOTE above).
        unsafe {
            out_mask.set_unchecked(i, valid);
        }
    }
}
/// Masked fused multiply-add over the indices of `lhs`.
///
/// For each index `k`:
/// - If `mask` bit `k` is set, `out[k] = lhs[k].mul_add(rhs[k], acc[k])`.
/// - Otherwise `out[k] = 0`.
///
/// `out_mask` bit `k` is set to the same value as `mask` bit `k`.
///
/// Panics if `rhs`, `acc` or `out` is shorter than `lhs`.
///
/// NOTE(review): `mask`/`out_mask` are accessed without bounds checks. This assumes both cover
/// every index of `lhs`; that is not verified here.
#[inline(always)]
pub fn fma_masked_body_std<T: Float>(
    lhs: &[T],
    rhs: &[T],
    acc: &[T],
    mask: &Bitmask,
    out: &mut [T],
    out_mask: &mut Bitmask,
) {
    for idx in 0..lhs.len() {
        // SAFETY: assumes `mask` covers index `idx` (see the NOTE above).
        let lane_on = unsafe { mask.get_unchecked(idx) };
        out[idx] = if lane_on {
            lhs[idx].mul_add(rhs[idx], acc[idx])
        } else {
            T::zero()
        };
        // SAFETY: assumes `out_mask` covers index `idx` (see the NOTE above).
        unsafe {
            out_mask.set_unchecked(idx, lane_on);
        }
    }
}
/// Dense fused multiply-add: `out[k] = lhs[k].mul_add(rhs[k], acc[k])` for every index `k` of `lhs`.
///
/// Panics if `rhs`, `acc` or `out` is shorter than `lhs`.
#[inline(always)]
pub fn fma_dense_body_std<T: Float>(lhs: &[T], rhs: &[T], acc: &[T], out: &mut [T]) {
    (0..lhs.len()).for_each(|k| out[k] = lhs[k].mul_add(rhs[k], acc[k]));
}