#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use crate::ops::BinaryOp;
const F32_LANES: usize = 4;
const F64_LANES: usize = 2;
/// Applies `op` element-wise between the f32 buffer at `a` and a broadcast
/// `scalar`, writing results to `out`.
///
/// Processes `len / 4` full NEON vectors, then hands any remaining tail
/// elements (and the ops with no NEON intrinsic: `Pow`, `Atan2`) to the
/// scalar fallback `scalar_scalar_f32`.
///
/// # Safety
/// `a` must be valid for reads of `len` f32 values and `out` valid for
/// writes of `len` f32 values; the caller must be running on a CPU with
/// NEON support (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn scalar_f32(op: BinaryOp, a: *const f32, scalar: f32, out: *mut f32, len: usize) {
    let full = len / F32_LANES;
    let tail = len % F32_LANES;
    // Broadcast the scalar once; every arm reuses this register.
    let vs = vdupq_n_f32(scalar);
    // Expands into the vectorized main loop for one NEON intrinsic,
    // keeping each match arm down to a single line.
    macro_rules! vector_loop {
        ($intrinsic:ident) => {
            for chunk in 0..full {
                let idx = chunk * F32_LANES;
                vst1q_f32(out.add(idx), $intrinsic(vld1q_f32(a.add(idx)), vs));
            }
        };
    }
    match op {
        BinaryOp::Add => vector_loop!(vaddq_f32),
        BinaryOp::Sub => vector_loop!(vsubq_f32),
        BinaryOp::Mul => vector_loop!(vmulq_f32),
        BinaryOp::Div => vector_loop!(vdivq_f32),
        BinaryOp::Max => vector_loop!(vmaxq_f32),
        BinaryOp::Min => vector_loop!(vminq_f32),
        BinaryOp::Pow | BinaryOp::Atan2 => {
            // No NEON equivalent for these ops: delegate the whole buffer
            // to the scalar implementation and skip the tail handling below.
            super::super::scalar_scalar_f32(op, a, scalar, out, len);
            return;
        }
    }
    if tail > 0 {
        // Finish the last < 4 elements with the scalar fallback.
        let idx = full * F32_LANES;
        super::super::scalar_scalar_f32(op, a.add(idx), scalar, out.add(idx), tail);
    }
}
/// Applies `op` element-wise between the f64 buffer at `a` and a broadcast
/// `scalar`, writing results to `out`.
///
/// Processes `len / 2` full NEON vectors, then hands any remaining tail
/// element (and the ops with no NEON intrinsic: `Pow`, `Atan2`) to the
/// scalar fallback `scalar_scalar_f64`.
///
/// # Safety
/// `a` must be valid for reads of `len` f64 values and `out` valid for
/// writes of `len` f64 values; the caller must be running on a CPU with
/// NEON support (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn scalar_f64(op: BinaryOp, a: *const f64, scalar: f64, out: *mut f64, len: usize) {
    let full = len / F64_LANES;
    let tail = len % F64_LANES;
    // Broadcast the scalar once; every arm reuses this register.
    let vs = vdupq_n_f64(scalar);
    // Expands into the vectorized main loop for one NEON intrinsic,
    // keeping each match arm down to a single line.
    macro_rules! vector_loop {
        ($intrinsic:ident) => {
            for chunk in 0..full {
                let idx = chunk * F64_LANES;
                vst1q_f64(out.add(idx), $intrinsic(vld1q_f64(a.add(idx)), vs));
            }
        };
    }
    match op {
        BinaryOp::Add => vector_loop!(vaddq_f64),
        BinaryOp::Sub => vector_loop!(vsubq_f64),
        BinaryOp::Mul => vector_loop!(vmulq_f64),
        BinaryOp::Div => vector_loop!(vdivq_f64),
        BinaryOp::Max => vector_loop!(vmaxq_f64),
        BinaryOp::Min => vector_loop!(vminq_f64),
        BinaryOp::Pow | BinaryOp::Atan2 => {
            // No NEON equivalent for these ops: delegate the whole buffer
            // to the scalar implementation and skip the tail handling below.
            super::super::scalar_scalar_f64(op, a, scalar, out, len);
            return;
        }
    }
    if tail > 0 {
        // Finish the last odd element with the scalar fallback.
        let idx = full * F64_LANES;
        super::super::scalar_scalar_f64(op, a.add(idx), scalar, out.add(idx), tail);
    }
}
/// Computes `out[i] = scalar - a[i]` (reverse subtraction, where the scalar
/// is the left-hand operand) over `len` f32 elements using NEON, handling
/// the final `len % 4` elements with plain scalar arithmetic.
///
/// # Safety
/// `a` must be valid for reads of `len` f32 values and `out` valid for
/// writes of `len` f32 values.
// Fix: this function used NEON intrinsics but was missing the
// `#[cfg(target_arch = "aarch64")]` / `#[target_feature(enable = "neon")]`
// attributes every sibling in this file carries; without the cfg gate the
// file fails to compile on non-aarch64 targets because the
// `std::arch::aarch64` import it relies on is cfg-gated at the top of the file.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn rsub_scalar_f32(a: *const f32, scalar: f32, out: *mut f32, len: usize) {
    let chunks = len / F32_LANES;
    let remainder = len % F32_LANES;
    // Broadcast the scalar once; it is the minuend in every lane.
    let vs = vdupq_n_f32(scalar);
    for i in 0..chunks {
        let offset = i * F32_LANES;
        let va = vld1q_f32(a.add(offset));
        // Operand order is scalar - element, unlike the Sub arm of scalar_f32.
        let vr = vsubq_f32(vs, va);
        vst1q_f32(out.add(offset), vr);
    }
    // Scalar tail for the last < 4 elements.
    for i in 0..remainder {
        let offset = chunks * F32_LANES + i;
        *out.add(offset) = scalar - *a.add(offset);
    }
}
/// Computes `out[i] = scalar - a[i]` (reverse subtraction, where the scalar
/// is the left-hand operand) over `len` f64 elements using NEON, handling
/// the final `len % 2` element with plain scalar arithmetic.
///
/// # Safety
/// `a` must be valid for reads of `len` f64 values and `out` valid for
/// writes of `len` f64 values.
// Fix: this function used NEON intrinsics but was missing the
// `#[cfg(target_arch = "aarch64")]` / `#[target_feature(enable = "neon")]`
// attributes every sibling in this file carries; without the cfg gate the
// file fails to compile on non-aarch64 targets because the
// `std::arch::aarch64` import it relies on is cfg-gated at the top of the file.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn rsub_scalar_f64(a: *const f64, scalar: f64, out: *mut f64, len: usize) {
    let chunks = len / F64_LANES;
    let remainder = len % F64_LANES;
    // Broadcast the scalar once; it is the minuend in every lane.
    let vs = vdupq_n_f64(scalar);
    for i in 0..chunks {
        let offset = i * F64_LANES;
        let va = vld1q_f64(a.add(offset));
        // Operand order is scalar - element, unlike the Sub arm of scalar_f64.
        let vr = vsubq_f64(vs, va);
        vst1q_f64(out.add(offset), vr);
    }
    // Scalar tail for the last odd element.
    for i in 0..remainder {
        let offset = chunks * F64_LANES + i;
        *out.add(offset) = scalar - *a.add(offset);
    }
}