#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod metal_shaders;
/// A compute backend that can be probed at runtime.
///
/// Implementors must be both `Sync` and `Send`, which also makes
/// `Box<dyn Backend>` shareable across threads.
pub trait Backend: Sync + Send {
    /// Human-readable identifier of the backend, e.g. `"CUDA"` or `"Metal"`.
    fn name(&self) -> &str;

    /// Whether this backend can be used in the current build / environment.
    fn is_available(&self) -> bool;

    /// Number of devices this backend reports.
    fn device_count(&self) -> usize;
}
/// Backend for NVIDIA CUDA devices.
///
/// All probing is delegated to `crate::driver`.
#[derive(Debug, Default)]
pub struct CudaBackend;

impl Backend for CudaBackend {
    fn name(&self) -> &str {
        "CUDA"
    }

    /// Forwards directly to `crate::driver::cuda_available()`.
    fn is_available(&self) -> bool {
        crate::driver::cuda_available()
    }

    /// Device count when built with the `cuda` feature.
    ///
    /// Returns `0` without querying devices if CUDA is unavailable; a device
    /// query that yields no count is also reported as `0`.
    #[cfg(feature = "cuda")]
    fn device_count(&self) -> usize {
        if !self.is_available() {
            return 0;
        }
        crate::driver::device_count().unwrap_or(0)
    }

    /// Device count when built without the `cuda` feature.
    ///
    /// Returns `0` without querying devices if CUDA is unavailable.
    #[cfg(not(feature = "cuda"))]
    fn device_count(&self) -> usize {
        if !self.is_available() {
            return 0;
        }
        crate::driver::device_count()
    }
}
/// Backend for Apple Metal devices.
///
/// Real probing happens only on macOS builds with the `metal` feature; every
/// other build reports the backend as unavailable with zero devices.
#[derive(Debug, Default)]
pub struct MetalBackend;

impl Backend for MetalBackend {
    fn name(&self) -> &str {
        "Metal"
    }

    // --- macOS + `metal` feature: ask `manzana` -------------------------

    /// Delegates to `manzana::metal::is_available()`.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn is_available(&self) -> bool {
        manzana::metal::is_available()
    }

    /// Number of devices returned by `manzana::metal::MetalCompute::devices()`.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn device_count(&self) -> usize {
        manzana::metal::MetalCompute::devices().len()
    }

    // --- every other build: Metal is compiled out -----------------------

    /// Always `false`: Metal support is not compiled into this build.
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    fn is_available(&self) -> bool {
        false
    }

    /// Always `0`: Metal support is not compiled into this build.
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    fn device_count(&self) -> usize {
        0
    }
}
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use manzana::metal::{CompiledShader as MetalShader, MetalBuffer, MetalCompute, MetalDevice};
/// Placeholder Vulkan backend.
///
/// There is no probing logic behind this type: it always reports itself as
/// unavailable and exposes zero devices.
#[derive(Debug, Default)]
pub struct VulkanBackend;

impl Backend for VulkanBackend {
    fn name(&self) -> &str {
        "Vulkan"
    }

    /// Always `false`.
    fn is_available(&self) -> bool {
        false
    }

    /// Always `0`, matching the permanently unavailable status.
    fn device_count(&self) -> usize {
        0
    }
}
/// Backend for WebGPU via `wgpu`.
///
/// Availability is decided at compile time: the backend is available exactly
/// when the crate is built with the `wgpu` feature. In that case it reports a
/// single device; otherwise it reports none.
/// NOTE(review): no adapter is actually probed at runtime.
#[derive(Debug, Default)]
pub struct WgpuBackend;

impl Backend for WgpuBackend {
    fn name(&self) -> &str {
        "WGPU"
    }

    /// `true` iff the `wgpu` feature is enabled for this build.
    fn is_available(&self) -> bool {
        cfg!(feature = "wgpu")
    }

    /// `1` when available, `0` otherwise (`bool` -> `usize` conversion).
    fn device_count(&self) -> usize {
        self.is_available().into()
    }
}
/// Selects the first usable compute backend.
///
/// Backends are probed in the fixed order CUDA, WGPU, Metal, Vulkan, and the
/// first one whose `is_available()` returns `true` is returned. If none is
/// available, a `CudaBackend` is returned anyway as the fallback, so callers
/// should still check `is_available()` on the result.
#[must_use]
pub fn detect_backend() -> Box<dyn Backend> {
    if CudaBackend.is_available() {
        Box::new(CudaBackend)
    } else if WgpuBackend.is_available() {
        Box::new(WgpuBackend)
    } else if MetalBackend.is_available() {
        Box::new(MetalBackend)
    } else if VulkanBackend.is_available() {
        Box::new(VulkanBackend)
    } else {
        // Nothing usable: hand back CUDA so callers always receive a concrete backend.
        Box::new(CudaBackend)
    }
}
// Unit tests for the backend types and `detect_backend`. Tests that depend on
// the build configuration use `#[cfg]` or `cfg!` so every configuration still
// asserts something.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cuda_backend_name() {
        let backend = CudaBackend;
        assert_eq!(backend.name(), "CUDA");
    }

    #[test]
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    fn test_metal_backend_unavailable() {
        let backend = MetalBackend;
        assert!(!backend.is_available());
    }

    #[test]
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn test_metal_backend_available() {
        let backend = MetalBackend;
        assert!(backend.is_available(), "Metal should be available on macOS");
        assert!(
            backend.device_count() > 0,
            "Should have at least one Metal device"
        );
    }

    #[test]
    fn test_detect_backend() {
        let backend = detect_backend();
        assert!(!backend.name().is_empty());
    }

    #[test]
    fn test_metal_backend_name() {
        let backend = MetalBackend;
        assert_eq!(backend.name(), "Metal");
    }

    #[test]
    fn test_vulkan_backend_name() {
        let backend = VulkanBackend;
        assert_eq!(backend.name(), "Vulkan");
    }

    #[test]
    fn test_vulkan_backend_unavailable() {
        let backend = VulkanBackend;
        assert!(!backend.is_available());
    }

    #[test]
    fn test_cuda_backend_device_count() {
        let backend = CudaBackend;
        let count = backend.device_count();
        assert!(backend.is_available() || count == 0);
    }

    #[test]
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    fn test_metal_backend_device_count() {
        let backend = MetalBackend;
        assert_eq!(backend.device_count(), 0);
    }

    #[test]
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn test_metal_backend_device_count_macos() {
        let backend = MetalBackend;
        assert!(
            backend.device_count() >= 1,
            "Should have at least one Metal device"
        );
    }

    #[test]
    fn test_vulkan_backend_device_count() {
        let backend = VulkanBackend;
        assert_eq!(backend.device_count(), 0);
    }

    #[test]
    fn test_cuda_backend_default() {
        let backend = CudaBackend::default();
        assert_eq!(backend.name(), "CUDA");
    }

    #[test]
    fn test_metal_backend_default() {
        let backend = MetalBackend::default();
        assert_eq!(backend.name(), "Metal");
    }

    #[test]
    fn test_vulkan_backend_default() {
        let backend = VulkanBackend::default();
        assert_eq!(backend.name(), "Vulkan");
    }

    #[test]
    fn test_wgpu_backend_name() {
        let backend = WgpuBackend;
        assert_eq!(backend.name(), "WGPU");
    }

    #[test]
    fn test_wgpu_backend_default() {
        let backend = WgpuBackend::default();
        assert_eq!(backend.name(), "WGPU");
    }

    #[test]
    fn test_wgpu_backend_device_count() {
        let backend = WgpuBackend;
        let count = backend.device_count();
        // Checked in every feature configuration: an unavailable backend has no devices.
        assert!(backend.is_available() || count == 0);
        if !cfg!(feature = "wgpu") {
            assert_eq!(count, 0);
        }
    }

    #[test]
    fn test_wgpu_backend_is_available() {
        let available = WgpuBackend.is_available();
        // Without the `wgpu` feature the backend must never report itself usable.
        if !cfg!(feature = "wgpu") {
            assert!(!available);
        }
    }

    #[test]
    fn test_cuda_backend_is_available() {
        let backend = CudaBackend;
        // `CudaBackend` must mirror the driver's own availability probe.
        assert_eq!(backend.is_available(), crate::driver::cuda_available());
    }

    #[test]
    fn test_cuda_backend_debug() {
        let backend = CudaBackend;
        let debug_str = format!("{backend:?}");
        assert!(debug_str.contains("CudaBackend"));
    }

    #[test]
    fn test_metal_backend_debug() {
        let backend = MetalBackend;
        let debug_str = format!("{backend:?}");
        assert!(debug_str.contains("MetalBackend"));
    }

    #[test]
    fn test_vulkan_backend_debug() {
        let backend = VulkanBackend;
        let debug_str = format!("{backend:?}");
        assert!(debug_str.contains("VulkanBackend"));
    }

    #[test]
    fn test_wgpu_backend_debug() {
        let backend = WgpuBackend;
        let debug_str = format!("{backend:?}");
        assert!(debug_str.contains("WgpuBackend"));
    }

    #[test]
    fn test_detect_backend_returns_valid_name() {
        let backend = detect_backend();
        let name = backend.name();
        assert!(!name.is_empty());
        let valid_names = ["CUDA", "Metal", "Vulkan", "WGPU"];
        assert!(valid_names.contains(&name), "Unknown backend name");
    }

    #[test]
    fn test_detect_backend_fallback_is_cuda() {
        let backend = detect_backend();
        let any_available = CudaBackend.is_available()
            || WgpuBackend.is_available()
            || MetalBackend.is_available()
            || VulkanBackend.is_available();
        if any_available {
            // Something usable exists, so detection must not return an unusable backend.
            assert!(backend.is_available());
        } else {
            assert_eq!(backend.name(), "CUDA");
        }
    }

    #[test]
    fn test_backend_trait_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CudaBackend>();
        assert_send_sync::<MetalBackend>();
        assert_send_sync::<VulkanBackend>();
        assert_send_sync::<WgpuBackend>();
    }

    #[test]
    fn test_all_backends_device_count_consistent() {
        let cuda = CudaBackend;
        let cuda_count = cuda.device_count();
        assert!(cuda.is_available() || cuda_count == 0);
        let metal = MetalBackend;
        let metal_count = metal.device_count();
        assert!(metal.is_available() || metal_count == 0);
        let vulkan = VulkanBackend;
        let vulkan_count = vulkan.device_count();
        assert!(vulkan.is_available() || vulkan_count == 0);
        let wgpu = WgpuBackend;
        let wgpu_count = wgpu.device_count();
        assert!(wgpu.is_available() || wgpu_count == 0);
    }

    #[test]
    fn test_detect_backend_is_deterministic() {
        let backend1 = detect_backend();
        let backend2 = detect_backend();
        assert_eq!(backend1.name(), backend2.name());
    }

    #[test]
    fn test_boxed_backend_trait_object() {
        let backends: Vec<Box<dyn Backend>> = vec![
            Box::new(CudaBackend),
            Box::new(MetalBackend),
            Box::new(VulkanBackend),
            Box::new(WgpuBackend),
        ];
        for backend in &backends {
            assert!(!backend.name().is_empty());
            let _ = backend.is_available();
            let _ = backend.device_count();
        }
    }
}