use crate::blis::*;
/// A 64^3 GEMM is too small to justify device dispatch overhead; the cost
/// model must route it to a host-side backend (vectorized CPU or scalar).
#[test]
fn test_backend_selection_small_problem_chooses_cpu() {
    let model = BackendCostModel::default();
    let chosen = model.select_backend(64, 64, 64);
    let is_host_backend = matches!(chosen, ComputeBackend::Cpu | ComputeBackend::Scalar);
    assert!(is_host_backend, "Small problem should use CPU, got {:?}", chosen);
}
/// For a 1024^3 problem the vectorized CPU path must be estimated strictly
/// faster than the scalar fallback; otherwise the model is mis-calibrated.
#[test]
fn test_backend_cost_model_time_estimate() {
    let model = BackendCostModel::default();
    let (m, n, k) = (1024, 1024, 1024);
    let t_cpu = model.estimate_time_us(m, n, k, ComputeBackend::Cpu);
    let t_scalar = model.estimate_time_us(m, n, k, ComputeBackend::Scalar);
    assert!(
        t_cpu < t_scalar,
        "CPU ({:.2}us) should be faster than scalar ({:.2}us)",
        t_cpu,
        t_scalar
    );
}
/// A 1024^3 GEMM has arithmetic intensity of roughly 170 flops/byte, well
/// above typical ridge points, so the roofline must classify it compute-bound.
#[test]
fn test_roofline_analysis_compute_bound() {
    let analysis = UnifiedBrickProfiler::new().roofline_analysis(1024, 1024, 1024);
    assert!(
        analysis.is_compute_bound(),
        "1024x1024x1024 should be compute-bound, AI={:.1}",
        analysis.arithmetic_intensity()
    );
}
/// Recording one backend selection must append one history entry, pin the
/// chosen backend, and account the output-element count (m * n).
#[test]
fn test_unified_profiler_records_selection() {
    const DIM: usize = 256;
    let mut profiler = UnifiedBrickProfiler::new();
    profiler.record_selection(DIM, DIM, DIM, ComputeBackend::Cpu);
    assert_eq!(profiler.selection_history.len(), 1);
    assert_eq!(profiler.backend, Some(ComputeBackend::Cpu));
    assert_eq!(profiler.total_elements, (DIM * DIM) as u64);
}
/// The default WGSL spec must emit a complete tiled-GEMM compute shader:
/// entry-point attributes, shared-memory tiles, all four bindings, the
/// `GemmParams` uniform layout, the alpha/beta epilogue, and the header
/// comments describing tile/workgroup geometry.
///
/// Defect fixed: the original used one bare `assert!` per substring, so a
/// failure reported only opaque expression text. Iterating a substring table
/// with a message now pinpoints exactly which fragment is missing.
#[test]
fn test_wgsl_spec_generation() {
    let wgsl = WgslMicrokernelSpec::default().generate_wgsl();
    let required = [
        "@compute",
        "@workgroup_size",
        "tile_a",
        "tile_b",
        "workgroupBarrier",
        "@group(0) @binding(0)",
        "@group(0) @binding(1)",
        "@group(0) @binding(2)",
        "@group(0) @binding(3)",
        "struct GemmParams",
        "m: u32",
        "n: u32",
        "k: u32",
        "alpha: f32",
        "beta: f32",
        "@workgroup_size(8, 8, 1)",
        "num_tiles",
        "var sum: f32 = 0.0",
        "params.alpha * sum + params.beta",
        "Tile: 8x8",
        "Workgroup: 8x8x1",
    ];
    for needle in required.iter() {
        assert!(wgsl.contains(needle), "generated WGSL is missing `{}`", needle);
    }
}
/// A 16x16 spec must propagate workgroup/tile dimensions into the shader
/// text and size the shared-memory tile arrays as 16 * 16 = 256 floats.
#[test]
fn test_wgsl_spec_custom_dimensions() {
    let spec = WgslMicrokernelSpec {
        workgroup_size: (16, 16, 1),
        tile_dim: (16, 16),
        use_shared_memory: true,
    };
    let shader = spec.generate_wgsl();
    assert!(shader.contains("@workgroup_size(16, 16, 1)"));
    assert!(shader.contains("Tile: 16x16"));
    assert!(shader.contains("Workgroup: 16x16x1"));
    assert!(shader.contains("array<f32, 256>"));
}
/// A 4x4 spec yields shared-memory tile arrays of 4 * 4 = 16 floats.
#[test]
fn test_wgsl_spec_small_workgroup() {
    let shader = WgslMicrokernelSpec {
        workgroup_size: (4, 4, 1),
        tile_dim: (4, 4),
        use_shared_memory: true,
    }
    .generate_wgsl();
    assert!(shader.contains("@workgroup_size(4, 4, 1)"));
    assert!(shader.contains("Tile: 4x4"));
    assert!(shader.contains("array<f32, 16>"));
}
/// Defaults for the PTX spec: sm_80 target, 64 registers/thread, 16x16 tile.
#[test]
fn test_ptx_spec_default() {
    let defaults = PtxMicrokernelSpec::default();
    assert_eq!(defaults.sm_target, "sm_80");
    assert_eq!(defaults.registers_per_thread, 64);
    assert_eq!(defaults.tile_dim, (16, 16));
}
/// `gemm_auto` must agree with the naive reference kernel on a 128^3 problem
/// built from small repeating patterns.
#[test]
fn test_gemm_auto_produces_correct_result() {
    const M: usize = 128;
    const N: usize = 128;
    const K: usize = 128;
    let a: Vec<f32> = (0..M * K).map(|i| ((i % 7) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..K * N).map(|i| ((i % 11) as f32) * 0.1).collect();
    let mut c_ref = vec![0.0; M * N];
    let mut c_auto = vec![0.0; M * N];
    gemm_reference(M, N, K, &a, &b, &mut c_ref).unwrap();
    gemm_auto(M, N, K, &a, &b, &mut c_auto, None).unwrap();
    // Largest absolute element-wise deviation from the reference result.
    let mut max_diff: f32 = 0.0;
    for (expected, actual) in c_ref.iter().zip(&c_auto) {
        max_diff = max_diff.max((expected - actual).abs());
    }
    assert!(max_diff < 1e-3, "gemm_auto should match reference, max_diff={}", max_diff);
}
/// When a profiler is supplied, `gemm_auto` must record which backend ran
/// and the number of output elements produced.
#[test]
fn test_gemm_auto_with_profiler() {
    let (m, n, k) = (64, 64, 64);
    let a = vec![1.0_f32; m * k];
    let b = vec![1.0_f32; k * n];
    let mut c = vec![0.0_f32; m * n];
    let mut profiler = UnifiedBrickProfiler::new();
    gemm_auto(m, n, k, &a, &b, &mut c, Some(&mut profiler)).unwrap();
    assert!(profiler.backend.is_some());
    assert_eq!(profiler.total_elements, (m * n) as u64);
}
/// F323: tiny problems stay on the host (PCIe transfer would dominate),
/// while 1024^3 has arithmetic intensity high enough to amortize transfer.
#[test]
fn test_f323_backend_selection_respects_pcie_rule() {
    let model = BackendCostModel::default();
    let choice = model.select_backend(32, 32, 32);
    assert!(
        matches!(choice, ComputeBackend::Cpu | ComputeBackend::Scalar),
        "F323: Small matrix should use CPU"
    );
    // AI = 2mnk / 4(mk + kn + mn); at m = n = k = 1024 this is ~170 flops/byte.
    let (m, n, k): (usize, usize, usize) = (1024, 1024, 1024);
    let flops = 2_u64 * m as u64 * n as u64 * k as u64;
    let bytes = 4_u64 * (m * k + k * n + m * n) as u64;
    let ai = flops as f64 / bytes as f64;
    assert!(ai > 100.0, "F323: AI should be high for large K, got {}", ai);
}
/// F324: both the BLIS-structured path and the auto-selected path must agree
/// with the naive reference kernel on identical 64^3 inputs.
///
/// Defect fixed: the failure messages omitted the observed max deviation
/// (inconsistent with `test_gemm_auto_produces_correct_result`); they now
/// report it, and the duplicated max-abs-diff reduction is a shared closure.
#[test]
fn test_f324_cross_backend_equivalence() {
    let m = 64;
    let n = 64;
    let k = 64;
    let a: Vec<f32> = (0..m * k).map(|i| ((i % 13) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * n).map(|i| ((i % 17) as f32) * 0.1).collect();
    let mut c_ref = vec![0.0; m * n];
    gemm_reference(m, n, k, &a, &b, &mut c_ref).unwrap();
    let mut c_blis = vec![0.0; m * n];
    gemm_blis(m, n, k, &a, &b, &mut c_blis, None).unwrap();
    let mut c_auto = vec![0.0; m * n];
    gemm_auto(m, n, k, &a, &b, &mut c_auto, None).unwrap();
    // Largest absolute element-wise deviation between two result buffers.
    let max_abs_diff = |xs: &[f32], ys: &[f32]| -> f32 {
        xs.iter().zip(ys).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
    };
    let max_diff_blis = max_abs_diff(&c_ref, &c_blis);
    let max_diff_auto = max_abs_diff(&c_ref, &c_auto);
    assert!(
        max_diff_blis < 1e-3,
        "F324: BLIS should match reference, max_diff={}",
        max_diff_blis
    );
    assert!(
        max_diff_auto < 1e-3,
        "F324: Auto should match reference, max_diff={}",
        max_diff_auto
    );
}
/// F325: the profiler's reported arithmetic intensity must equal the
/// textbook GEMM formula 2mnk / 4(mk + kn + mn) within 0.01.
#[test]
fn test_f325_profiler_reports_consistent_metrics() {
    let (m, n, k): (usize, usize, usize) = (128, 128, 128);
    let roofline = UnifiedBrickProfiler::new().roofline_analysis(m, n, k);
    let reported = roofline.arithmetic_intensity();
    let flop_count = 2.0 * m as f64 * n as f64 * k as f64;
    let byte_count = 4.0 * (m * k + k * n + m * n) as f64;
    let expected = flop_count / byte_count;
    assert!(
        (reported - expected).abs() < 0.01,
        "F325: Profiler AI ({}) should match manual calculation ({})",
        reported,
        expected
    );
}
/// F329: running the BLIS-structured GEMM under an enabled profiler must
/// record at least one event at the macro level of the brick hierarchy.
#[test]
fn test_f329_brick_hierarchy_profiled() {
    let dim = 128;
    let lhs = vec![1.0; dim * dim];
    let rhs = vec![1.0; dim * dim];
    let mut out = vec![0.0; dim * dim];
    let mut profiler = BlisProfiler::enabled();
    gemm_blis(dim, dim, dim, &lhs, &rhs, &mut out, Some(&mut profiler)).unwrap();
    assert!(profiler.macro_stats.count > 0, "F329: Macro level should be profiled");
}
/// The hand-pipelined AVX2 8x6 microkernel must match the scalar reference
/// element-wise to within 1e-5 relative error for k = 64.
///
/// Defect fixed: the original returned silently when AVX2/FMA was absent,
/// making a skipped run indistinguishable from a real pass in test output;
/// the skip is now reported on stderr. A SAFETY note was also added to the
/// `unsafe` call.
#[test]
#[cfg(target_arch = "x86_64")]
fn test_microkernel_pipelined_matches_reference() {
    if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
        eprintln!(
            "skipping test_microkernel_pipelined_matches_reference: AVX2/FMA not available"
        );
        return;
    }
    let k = 64;
    let a: Vec<f32> = (0..MR * k).map(|i| (i as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| (i as f32) * 0.01).collect();
    let mut c_scalar = vec![0.0; MR * NR];
    let mut c_pipelined = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut c_scalar, MR);
    // SAFETY: AVX2 and FMA were verified above; the pointers cover MR*k,
    // k*NR, and MR*NR contiguous f32 elements respectively, matching the
    // extents the kernel reads/writes with leading dimension MR.
    unsafe {
        microkernel_8x6_avx2_asm(k, a.as_ptr(), b.as_ptr(), c_pipelined.as_mut_ptr(), MR);
    }
    for i in 0..MR * NR {
        let diff = (c_scalar[i] - c_pipelined[i]).abs();
        // max(1e-10) guards the division when the reference value is ~0.
        let rel_diff = diff / c_scalar[i].abs().max(1e-10);
        assert!(
            rel_diff < 1e-5,
            "Pipelined microkernel mismatch at {}: scalar={}, pipelined={}, rel_diff={}",
            i,
            c_scalar[i],
            c_pipelined[i],
            rel_diff
        );
    }
}