#![cfg(all(test, feature = "metal", target_os = "macos"))]
use scirs2_core::gpu::{
backends::{MetalBufferOptions, MetalContext, MetalStorageMode},
GpuBackend, GpuContext, GpuError,
};
use std::sync::Arc;
/// Checks that `detect_gpu_backends` reports at least one Metal device on macOS and
/// that every reported Metal device advertises tensor support.
///
/// Device details (name, memory, compute capability) are printed for diagnostics;
/// run with `--nocapture` to see them.
#[test]
#[allow(dead_code)]
fn test_metal_device_detection() {
    use scirs2_core::gpu::backends::detect_gpu_backends;

    // Bytes per GiB as a float: integer division would truncate the printed size
    // (e.g. a 512 MiB device would be reported as "0 GB").
    const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;

    let detection_result = detect_gpu_backends();
    let metal_devices: Vec<_> = detection_result
        .devices
        .iter()
        .filter(|d| d.backend == GpuBackend::Metal)
        .collect();
    assert!(
        !metal_devices.is_empty(),
        "No Metal devices detected on macOS"
    );
    for device in metal_devices {
        println!("Metal Device: {}", device.device_name);
        if let Some(memory) = device.memory_bytes {
            println!(" Memory: {:.2} GB", memory as f64 / BYTES_PER_GIB);
        }
        if let Some(capability) = &device.compute_capability {
            println!(" Capability: {}", capability);
        }
        assert!(
            device.supports_tensors,
            "Metal device {} does not report tensor support",
            device.device_name
        );
    }
}
/// Creates a `GpuContext` on the Metal backend and checks its reported backend,
/// backend name and memory figures.
///
/// Context creation may fail (e.g. in CI, as the logged message says). That is
/// reported on stderr and the test still passes.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_context_creation() {
    let ctx = match GpuContext::new(GpuBackend::Metal) {
        Ok(ctx) => ctx,
        Err(err) => {
            eprintln!("Metal context creation failed (expected in CI): {}", err);
            return;
        }
    };

    assert_eq!(ctx.backend(), GpuBackend::Metal);
    assert_eq!(ctx.backend_name(), "Metal");
    assert!(ctx.get_available_memory().is_some());
    assert!(ctx.get_total_memory().is_some());
}
/// Exercises Metal buffer allocation and a host -> device -> host round trip.
///
/// Returns silently when no Metal context can be created.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_buffer_creation() {
    let Ok(context) = GpuContext::new(GpuBackend::Metal) else {
        return;
    };

    // A freshly allocated buffer reports the requested element count.
    let sized = context.create_buffer::<f32>(1024);
    assert_eq!(sized.len(), 1024);
    assert!(!sized.is_empty());

    // A buffer seeded from host data reads back unchanged.
    let host = vec![1.0f32, 2.0, 3.0, 4.0];
    let seeded = context.create_buffer_from_slice(&host);
    assert_eq!(seeded.len(), 4);

    let mut read_back = vec![0.0f32; 4];
    seeded
        .copy_to_host(&mut read_back)
        .expect("Test: operation failed");
    assert_eq!(read_back, host);
}
/// Creates raw Metal buffers with explicit storage, CPU-cache and hazard-tracking
/// options. Checks that each returned `Arc` handle is uniquely owned by the caller.
///
/// Two configurations are covered:
/// - `Shared` storage with default hazard tracking.
/// - `Private` storage with hazard tracking set to `Untracked`.
///
/// Returns early when no `MetalContext` can be created.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_buffer_options() {
    use metal::MTLCPUCacheMode;
    use metal::MTLHazardTrackingMode;

    let context = match MetalContext::new() {
        Ok(c) => c,
        Err(_) => return,
    };

    let options = MetalBufferOptions {
        storage_mode: MetalStorageMode::Shared,
        cache_mode: MTLCPUCacheMode::DefaultCache,
        hazard_tracking_mode: MTLHazardTrackingMode::Default,
    };
    let buffer = context.create_buffer_with_options(1024, options);
    // `assert_eq!` prints the actual count on failure, unlike `assert!(a == b)`.
    assert_eq!(
        Arc::strong_count(&buffer),
        1,
        "shared-storage buffer handle should be uniquely owned"
    );

    let private_options = MetalBufferOptions {
        storage_mode: MetalStorageMode::Private,
        cache_mode: MTLCPUCacheMode::DefaultCache,
        hazard_tracking_mode: MTLHazardTrackingMode::Untracked,
    };
    let private_buffer = context.create_buffer_with_options(2048, private_options);
    assert_eq!(
        Arc::strong_count(&private_buffer),
        1,
        "private-storage buffer handle should be uniquely owned"
    );
}
/// Checks that the built-in `axpy` and `complex_multiply` kernels can be obtained
/// from a Metal `GpuContext`.
///
/// Returns early when no Metal context can be created.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_kernel_compilation() {
    let context = match GpuContext::new(GpuBackend::Metal) {
        Ok(c) => c,
        Err(_) => return,
    };

    // Include the underlying `GpuError` in the failure message. A bare
    // `assert!(result.is_ok())` discards it, so a shader build failure could not
    // be diagnosed from the test log.
    let kernels = [("axpy", "AXPY"), ("complex_multiply", "complex multiply")];
    for (kernel_name, label) in kernels {
        if let Err(e) = context.get_kernel(kernel_name) {
            panic!("Failed to get {} kernel: {}", label, e);
        }
    }
}
/// Runs the `axpy` kernel on four elements and checks that the `y` buffer ends up
/// holding `alpha * x + y` (with `alpha = 2`).
///
/// Returns early when no Metal context or `axpy` kernel is available.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_kernel_execution() {
    let context = match GpuContext::new(GpuBackend::Metal) {
        Ok(ctx) => ctx,
        Err(_) => return,
    };

    let x = vec![1.0f32, 2.0, 3.0, 4.0];
    let y = vec![5.0f32, 6.0, 7.0, 8.0];
    let alpha = 2.0f32;
    let x_buffer = context.create_buffer_from_slice(&x);
    let mut y_buffer = context.create_buffer_from_slice(&y);

    let Ok(kernel) = context.get_kernel("axpy") else {
        return;
    };
    // The result is read back from `y_buffer`, i.e. the kernel updates `y` in place.
    kernel.set_buffer("x", &x_buffer);
    kernel.set_buffer("y", &y_buffer);
    kernel.set_f32("alpha", alpha);
    kernel.set_i32("n", x.len() as i32);
    kernel.dispatch([1, 1, 1]);

    let mut gpu_y = vec![0.0f32; 4];
    y_buffer
        .copy_to_host(&mut gpu_y)
        .expect("Test: operation failed");

    // 2 * [1, 2, 3, 4] + [5, 6, 7, 8]
    let expected = [7.0f32, 10.0, 13.0, 16.0];
    for (idx, (&got, &want)) in gpu_y.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-6,
            "Result mismatch at index {}: {} vs {}",
            idx,
            got,
            want
        );
    }
}
/// Runs the `complex_multiply` kernel on four complex numbers stored as interleaved
/// `(re, im)` `f32` pairs and checks the element-wise products.
///
/// Returns early when no Metal context or `complex_multiply` kernel is available.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_complex_operations() {
    let Ok(context) = GpuContext::new(GpuBackend::Metal) else {
        return;
    };

    // a = [1+0i, 2+1i, 3-1i, 0+2i]
    let a = vec![1.0f32, 0.0, 2.0, 1.0, 3.0, -1.0, 0.0, 2.0];
    // b = [2+0i, 1-1i, 0+1i, 3+1i]
    let b = vec![2.0f32, 0.0, 1.0, -1.0, 0.0, 1.0, 3.0, 1.0];
    let a_buffer = context.create_buffer_from_slice(&a);
    let b_buffer = context.create_buffer_from_slice(&b);
    let result_buffer = context.create_buffer::<f32>(8);

    let kernel = match context.get_kernel("complex_multiply") {
        Ok(k) => k,
        Err(_) => return,
    };
    kernel.set_buffer("a", &a_buffer);
    kernel.set_buffer("b", &b_buffer);
    kernel.set_buffer("result", &result_buffer);
    // n = 4 matches the four complex pairs (8 floats) in each buffer.
    kernel.set_u32("n", 4);
    kernel.dispatch([1, 1, 1]);

    let mut product = vec![0.0f32; 8];
    result_buffer
        .copy_to_host(&mut product)
        .expect("Test: operation failed");

    // a * b = [2+0i, 3-1i, 1+3i, -2+6i]
    let expected = [2.0f32, 0.0, 3.0, -1.0, 1.0, 3.0, -2.0, 6.0];
    product
        .iter()
        .zip(expected.iter())
        .enumerate()
        .for_each(|(slot, (actual, wanted))| {
            assert!(
                (actual - wanted).abs() < 1e-6,
                "Complex result mismatch at index {}: {} vs {}",
                slot,
                actual,
                wanted
            );
        });
}
/// Reports whether Metal Performance Shaders are exposed by the `MetalContext`.
///
/// The body only compiles when the `metal-performance-shaders` feature is enabled.
/// The result is informational and never fails the test. Returns early if no
/// `MetalContext` can be created.
#[test]
#[allow(dead_code)]
// Presumably needed because the `metal-performance-shaders` feature is not declared
// for every build configuration; confirm against Cargo.toml.
#[allow(unexpected_cfgs)]
fn test_metal_performance_shaders() {
    #[cfg(feature = "metal-performance-shaders")]
    {
        let context = match MetalContext::new() {
            Ok(c) => c,
            Err(_) => return,
        };
        // Only availability matters here. Binding the value (`Some(mps_ops)`) left
        // an unused variable and triggered a compiler warning.
        if context.mps_operations().is_some() {
            println!("Metal Performance Shaders available");
        } else {
            println!("Metal Performance Shaders not available");
        }
    }
}
/// Prints the Metal device name and its unified-memory flag. Devices whose name
/// contains "Apple" are required to report unified memory.
///
/// Returns early when no `MetalContext` can be created.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metal_unified_memory() {
    let context = match MetalContext::new() {
        Ok(ctx) => ctx,
        Err(_) => return,
    };

    let name = context.device_name();
    let unified = context.has_unified_memory();
    println!("Device: {}", name);
    println!("Unified Memory: {}", unified);

    let is_apple_gpu = name.contains("Apple");
    if is_apple_gpu {
        assert!(unified, "Apple Silicon should have unified memory");
    }
}
/// Checks two error-handling paths on the Metal backend:
/// - Looking up an unknown kernel yields `GpuError::KernelNotFound`.
/// - A very large buffer request still produces a non-empty buffer handle.
///
/// NOTE(review): `usize::MAX / 2` `f32` elements corresponds to a byte count
/// (`len * size_of::<f32>()`) that does not fit in `usize`. Whether this passes
/// depends on how `create_buffer` treats such a request (lazy allocation, panic or
/// abort). Confirm the intended contract; asserting a returned error would be a
/// more meaningful check.
///
/// Returns early when no Metal context can be created.
#[test]
#[ignore]
#[allow(dead_code)]
fn test_metalerror_handling() {
    let Ok(context) = GpuContext::new(GpuBackend::Metal) else {
        return;
    };

    let missing = context.get_kernel("nonexistent_kernel");
    assert!(matches!(missing, Err(GpuError::KernelNotFound(_))));

    let oversized_len = usize::MAX / 2;
    let oversized = context.create_buffer::<f32>(oversized_len);
    assert!(!oversized.is_empty());
}
/// Copying more host elements than a device buffer holds must panic with
/// "Data size exceeds buffer size".
///
/// When no Metal context is available, the test panics with that same message.
/// This appears deliberate, so that `should_panic` is satisfied; as a consequence
/// the test also passes on machines without Metal.
///
/// NOTE(review): if `copy_from_host` reports the overflow as an `Err` rather than
/// panicking, the panic raised by `expect` reads "Test: operation failed: <err>".
/// That only matches `should_panic` if the error's `Debug` output contains the
/// expected text.
#[test]
#[ignore]
#[should_panic(expected = "Data size exceeds buffer size")]
#[allow(dead_code)]
fn test_metal_buffer_overflow() {
    let Ok(context) = GpuContext::new(GpuBackend::Metal) else {
        panic!("Data size exceeds buffer size");
    };

    // Write 8 elements into a 4-element buffer.
    let capacity = 4;
    let buffer = context.create_buffer::<f32>(capacity);
    let too_much = vec![1.0f32; 2 * capacity];
    buffer
        .copy_from_host(&too_much)
        .expect("Test: operation failed");
}
/// Manual throughput benchmarks for the Metal backend.
///
/// These tests are `#[ignore]`d. Run them explicitly, e.g.
/// `cargo test -- --ignored --nocapture`.
mod benchmarks {
    use super::*;
    use std::time::Instant;

    /// Times host-to-device and device-to-host transfers for `f32` buffers of
    /// 4 KiB, 4 MiB and 64 MiB, and prints the approximate throughput.
    ///
    /// The "H2D" figure times `create_buffer_from_slice` as a whole, so it includes
    /// buffer allocation as well as the copy. Returns early when no Metal context can
    /// be created.
    #[test]
    #[ignore]
    fn bench_metal_buffer_transfer() {
        let context = match GpuContext::new(GpuBackend::Metal) {
            Ok(c) => c,
            Err(_) => return,
        };

        // Element counts, not bytes. A fixed array avoids a needless heap-allocated
        // `Vec` (clippy::useless_vec).
        let sizes = [1024, 1024 * 1024, 16 * 1024 * 1024];
        for size in sizes {
            let data = vec![1.0f32; size];

            let start = Instant::now();
            let buffer = context.create_buffer_from_slice(&data);
            let h2d_time = start.elapsed();

            let mut result = vec![0.0f32; size];
            let start = Instant::now();
            buffer
                .copy_to_host(&mut result)
                .expect("Test: operation failed");
            let d2h_time = start.elapsed();

            // Derive the byte size from the element type instead of a hard-coded 4.
            let size_mb = (size * std::mem::size_of::<f32>()) as f64 / (1024.0 * 1024.0);
            let size_gb = size_mb / 1024.0;
            println!("Buffer size: {:.2} MB", size_mb);
            println!(
                " H2D: {:?} ({:.2} GB/s)",
                h2d_time,
                size_gb / h2d_time.as_secs_f64()
            );
            println!(
                " D2H: {:?} ({:.2} GB/s)",
                d2h_time,
                size_gb / d2h_time.as_secs_f64()
            );
        }
    }
}