#![cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Returns the lane sum of `data` plus the sum of its squares.
///
/// Both helpers are called without `unsafe`: they enable only `avx2`, a subset of
/// this function's `avx2,fma` feature set.
#[target_feature(enable = "avx2,fma")]
fn outer(data: &[f32; 8]) -> f32 {
    inner_add(data) + inner_mul(data)
}
/// Returns the sum of all eight lanes of `data`.
///
/// `_mm256_hadd_ps` adds adjacent pairs within each 128-bit half. After two passes,
/// element 0 of each half holds that half's four-element sum, and the two are then added.
#[target_feature(enable = "avx2")]
#[inline]
fn inner_add(data: &[f32; 8]) -> f32 {
    // SAFETY: `data` is a `&[f32; 8]`, so its pointer is valid for an unaligned
    // 32-byte read; the AVX instructions are enabled by this fn's `avx2` feature.
    unsafe {
        let lanes = _mm256_loadu_ps(data.as_ptr());
        let pairs = _mm256_hadd_ps(lanes, lanes);
        let quads = _mm256_hadd_ps(pairs, pairs);
        let low_half = _mm256_castps256_ps128(quads);
        let high_half = _mm256_extractf128_ps::<1>(quads);
        let low_sum = _mm_cvtss_f32(low_half);
        let high_sum = _mm_cvtss_f32(high_half);
        low_sum + high_sum
    }
}
/// Returns the sum of the squares of all eight lanes of `data`.
///
/// Squares every lane, then reduces with two in-half `_mm256_hadd_ps` passes and
/// adds element 0 of the low and high 128-bit halves.
#[target_feature(enable = "avx2")]
#[inline]
fn inner_mul(data: &[f32; 8]) -> f32 {
    // SAFETY: `data` is a `&[f32; 8]`, so its pointer is valid for an unaligned
    // 32-byte read; the AVX instructions are enabled by this fn's `avx2` feature.
    unsafe {
        let lanes = _mm256_loadu_ps(data.as_ptr());
        let squares = _mm256_mul_ps(lanes, lanes);
        let pairs = _mm256_hadd_ps(squares, squares);
        let quads = _mm256_hadd_ps(pairs, pairs);
        let low_sum = _mm_cvtss_f32(_mm256_castps256_ps128(quads));
        let high_sum = _mm_cvtss_f32(_mm256_extractf128_ps::<1>(quads));
        low_sum + high_sum
    }
}
/// Returns the lane sum of `data` via the `avx2`-only `inner_add`.
///
/// Enables `bmi1` and `bmi2` in addition to `avx2,fma` even though the body does not
/// need them. This checks that a strict superset of the callee's features still allows
/// a safe call. Callers without these features must detect all four before calling.
#[target_feature(enable = "avx2,fma,bmi1,bmi2")]
fn v3_outer(data: &[f32; 8]) -> f32 {
    inner_add(data)
}
/// Top of a three-deep safe call chain: doubles the result of `level2`.
#[target_feature(enable = "avx2,fma")]
fn level1(data: &[f32; 8]) -> f32 {
    2.0 * level2(data)
}
/// Middle of the call chain: adds one to the result of `level3`.
///
/// Has the same `avx2,fma` feature set as its caller and calls the `avx2`-only
/// `level3` without `unsafe`.
#[target_feature(enable = "avx2,fma")]
#[inline]
fn level2(data: &[f32; 8]) -> f32 {
    1.0 + level3(data)
}
/// Bottom of the call chain: returns the lane sum of `data` via `inner_add`.
///
/// Has the same `avx2` feature set as `inner_add`, so the call needs no `unsafe`.
#[target_feature(enable = "avx2")]
#[inline]
fn level3(data: &[f32; 8]) -> f32 {
    inner_add(data)
}
/// Checks that safe `#[target_feature]` functions call each other without `unsafe`
/// when the caller's feature set covers the callee's, and that the results are correct.
///
/// Only the entry calls from this feature-less test need `unsafe`. Each one is guarded
/// by runtime detection of every feature its callee enables.
#[test]
fn test_nested_target_feature_no_unsafe() {
    let has_avx2_fma = std::arch::is_x86_feature_detected!("avx2")
        && std::arch::is_x86_feature_detected!("fma");
    if !has_avx2_fma {
        return;
    }

    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    // sum(1..=8) = 36 and the sum of squares is 204. Every partial sum is a small
    // integer, so the SIMD reductions are exact in f32 and can be compared with ==.
    let sum = 36.0_f32;
    let sum_sq = 204.0_f32;

    // SAFETY: `outer` enables avx2 and fma, both detected above.
    let result = unsafe { outer(&data) };
    println!("outer result: {}", result);
    assert_eq!(result, sum + sum_sq);

    // `v3_outer` also enables bmi1 and bmi2. Calling it on a CPU (or VM) that lacks them
    // is undefined behavior, so detect them as well rather than assuming AVX2 implies BMI.
    let has_bmi = std::arch::is_x86_feature_detected!("bmi1")
        && std::arch::is_x86_feature_detected!("bmi2");
    if has_bmi {
        // SAFETY: avx2, fma, bmi1 and bmi2 are all detected.
        let result = unsafe { v3_outer(&data) };
        println!("v3_outer result: {}", result);
        assert_eq!(result, sum);
    }

    // SAFETY: `level1` enables avx2 and fma, both detected above.
    let result = unsafe { level1(&data) };
    println!("level1 result: {}", result);
    assert_eq!(result, (sum + 1.0) * 2.0);
}