use crate::dtype::Element;
/// Contiguous element-wise select with a byte mask:
/// `out[i] = if cond[i] != 0 { x[i] } else { y[i] }` for every `i < len`.
///
/// On `x86_64` / `aarch64`, the `F32` and `F64` dtypes (plus `F16` / `BF16` when the `f16`
/// feature is enabled) are forwarded to the kernels in `super::simd::where_select`. Those
/// kernels are assumed to implement the same selection rule; their bodies are not visible
/// here. Every other dtype or target falls through to [`where_kernel_generic`].
///
/// # Safety
///
/// `cond`, `x` and `y` must be valid for reads of `len` elements, and `out` must be valid for
/// writes of `len` elements, all properly aligned for their element types. `out` must not
/// overlap the inputs, because the generic fallback views all four buffers as slices.
#[inline]
pub unsafe fn where_kernel<T: Element>(
    cond: *const u8,
    x: *const T,
    y: *const T,
    out: *mut T,
    len: usize,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::where_select as simd;
        use crate::dtype::DType;

        // The pointer casts below rely on `T::DTYPE` naming `T`'s real element type
        // (assumed invariant of `Element` — not visible in this file).
        match T::DTYPE {
            DType::F32 => {
                simd::where_f32(cond, x.cast::<f32>(), y.cast::<f32>(), out.cast::<f32>(), len);
                return;
            }
            DType::F64 => {
                simd::where_f64(cond, x.cast::<f64>(), y.cast::<f64>(), out.cast::<f64>(), len);
                return;
            }
            #[cfg(feature = "f16")]
            DType::F16 => {
                simd::where_f16(
                    cond,
                    x.cast::<half::f16>(),
                    y.cast::<half::f16>(),
                    out.cast::<half::f16>(),
                    len,
                );
                return;
            }
            #[cfg(feature = "f16")]
            DType::BF16 => {
                simd::where_bf16(
                    cond,
                    x.cast::<half::bf16>(),
                    y.cast::<half::bf16>(),
                    out.cast::<half::bf16>(),
                    len,
                );
                return;
            }
            _ => (),
        }
    }

    // No SIMD kernel for this dtype/target: use the portable scalar loop.
    where_kernel_generic::<u8, T>(cond, x, y, out, len)
}
/// Contiguous element-wise select with a condition of any [`Element`] dtype:
/// `out[i] = if cond[i] != C::zero() { x[i] } else { y[i] }` for every `i < len`.
///
/// # Safety
///
/// When `len > 0`:
/// - `cond`, `x` and `y` must be non-null, properly aligned and valid for reads of `len` elements.
/// - `out` must be non-null, properly aligned and valid for writes of `len` elements.
/// - `out` must not overlap `cond`, `x` or `y`. The buffers are viewed as slices, and a
///   `&mut [T]` may not alias any other live reference.
///
/// When `len == 0` no pointer is touched, so null or dangling pointers are accepted.
#[inline]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn where_kernel_generic<C: Element, T: Element>(
    cond: *const C,
    x: *const T,
    y: *const T,
    out: *mut T,
    len: usize,
) {
    // `slice::from_raw_parts` requires non-null, aligned pointers even for zero-length
    // slices. Callers may hand over null buffers for empty inputs, so bail out before
    // forming any slice.
    if len == 0 {
        return;
    }
    let cond_slice = std::slice::from_raw_parts(cond, len);
    let x_slice = std::slice::from_raw_parts(x, len);
    let y_slice = std::slice::from_raw_parts(y, len);
    let out_slice = std::slice::from_raw_parts_mut(out, len);
    let zero = C::zero();
    // All four slices have length `len`, so zipping them needs no per-element bounds checks.
    let lanes = out_slice
        .iter_mut()
        .zip(cond_slice)
        .zip(x_slice)
        .zip(y_slice);
    for (((o, &c), &xv), &yv) in lanes {
        *o = if c != zero { xv } else { yv };
    }
}
/// Reports whether `cond`, `x` and `y` all address their elements in plain row-major
/// (C-contiguous) order over `out_shape`, starting at offset zero.
///
/// When this holds, the strided walk can be replaced by one flat pass over
/// `out_shape.iter().product()` elements.
///
/// Dimensions of extent 1 only ever use index 0, so their strides never affect addressing
/// and are ignored. This lets layouts such as shape `[1, n]` with an arbitrary leading
/// stride take the fast path. A rank-0 layout (`ndim == 0`) always reports `false`, which
/// leaves scalars to the strided path.
///
/// # Panics
///
/// Panics if `out_shape` or any stride slice has fewer than `ndim` entries and the stride
/// loop is reached.
#[inline]
#[allow(clippy::too_many_arguments)]
fn is_contiguous_layout(
    ndim: usize,
    out_shape: &[usize],
    cond_strides: &[isize],
    x_strides: &[isize],
    y_strides: &[isize],
    cond_offset: usize,
    x_offset: usize,
    y_offset: usize,
) -> bool {
    if ndim == 0 || cond_offset != 0 || x_offset != 0 || y_offset != 0 {
        return false;
    }
    // Stride a row-major buffer must have at the current dimension (product of inner extents).
    let mut expected_stride = 1isize;
    for dim in (0..ndim).rev() {
        let extent = out_shape[dim];
        let strides_match = cond_strides[dim] == expected_stride
            && x_strides[dim] == expected_stride
            && y_strides[dim] == expected_stride;
        if extent != 1 && !strides_match {
            return false;
        }
        expected_stride *= extent as isize;
    }
    true
}
/// Shared odometer walk behind the strided `where` kernels.
///
/// Visits every multi-index of `out_shape` in row-major order and writes the result to
/// `out` at the matching flat position. At each position it picks `x` when
/// `is_true(cond, cond_index)` returns `true` and `y` otherwise.
///
/// Each operand's element index starts at its `*_offset` and moves by that operand's stride
/// whenever the corresponding dimension counter advances. A stride of 0 therefore re-reads
/// the same element. All positions are counted in elements, not bytes.
///
/// # Safety
///
/// - `x` and `y` must be readable at every index the walk produces.
/// - `is_true` must be sound for every `cond` index it receives.
/// - `out` must be writable for `out_shape.iter().product()` elements.
///
/// The stride slices need at least `out_shape.len()` entries; indexing panics otherwise.
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn where_strided_impl<C, T: Element, F>(
    cond: *const C,
    x: *const T,
    y: *const T,
    out: *mut T,
    out_shape: &[usize],
    cond_strides: &[isize],
    x_strides: &[isize],
    y_strides: &[isize],
    cond_offset: usize,
    x_offset: usize,
    y_offset: usize,
    is_true: F,
) where
    F: Fn(*const C, isize) -> bool,
{
    let total: usize = out_shape.iter().product();
    if total == 0 {
        return;
    }
    let rank = out_shape.len();
    // Current multi-index, one counter per dimension (innermost = last).
    let mut counter = vec![0usize; rank];
    let mut c_pos = cond_offset as isize;
    let mut x_pos = x_offset as isize;
    let mut y_pos = y_offset as isize;

    for flat in 0..total {
        let picked = if is_true(cond, c_pos) {
            *x.offset(x_pos)
        } else {
            *y.offset(y_pos)
        };
        *out.add(flat) = picked;

        // Bump the innermost counter. On wrap-around, rewind that dimension's full span
        // and carry into the next outer dimension.
        let mut dim = rank;
        while dim > 0 {
            dim -= 1;
            counter[dim] += 1;
            c_pos += cond_strides[dim];
            x_pos += x_strides[dim];
            y_pos += y_strides[dim];
            if counter[dim] < out_shape[dim] {
                break;
            }
            counter[dim] = 0;
            let extent = out_shape[dim] as isize;
            c_pos -= extent * cond_strides[dim];
            x_pos -= extent * x_strides[dim];
            y_pos -= extent * y_strides[dim];
        }
    }
}
/// Strided `where` with a `u8` condition.
///
/// For every multi-index `idx` of `out_shape`, visited in row-major order, writes `x[idx]`
/// into `out` when `cond[idx] != 0` and `y[idx]` otherwise. `out` is written densely at the
/// flat position.
///
/// Operand element `idx` is read at `ptr + offset + Σ idx[d] * strides[d]`, counted in
/// elements, so a zero stride broadcasts along that dimension. When the strides of all three
/// operands are row-major contiguous over `out_shape`, the offset-adjusted pointers go to
/// [`where_kernel`], so the SIMD paths apply even to views starting at a non-zero offset.
///
/// # Safety
///
/// - Every address `ptr + offset + Σ idx[d] * strides[d]` reachable for `cond`, `x` and `y`
///   must be valid for reads.
/// - `out` must be valid for writes of `out_shape.iter().product()` elements and must not
///   overlap the inputs.
///
/// # Panics
///
/// May panic if a stride slice has fewer than `out_shape.len()` entries.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn where_strided_kernel<T: Element>(
    cond: *const u8,
    x: *const T,
    y: *const T,
    out: *mut T,
    out_shape: &[usize],
    cond_strides: &[isize],
    x_strides: &[isize],
    y_strides: &[isize],
    cond_offset: usize,
    x_offset: usize,
    y_offset: usize,
) {
    let ndim = out_shape.len();
    let total = out_shape.iter().product::<usize>();
    if total == 0 {
        return;
    }
    // Offsets are folded into the pointers below, so only the strides decide contiguity.
    if is_contiguous_layout(ndim, out_shape, cond_strides, x_strides, y_strides, 0, 0, 0) {
        // SAFETY: with row-major strides, the strided walk would read elements
        // `offset..offset + total` of each operand in order. The caller guarantees those
        // are readable, so the advanced pointers cover exactly the same `total` elements.
        unsafe {
            where_kernel(
                cond.add(cond_offset),
                x.add(x_offset),
                y.add(y_offset),
                out,
                total,
            );
        }
        return;
    }
    // SAFETY: forwards the caller's contract. Every strided address is readable and `out`
    // holds `total` writable elements.
    unsafe {
        where_strided_impl(
            cond,
            x,
            y,
            out,
            out_shape,
            cond_strides,
            x_strides,
            y_strides,
            cond_offset,
            x_offset,
            y_offset,
            |cond_ptr, idx| *cond_ptr.offset(idx) != 0,
        );
    }
}
/// Strided `where` with a condition of any [`Element`] dtype.
///
/// For every multi-index `idx` of `out_shape`, visited in row-major order, writes `x[idx]`
/// into `out` when `cond[idx] != C::zero()` and `y[idx]` otherwise. `out` is written densely
/// at the flat position.
///
/// Operand element `idx` is read at `ptr + offset + Σ idx[d] * strides[d]`, counted in
/// elements, so a zero stride broadcasts. When the strides of all operands are row-major
/// contiguous over `out_shape`, the offset-adjusted pointers go to [`where_kernel_generic`]
/// for a single flat pass, including views that start at a non-zero offset.
///
/// # Safety
///
/// - Every address `ptr + offset + Σ idx[d] * strides[d]` reachable for `cond`, `x` and `y`
///   must be valid for reads.
/// - `out` must be valid for writes of `out_shape.iter().product()` elements and must not
///   overlap the inputs.
///
/// # Panics
///
/// May panic if a stride slice has fewer than `out_shape.len()` entries.
#[inline]
#[allow(clippy::too_many_arguments)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn where_strided_kernel_generic<C: Element, T: Element>(
    cond: *const C,
    x: *const T,
    y: *const T,
    out: *mut T,
    out_shape: &[usize],
    cond_strides: &[isize],
    x_strides: &[isize],
    y_strides: &[isize],
    cond_offset: usize,
    x_offset: usize,
    y_offset: usize,
) {
    let ndim = out_shape.len();
    let total = out_shape.iter().product::<usize>();
    if total == 0 {
        return;
    }
    // Offsets are folded into the pointers below, so only the strides decide contiguity.
    if is_contiguous_layout(ndim, out_shape, cond_strides, x_strides, y_strides, 0, 0, 0) {
        // With row-major strides, the strided walk would read elements
        // `offset..offset + total` of each operand in order. Advancing each pointer by its
        // offset therefore yields exactly the same elements.
        where_kernel_generic(
            cond.add(cond_offset),
            x.add(x_offset),
            y.add(y_offset),
            out,
            total,
        );
        return;
    }
    let zero = C::zero();
    where_strided_impl(
        cond,
        x,
        y,
        out,
        out_shape,
        cond_strides,
        x_strides,
        y_strides,
        cond_offset,
        x_offset,
        y_offset,
        |cond_ptr, idx| *cond_ptr.offset(idx) != zero,
    );
}