use anyhow::Result;
/// Number of rows in `A` (and in the product `C`).
const M: usize = 2;
/// Shared inner dimension: columns of `A`, rows of `B`.
const K: usize = 3;
/// Number of columns in `B` (and in the product `C`).
const N: usize = 2;

/// Computes the reference product `C = A @ B` on the CPU.
///
/// `a` is read as an `m x k` row-major matrix, `b` as a `k x n` row-major
/// matrix, and the returned vector holds the `m x n` row-major result.
///
/// # Panics
///
/// Panics if `a.len() != m * k` or `b.len() != k * n`.
fn cpu_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "lhs must hold m * k elements");
    assert_eq!(b.len(), k * n, "rhs must hold k * n elements");

    let mut c = vec![0.0_f32; m * n];
    // `chunks_exact(0)` panics, and an empty inner or output dimension
    // already yields the correct (all-zero or empty) result.
    if k == 0 || n == 0 {
        return c;
    }
    for (a_row, c_row) in a.chunks_exact(k).zip(c.chunks_exact_mut(n)) {
        // Accumulate a_row[p] * B[p, ..] into the output row for every p.
        for (&a_ip, b_row) in a_row.iter().zip(b.chunks_exact(n)) {
            for (c_ij, &b_pj) in c_row.iter_mut().zip(b_row) {
                *c_ij += a_ip * b_pj;
            }
        }
    }
    c
}

/// Prints `title`, then `data` as a row-major matrix with `cols` columns
/// (one bracketed row per line, one decimal place per value), followed by a
/// blank line.
///
/// # Panics
///
/// Panics if `cols` is zero.
fn print_matrix(title: &str, data: &[f32], cols: usize) {
    println!("{}", title);
    for row in data.chunks(cols) {
        let cells: Vec<String> = row.iter().map(|v| format!("{:.1}", v)).collect();
        println!(" [{}]", cells.join(", "));
    }
    println!();
}

/// Checks Metal Performance Shaders (MPS) matmul against a CPU reference.
///
/// Multiplies a fixed `M x K` matrix by a fixed `K x N` matrix, prints the
/// inputs and the CPU reference result. When built for macOS with the
/// `metal` feature, it also runs the same product through
/// `matmul_gpu_to_gpu_mps` and compares the results element-wise.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a Metal backend call fails;
/// - the downloaded result holds fewer than `M * N` elements;
/// - any element differs from the CPU reference by more than the tolerance;
/// - any element is NaN.
///
/// A failed check therefore produces a non-zero exit status.
fn main() -> Result<()> {
    println!("================================================================================");
    println!("🧪 MPS Matmul Correctness Test");
    println!("================================================================================\n");
    println!("Test Case: {}x{} @ {}x{} = {}x{}\n", M, K, K, N, M, N);

    // Array lengths are tied to the dimensions, so a size typo fails to compile.
    let a_data: [f32; M * K] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_data: [f32; K * N] = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    print_matrix(&format!("Matrix A ({}x{}):", M, K), &a_data, K);
    print_matrix(&format!("Matrix B ({}x{}):", K, N), &b_data, N);

    // Derive the expected values instead of hard-coding them, so they always
    // agree with the inputs above.
    let expected = cpu_matmul(&a_data, &b_data, M, K, N);
    print_matrix("Expected Result (CPU):", &expected, N);

    #[cfg(all(target_os = "macos", feature = "metal"))]
    {
        use trustformers_core::gpu_ops::metal::get_metal_backend;

        // Maximum absolute per-element difference accepted from the GPU.
        const TOLERANCE: f32 = 1e-3;

        println!("Testing MPS matmul...");
        let backend = get_metal_backend()?;
        let a_buffer_id = backend.create_persistent_buffer(&a_data)?;
        let b_buffer_id = backend.create_persistent_buffer(&b_data)?;
        let c_buffer_id = backend.matmul_gpu_to_gpu_mps(&a_buffer_id, &b_buffer_id, M, K, N)?;
        let c_data = backend.download_buffer_to_vec(&c_buffer_id)?;

        // Report a short result as an error instead of panicking on indexing.
        anyhow::ensure!(
            c_data.len() >= M * N,
            "MPS result holds {} elements, expected at least {}",
            c_data.len(),
            M * N
        );
        let c_data = &c_data[..M * N];
        print_matrix("MPS Result:", c_data, N);

        let mut mismatches = 0_usize;
        for (i, (&got, &want)) in c_data.iter().zip(&expected).enumerate() {
            let diff = (got - want).abs();
            // A plain `diff > TOLERANCE` is false for NaN, which would let a
            // NaN result pass as correct.
            if diff.is_nan() || diff > TOLERANCE {
                println!(
                    "❌ Mismatch at position {}: expected {}, got {} (diff: {})",
                    i, want, got, diff
                );
                mismatches += 1;
            }
        }

        if mismatches == 0 {
            println!("✅ MPS matmul is CORRECT!");
        } else {
            println!("🔴 MPS matmul is INCORRECT!");
            // Propagate failure so scripts/CI see a non-zero exit status.
            anyhow::bail!("MPS matmul produced {} mismatching element(s)", mismatches);
        }
    }

    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    {
        println!("⚠️ Metal not available, skipping GPU test");
    }

    Ok(())
}