use anyhow::Result;
#[cfg(all(target_os = "macos", feature = "metal"))]
use trustformers_core::device::Device;
#[cfg(all(target_os = "macos", feature = "metal"))]
use trustformers_core::gpu_ops::metal::get_metal_backend;
/// Smoke-tests the causal-softmax Metal kernel on a small square score matrix.
///
/// Builds a `seq_len x seq_len` matrix whose every row is `1.0, 2.0, ..., seq_len`,
/// uploads it, runs `softmax_causal_gpu_to_gpu`, downloads the result and checks that:
///
/// * every value is finite,
/// * every entry above the diagonal (`j > i`) is ~0,
/// * every row sums to ~1.
///
/// Entries on or below the diagonal that are close to zero only produce a warning.
///
/// # Errors
///
/// Returns an error if a backend call fails, if the downloaded buffer holds fewer than
/// `seq_len * seq_len` values, or if any verification check fails, so that the process
/// exits with a non-zero status instead of silently reporting success.
#[cfg(all(target_os = "macos", feature = "metal"))]
fn main() -> Result<()> {
    // Largest magnitude tolerated at a masked position; also the "suspiciously small"
    // threshold for positions that are not masked.
    const MASK_TOLERANCE: f32 = 1e-6;
    // Allowed deviation of a row sum from 1.0.
    const ROW_SUM_TOLERANCE: f32 = 0.01;

    println!("===============================================");
    println!("🧪 Testing Softmax Causal GPU Kernel");
    println!("===============================================\n");

    let device = Device::metal_if_available(0);
    println!("Device: {:?}\n", device);
    let backend = get_metal_backend()?;

    let seq_len = 5;
    // Row-major matrix: the value depends only on the column, so all rows are identical.
    let input_data: Vec<f32> = (0..seq_len * seq_len)
        .map(|idx| (idx % seq_len + 1) as f32)
        .collect();

    println!("📥 Input scores ({} x {}):", seq_len, seq_len);
    for (i, row) in input_data.chunks_exact(seq_len).enumerate() {
        let cells: Vec<String> = row.iter().map(|v| format!("{:5.2}", v)).collect();
        println!(" Row {}: [{}]", i, cells.join(", "));
    }
    println!();

    let input_buffer_id = backend.create_persistent_buffer(&input_data)?;
    println!("🚀 Running softmax_causal GPU kernel...\n");
    let output_buffer_id = backend.softmax_causal_gpu_to_gpu(&input_buffer_id, seq_len)?;
    let output_data = backend.download_buffer_to_vec(&output_buffer_id)?;

    // Fail with a readable message rather than an index panic further down.
    anyhow::ensure!(
        output_data.len() >= seq_len * seq_len,
        "softmax_causal returned {} values, expected at least {}",
        output_data.len(),
        seq_len * seq_len
    );

    println!("📤 Softmax output ({} x {}):", seq_len, seq_len);
    for (i, row) in output_data.chunks_exact(seq_len).take(seq_len).enumerate() {
        let cells: Vec<String> = row.iter().map(|v| format!("{:7.4}", v)).collect();
        let row_sum: f32 = row.iter().sum();
        println!(" Row {}: [{}] sum={:.6}", i, cells.join(", "), row_sum);
    }
    println!();

    println!("🔍 Verification:");
    let mut all_correct = true;
    for (i, row) in output_data.chunks_exact(seq_len).take(seq_len).enumerate() {
        for (j, &val) in row.iter().enumerate() {
            if !val.is_finite() {
                // NaN compares false against every threshold, so it must be rejected
                // explicitly or it would pass all of the checks below.
                println!(" ❌ Position ({}, {}) is not finite: {}", i, j, val);
                all_correct = false;
            } else if j > i {
                // Masked (future) position: must be zero; a negative value is just as wrong.
                if val.abs() > MASK_TOLERANCE {
                    println!(
                        " ❌ Position ({}, {}) should be 0.0 but is {:.6}",
                        i, j, val
                    );
                    all_correct = false;
                }
            } else if val < MASK_TOLERANCE {
                // Visible position: a tiny value is suspicious but is only reported.
                println!(
                    " ⚠️ Position ({}, {}) is {:.6} (seems too small)",
                    i, j, val
                );
            }
        }

        let row_sum: f32 = row.iter().sum();
        if !row_sum.is_finite() || (row_sum - 1.0).abs() > ROW_SUM_TOLERANCE {
            println!(" ❌ Row {} sum is {:.6}, expected ~1.0", i, row_sum);
            all_correct = false;
        }
    }
    if all_correct {
        println!(" ✅ All checks passed!");
    }
    println!("\n===============================================");

    // Surface the failure through the exit status so scripts and CI can detect it.
    anyhow::ensure!(all_correct, "softmax_causal verification failed");
    Ok(())
}
/// Fallback entry point for builds without macOS + the `metal` feature.
///
/// Prints a notice explaining why nothing was tested and reports success.
#[cfg(not(all(target_os = "macos", feature = "metal")))]
fn main() -> Result<()> {
    let notice = "This test requires macOS and the 'metal' feature to be enabled.";
    println!("{}", notice);
    Ok(())
}