#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// AVX2 cumulative sum along the scan axis of a row-major f32 tensor shaped
/// `[outer_size, scan_size, inner_size]`.
///
/// For every `(outer, inner)` pair the running sum is taken over the middle
/// (scan) axis. Elements that differ only in the scan index are `inner_size`
/// apart in memory, so 8 adjacent inner columns share one 256-bit accumulator;
/// the leftover `inner_size % 8` columns are handled by a scalar loop.
///
/// # Safety
/// - The CPU must support AVX2; callers should gate on
///   `is_x86_feature_detected!("avx2")`.
/// - `a` must be valid for reads and `out` valid for writes of
///   `outer_size * scan_size * inner_size` `f32` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn cumsum_strided_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    // Number of f32 lanes in a 256-bit vector.
    const LANES: usize = 8;
    let chunks = inner_size / LANES;
    for o in 0..outer_size {
        let outer_base = o * scan_size * inner_size;
        // Vectorized columns: one running-sum register per group of 8 columns,
        // carried across the whole scan axis.
        for chunk in 0..chunks {
            let i_base = chunk * LANES;
            let mut acc = _mm256_setzero_ps();
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i_base;
                // Unaligned load/store: no alignment guarantee on `a`/`out`.
                let val = _mm256_loadu_ps(a.add(idx));
                acc = _mm256_add_ps(acc, val);
                _mm256_storeu_ps(out.add(idx), acc);
            }
        }
        // Scalar tail for the remaining inner_size % LANES columns.
        let i_start = chunks * LANES;
        for i in i_start..inner_size {
            let mut acc = 0.0f32;
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i;
                acc += *a.add(idx);
                *out.add(idx) = acc;
            }
        }
    }
}
/// AVX2 cumulative sum along the scan axis of a row-major f64 tensor shaped
/// `[outer_size, scan_size, inner_size]`.
///
/// f64 counterpart of [`cumsum_strided_f32`]: 4 adjacent inner columns share
/// one 256-bit accumulator (elements along the scan axis are `inner_size`
/// apart), with a scalar loop for the `inner_size % 4` remainder.
///
/// # Safety
/// - The CPU must support AVX2; callers should gate on
///   `is_x86_feature_detected!("avx2")`.
/// - `a` must be valid for reads and `out` valid for writes of
///   `outer_size * scan_size * inner_size` `f64` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn cumsum_strided_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    // Number of f64 lanes in a 256-bit vector.
    const LANES: usize = 4;
    let chunks = inner_size / LANES;
    for o in 0..outer_size {
        let outer_base = o * scan_size * inner_size;
        // Vectorized columns: one running-sum register per group of 4 columns.
        for chunk in 0..chunks {
            let i_base = chunk * LANES;
            let mut acc = _mm256_setzero_pd();
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i_base;
                // Unaligned load/store: no alignment guarantee on `a`/`out`.
                let val = _mm256_loadu_pd(a.add(idx));
                acc = _mm256_add_pd(acc, val);
                _mm256_storeu_pd(out.add(idx), acc);
            }
        }
        // Scalar tail for the remaining inner_size % LANES columns.
        let i_start = chunks * LANES;
        for i in i_start..inner_size {
            let mut acc = 0.0f64;
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i;
                acc += *a.add(idx);
                *out.add(idx) = acc;
            }
        }
    }
}
/// AVX2 cumulative product along the scan axis of a row-major f32 tensor
/// shaped `[outer_size, scan_size, inner_size]`.
///
/// Same traversal as [`cumsum_strided_f32`] but the accumulator starts at 1.0
/// and multiplies: 8 adjacent inner columns share one 256-bit accumulator,
/// with a scalar loop for the `inner_size % 8` remainder.
///
/// # Safety
/// - The CPU must support AVX2; callers should gate on
///   `is_x86_feature_detected!("avx2")`.
/// - `a` must be valid for reads and `out` valid for writes of
///   `outer_size * scan_size * inner_size` `f32` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn cumprod_strided_f32(
    a: *const f32,
    out: *mut f32,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    // Number of f32 lanes in a 256-bit vector.
    const LANES: usize = 8;
    let chunks = inner_size / LANES;
    for o in 0..outer_size {
        let outer_base = o * scan_size * inner_size;
        // Vectorized columns: running product per group of 8 columns,
        // seeded with the multiplicative identity.
        for chunk in 0..chunks {
            let i_base = chunk * LANES;
            let mut acc = _mm256_set1_ps(1.0);
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i_base;
                // Unaligned load/store: no alignment guarantee on `a`/`out`.
                let val = _mm256_loadu_ps(a.add(idx));
                acc = _mm256_mul_ps(acc, val);
                _mm256_storeu_ps(out.add(idx), acc);
            }
        }
        // Scalar tail for the remaining inner_size % LANES columns.
        let i_start = chunks * LANES;
        for i in i_start..inner_size {
            let mut acc = 1.0f32;
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i;
                acc *= *a.add(idx);
                *out.add(idx) = acc;
            }
        }
    }
}
/// AVX2 cumulative product along the scan axis of a row-major f64 tensor
/// shaped `[outer_size, scan_size, inner_size]`.
///
/// f64 counterpart of [`cumprod_strided_f32`]: 4 adjacent inner columns share
/// one 256-bit accumulator seeded with 1.0, with a scalar loop for the
/// `inner_size % 4` remainder.
///
/// # Safety
/// - The CPU must support AVX2; callers should gate on
///   `is_x86_feature_detected!("avx2")`.
/// - `a` must be valid for reads and `out` valid for writes of
///   `outer_size * scan_size * inner_size` `f64` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn cumprod_strided_f64(
    a: *const f64,
    out: *mut f64,
    scan_size: usize,
    outer_size: usize,
    inner_size: usize,
) {
    // Number of f64 lanes in a 256-bit vector.
    const LANES: usize = 4;
    let chunks = inner_size / LANES;
    for o in 0..outer_size {
        let outer_base = o * scan_size * inner_size;
        // Vectorized columns: running product per group of 4 columns,
        // seeded with the multiplicative identity.
        for chunk in 0..chunks {
            let i_base = chunk * LANES;
            let mut acc = _mm256_set1_pd(1.0);
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i_base;
                // Unaligned load/store: no alignment guarantee on `a`/`out`.
                let val = _mm256_loadu_pd(a.add(idx));
                acc = _mm256_mul_pd(acc, val);
                _mm256_storeu_pd(out.add(idx), acc);
            }
        }
        // Scalar tail for the remaining inner_size % LANES columns.
        let i_start = chunks * LANES;
        for i in i_start..inner_size {
            let mut acc = 1.0f64;
            for s in 0..scan_size {
                let idx = outer_base + s * inner_size + i;
                acc *= *a.add(idx);
                *out.add(idx) = acc;
            }
        }
    }
}
// Gated on target_arch as well as `test`: the functions under test use
// x86-only intrinsics, so this module cannot compile on other architectures.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    /// Runtime gate: the kernels require AVX2, so tests return early (skip)
    /// on CPUs without it instead of faulting on an illegal instruction.
    fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[test]
    fn test_cumsum_strided_f32_avx2() {
        if !has_avx2() {
            return;
        }
        // Shape [1, 4, 16]: each output row s holds the sum of input rows 0..=s.
        let input: Vec<f32> = (0..64).map(|x| x as f32).collect();
        let mut output = vec![0.0f32; 64];
        unsafe {
            cumsum_strided_f32(input.as_ptr(), output.as_mut_ptr(), 4, 1, 16);
        }
        // Column 0: 0, 0+16, 0+16+32, 0+16+32+48.
        assert_eq!(output[0], 0.0);
        assert_eq!(output[16], 16.0);
        assert_eq!(output[32], 48.0);
        assert_eq!(output[48], 96.0);
        // Column 1: 1, 1+17, 1+17+33, 1+17+33+49.
        assert_eq!(output[1], 1.0);
        assert_eq!(output[17], 18.0);
        assert_eq!(output[33], 51.0);
        assert_eq!(output[49], 100.0);
    }

    #[test]
    fn test_cumsum_strided_f32_tail_avx2() {
        if !has_avx2() {
            return;
        }
        // inner_size = 5 is below the 8-lane width, forcing the scalar tail.
        let input: Vec<f32> = (0..15).map(|x| x as f32).collect();
        let mut output = vec![0.0f32; 15];
        unsafe {
            cumsum_strided_f32(input.as_ptr(), output.as_mut_ptr(), 3, 1, 5);
        }
        // Column 0: 0, 0+5, 0+5+10.
        assert_eq!(output[0], 0.0);
        assert_eq!(output[5], 5.0);
        assert_eq!(output[10], 15.0);
        // Column 4: 4, 4+9, 4+9+14.
        assert_eq!(output[4], 4.0);
        assert_eq!(output[9], 13.0);
        assert_eq!(output[14], 27.0);
    }

    #[test]
    fn test_cumprod_strided_f32_avx2() {
        if !has_avx2() {
            return;
        }
        // Shape [1, 3, 8]: rows of 2s, 3s, 4s -> running products 2, 6, 24.
        let input = vec![
            2.0f32, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, ];
        let mut output = vec![0.0f32; 24];
        unsafe {
            cumprod_strided_f32(input.as_ptr(), output.as_mut_ptr(), 3, 1, 8);
        }
        for i in 0..8 {
            assert_eq!(output[i], 2.0, "s=0, i={}", i);
            assert_eq!(output[8 + i], 6.0, "s=1, i={}", i);
            assert_eq!(output[16 + i], 24.0, "s=2, i={}", i);
        }
    }

    #[test]
    fn test_cumsum_strided_f64_avx2() {
        if !has_avx2() {
            return;
        }
        // Shape [1, 3, 8]: column 0 accumulates 0, 0+8, 0+8+16.
        let input: Vec<f64> = (0..24).map(|x| x as f64).collect();
        let mut output = vec![0.0f64; 24];
        unsafe {
            cumsum_strided_f64(input.as_ptr(), output.as_mut_ptr(), 3, 1, 8);
        }
        assert_eq!(output[0], 0.0);
        assert_eq!(output[8], 8.0);
        assert_eq!(output[16], 24.0);
    }

    #[test]
    fn test_cumprod_strided_f64_avx2() {
        if !has_avx2() {
            return;
        }
        // Shape [1, 2, 4]: rows of 2s then 3s -> running products 2, 6.
        let input = vec![2.0f64, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0];
        let mut output = vec![0.0f64; 8];
        unsafe {
            cumprod_strided_f64(input.as_ptr(), output.as_mut_ptr(), 2, 1, 4);
        }
        for i in 0..4 {
            assert_eq!(output[i], 2.0, "s=0, i={}", i);
            assert_eq!(output[4 + i], 6.0, "s=1, i={}", i);
        }
    }
}