#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// NEON-accelerated inclusive cumulative sum of `f32` values along the middle
/// ("scan") axis of a dense `[outer_size, scan_size, inner_size]` buffer.
///
/// Element `(o, s, i)` is stored at flat index
/// `o * scan_size * inner_size + s * inner_size + i`. The output at that index
/// is `a(o, 0, i) + a(o, 1, i) + ... + a(o, s, i)`, accumulated in scan order
/// starting from `0.0`.
///
/// Columns are processed four at a time in 128-bit NEON registers. The
/// trailing `inner_size % 4` columns use a scalar loop with the same
/// accumulation order.
///
/// # Safety
///
/// - `a` must be valid for reads, and `out` valid for writes, of
///   `outer_size * scan_size * inner_size` properly aligned `f32` values.
///   When that product is zero, no memory is accessed.
/// - The buffers must be either disjoint or identical. Each input element is
///   read before the output element at the same index is written, so in-place
///   use (`a == out`) is fine. Partially overlapping buffers are not supported.
/// - The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn cumsum_strided_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    // Number of f32 lanes in a 128-bit NEON register.
    const LANES: usize = 4;
    let chunks = inner_size / LANES;
    // First column that does not fit in a full vector and goes to the scalar tail.
    let vector_end = chunks * LANES;

    for o in 0..outer_size {
        let outer_offset = o * scan_size * inner_size;

        // Vector path: LANES adjacent columns share one running accumulator.
        for chunk in 0..chunks {
            let i_base = chunk * LANES;
            let mut acc = vdupq_n_f32(0.0);
            for s in 0..scan_size {
                let idx = outer_offset + s * inner_size + i_base;
                let v = vld1q_f32(a.add(idx));
                acc = vaddq_f32(acc, v);
                vst1q_f32(out.add(idx), acc);
            }
        }

        // Scalar tail: the remaining `inner_size % LANES` columns.
        for i in vector_end..inner_size {
            let mut acc = 0.0f32;
            for s in 0..scan_size {
                let idx = outer_offset + s * inner_size + i;
                acc += *a.add(idx);
                *out.add(idx) = acc;
            }
        }
    }
}
/// Inclusive prefix sum of `f64` data along the middle axis of a dense
/// `[outer_size, scan_size, inner_size]` buffer, vectorised with NEON.
///
/// The value at flat index `o * scan_size * inner_size + s * inner_size + i`
/// becomes the sum of the inputs at rows `0..=s` of column `i` in plane `o`.
/// Sums are accumulated in row order starting from `0.0`.
///
/// Pairs of adjacent columns are processed together in 128-bit NEON registers.
/// If `inner_size` is odd, the final column is summed with scalar code.
///
/// # Safety
///
/// - `a` must be valid for reads, and `out` valid for writes, of
///   `outer_size * scan_size * inner_size` properly aligned `f64` values.
/// - The buffers must be either disjoint or identical. Every input element is
///   read before the output at the same index is written, so `a == out`
///   (in-place) works. Partial overlap is not supported.
/// - The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn cumsum_strided_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    const LANES: usize = 2;
    // Columns [0, vector_cols) go through NEON; the rest through the scalar tail.
    let vector_cols = inner_size - inner_size % LANES;

    for plane in 0..outer_size {
        let plane_start = plane * scan_size * inner_size;

        for col in (0..vector_cols).step_by(LANES) {
            let mut running = vdupq_n_f64(0.0);
            // Walk down the column by jumping one full row per step.
            let mut pos = plane_start + col;
            for _ in 0..scan_size {
                running = vaddq_f64(running, vld1q_f64(a.add(pos)));
                vst1q_f64(out.add(pos), running);
                pos += inner_size;
            }
        }

        for col in vector_cols..inner_size {
            let mut running = 0.0f64;
            let mut pos = plane_start + col;
            for _ in 0..scan_size {
                running += *a.add(pos);
                *out.add(pos) = running;
                pos += inner_size;
            }
        }
    }
}
/// Inclusive running product of `f32` data along the middle axis of a dense
/// `[outer_size, scan_size, inner_size]` buffer, using NEON.
///
/// For flat index `o * scan_size * inner_size + s * inner_size + i`, the output
/// is the product of the inputs at rows `0..=s` of column `i` in plane `o`.
/// Products are multiplied in row order starting from `1.0`.
///
/// A column cursor first advances four columns at a time through 128-bit NEON
/// registers. It then finishes the leftover `inner_size % 4` columns one at a
/// time with scalar code.
///
/// # Safety
///
/// - `a` must be valid for reads, and `out` valid for writes, of
///   `outer_size * scan_size * inner_size` properly aligned `f32` values.
/// - The buffers must be either disjoint or identical. Each input element is
///   read before its output slot is written, so in-place use (`a == out`)
///   works. Partial overlap is not supported.
/// - The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn cumprod_strided_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    const WIDTH: usize = 4;

    for o in 0..outer_size {
        let base = o * scan_size * inner_size;
        let mut col = 0usize;

        // Vector columns: WIDTH adjacent inner positions share one register.
        while inner_size - col >= WIDTH {
            let mut product = vdupq_n_f32(1.0);
            for idx in (0..scan_size).map(move |s| base + s * inner_size + col) {
                product = vmulq_f32(product, vld1q_f32(a.add(idx)));
                vst1q_f32(out.add(idx), product);
            }
            col += WIDTH;
        }

        // Scalar tail: whatever columns did not fill a whole vector.
        while col < inner_size {
            let mut product = 1.0f32;
            for idx in (0..scan_size).map(move |s| base + s * inner_size + col) {
                product *= *a.add(idx);
                *out.add(idx) = product;
            }
            col += 1;
        }
    }
}
/// Inclusive running product of `f64` data along the middle axis of a dense
/// `[outer_size, scan_size, inner_size]` buffer, using NEON.
///
/// For flat index `o * scan_size * inner_size + s * inner_size + i`, the output
/// is the product of the inputs at rows `0..=s` of column `i` in plane `o`.
/// Products are multiplied in row order starting from `1.0`.
///
/// Column pairs are handled by 128-bit NEON registers. If `inner_size` is odd,
/// the last column is handled by scalar code.
///
/// # Safety
///
/// - `a` must be valid for reads, and `out` valid for writes, of
///   `outer_size * scan_size * inner_size` properly aligned `f64` values.
/// - The buffers must be either disjoint or identical. Each input element is
///   read before its output slot is written, so in-place use (`a == out`)
///   works. Partial overlap is not supported.
/// - The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn cumprod_strided_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    // Largest even column count: these columns are processed two at a time.
    let paired_cols = inner_size / 2 * 2;

    for o in 0..outer_size {
        // Re-base both pointers at the start of this plane so the inner loops
        // only deal with (row, column) offsets. The offset is either 0 or the
        // index of element (o, 0, 0), so the pointer stays in bounds.
        let plane_offset = o * scan_size * inner_size;
        let src = a.add(plane_offset);
        let dst = out.add(plane_offset);

        for col in (0..paired_cols).step_by(2) {
            let mut prod = vdupq_n_f64(1.0);
            for row in 0..scan_size {
                let at = row * inner_size + col;
                let x = vld1q_f64(src.add(at));
                prod = vmulq_f64(prod, x);
                vst1q_f64(dst.add(at), prod);
            }
        }

        for col in paired_cols..inner_size {
            let mut prod = 1.0f64;
            for row in 0..scan_size {
                let at = row * inner_size + col;
                prod *= *src.add(at);
                *dst.add(at) = prod;
            }
        }
    }
}