#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
#[cfg(target_arch = "aarch64")]
mod aarch64;
use super::{SimdLevel, detect_simd};
// Minimum inner-axis length before SIMD dispatch pays off; shorter inner
// dimensions cannot fill a vector register, so they take the scalar path.
const SIMD_THRESHOLD: usize = 16;
/// Cumulative sum of `f32` data laid out row-major as
/// `(outer_size, scan_size, inner_size)`, scanning along the middle axis.
/// Dispatches to the best SIMD backend detected at runtime, falling back
/// to the scalar kernel for narrow inner dimensions.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f32` elements.
pub unsafe fn cumsum_strided_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    let level = detect_simd();
    // Too little inner-axis parallelism (or no SIMD at all): scalar path.
    if level == SimdLevel::Scalar || inner_size < SIMD_THRESHOLD {
        return cumsum_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::cumsum_strided_f32(a, out, scan_size, outer_size, inner_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::cumsum_strided_f32(a, out, scan_size, outer_size, inner_size);
        } else {
            cumsum_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            aarch64::neon::cumsum_strided_f32(a, out, scan_size, outer_size, inner_size);
        } else {
            cumsum_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    cumsum_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
}
/// Cumulative sum of `f64` data laid out row-major as
/// `(outer_size, scan_size, inner_size)`, scanning along the middle axis.
/// Dispatches to the best SIMD backend detected at runtime, falling back
/// to the scalar kernel for narrow inner dimensions.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f64` elements.
pub unsafe fn cumsum_strided_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    let level = detect_simd();
    // Too little inner-axis parallelism (or no SIMD at all): scalar path.
    if level == SimdLevel::Scalar || inner_size < SIMD_THRESHOLD {
        return cumsum_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::cumsum_strided_f64(a, out, scan_size, outer_size, inner_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::cumsum_strided_f64(a, out, scan_size, outer_size, inner_size);
        } else {
            cumsum_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            aarch64::neon::cumsum_strided_f64(a, out, scan_size, outer_size, inner_size);
        } else {
            cumsum_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    cumsum_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
}
/// Cumulative product of `f32` data laid out row-major as
/// `(outer_size, scan_size, inner_size)`, scanning along the middle axis.
/// Dispatches to the best SIMD backend detected at runtime, falling back
/// to the scalar kernel for narrow inner dimensions.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f32` elements.
pub unsafe fn cumprod_strided_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    let level = detect_simd();
    // Too little inner-axis parallelism (or no SIMD at all): scalar path.
    if level == SimdLevel::Scalar || inner_size < SIMD_THRESHOLD {
        return cumprod_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::cumprod_strided_f32(a, out, scan_size, outer_size, inner_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::cumprod_strided_f32(a, out, scan_size, outer_size, inner_size);
        } else {
            cumprod_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            aarch64::neon::cumprod_strided_f32(a, out, scan_size, outer_size, inner_size);
        } else {
            cumprod_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    cumprod_strided_scalar_f32(a, out, scan_size, outer_size, inner_size);
}
/// Cumulative product of `f64` data laid out row-major as
/// `(outer_size, scan_size, inner_size)`, scanning along the middle axis.
/// Dispatches to the best SIMD backend detected at runtime, falling back
/// to the scalar kernel for narrow inner dimensions.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f64` elements.
pub unsafe fn cumprod_strided_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    let level = detect_simd();
    // Too little inner-axis parallelism (or no SIMD at all): scalar path.
    if level == SimdLevel::Scalar || inner_size < SIMD_THRESHOLD {
        return cumprod_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
    }
    #[cfg(target_arch = "x86_64")]
    {
        if level == SimdLevel::Avx512 {
            avx512::cumprod_strided_f64(a, out, scan_size, outer_size, inner_size);
        } else if level == SimdLevel::Avx2Fma {
            avx2::cumprod_strided_f64(a, out, scan_size, outer_size, inner_size);
        } else {
            cumprod_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if matches!(level, SimdLevel::Neon | SimdLevel::NeonFp16) {
            aarch64::neon::cumprod_strided_f64(a, out, scan_size, outer_size, inner_size);
        } else {
            cumprod_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    cumprod_strided_scalar_f64(a, out, scan_size, outer_size, inner_size);
}
/// Scalar reference implementation of the strided `f32` cumulative sum.
///
/// Data is laid out row-major as `(outer_size, scan_size, inner_size)`;
/// the running sum is taken along the scan axis independently for every
/// `(outer, inner)` pair.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f32` elements.
#[inline]
unsafe fn cumsum_strided_scalar_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    for o in 0..outer_size {
        // Offset of this outer slice, hoisted out of the inner loops.
        let base = o * scan_size * inner_size;
        for i in 0..inner_size {
            let mut acc = 0.0f32;
            // Step along the scan axis by `inner_size` instead of
            // recomputing the full flat index every iteration.
            let mut idx = base + i;
            for _ in 0..scan_size {
                acc += *a.add(idx);
                *out.add(idx) = acc;
                idx += inner_size;
            }
        }
    }
}
/// Scalar reference implementation of the strided `f64` cumulative sum.
///
/// Data is laid out row-major as `(outer_size, scan_size, inner_size)`;
/// the running sum is taken along the scan axis independently for every
/// `(outer, inner)` pair.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f64` elements.
#[inline]
unsafe fn cumsum_strided_scalar_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    for o in 0..outer_size {
        // Offset of this outer slice, hoisted out of the inner loops.
        let base = o * scan_size * inner_size;
        for i in 0..inner_size {
            let mut acc = 0.0f64;
            // Step along the scan axis by `inner_size` instead of
            // recomputing the full flat index every iteration.
            let mut idx = base + i;
            for _ in 0..scan_size {
                acc += *a.add(idx);
                *out.add(idx) = acc;
                idx += inner_size;
            }
        }
    }
}
/// Scalar reference implementation of the strided `f32` cumulative product.
///
/// Data is laid out row-major as `(outer_size, scan_size, inner_size)`;
/// the running product is taken along the scan axis independently for
/// every `(outer, inner)` pair.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f32` elements.
#[inline]
unsafe fn cumprod_strided_scalar_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    for o in 0..outer_size {
        // Offset of this outer slice, hoisted out of the inner loops.
        let base = o * scan_size * inner_size;
        for i in 0..inner_size {
            let mut acc = 1.0f32;
            // Step along the scan axis by `inner_size` instead of
            // recomputing the full flat index every iteration.
            let mut idx = base + i;
            for _ in 0..scan_size {
                acc *= *a.add(idx);
                *out.add(idx) = acc;
                idx += inner_size;
            }
        }
    }
}
/// Scalar reference implementation of the strided `f64` cumulative product.
///
/// Data is laid out row-major as `(outer_size, scan_size, inner_size)`;
/// the running product is taken along the scan axis independently for
/// every `(outer, inner)` pair.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` contiguous `f64` elements.
#[inline]
unsafe fn cumprod_strided_scalar_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    for o in 0..outer_size {
        // Offset of this outer slice, hoisted out of the inner loops.
        let base = o * scan_size * inner_size;
        for i in 0..inner_size {
            let mut acc = 1.0f64;
            // Step along the scan axis by `inner_size` instead of
            // recomputing the full flat index every iteration.
            let mut idx = base + i;
            for _ in 0..scan_size {
                acc *= *a.add(idx);
                *out.add(idx) = acc;
                idx += inner_size;
            }
        }
    }
}
/// Strided cumulative sum for `f16`, computed by widening to `f32`,
/// running the `f32` kernel, and narrowing back into `out`.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` elements.
#[cfg(feature = "f16")]
pub unsafe fn cumsum_strided_f16(
    a: *const half::f16,
    out: *mut half::f16,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    use super::half_convert_utils::*;
    let len = outer_size * scan_size * inner_size;
    // Widen the input into an f32 scratch buffer.
    let mut src = vec![0.0f32; len];
    convert_f16_to_f32(a as *const u16, src.as_mut_ptr(), len);
    // Run the scan in f32 precision.
    let mut dst = vec![0.0f32; len];
    cumsum_strided_f32(src.as_ptr(), dst.as_mut_ptr(), scan_size, outer_size, inner_size);
    // Narrow the result back into the caller's f16 buffer.
    convert_f32_to_f16(dst.as_ptr(), out as *mut u16, len);
}
/// Strided cumulative sum for `bf16`, computed by widening to `f32`,
/// running the `f32` kernel, and narrowing back into `out`.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` elements.
#[cfg(feature = "f16")]
pub unsafe fn cumsum_strided_bf16(
    a: *const half::bf16,
    out: *mut half::bf16,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    use super::half_convert_utils::*;
    let len = outer_size * scan_size * inner_size;
    // Widen the input into an f32 scratch buffer.
    let mut src = vec![0.0f32; len];
    convert_bf16_to_f32(a as *const u16, src.as_mut_ptr(), len);
    // Run the scan in f32 precision.
    let mut dst = vec![0.0f32; len];
    cumsum_strided_f32(src.as_ptr(), dst.as_mut_ptr(), scan_size, outer_size, inner_size);
    // Narrow the result back into the caller's bf16 buffer.
    convert_f32_to_bf16(dst.as_ptr(), out as *mut u16, len);
}
/// Strided cumulative product for `f16`, computed by widening to `f32`,
/// running the `f32` kernel, and narrowing back into `out`.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` elements.
#[cfg(feature = "f16")]
pub unsafe fn cumprod_strided_f16(
    a: *const half::f16,
    out: *mut half::f16,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    use super::half_convert_utils::*;
    let len = outer_size * scan_size * inner_size;
    // Widen the input into an f32 scratch buffer.
    let mut src = vec![0.0f32; len];
    convert_f16_to_f32(a as *const u16, src.as_mut_ptr(), len);
    // Run the scan in f32 precision.
    let mut dst = vec![0.0f32; len];
    cumprod_strided_f32(src.as_ptr(), dst.as_mut_ptr(), scan_size, outer_size, inner_size);
    // Narrow the result back into the caller's f16 buffer.
    convert_f32_to_f16(dst.as_ptr(), out as *mut u16, len);
}
/// Strided cumulative product for `bf16`, computed by widening to `f32`,
/// running the `f32` kernel, and narrowing back into `out`.
///
/// # Safety
/// `a` must be valid for reads and `out` valid for writes of
/// `outer_size * scan_size * inner_size` elements.
#[cfg(feature = "f16")]
pub unsafe fn cumprod_strided_bf16(
    a: *const half::bf16,
    out: *mut half::bf16,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    use super::half_convert_utils::*;
    let len = outer_size * scan_size * inner_size;
    // Widen the input into an f32 scratch buffer.
    let mut src = vec![0.0f32; len];
    convert_bf16_to_f32(a as *const u16, src.as_mut_ptr(), len);
    // Run the scan in f32 precision.
    let mut dst = vec![0.0f32; len];
    cumprod_strided_f32(src.as_ptr(), dst.as_mut_ptr(), scan_size, outer_size, inner_size);
    // Narrow the result back into the caller's bf16 buffer.
    convert_f32_to_bf16(dst.as_ptr(), out as *mut u16, len);
}
#[cfg(test)]
mod tests {
    use super::*;

    // Layout for all tests: row-major (outer_size, scan_size, inner_size),
    // flat index = o * scan * inner + s * inner + i.
    #[test]
    fn test_cumsum_strided_f32() {
        let input: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let mut output = vec![0.0f32; 24];
        unsafe {
            cumsum_strided_f32(input.as_ptr(), output.as_mut_ptr(), 3, 2, 4);
        }
        assert_eq!(output[0], 0.0);
        assert_eq!(output[4], 4.0);
        assert_eq!(output[8], 12.0);
        assert_eq!(output[1], 1.0);
        assert_eq!(output[5], 6.0);
        assert_eq!(output[9], 15.0);
    }

    #[test]
    fn test_cumprod_strided_f32() {
        let input = vec![1.0f32, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0];
        let mut output = vec![0.0f32; 8];
        unsafe {
            cumprod_strided_f32(input.as_ptr(), output.as_mut_ptr(), 4, 1, 2);
        }
        assert_eq!(output[0], 1.0);
        assert_eq!(output[2], 2.0);
        assert_eq!(output[4], 6.0);
        assert_eq!(output[6], 24.0);
        assert_eq!(output[1], 2.0);
        assert_eq!(output[3], 6.0);
        assert_eq!(output[5], 24.0);
        assert_eq!(output[7], 120.0);
    }

    // The two tests above use inner_size < SIMD_THRESHOLD, so they only
    // ever exercise the scalar fallback. These use inner_size = 32 so the
    // dispatcher takes the SIMD path (when the host CPU supports one) and
    // check it against the scalar reference. Inputs are chosen so every
    // intermediate value is exactly representable, making bitwise equality
    // valid regardless of backend.
    #[test]
    fn test_cumsum_simd_matches_scalar_f32() {
        let (scan, outer, inner) = (5usize, 2usize, 32usize);
        let total = scan * outer * inner;
        let input: Vec<f32> = (0..total).map(|x| (x % 7) as f32 - 3.0).collect();
        let mut got = vec![0.0f32; total];
        let mut want = vec![0.0f32; total];
        unsafe {
            cumsum_strided_f32(input.as_ptr(), got.as_mut_ptr(), scan, outer, inner);
            cumsum_strided_scalar_f32(input.as_ptr(), want.as_mut_ptr(), scan, outer, inner);
        }
        assert_eq!(got, want);
    }

    #[test]
    fn test_cumprod_simd_matches_scalar_f32() {
        let (scan, outer, inner) = (5usize, 2usize, 32usize);
        let total = scan * outer * inner;
        // Powers of two keep every running product exact in f32.
        let input: Vec<f32> = (0..total)
            .map(|x| [0.5f32, 1.0, 2.0][x % 3])
            .collect();
        let mut got = vec![0.0f32; total];
        let mut want = vec![0.0f32; total];
        unsafe {
            cumprod_strided_f32(input.as_ptr(), got.as_mut_ptr(), scan, outer, inner);
            cumprod_strided_scalar_f32(input.as_ptr(), want.as_mut_ptr(), scan, outer, inner);
        }
        assert_eq!(got, want);
    }

    #[test]
    fn test_cumsum_simd_matches_scalar_f64() {
        let (scan, outer, inner) = (4usize, 3usize, 32usize);
        let total = scan * outer * inner;
        let input: Vec<f64> = (0..total).map(|x| (x % 11) as f64 - 5.0).collect();
        let mut got = vec![0.0f64; total];
        let mut want = vec![0.0f64; total];
        unsafe {
            cumsum_strided_f64(input.as_ptr(), got.as_mut_ptr(), scan, outer, inner);
            cumsum_strided_scalar_f64(input.as_ptr(), want.as_mut_ptr(), scan, outer, inner);
        }
        assert_eq!(got, want);
    }
}