#[cfg(all(target_os = "macos", feature = "metal-native"))]
mod metal_detection_tests {
    use hive_gpu::error::HiveGpuError;
    use hive_gpu::metal::MetalNativeContext;
    use hive_gpu::traits::GpuContext;

    /// Creates a Metal context for a test, or returns `None` when no Metal
    /// device exists so the caller can skip gracefully.
    ///
    /// Any error other than `NoDeviceAvailable` is a real failure and panics.
    /// This consolidates the identical skip boilerplate previously repeated
    /// in each test below.
    fn context_or_skip() -> Option<MetalNativeContext> {
        match MetalNativeContext::new() {
            Ok(ctx) => Some(ctx),
            Err(HiveGpuError::NoDeviceAvailable) => {
                println!("⚠️ Metal not available, skipping test");
                None
            }
            Err(e) => panic!("Failed to create Metal context: {}", e),
        }
    }

    /// Smoke test: context creation either succeeds (and reports a Metal
    /// device) or fails with the specific "no device" error.
    #[test]
    fn test_metal_device_availability() {
        // This test deliberately matches all three outcomes itself (instead
        // of using `context_or_skip`) because it asserts on the error taxonomy.
        match MetalNativeContext::new() {
            Ok(context) => {
                println!("✅ Metal device available");
                println!(" Context created successfully");
                let info = context.device_info().expect("Failed to get device info");
                assert!(!info.name.is_empty(), "Device name should not be empty");
                assert_eq!(info.backend, "Metal", "Backend should be Metal");
            }
            Err(HiveGpuError::NoDeviceAvailable) => {
                // Acceptable outcome on machines without a Metal GPU.
                println!("⚠️ Metal device not available on this system");
            }
            Err(e) => {
                panic!("Unexpected error during Metal detection: {}", e);
            }
        }
    }

    /// The reported device name must be non-empty and start with a known
    /// GPU vendor prefix.
    #[test]
    fn test_metal_device_name_retrieval() {
        let context = match context_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };
        let info = context.device_info().expect("Failed to get device info");
        assert!(!info.name.is_empty(), "Device name should not be empty");
        println!("✅ Metal device name: {}", info.name);
        // Every Metal-capable Mac ships a GPU from one of these vendors.
        let valid_prefixes = ["Apple", "AMD", "Intel"];
        let has_valid_prefix = valid_prefixes
            .iter()
            .any(|prefix| info.name.starts_with(prefix));
        assert!(
            has_valid_prefix,
            "Device name should start with known vendor: {}",
            info.name
        );
    }

    /// Reported capability numbers must be positive, with stricter floors
    /// on Apple Silicon.
    #[test]
    fn test_metal_device_capabilities() {
        let context = match context_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };
        let info = context.device_info().expect("Failed to get device info");
        assert!(
            info.max_threads_per_block > 0,
            "Max threads should be positive"
        );
        assert!(
            info.max_shared_memory_per_block > 0,
            "Max shared memory should be positive"
        );
        assert!(info.total_vram_bytes > 0, "Total VRAM should be positive");
        println!("✅ Metal capabilities:");
        println!(" Max threads/block: {}", info.max_threads_per_block);
        println!(
            " Max shared memory: {} KB",
            info.max_shared_memory_per_block / 1024
        );
        println!(" Total VRAM: {} MB", info.total_vram_mb());
        // Apple Silicon GPUs have documented minimum capabilities; enforce
        // conservative floors when the device name indicates Apple hardware.
        if info.name.contains("Apple") {
            assert!(
                info.max_threads_per_block >= 512,
                "Apple Silicon should support at least 512 threads per block"
            );
            assert!(
                info.max_shared_memory_per_block >= 16 * 1024,
                "Apple Silicon should have at least 16KB shared memory"
            );
        }
    }

    /// Two independently created contexts must report the same device and
    /// backend (they share the single system GPU).
    #[test]
    fn test_metal_multiple_contexts() {
        let context1 = match context_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };
        // If the first context succeeded, a second creation failure of ANY
        // kind (including NoDeviceAvailable) is a genuine error.
        let context2 = match MetalNativeContext::new() {
            Ok(ctx) => ctx,
            Err(e) => panic!("Failed to create second Metal context: {}", e),
        };
        let info1 = context1
            .device_info()
            .expect("Failed to get info from context 1");
        let info2 = context2
            .device_info()
            .expect("Failed to get info from context 2");
        assert_eq!(
            info1.name, info2.name,
            "Both contexts should use same device"
        );
        assert_eq!(
            info1.backend, info2.backend,
            "Both should use Metal backend"
        );
        println!("✅ Multiple Metal contexts created successfully");
        println!(" Context 1: {}", info1.name);
        println!(" Context 2: {}", info2.name);
    }

    /// VRAM accounting invariants: total > 0, available <= total, and
    /// used == total - available (all fields come from one snapshot, so
    /// the identity must hold exactly).
    #[test]
    fn test_metal_vram_query() {
        let context = match context_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };
        let info = context.device_info().expect("Failed to get device info");
        assert!(info.total_vram_bytes > 0, "Total VRAM should be positive");
        assert!(
            info.available_vram_bytes <= info.total_vram_bytes,
            "Available VRAM should not exceed total"
        );
        // assert_eq! (rather than assert!(a == b)) prints both values on
        // failure; the subtraction cannot underflow thanks to the check above.
        assert_eq!(
            info.used_vram_bytes,
            info.total_vram_bytes - info.available_vram_bytes,
            "Used VRAM should equal total - available"
        );
        println!("✅ VRAM information:");
        println!(" Total: {} MB", info.total_vram_mb());
        println!(" Available: {} MB", info.available_vram_mb());
        println!(" Used: {} MB", info.used_vram_bytes / (1024 * 1024));
        println!(" Usage: {:.1}%", info.vram_usage_percent());
    }
}
mod fallback_tests {
    use hive_gpu::backends::detector::{
        GpuBackendType, detect_available_backends, select_best_backend,
    };

    /// Detection must always yield a non-empty list containing the CPU
    /// fallback; a Metal entry is only legal on macOS builds.
    #[test]
    fn test_backend_detection() {
        let detected = detect_available_backends();
        println!("✅ Detected backends: {:?}", detected);
        assert!(
            !detected.is_empty(),
            "Should detect at least one backend (CPU)"
        );
        assert!(
            detected.contains(&GpuBackendType::Cpu),
            "CPU should always be available as fallback"
        );
        for kind in detected.iter() {
            match kind {
                GpuBackendType::Cpu => {
                    println!(" CPU backend available");
                }
                GpuBackendType::Metal => {
                    println!(" Metal backend available");
                    // Compile-time guard: a non-macOS build reporting Metal
                    // is a detector bug.
                    #[cfg(not(target_os = "macos"))]
                    panic!("Metal should only be detected on macOS");
                }
                GpuBackendType::Cuda => {
                    println!(" CUDA backend available");
                }
                GpuBackendType::Rocm => {
                    println!(" ROCm backend available");
                }
                GpuBackendType::Intel => {
                    println!(" Intel backend available");
                }
            }
        }
    }

    /// The selected backend must be one of the detected backends, and on
    /// a metal-native macOS build Metal wins whenever it is present.
    #[test]
    fn test_best_backend_selection() {
        let chosen = select_best_backend().expect("Should always find a backend");
        println!("✅ Best backend selected: {:?}", chosen);
        let detected = detect_available_backends();
        assert!(
            detected.contains(&chosen),
            "Best backend should be in available backends"
        );
        #[cfg(all(target_os = "macos", feature = "metal-native"))]
        {
            if detected.contains(&GpuBackendType::Metal) {
                assert_eq!(
                    chosen,
                    GpuBackendType::Metal,
                    "Metal should be preferred on macOS"
                );
            }
        }
    }

    /// Selection must never fail outright: with no GPU it falls back to CPU.
    #[test]
    fn test_graceful_fallback_no_gpu() {
        println!("✅ Testing graceful fallback behavior");
        let fallback = select_best_backend().expect("Should fallback to CPU if no GPU");
        println!(" Fallback backend: {:?}", fallback);
        // Extra info is best-effort; a lookup failure is not a test failure.
        if let Ok(details) = hive_gpu::backends::detector::get_backend_info(fallback) {
            println!(" Info: {}", details);
        }
    }

    /// Prints the performance characteristics reported for every detected
    /// backend (informational; no assertions beyond not panicking).
    #[test]
    fn test_backend_performance_info() {
        println!("✅ Backend performance characteristics:");
        for kind in detect_available_backends() {
            let caps = hive_gpu::backends::detector::get_backend_performance_info(kind);
            println!(" {} ({}):", caps.name, kind);
            println!(
                " Memory bandwidth: {:.1} GB/s",
                caps.memory_bandwidth_gbps
            );
            println!(" Compute units: {}", caps.compute_units);
            println!(" Memory size: {} GB", caps.memory_size_gb);
            println!(" HNSW support: {}", caps.supports_hnsw);
            println!(" Batch support: {}", caps.supports_batch);
        }
    }
}