use core::{marker::PhantomData, slice};
use burn_backend::Element;
use macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned};
use ndarray::ArrayD;
use seq_macro::seq;
use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};
use super::should_use_simd;
/// A comparison that can be evaluated either on whole SIMD vectors or on single scalars.
///
/// Every method is an associated function (no `self`), so an operation is selected purely
/// through a type parameter; the implementors in this file are unit structs.
pub trait SimdCmpOp<T: Scalar> {
/// Compares the vectors `lhs` and `rhs` and returns the resulting mask.
fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T>;
/// Compares two scalars. The kernels in this file use it for the elements that do not
/// fill a whole vector.
fn apply(lhs: T, rhs: T) -> bool;
/// Reports whether the vector form is accelerated for the SIMD backend `S`. The `try_*`
/// entry points in this file hand their inputs back as `Err` when this is `false`.
fn is_accelerated<S: Simd>() -> bool;
}
/// Marker type selecting the equality comparison (`lhs == rhs`).
pub struct VecEquals;

impl<T> SimdCmpOp<T> for VecEquals
where
    T: VEq,
{
    /// Vector form: forwards to the vector's `eq` method.
    fn apply_vec<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Mask<S, T> {
        a.eq(b)
    }

    /// Scalar form: plain `==`.
    fn apply(a: T, b: T) -> bool {
        a == b
    }

    /// Defers to `VEq::is_accelerated` for the element type.
    fn is_accelerated<S: Simd>() -> bool {
        <T as VEq>::is_accelerated::<S>()
    }
}
/// Marker type selecting the strict "greater than" comparison (`lhs > rhs`).
pub struct VecGreater;

impl<T> SimdCmpOp<T> for VecGreater
where
    T: VOrd + PartialOrd,
{
    /// Vector form: forwards to the vector's `gt` method.
    fn apply_vec<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Mask<S, T> {
        a.gt(b)
    }

    /// Scalar form: plain `>`.
    fn apply(a: T, b: T) -> bool {
        a > b
    }

    /// Defers to `VOrd::is_cmp_accelerated` for the element type.
    fn is_accelerated<S: Simd>() -> bool {
        <T as VOrd>::is_cmp_accelerated::<S>()
    }
}
/// Marker type selecting the "greater than or equal" comparison (`lhs >= rhs`).
pub struct VecGreaterEq;

impl<T> SimdCmpOp<T> for VecGreaterEq
where
    T: VOrd + PartialOrd,
{
    /// Vector form: forwards to the vector's `ge` method.
    fn apply_vec<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Mask<S, T> {
        a.ge(b)
    }

    /// Scalar form: plain `>=`.
    fn apply(a: T, b: T) -> bool {
        a >= b
    }

    /// Defers to `VOrd::is_cmp_accelerated` for the element type.
    fn is_accelerated<S: Simd>() -> bool {
        <T as VOrd>::is_cmp_accelerated::<S>()
    }
}
/// Marker type selecting the "lower than or equal" comparison (`lhs <= rhs`).
pub struct VecLowerEq;

impl<T> SimdCmpOp<T> for VecLowerEq
where
    T: VOrd + PartialOrd,
{
    /// Vector form: forwards to the vector's `le` method.
    fn apply_vec<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Mask<S, T> {
        a.le(b)
    }

    /// Scalar form: plain `<=`.
    fn apply(a: T, b: T) -> bool {
        a <= b
    }

    /// Defers to `VOrd::is_cmp_accelerated` for the element type.
    fn is_accelerated<S: Simd>() -> bool {
        <T as VOrd>::is_cmp_accelerated::<S>()
    }
}
/// Marker type selecting the strict "lower than" comparison (`lhs < rhs`).
pub struct VecLower;

impl<T> SimdCmpOp<T> for VecLower
where
    T: VOrd + PartialOrd,
{
    /// Vector form: forwards to the vector's `lt` method.
    fn apply_vec<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Mask<S, T> {
        a.lt(b)
    }

    /// Scalar form: plain `<`.
    fn apply(a: T, b: T) -> bool {
        a < b
    }

    /// Defers to `VOrd::is_cmp_accelerated` for the element type.
    fn is_accelerated<S: Simd>() -> bool {
        <T as VOrd>::is_cmp_accelerated::<S>()
    }
}
// Asks `Op` whether its vector form is accelerated for the SIMD backend `S`.
//
// Callers in this file invoke this as `is_accelerated::<T, Op>(PhantomData)` and never name
// `S`, so `S` is evidently supplied by `#[macerator::with_simd]` (presumably the backend
// selected at runtime -- confirm against the macerator docs).
// The `PhantomData` argument carries no data; it only lets `T` and `Op` appear in the
// parameter list.
#[macerator::with_simd]
fn is_accelerated<S: Simd, T: Scalar, Op: SimdCmpOp<T>>(_x: PhantomData<(T, Op)>) -> bool {
Op::is_accelerated::<S>()
}
/// Tries to compare two arrays element-wise with `Op` using SIMD.
///
/// Both inputs are handed back unchanged as `Err((lhs, rhs))` when the SIMD path does not
/// apply, i.e. when any of the following fails:
/// - `should_use_simd` accepts the larger of the two lengths,
/// - both arrays are in standard layout,
/// - both arrays have the same shape,
/// - `Op` is accelerated for `T`.
///
/// On success the result is a `bool` array produced by `cmp_simd_same`.
///
/// NOTE(review): `E` is reinterpreted as `T` without any check in this function; callers
/// are presumably expected to pass a `T` that is the same concrete type as `E` -- confirm
/// at the call sites.
// The error payload is the pair of input arrays, hence the large `Err` variant.
#[allow(clippy::result_large_err)]
pub fn try_cmp_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    lhs: SharedArray<E>,
    rhs: SharedArray<E>,
) -> Result<SharedArray<bool>, (SharedArray<E>, SharedArray<E>)> {
    let simd_applicable = should_use_simd(lhs.len().max(rhs.len()))
        && lhs.is_standard_layout()
        && rhs.is_standard_layout()
        && lhs.shape() == rhs.shape()
        && is_accelerated::<T, Op>(PhantomData);
    if !simd_applicable {
        return Err((lhs, rhs));
    }

    // SAFETY: only sound when `E` and `T` are the same type (see the note in the docs above);
    // the transmute then merely changes the static element type of the array handles.
    let (lhs, rhs) = unsafe {
        (
            core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs),
            core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs),
        )
    };

    Ok(cmp_simd_same::<T, Op>(lhs, rhs))
}
/// Compares two same-shaped, standard-layout arrays element-wise and returns the `bool`
/// result, reusing one of the input allocations when possible.
///
/// An input buffer is reused for the output only if that array is uniquely owned and `T`
/// has the same size as `bool`; `lhs` is preferred over `rhs`. Otherwise a new output array
/// shaped like `lhs` is obtained from `uninit_array_like`.
///
/// Panics if either array cannot be viewed as a contiguous slice (`as_slice` returns `None`).
fn cmp_simd_same<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    lhs: SharedArray<T>,
    rhs: SharedArray<T>,
) -> SharedArray<bool> {
    // In-place reuse needs every `bool` to sit exactly where its source element was.
    let same_width = size_of::<T>() == size_of::<bool>();

    let result: ArrayD<bool> = if lhs.is_unique() && same_width {
        // Write the result over the left operand's storage.
        let mut storage = lhs.into_owned();
        let left = storage.as_slice_mut().unwrap();
        let right = rhs.as_slice().unwrap();
        // NOTE(review): `dst` aliases `left`; the kernel then receives a shared and a mutable
        // slice over the same memory. The kernel loads each vector before storing to the same
        // position, but this still deserves a soundness review.
        let dst = unsafe_alias_slice_mut(left);
        let dst = unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(dst) };
        cmp(left, right, dst, PhantomData::<Op>);
        // The buffer has been overwritten by the kernel's `bool` output, so it is retyped.
        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(storage) }
    } else if rhs.is_unique() && same_width {
        // Same as above, but the right operand's storage receives the result.
        let mut storage = rhs.into_owned();
        let left = lhs.as_slice().unwrap();
        let right = storage.as_slice_mut().unwrap();
        let dst = unsafe_alias_slice_mut(right);
        let dst = unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(dst) };
        cmp(left, right, dst, PhantomData::<Op>);
        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(storage) }
    } else {
        // No reusable buffer: write into a separate output array.
        let mut fresh = uninit_array_like(&lhs);
        let left = lhs.as_slice().unwrap();
        let right = rhs.as_slice().unwrap();
        let dst = fresh.as_slice_mut().unwrap();
        cmp(left, right, dst, PhantomData::<Op>);
        fresh
    };

    result.into_shared()
}
// Element-wise comparison kernel: stores `Op` applied to `lhs[i]` and `rhs[i]` into `out[i]`.
//
// The work is split into three phases: blocks of 8 vectors (unrolled with `seq!`), then
// single vectors, then a scalar loop for whatever is left. Each phase stops as soon as any
// of the three slices runs out, so the kernel effectively assumes equal lengths.
//
// `S` is not named at the call sites in this file, so it is evidently supplied by
// `#[macerator::with_simd]`. `_op` carries no data; it only selects `Op`.
//
// The clippy allows cover the `0 * lanes` / `1 * lanes` expressions produced by the unrolling.
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
lhs: &'a [T],
rhs: &'a [T],
out: &'a mut [bool],
_op: PhantomData<Op>,
) where
'a: 'a,
{
// Number of `T` elements per vector for backend `S`.
let lanes = T::lanes::<S>();
// Phase 1: blocks of exactly `8 * lanes` elements.
let mut chunks_lhs = lhs.chunks_exact(8 * lanes);
let mut chunks_rhs = rhs.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
while let Some(((lhs, rhs), out)) = chunks_lhs
.next()
.zip(chunks_rhs.next())
.zip(chunks_out.next())
{
seq!(N in 0..8 {
// Each chunk holds exactly `8 * lanes` elements, so `lanes` elements starting at
// `N * lanes` are in bounds for every N in 0..8. This relies on the load/store touching
// exactly one vector (`lanes` elements) -- macerator API, TODO confirm.
let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };
let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };
let s~N = Op::apply_vec(lhs~N, rhs~N);
// Write the mask as `bool`s at the same offset in the output chunk.
unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };
});
}
// Phase 2: what is left of phase 1, one vector (`lanes` elements) at a time.
let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);
let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some(((lhs, rhs), out)) = chunks_lhs
.next()
.zip(chunks_rhs.next())
.zip(chunks_out.next())
{
// Each chunk is exactly `lanes` long, i.e. one full vector starting at its first element.
let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };
let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };
let s0 = Op::apply_vec(lhs0, rhs0);
unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };
}
// Phase 3: fewer than `lanes` elements remain; compare them one by one with the scalar op.
for ((lhs, rhs), out) in chunks_lhs
.remainder()
.iter()
.zip(chunks_rhs.remainder())
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*lhs, *rhs)
}
}
/// Rebuilds `slice` from its raw pointer and length, returning a mutable slice whose
/// lifetime `'a` is chosen by the caller and is not tied to the borrow of `slice`.
///
/// This deliberately sidesteps the borrow checker: the returned slice covers the same memory
/// as the input, so the caller is responsible for how the two are used together and for not
/// letting the result outlive the underlying storage.
fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {
    let (data, len) = (slice.as_mut_ptr(), slice.len());
    // SAFETY: `data` and `len` come from a live `&mut [T]`, so they describe `len` valid,
    // properly aligned elements. Lifetime and aliasing are the caller's responsibility.
    unsafe { slice::from_raw_parts_mut(data, len) }
}
pub use elemwise::try_cmp_scalar_simd;
mod elemwise {
use bytemuck::cast;
use macerator::vload;
use super::*;
/// Tries to compare every element of `input` against the scalar `elem` with `Op` using SIMD.
///
/// `input` is handed back unchanged as `Err(input)` when the SIMD path does not apply, i.e.
/// when any of the following fails:
/// - `should_use_simd` accepts the array length,
/// - the array can be viewed as one slice in memory order,
/// - `Op` is accelerated for `T`.
///
/// On success the result is written over the input's own buffer when that buffer is uniquely
/// owned and `T` is size- and alignment-compatible with `bool`; otherwise a separate output
/// array is used.
///
/// NOTE(review): `E` is reinterpreted as `T` without any check in this function; callers
/// are presumably expected to pass a `T` that is the same concrete type as `E` -- confirm
/// at the call sites.
pub fn try_cmp_scalar_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    input: SharedArray<E>,
    elem: T,
) -> Result<SharedArray<bool>, SharedArray<E>> {
    let simd_applicable = should_use_simd(input.len())
        && input.as_slice_memory_order().is_some()
        && is_accelerated::<T, Op>(PhantomData);
    if !simd_applicable {
        return Err(input);
    }

    // SAFETY: only sound when `E` and `T` are the same type (see the note in the docs above).
    let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };

    let reuse_input = size_of::<T>() == size_of::<bool>()
        && align_of::<T>() >= align_of::<bool>()
        && input.is_unique();

    Ok(if reuse_input {
        // SAFETY: the layout compatibility and unique ownership were checked just above.
        unsafe { cmp_scalar_simd_inplace::<T, Op>(input, elem) }
    } else {
        cmp_scalar_simd_owned::<T, Op>(input, elem)
    })
}
/// Compares every element of `input` against `elem`, overwriting the input's buffer with the
/// `bool` results and returning that buffer retyped as a `bool` array.
///
/// Panics if the array cannot be viewed as one slice in memory order.
///
/// # Safety
///
/// The buffer is reinterpreted from `T` to `bool` after the kernel ran, so `T` must be size-
/// and alignment-compatible with `bool`. The only caller in this file checks that, plus
/// unique ownership of `input`, before calling.
unsafe fn cmp_scalar_simd_inplace<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    input: SharedArray<T>,
    elem: T,
) -> SharedArray<bool> {
    let mut owned = input.into_owned();

    // SAFETY: forwarded from this function's own contract.
    unsafe {
        cmp_scalar_slice_inplace::<T, Op>(
            owned.as_slice_memory_order_mut().unwrap(),
            elem,
            PhantomData,
        );
    }

    // SAFETY: the kernel has replaced every element with a `bool` result, and `T` is
    // layout-compatible with `bool` per the contract.
    let as_bools = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(owned) };
    as_bools.into_shared()
}
/// Compares every element of `input` against `elem`, writing the `bool` results into a
/// separate output array obtained from `uninit_array_like(&input)`.
///
/// Both arrays are processed as flat slices in memory order; panics if either cannot be
/// viewed that way.
fn cmp_scalar_simd_owned<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
    input: SharedArray<T>,
    elem: T,
) -> SharedArray<bool> {
    let mut result = uninit_array_like(&input);
    {
        let src = input.as_slice_memory_order().unwrap();
        let dst = result.as_slice_memory_order_mut().unwrap();
        cmp_scalar_slice::<T, Op>(src, dst, elem, PhantomData);
    }
    result.into_shared()
}
// Scalar-broadcast comparison kernel: stores `Op` applied to `input[i]` and `rhs` into
// `out[i]`.
//
// The work is split into three phases: blocks of 8 vectors (unrolled with `seq!`), then
// single vectors, then a scalar loop for whatever is left. Each phase stops as soon as
// either slice runs out, so the kernel effectively assumes `input` and `out` are equally long.
//
// `S` is not named at the call site in this file, so it is evidently supplied by
// `#[macerator::with_simd]`. `_op` carries no data; it only selects `Op`.
//
// The clippy allows cover the `0 * lanes` / `1 * lanes` expressions produced by the unrolling.
#[inline(always)]
#[allow(clippy::erasing_op, clippy::identity_op)]
#[macerator::with_simd]
fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
input: &'a [T],
out: &'a mut [bool],
rhs: T,
_op: PhantomData<Op>,
) where
'a: 'a,
{
// Number of `T` elements per vector for backend `S`.
let lanes = T::lanes::<S>();
// Phase 1: blocks of exactly `8 * lanes` elements.
let mut chunks_input = input.chunks_exact(8 * lanes);
let mut chunks_out = out.chunks_exact_mut(8 * lanes);
// Build the vector form of the scalar operand once, outside the loops.
let rhs_vec = rhs.splat::<S>();
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
seq!(N in 0..8 {
// Each chunk holds exactly `8 * lanes` elements, so `lanes` elements starting at
// `N * lanes` are in bounds for every N in 0..8. This relies on the load/store touching
// exactly one vector (`lanes` elements) -- macerator API, TODO confirm.
let s~N = unsafe { vload_unaligned(&input[N * lanes]) };
let s~N = Op::apply_vec(s~N, rhs_vec);
// Write the mask as `bool`s at the same offset in the output chunk.
unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };
});
}
// Phase 2: what is left of phase 1, one vector (`lanes` elements) at a time.
let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);
let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);
while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {
// Each chunk is exactly `lanes` long, i.e. one full vector starting at its first element.
let s0 = unsafe { vload_unaligned(input.as_ptr()) };
let s0 = Op::apply_vec(s0, rhs_vec);
unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };
}
// Phase 3: fewer than `lanes` elements remain; compare them one by one with the scalar op.
for (input, out) in chunks_input
.remainder()
.iter()
.zip(chunks_out.into_remainder())
{
*out = Op::apply(*input, rhs)
}
}
// In-place scalar-broadcast comparison kernel: every element of `buf` is replaced by the
// result of `Op` applied to that element and `rhs`, stored as a `bool` in the element's slot.
//
// Safety contract (as relied on by the code below): results are written through `*mut bool`
// pointers into `T` storage and via `bytemuck::cast` from `bool` to `T`, so `T` must have
// the same size as `bool`. The only call chain in this file checks size and alignment
// compatibility before getting here.
//
// `S` is not named at the call site in this file, so it is evidently supplied by
// `#[macerator::with_simd]`. `_op` carries no data; it only selects `Op`.
#[inline(always)]
#[macerator::with_simd]
unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(
buf: &'a mut [T],
rhs: T,
_op: PhantomData<Op>,
) where
'a: 'a,
{
// Split the buffer into an unaligned prefix, a middle part viewed as whole, vector-aligned
// `Vector<S, T>` values, and an unaligned suffix.
let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
// Prefix and suffix are handled element by element; `cast` reinterprets the `bool` result
// as a `T` so it can be stored back into the same slot.
for elem in head.iter_mut().chain(tail) {
*elem = cast(Op::apply(*elem, rhs));
}
// Aligned middle part: blocks of 8 vectors first.
let mut chunks = main.chunks_exact_mut(8);
// From here on `rhs` is the vector form of the scalar operand.
let rhs = rhs.splat::<S>();
for elem in chunks.by_ref() {
seq!(N in 0..8 {
// `elem[N]` is one whole vector from the aligned middle part, so the aligned load
// reads a full, in-bounds vector through its address.
let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };
let s~N = Op::apply_vec(s~N, rhs);
// Store the mask as `bool`s over the vector that was just read.
unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) };
});
}
// Fewer than 8 vectors remain: process them one vector at a time.
for elem in chunks.into_remainder() {
let s0 = unsafe { vload(elem as *const _ as *const T) };
let s0 = Op::apply_vec(s0, rhs);
unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) };
}
}
}