#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn merge_histogram_grads_avx2(self_grads: &mut [f32; 256], other_grads: &[f32; 256]) {
for i in (0..256).step_by(8) {
let self_vec = _mm256_loadu_ps(self_grads.as_ptr().add(i));
let other_vec = _mm256_loadu_ps(other_grads.as_ptr().add(i));
let sum = _mm256_add_ps(self_vec, other_vec);
_mm256_storeu_ps(self_grads.as_mut_ptr().add(i), sum);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn merge_histogram_hess_avx2(self_hess: &mut [f32; 256], other_hess: &[f32; 256]) {
for i in (0..256).step_by(8) {
let self_vec = _mm256_loadu_ps(self_hess.as_ptr().add(i));
let other_vec = _mm256_loadu_ps(other_hess.as_ptr().add(i));
let sum = _mm256_add_ps(self_vec, other_vec);
_mm256_storeu_ps(self_hess.as_mut_ptr().add(i), sum);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn merge_histogram_counts_avx2(self_counts: &mut [u32; 256], other_counts: &[u32; 256]) {
for i in (0..256).step_by(8) {
let self_vec = _mm256_loadu_si256(self_counts.as_ptr().add(i) as *const __m256i);
let other_vec = _mm256_loadu_si256(other_counts.as_ptr().add(i) as *const __m256i);
let sum = _mm256_add_epi32(self_vec, other_vec);
_mm256_storeu_si256(self_counts.as_mut_ptr().add(i) as *mut __m256i, sum);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn subtract_histogram_grads_avx2(self_grads: &mut [f32; 256], other_grads: &[f32; 256]) {
for i in (0..256).step_by(8) {
let self_vec = _mm256_loadu_ps(self_grads.as_ptr().add(i));
let other_vec = _mm256_loadu_ps(other_grads.as_ptr().add(i));
let diff = _mm256_sub_ps(self_vec, other_vec);
_mm256_storeu_ps(self_grads.as_mut_ptr().add(i), diff);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn subtract_histogram_hess_avx2(self_hess: &mut [f32; 256], other_hess: &[f32; 256]) {
for i in (0..256).step_by(8) {
let self_vec = _mm256_loadu_ps(self_hess.as_ptr().add(i));
let other_vec = _mm256_loadu_ps(other_hess.as_ptr().add(i));
let diff = _mm256_sub_ps(self_vec, other_vec);
_mm256_storeu_ps(self_hess.as_mut_ptr().add(i), diff);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn subtract_histogram_counts_avx2(
self_counts: &mut [u32; 256],
other_counts: &[u32; 256],
) {
for i in (0..256).step_by(8) {
let self_vec = _mm256_loadu_si256(self_counts.as_ptr().add(i) as *const __m256i);
let other_vec = _mm256_loadu_si256(other_counts.as_ptr().add(i) as *const __m256i);
let diff = _mm256_sub_epi32(self_vec, other_vec);
_mm256_storeu_si256(self_counts.as_mut_ptr().add(i) as *mut __m256i, diff);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_merge_grads_avx2() {
if !std::arch::is_x86_feature_detected!("avx2") {
println!("AVX2 not available, skipping test");
return;
}
let mut self_grads = [0.0f32; 256];
let mut other_grads = [0.0f32; 256];
for i in 0..256 {
self_grads[i] = i as f32;
other_grads[i] = (256 - i) as f32;
}
unsafe {
merge_histogram_grads_avx2(&mut self_grads, &other_grads);
}
for (i, &val) in self_grads.iter().enumerate() {
assert!(
(val - 256.0).abs() < 1e-5,
"Mismatch at index {}: expected 256, got {}",
i,
val
);
}
}
#[test]
fn test_merge_counts_avx2() {
if !std::arch::is_x86_feature_detected!("avx2") {
println!("AVX2 not available, skipping test");
return;
}
let mut self_counts = [0u32; 256];
let mut other_counts = [0u32; 256];
for i in 0..256 {
self_counts[i] = i as u32;
other_counts[i] = 1;
}
unsafe {
merge_histogram_counts_avx2(&mut self_counts, &other_counts);
}
for i in 0..256 {
assert_eq!(self_counts[i], i as u32 + 1, "Mismatch at index {}", i);
}
}
#[test]
fn test_subtract_grads_avx2() {
if !std::arch::is_x86_feature_detected!("avx2") {
println!("AVX2 not available, skipping test");
return;
}
let mut self_grads = [100.0f32; 256];
let mut other_grads = [0.0f32; 256];
for i in 0..256 {
other_grads[i] = i as f32;
}
unsafe {
subtract_histogram_grads_avx2(&mut self_grads, &other_grads);
}
for i in 0..256 {
let expected = 100.0 - i as f32;
assert!(
(self_grads[i] - expected).abs() < 1e-5,
"Mismatch at index {}: expected {}, got {}",
i,
expected,
self_grads[i]
);
}
}
}