optical_embeddings/
backend.rs1use burn::prelude::*;
4
/// Tensor backend chosen at compile time from the enabled Cargo features.
///
/// Priority is `cuda`, then `wgpu`, then `cpu`: each alias is only compiled
/// when no higher-priority feature is enabled, so at most one definition of
/// `AutoBackend` exists in any build. If none of the three features is
/// enabled, no alias is defined at all.
#[cfg(feature = "cuda")]
pub type AutoBackend = burn_cuda::Cuda<f32>;

/// WGPU variant of the backend alias, used when `wgpu` is on and `cuda` is off.
#[cfg(all(not(feature = "cuda"), feature = "wgpu"))]
pub type AutoBackend = burn_wgpu::Wgpu<f32, i32>;

/// NdArray (CPU) variant of the backend alias, used only when neither GPU
/// feature is enabled.
#[cfg(all(
    feature = "cpu",
    not(feature = "cuda"),
    not(feature = "wgpu")
))]
pub type AutoBackend = burn_ndarray::NdArray<f32>;
14
15pub fn get_device() -> <AutoBackend as Backend>::Device {
17 #[cfg(feature = "cuda")]
18 {
19 println!("🚀 Using CUDA backend (NVIDIA GPU)");
20 burn_cuda::CudaDevice::default()
21 }
22
23 #[cfg(all(feature = "wgpu", not(feature = "cuda")))]
24 {
25 println!("🎮 Using WGPU backend (GPU via Vulkan/Metal/DX12)");
26 burn_wgpu::WgpuDevice::default()
27 }
28
29 #[cfg(all(feature = "cpu", not(any(feature = "cuda", feature = "wgpu"))))]
30 {
31 println!("💻 Using NdArray backend (CPU)");
32 Default::default()
33 }
34}
35
/// Prints a short, human-readable summary of the compiled-in backend to stdout.
///
/// The output is a three-line banner, then the lines describing whichever
/// backend the enabled Cargo features selected (priority: `cuda`, `wgpu`,
/// `cpu`), then one empty line. Only compile-time features decide what is
/// printed.
pub fn print_backend_info() {
    for banner_line in [
        "╔════════════════════════════════════════╗",
        "║ Backend Configuration ║",
        "╚════════════════════════════════════════╝",
    ] {
        println!("{}", banner_line);
    }

    #[cfg(feature = "cuda")]
    {
        println!(" Backend: CUDA (NVIDIA GPU)");
        println!(" Features: Fast matrix ops, tensor cores");
        // NOTE(review): the handle is built and immediately discarded; nothing
        // visible here checks that GPU 0 is actually usable — confirm that the
        // status line below is accurate.
        let _device = burn_cuda::CudaDevice::new(0);
        println!(" Status: ✓ GPU 0 available");
    }

    #[cfg(all(not(feature = "cuda"), feature = "wgpu"))]
    {
        for detail in [
            " Backend: WGPU (Cross-platform GPU)",
            " Features: Vulkan/Metal/DX12 support",
            " Status: ✓ GPU acceleration enabled",
        ] {
            println!("{}", detail);
        }
    }

    #[cfg(all(
        feature = "cpu",
        not(feature = "cuda"),
        not(feature = "wgpu")
    ))]
    {
        for detail in [
            " Backend: NdArray (CPU)",
            " Features: Portable, no GPU required",
            " Note: For GPU acceleration, rebuild with:",
            " cargo run --release --features wgpu",
            " cargo run --release --features cuda (NVIDIA)",
        ] {
            println!("{}", detail);
        }
    }

    // Trailing blank line separates this summary from subsequent output.
    println!();
}
68
/// Reports whether this binary was compiled with a GPU backend.
///
/// Returns `true` when the `cuda` or `wgpu` Cargo feature is enabled and
/// `false` otherwise. The answer is fixed at compile time; no hardware is
/// probed at runtime.
pub fn is_gpu_available() -> bool {
    cfg!(any(feature = "cuda", feature = "wgpu"))
}