use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use std::ops::Add;
pub fn test_metal_gpu_basic() -> RusTorchResult<()> {
#[cfg(all(target_os = "macos", feature = "metal"))]
{
use metal::{Device, MTLResourceOptions};
if let Some(device) = Device::system_default() {
println!("✅ Metal device found: {}", device.name());
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let buffer = device.new_buffer_with_data(
data.as_ptr() as *const _,
(data.len() * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
println!(
"✅ Metal buffer created successfully: {} bytes",
buffer.length()
);
let contents = buffer.contents();
let read_data =
unsafe { std::slice::from_raw_parts(contents as *const f32, data.len()) };
if read_data == data.as_slice() {
println!("✅ Metal buffer read-back successful");
Ok(())
} else {
Err(RusTorchError::gpu("Metal buffer data mismatch"))
}
} else {
Err(RusTorchError::gpu("No Metal device available"))
}
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
{
println!("ℹ️ Metal testing skipped (requires macOS and metal feature)");
Ok(())
}
}
pub fn test_metal_tensor_operations() -> RusTorchResult<()> {
let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let b = Tensor::<f32>::from_vec(vec![2.0, 3.0, 4.0, 5.0], vec![2, 2]);
let cpu_result = &a + &b;
println!(
"✅ CPU tensor addition successful: {:?}",
cpu_result.data.as_slice()
);
println!("✅ Metal GPU operations currently using CPU fallback (safe mode)");
Ok(())
}
pub fn benchmark_metal_performance() -> RusTorchResult<()> {
use std::time::Instant;
println!("🚀 Performance benchmark starting...");
let sizes = vec![64, 128, 256, 512];
for size in sizes {
println!("\n📊 Testing matrix size: {}x{}", size, size);
let data_a: Vec<f32> = (0..(size * size)).map(|i| i as f32 * 0.1).collect();
let data_b: Vec<f32> = (0..(size * size)).map(|i| (i as f32 + 1.0) * 0.1).collect();
let matrix_a = Tensor::<f32>::from_vec(data_a, vec![size, size]);
let matrix_b = Tensor::<f32>::from_vec(data_b, vec![size, size]);
let start = Instant::now();
let cpu_result = matrix_a
.matmul(&matrix_b)
.map_err(|e| RusTorchError::gpu(e.to_string()))?;
let cpu_time = start.elapsed();
println!(" CPU time: {:.2}ms", cpu_time.as_secs_f64() * 1000.0);
let flops = 2.0 * (size as f64).powi(3); let cpu_gflops = flops / (cpu_time.as_secs_f64() * 1e9);
println!(" CPU performance: {:.2} GFLOPS", cpu_gflops);
println!(" Result shape: {:?}", cpu_result.shape());
println!(
" First few elements: {:?}",
&cpu_result.data.as_slice().unwrap()[..4.min(cpu_result.data.len())]
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[cfg(all(target_os = "macos", feature = "metal"))]
fn test_metal_availability() {
let result = test_metal_gpu_basic();
match result {
Ok(()) => println!("Metal GPU test passed"),
Err(e) => println!("Metal GPU test failed: {}", e),
}
}
#[test]
fn test_tensor_operations() {
let result = test_metal_tensor_operations();
assert!(result.is_ok());
}
#[test]
#[cfg(all(target_os = "macos", feature = "metal"))]
fn test_performance_benchmark() {
let result = benchmark_metal_performance();
assert!(result.is_ok());
}
}