#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

#[cfg(target_arch = "aarch64")]
use crate::ops::CompareOp;

/// Number of `f32` lanes in a 128-bit NEON register.
#[cfg(target_arch = "aarch64")]
const F32_LANES: usize = 4;
/// Number of `f64` lanes in a 128-bit NEON register.
#[cfg(target_arch = "aarch64")]
const F64_LANES: usize = 2;

/// Element-wise comparison of two `f32` buffers, writing `1.0` where the
/// predicate holds and `0.0` where it does not.
///
/// # Safety
///
/// `a`, `b`, and `out` must each be valid for `len` reads/writes. NEON is
/// part of the aarch64 baseline, so the target-feature requirement is
/// always satisfied on this architecture.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn compare_f32(op: CompareOp, a: *const f32, b: *const f32, out: *mut f32, len: usize) {
    let chunks = len / F32_LANES;
    let remainder = len % F32_LANES;
    // Splatted select operands: each output lane is either 1.0 or 0.0.
    let one = vdupq_n_f32(1.0);
    let zero = vdupq_n_f32(0.0);
    match op {
        CompareOp::Eq => {
            for i in 0..chunks {
                let offset = i * F32_LANES;
                let va = vld1q_f32(a.add(offset));
                let vb = vld1q_f32(b.add(offset));
                let mask = vceqq_f32(va, vb);
                // Select 1.0 where the mask is all-ones, 0.0 elsewhere.
                let result = vbslq_f32(mask, one, zero);
                vst1q_f32(out.add(offset), result);
            }
        }
        CompareOp::Ne => {
            for i in 0..chunks {
                let offset = i * F32_LANES;
                let va = vld1q_f32(a.add(offset));
                let vb = vld1q_f32(b.add(offset));
                // NEON has no not-equal comparison, so build an equality
                // mask and swap the select operands. This also yields IEEE
                // 754 semantics for NaN: `NaN != x` produces 1.0.
                let mask = vceqq_f32(va, vb);
                let result = vbslq_f32(mask, zero, one);
                vst1q_f32(out.add(offset), result);
            }
        }
        CompareOp::Lt => {
            for i in 0..chunks {
                let offset = i * F32_LANES;
                let va = vld1q_f32(a.add(offset));
                let vb = vld1q_f32(b.add(offset));
                let mask = vcltq_f32(va, vb);
                let result = vbslq_f32(mask, one, zero);
                vst1q_f32(out.add(offset), result);
            }
        }
        CompareOp::Le => {
            for i in 0..chunks {
                let offset = i * F32_LANES;
                let va = vld1q_f32(a.add(offset));
                let vb = vld1q_f32(b.add(offset));
                let mask = vcleq_f32(va, vb);
                let result = vbslq_f32(mask, one, zero);
                vst1q_f32(out.add(offset), result);
            }
        }
        CompareOp::Gt => {
            for i in 0..chunks {
                let offset = i * F32_LANES;
                let va = vld1q_f32(a.add(offset));
                let vb = vld1q_f32(b.add(offset));
                let mask = vcgtq_f32(va, vb);
                let result = vbslq_f32(mask, one, zero);
                vst1q_f32(out.add(offset), result);
            }
        }
        CompareOp::Ge => {
            for i in 0..chunks {
                let offset = i * F32_LANES;
                let va = vld1q_f32(a.add(offset));
                let vb = vld1q_f32(b.add(offset));
                let mask = vcgeq_f32(va, vb);
                let result = vbslq_f32(mask, one, zero);
                vst1q_f32(out.add(offset), result);
            }
        }
    }
    // Handle tail elements that do not fill a full vector with the scalar
    // fallback.
    if remainder > 0 {
        let offset = chunks * F32_LANES;
        super::super::compare_scalar_f32(
            op,
            a.add(offset),
            b.add(offset),
            out.add(offset),
            remainder,
        );
    }
}
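
// A possible safe wrapper, shown as an illustrative sketch (the name
// `compare_f32_slice` is not part of the original API): it validates slice
// lengths up front so the raw-pointer arithmetic in `compare_f32` stays in
// bounds, and confines the `unsafe` to one audited call site.
#[cfg(target_arch = "aarch64")]
pub fn compare_f32_slice(op: CompareOp, a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len(), "input lengths must match");
    assert_eq!(a.len(), out.len(), "output length must match inputs");
    // SAFETY: the slices are valid for `a.len()` reads/writes, and NEON is
    // part of the aarch64 baseline, so the target-feature contract holds.
    unsafe { compare_f32(op, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), a.len()) }
}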

/// Element-wise comparison of two `f64` buffers, writing `1.0` where the
/// predicate holds and `0.0` where it does not.
///
/// # Safety
///
/// Same contract as [`compare_f32`]: `a`, `b`, and `out` must each be valid
/// for `len` reads/writes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn compare_f64(op: CompareOp, a: *const f64, b: *const f64, out: *mut f64, len: usize) {
    let chunks = len / F64_LANES;
    let remainder = len % F64_LANES;
    let one = vdupq_n_f64(1.0);
    let zero = vdupq_n_f64(0.0);
    match op {
        CompareOp::Eq => {
            for i in 0..chunks {
                let offset = i * F64_LANES;
                let va = vld1q_f64(a.add(offset));
                let vb = vld1q_f64(b.add(offset));
                let mask = vceqq_f64(va, vb);
                let result = vbslq_f64(mask, one, zero);
                vst1q_f64(out.add(offset), result);
            }
        }
        CompareOp::Ne => {
            for i in 0..chunks {
                let offset = i * F64_LANES;
                let va = vld1q_f64(a.add(offset));
                let vb = vld1q_f64(b.add(offset));
                // Equality mask with swapped select operands; see the f32
                // `Ne` arm for the rationale.
                let mask = vceqq_f64(va, vb);
                let result = vbslq_f64(mask, zero, one);
                vst1q_f64(out.add(offset), result);
            }
        }
        CompareOp::Lt => {
            for i in 0..chunks {
                let offset = i * F64_LANES;
                let va = vld1q_f64(a.add(offset));
                let vb = vld1q_f64(b.add(offset));
                let mask = vcltq_f64(va, vb);
                let result = vbslq_f64(mask, one, zero);
                vst1q_f64(out.add(offset), result);
            }
        }
        CompareOp::Le => {
            for i in 0..chunks {
                let offset = i * F64_LANES;
                let va = vld1q_f64(a.add(offset));
                let vb = vld1q_f64(b.add(offset));
                let mask = vcleq_f64(va, vb);
                let result = vbslq_f64(mask, one, zero);
                vst1q_f64(out.add(offset), result);
            }
        }
        CompareOp::Gt => {
            for i in 0..chunks {
                let offset = i * F64_LANES;
                let va = vld1q_f64(a.add(offset));
                let vb = vld1q_f64(b.add(offset));
                let mask = vcgtq_f64(va, vb);
                let result = vbslq_f64(mask, one, zero);
                vst1q_f64(out.add(offset), result);
            }
        }
        CompareOp::Ge => {
            for i in 0..chunks {
                let offset = i * F64_LANES;
                let va = vld1q_f64(a.add(offset));
                let vb = vld1q_f64(b.add(offset));
                let mask = vcgeq_f64(va, vb);
                let result = vbslq_f64(mask, one, zero);
                vst1q_f64(out.add(offset), result);
            }
        }
    }
    // Scalar fallback for the tail, as in `compare_f32`.
    if remainder > 0 {
        let offset = chunks * F64_LANES;
        super::super::compare_scalar_f64(
            op,
            a.add(offset),
            b.add(offset),
            out.add(offset),
            remainder,
        );
    }
}
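
// A minimal usage sketch (assumed test scaffolding, not part of the original
// module): each case covers one full SIMD chunk plus a scalar tail, and the
// f64 case checks that `Ne` follows IEEE 754 semantics for NaN. It assumes
// the scalar fallbacks referenced above behave the same way.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;

    #[test]
    fn compare_f32_lt_with_tail() {
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = [2.0f32, 2.0, 2.0, 2.0, 9.0];
        let mut out = [0.0f32; 5];
        // SAFETY: all pointers are valid for `len` elements, and NEON is
        // part of the aarch64 baseline.
        unsafe { compare_f32(CompareOp::Lt, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), a.len()) };
        // Lanes 0..4 take the SIMD path; the fifth element is the tail.
        assert_eq!(out, [1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn compare_f64_ne_with_nan() {
        let a = [f64::NAN, 2.0, 3.0];
        let b = [f64::NAN, 2.0, 7.0];
        let mut out = [0.0f64; 3];
        // SAFETY: as above.
        unsafe { compare_f64(CompareOp::Ne, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), a.len()) };
        // `NaN != NaN` is true under IEEE 754, so the first lane is 1.0.
        assert_eq!(out, [1.0, 0.0, 1.0]);
    }
}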