use super::super::*;
#[test]
#[cfg(target_arch = "x86_64")]
fn test_horizontal_sum_avx2() {
    use std::arch::x86_64::*;

    // The routine under test is AVX2-only; skip (rather than fail) on hosts without it.
    if !is_x86_feature_detected!("avx2") {
        println!("Skipping AVX2 horizontal sum test (CPU doesn't support AVX2)");
        return;
    }

    // Packs eight lanes (lane 0 = `lanes[0]`) into a 256-bit vector and reduces it
    // with `Matrix::<f32>::horizontal_sum_avx2`.
    let reduce = |lanes: [f32; 8]| -> f32 {
        // SAFETY: AVX2 support was confirmed at runtime by the guard above, and
        // `lanes` is a live 8-element f32 array, so the unaligned 8-lane load
        // stays within its bounds.
        unsafe { Matrix::<f32>::horizontal_sum_avx2(_mm256_loadu_ps(lanes.as_ptr())) }
    };

    // Uniform input: eight ones.
    let total = reduce([1.0; 8]);
    assert!((total - 8.0).abs() < 1e-6, "Expected 8.0, got {}", total);

    // Distinct value per lane, so a reduction that drops or duplicates a lane is caught.
    let total = reduce([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    assert!((total - 36.0).abs() < 1e-6, "Expected 36.0, got {}", total);

    // Alternating signs that cancel to zero.
    let total = reduce([1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]);
    assert!(total.abs() < 1e-6, "Expected ~0.0, got {}", total);

    // Larger magnitudes, checked with a looser absolute tolerance.
    let total = reduce([100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700.0, 800.0]);
    assert!((total - 3600.0).abs() < 1e-3, "Expected 3600.0, got {}", total);

    // Mixed-sign fractional values; the reference is the sequential lane sum.
    let mixed = [10.5, -5.25, 3.75, -8.0, 12.0, -6.5, 4.25, -2.75];
    let expected: f32 = mixed.iter().sum();
    let total = reduce(mixed);
    assert!((total - expected).abs() < 1e-5, "Expected {}, got {}", expected, total);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_matmul_microkernel_4x1_avx2() {
    /// Builds a row holding the consecutive integers `lo..=hi` as `f32`.
    fn seq(lo: i32, hi: i32) -> Vec<f32> {
        (lo..=hi).map(|x| x as f32).collect()
    }

    /// Returns a copy of `v` with every element multiplied by `s`.
    fn scaled(v: &[f32], s: f32) -> Vec<f32> {
        v.iter().map(|x| x * s).collect()
    }

    /// Runs `Matrix::<f32>::matmul_microkernel_4x1_avx2` on four rows against one
    /// column. Asserts that each of the four outputs is within the absolute
    /// tolerance `tol` of a scalar reference dot product. `case` labels the
    /// scenario in failure messages, so a failure says which input set broke.
    fn check(case: &str, rows: &[Vec<f32>; 4], b_col: &[f32], tol: f32) {
        // Test-setup invariant: every scenario feeds rows exactly as long as the
        // column. (Whether the kernel itself requires this is not visible from
        // here; NOTE(review): confirm against its definition.)
        for row in rows {
            assert_eq!(row.len(), b_col.len(), "{}: row/column length mismatch", case);
        }
        // Re-checked here so this helper is sound on its own, not only because of
        // the early return in the enclosing test.
        assert!(
            is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma"),
            "{}: AVX2/FMA required",
            case
        );

        let a_rows = [
            rows[0].as_slice(),
            rows[1].as_slice(),
            rows[2].as_slice(),
            rows[3].as_slice(),
        ];
        let mut results = [0.0f32; 4];
        // SAFETY: AVX2 and FMA support was asserted just above, and every row has
        // the same length as `b_col`.
        unsafe {
            Matrix::<f32>::matmul_microkernel_4x1_avx2(a_rows, b_col, &mut results);
        }

        for (i, (row, &got)) in rows.iter().zip(results.iter()).enumerate() {
            // Plain sequential dot product, independent of the SIMD code path.
            let want: f32 = row.iter().zip(b_col).map(|(&a, &b)| a * b).sum();
            assert!(
                (got - want).abs() < tol,
                "{}: Row {}: expected {}, got {}",
                case,
                i,
                want,
                got
            );
        }
    }

    if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
        println!("Skipping AVX2 micro-kernel test (CPU doesn't support AVX2/FMA)");
        return;
    }

    // 1. Full 16-element rows of consecutive integers against an all-ones column.
    check(
        "consecutive x ones",
        &[seq(1, 16), seq(17, 32), seq(33, 48), seq(49, 64)],
        &[1.0f32; 16],
        1e-3,
    );

    // 2. One-hot rows: row i selects element i of the column, giving 1, 2, 3, 4.
    let one_hot = |hot: usize| -> Vec<f32> {
        (0..16).map(|j| if j == hot { 1.0 } else { 0.0 }).collect()
    };
    check(
        "one-hot",
        &[one_hot(0), one_hot(1), one_hot(2), one_hot(3)],
        &seq(1, 16),
        1e-6,
    );

    // 3. Length 10 is not a multiple of the 8 f32 lanes in a 256-bit register, so
    //    this presumably exercises the kernel's remainder handling.
    check(
        "length 10 x twos",
        &[seq(1, 10), seq(11, 20), seq(21, 30), seq(31, 40)],
        &[2.0f32; 10],
        1e-3,
    );

    // 4. Alternating signs: positive and negative partial products must cancel.
    let alternating: Vec<f32> = (1..=16)
        .map(|k| if k % 2 == 1 { k as f32 } else { -(k as f32) })
        .collect();
    let doubled = scaled(&alternating, 2.0);
    let halved = scaled(&alternating, 0.5);
    let plus_minus_ten: Vec<f32> = (0..16)
        .map(|k| if k % 2 == 0 { 10.0 } else { -10.0 })
        .collect();
    check(
        "alternating signs",
        &[alternating, doubled, halved, plus_minus_ten],
        &[1.0f32; 16],
        1e-4,
    );

    // 5. All-zero rows must produce ~0 regardless of the column contents.
    let zero = || vec![0.0f32; 16];
    check("zeros", &[zero(), zero(), zero(), zero()], &seq(1, 16), 1e-6);

    // 6. Fractional column: each row (1x, 2x, 0.5x, 3x of 1..=16) is halved.
    let base = seq(1, 16);
    let (twice, half, thrice) = (scaled(&base, 2.0), scaled(&base, 0.5), scaled(&base, 3.0));
    check(
        "multiples x halves",
        &[base, twice, half, thrice],
        &[0.5f32; 16],
        1e-3,
    );
}