optical_embeddings/
backend.rs

//! Compile-time backend selection for Burn 0.18: the enabled Cargo feature picks CUDA, WGPU, or the NdArray CPU backend (priority: cuda > wgpu > cpu).
2
3use burn::prelude::*;
4
// At most one `AutoBackend` alias survives cfg-stripping. The predicates
// encode the priority CUDA > WGPU > CPU. If none of the `cuda`, `wgpu` or
// `cpu` features is enabled, no `AutoBackend` is defined at all.

/// Tensor backend chosen at compile time: CUDA (NVIDIA GPU).
#[cfg(feature = "cuda")]
pub type AutoBackend = ::burn_cuda::Cuda<f32>;

/// Tensor backend chosen at compile time: WGPU, used when `cuda` is off.
#[cfg(all(not(feature = "cuda"), feature = "wgpu"))]
pub type AutoBackend = ::burn_wgpu::Wgpu<f32, i32>;

/// Tensor backend chosen at compile time: NdArray on the CPU, used when
/// neither `cuda` nor `wgpu` is enabled.
#[cfg(all(not(any(feature = "cuda", feature = "wgpu")), feature = "cpu"))]
pub type AutoBackend = ::burn_ndarray::NdArray<f32>;
14
15/// Get the best available device
16pub fn get_device() -> <AutoBackend as Backend>::Device {
17    #[cfg(feature = "cuda")]
18    {
19        println!("🚀 Using CUDA backend (NVIDIA GPU)");
20        burn_cuda::CudaDevice::default()
21    }
22
23    #[cfg(all(feature = "wgpu", not(feature = "cuda")))]
24    {
25        println!("🎮 Using WGPU backend (GPU via Vulkan/Metal/DX12)");
26        burn_wgpu::WgpuDevice::default()
27    }
28
29    #[cfg(all(feature = "cpu", not(any(feature = "cuda", feature = "wgpu"))))]
30    {
31        println!("💻 Using NdArray backend (CPU)");
32        Default::default()
33    }
34}
35
/// Prints a human-readable summary of the backend compiled into this build.
///
/// The reported backend follows the same feature priority as [`AutoBackend`]
/// (`cuda` > `wgpu` > `cpu`). No hardware or driver is probed here: every
/// status line describes what was enabled at compile time, not whether a GPU
/// is actually usable on this machine.
pub fn print_backend_info() {
    println!("╔════════════════════════════════════════╗");
    println!("║        Backend Configuration           ║");
    println!("╚════════════════════════════════════════╝");

    #[cfg(feature = "cuda")]
    {
        println!("  Backend: CUDA (NVIDIA GPU)");
        println!("  Features: Fast matrix ops, tensor cores");
        // No runtime check happens here, so do not claim a specific GPU is
        // available; report the compile-time fact, matching the WGPU branch.
        println!("  Status: ✓ GPU acceleration enabled");
    }

    #[cfg(all(feature = "wgpu", not(feature = "cuda")))]
    {
        println!("  Backend: WGPU (Cross-platform GPU)");
        println!("  Features: Vulkan/Metal/DX12 support");
        println!("  Status: ✓ GPU acceleration enabled");
    }

    #[cfg(all(feature = "cpu", not(any(feature = "cuda", feature = "wgpu"))))]
    {
        println!("  Backend: NdArray (CPU)");
        println!("  Features: Portable, no GPU required");
        println!("  Note: For GPU acceleration, rebuild with:");
        println!("    cargo run --release --features wgpu");
        println!("    cargo run --release --features cuda (NVIDIA)");
    }

    println!();
}
68
/// Reports whether a GPU backend (`cuda` or `wgpu`) was compiled into this build.
///
/// This is decided entirely by Cargo features at compile time. It does not
/// check for GPU hardware or drivers at runtime.
pub fn is_gpu_available() -> bool {
    cfg!(any(feature = "cuda", feature = "wgpu"))
}