#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
// Number of lanes in a 128-bit NEON register for each element type.
const F32_LANES: usize = 4;
const F64_LANES: usize = 2;
/// Elementwise select for `f32`: `out[i] = if cond[i] != 0 { x[i] } else { y[i] }`.
///
/// Processes `F32_LANES` (4) elements per iteration with NEON `vbslq_f32`;
/// any tail shorter than a full vector is delegated to the scalar fallback
/// `where_scalar_f32`.
///
/// # Safety
/// - `cond`, `x`, and `y` must be valid for reads of `len` elements, and
///   `out` must be valid for writes of `len` elements.
/// - `out` must not overlap `cond`, `x`, or `y`.
/// - Callable only on aarch64 with NEON available (always true on aarch64,
///   but `#[target_feature]` still makes this fn unsafe to call).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn where_f32(cond: *const u8, x: *const f32, y: *const f32, out: *mut f32, len: usize) {
    let chunks = len / F32_LANES;
    let remainder = len % F32_LANES;
    for i in 0..chunks {
        let offset = i * F32_LANES;
        // Widen the four condition bytes to u32 lanes, then build the
        // per-lane select mask branchlessly: vtst(a, a) sets a lane to
        // all-ones when the lane is nonzero and to zero otherwise. This
        // replaces four scalar `if c != 0 { !0 } else { 0 }` branches.
        let widened = [
            *cond.add(offset) as u32,
            *cond.add(offset + 1) as u32,
            *cond.add(offset + 2) as u32,
            *cond.add(offset + 3) as u32,
        ];
        let vc = vld1q_u32(widened.as_ptr());
        let mask = vtstq_u32(vc, vc);
        let vx = vld1q_f32(x.add(offset));
        let vy = vld1q_f32(y.add(offset));
        // Bitwise select: lanes where mask is all-ones take x, others take y.
        vst1q_f32(out.add(offset), vbslq_f32(mask, vx, vy));
    }
    if remainder > 0 {
        // Scalar fallback for the tail that does not fill a full vector.
        let offset = chunks * F32_LANES;
        super::super::where_scalar_f32(
            cond.add(offset),
            x.add(offset),
            y.add(offset),
            out.add(offset),
            remainder,
        );
    }
}
/// Elementwise select for `f64`: `out[i] = if cond[i] != 0 { x[i] } else { y[i] }`.
///
/// Processes `F64_LANES` (2) elements per iteration with NEON `vbslq_f64`;
/// any tail shorter than a full vector is delegated to the scalar fallback
/// `where_scalar_f64`.
///
/// # Safety
/// - `cond`, `x`, and `y` must be valid for reads of `len` elements, and
///   `out` must be valid for writes of `len` elements.
/// - `out` must not overlap `cond`, `x`, or `y`.
/// - Callable only on aarch64 with NEON available (always true on aarch64,
///   but `#[target_feature]` still makes this fn unsafe to call).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn where_f64(cond: *const u8, x: *const f64, y: *const f64, out: *mut f64, len: usize) {
    let chunks = len / F64_LANES;
    let remainder = len % F64_LANES;
    for i in 0..chunks {
        let offset = i * F64_LANES;
        // Widen the two condition bytes to u64 lanes and derive the select
        // mask branchlessly: vtst(a, a) yields all-ones for nonzero lanes,
        // zero otherwise — no per-lane scalar branches.
        let widened = [*cond.add(offset) as u64, *cond.add(offset + 1) as u64];
        let vc = vld1q_u64(widened.as_ptr());
        let mask = vtstq_u64(vc, vc);
        let vx = vld1q_f64(x.add(offset));
        let vy = vld1q_f64(y.add(offset));
        // Bitwise select: lanes where mask is all-ones take x, others take y.
        vst1q_f64(out.add(offset), vbslq_f64(mask, vx, vy));
    }
    if remainder > 0 {
        // Scalar fallback for the tail that does not fill a full vector.
        let offset = chunks * F64_LANES;
        super::super::where_scalar_f64(
            cond.add(offset),
            x.add(offset),
            y.add(offset),
            out.add(offset),
            remainder,
        );
    }
}