use std::ptr;
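
/// Copies `n_elements` `f32` values spaced `stride` elements apart, starting
/// at `src`, into the contiguous prefix of `dst`. On `x86_64` with AVX2 this
/// dispatches to a hardware-gather kernel; otherwise it falls back to a
/// scalar loop.
///
/// # Safety
///
/// - `src` must be valid for reads of `(n_elements - 1) * stride + 1` `f32`s.
/// - `dst.len() >= n_elements` must hold; it is only checked via
///   `debug_assert!`, so release builds will write out of bounds if violated.
///
/// # Examples
///
/// A minimal sketch (`my_crate` is a placeholder for wherever this function
/// actually lives, so the doctest is marked `ignore`):
///
/// ```ignore
/// let src: Vec<f32> = (0..12).map(|x| x as f32).collect();
/// let mut dst = vec![0.0_f32; 4];
/// // Gather every third element of `src`.
/// unsafe { my_crate::copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 4, 3) };
/// assert_eq!(dst, [0.0, 3.0, 6.0, 9.0]);
/// ```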
pub unsafe fn copy_strided_to_contiguous_f32(
    src: *const f32,
    dst: &mut [f32],
    n_elements: usize,
    stride: usize,
) {
    debug_assert!(
        dst.len() >= n_elements,
        "dst must have at least n_elements slots"
    );
    // Unit stride is a plain contiguous copy; let memcpy handle it.
    if stride == 1 {
        unsafe {
            ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n_elements);
        }
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        const AVX2_LANES: usize = 8;
        // vpgatherdps takes 32-bit indices; the largest index in one chunk is
        // 7 * stride, so reject strides that would overflow an i32.
        if is_x86_feature_detected!("avx2") && stride <= (i32::MAX as usize) / (AVX2_LANES - 1) {
            unsafe {
                gather_f32_avx2(src, dst, n_elements, stride);
            }
            return;
        }
    }
    unsafe {
        scalar_gather_f32(src, dst, n_elements, stride);
    }
}
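
/// `f64` counterpart of [`copy_strided_to_contiguous_f32`]: gathers
/// `n_elements` values spaced `stride` apart from `src` into `dst`.
///
/// # Safety
///
/// Same contract as the `f32` version: `src` must be valid for reads of
/// `(n_elements - 1) * stride + 1` `f64`s, and `dst.len() >= n_elements`
/// is only enforced by `debug_assert!`.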
pub unsafe fn copy_strided_to_contiguous_f64(
    src: *const f64,
    dst: &mut [f64],
    n_elements: usize,
    stride: usize,
) {
    debug_assert!(
        dst.len() >= n_elements,
        "dst must have at least n_elements slots"
    );
    // Unit stride is a plain contiguous copy; let memcpy handle it.
    if stride == 1 {
        unsafe {
            ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n_elements);
        }
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        // vpgatherqpd takes 64-bit indices, so no stride guard is needed here:
        // 3 * stride cannot overflow an i64 for any addressable buffer.
        if is_x86_feature_detected!("avx2") {
            unsafe {
                gather_f64_avx2(src, dst, n_elements, stride);
            }
            return;
        }
    }
    unsafe {
        scalar_gather_f64(src, dst, n_elements, stride);
    }
}
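
// Portable fallback: one strided load per element. Bounds on `dst` are the
// caller's responsibility (see the safety contracts above), which is why
// `get_unchecked_mut` is acceptable here.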
#[inline]
unsafe fn scalar_gather_f32(src: *const f32, dst: &mut [f32], n_elements: usize, stride: usize) {
    for i in 0..n_elements {
        unsafe {
            *dst.get_unchecked_mut(i) = *src.add(i * stride);
        }
    }
}

#[inline]
unsafe fn scalar_gather_f64(src: *const f64, dst: &mut [f64], n_elements: usize, stride: usize) {
    for i in 0..n_elements {
        unsafe {
            *dst.get_unchecked_mut(i) = *src.add(i * stride);
        }
    }
}
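
// AVX2 kernel: build an index vector [0, s, ..., 7s] once, then issue one
// vpgatherdps per 8 outputs, advancing the base pointer by 8 * stride each
// iteration. Indices are 32-bit, hence the stride guard in the dispatcher.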
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather_f32_avx2(src: *const f32, dst: &mut [f32], n_elements: usize, stride: usize) {
    use std::arch::x86_64::*;
    let stride_i32 = stride as i32;
    // _mm256_set_epi32 takes lanes highest-first, so lane 0 holds index 0.
    let vindex = _mm256_set_epi32(
        7 * stride_i32,
        6 * stride_i32,
        5 * stride_i32,
        4 * stride_i32,
        3 * stride_i32,
        2 * stride_i32,
        stride_i32,
        0,
    );
    let chunks = n_elements / 8;
    let remainder = n_elements % 8;
    let mut dst_ptr = dst.as_mut_ptr();
    for chunk in 0..chunks {
        unsafe {
            let chunk_src = src.add(chunk * 8 * stride);
            // SCALE = 4: the i32 indices count f32 elements, not bytes.
            let gathered = _mm256_i32gather_ps::<4>(chunk_src, vindex);
            _mm256_storeu_ps(dst_ptr, gathered);
            dst_ptr = dst_ptr.add(8);
        }
    }
    // Scalar tail for the last n_elements % 8 values.
    let tail_src_base = unsafe { src.add(chunks * 8 * stride) };
    for i in 0..remainder {
        unsafe {
            *dst_ptr.add(i) = *tail_src_base.add(i * stride);
        }
    }
}
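
// Same scheme with 4 lanes of f64: vpgatherqpd uses 64-bit indices, so no
// stride guard is required in the dispatcher.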
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather_f64_avx2(src: *const f64, dst: &mut [f64], n_elements: usize, stride: usize) {
    use std::arch::x86_64::*;
    let stride_i64 = stride as i64;
    // _mm256_set_epi64x takes lanes highest-first, so lane 0 holds index 0.
    let vindex = _mm256_set_epi64x(3 * stride_i64, 2 * stride_i64, stride_i64, 0);
    let chunks = n_elements / 4;
    let remainder = n_elements % 4;
    let mut dst_ptr = dst.as_mut_ptr();
    for chunk in 0..chunks {
        unsafe {
            let chunk_src = src.add(chunk * 4 * stride);
            // SCALE = 8: the i64 indices count f64 elements, not bytes.
            let gathered = _mm256_i64gather_pd::<8>(chunk_src, vindex);
            _mm256_storeu_pd(dst_ptr, gathered);
            dst_ptr = dst_ptr.add(4);
        }
    }
    // Scalar tail for the last n_elements % 4 values.
    let tail_src_base = unsafe { src.add(chunks * 4 * stride) };
    for i in 0..remainder {
        unsafe {
            *dst_ptr.add(i) = *tail_src_base.add(i * stride);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_stride1_is_memcpy() {
        let src: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let mut dst = vec![0.0_f32; 16];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 16, 1);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn test_f32_stride2() {
        let src: Vec<f32> = (0..18).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..9).map(|x| (x * 2) as f32).collect();
        let mut dst = vec![0.0_f32; 9];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 9, 2);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f32_stride3() {
        let src: Vec<f32> = (0..21).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..7).map(|x| (x * 3) as f32).collect();
        let mut dst = vec![0.0_f32; 7];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 7, 3);
        }
        assert_eq!(dst, expected);
    }
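
    // Added coverage (not in the original tests): 13 elements at stride 5
    // forces one full 8-lane AVX2 chunk plus a 5-element scalar tail on
    // machines where the AVX2 path is taken, and a plain loop elsewhere.
    #[test]
    fn test_f32_stride5_chunk_plus_tail() {
        let src: Vec<f32> = (0..65).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..13).map(|x| (x * 5) as f32).collect();
        let mut dst = vec![0.0_f32; 13];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 13, 5);
        }
        assert_eq!(dst, expected);
    }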

    #[test]
    fn test_f64_stride1_is_memcpy() {
        let src: Vec<f64> = (0..16).map(|x| x as f64).collect();
        let mut dst = vec![0.0_f64; 16];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 16, 1);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn test_f64_stride2() {
        let src: Vec<f64> = (0..18).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..9).map(|x| (x * 2) as f64).collect();
        let mut dst = vec![0.0_f64; 9];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 9, 2);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f64_stride4() {
        let n = 10_000_usize;
        let stride = 4_usize;
        let src: Vec<f64> = (0..(n * stride)).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..n).map(|x| (x * stride) as f64).collect();
        let mut dst = vec![0.0_f64; n];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, n, stride);
        }
        assert_eq!(dst, expected);
    }
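
    // Added coverage (not in the original tests): 11 elements at stride 3
    // leaves a 3-element scalar remainder after two full 4-lane AVX2 chunks.
    #[test]
    fn test_f64_stride3_chunk_plus_tail() {
        let src: Vec<f64> = (0..33).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..11).map(|x| (x * 3) as f64).collect();
        let mut dst = vec![0.0_f64; 11];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 11, 3);
        }
        assert_eq!(dst, expected);
    }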

    // Not a real benchmark; it just records rough wall-clock time for one
    // strided copy. Run with `cargo test -- --nocapture` to see the output.
    #[test]
    fn benchmark_copy_overhead_documentation() {
        let n = 1_000_000_usize;
        let stride = 3_usize;
        let src: Vec<f64> = (0..(n * stride)).map(|x| x as f64).collect();
        let mut dst = vec![0.0_f64; n];
        let start = std::time::Instant::now();
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, n, stride);
        }
        let elapsed = start.elapsed();
        assert_eq!(dst[0], 0.0);
        assert_eq!(dst[n - 1], ((n - 1) * stride) as f64);
        eprintln!(
            "copy_strided_to_contiguous_f64: {} elements, stride={}, elapsed={:.2?}",
            n, stride, elapsed
        );
    }
}