rlx_driver/device.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Device selection — which backend to use.
17
/// Target device for graph execution.
///
/// Each variant maps to a backend crate gated by a Cargo feature.
/// This driver-layer type carries no availability information: whether a
/// device is usable is decided by the engine's backend registry in
/// rlx-runtime (see `rlx_runtime::available_devices()`), which keeps the
/// dependency direction one-way (runtime → driver).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    // ── CPU ─────────────────────────────────────────────────
    /// CPU with SIMD (NEON/AVX) + optional BLAS.
    Cpu,

    // ── Apple ───────────────────────────────────────────────
    /// GPU via Apple Metal (Metal Performance Shaders).
    Metal,
    /// Apple MLX framework (unified memory GPU).
    Mlx,
    /// Apple Neural Engine.
    Ane,

    // ── NVIDIA ──────────────────────────────────────────────
    /// NVIDIA GPU via native CUDA (cuBLAS, cuDNN).
    Cuda,

    // ── AMD ─────────────────────────────────────────────────
    /// AMD GPU via ROCm/HIP.
    Rocm,

    // ── Google ──────────────────────────────────────────────
    /// Google TPU via libtpu's PJRT plugin (no Python).
    Tpu,

    // ── Cross-platform GPU ──────────────────────────────────
    /// Portable GPU via wgpu (Metal/Vulkan/DX12/WebGPU).
    Gpu,
    /// Vulkan compute shaders.
    Vulkan,
    /// OpenGL compute shaders (legacy).
    OpenGl,
    /// DirectX 12 compute (Windows).
    DirectX,
    /// WebGPU (WASM target).
    WebGpu,
}
60
61impl Device {
62 /// Human-readable name (no engine-layer info).
63 /// `is_available` / `available` live in rlx-runtime since they
64 /// consult the engine's backend registry — keeping them out of
65 /// the driver layer preserves the one-way dep direction.
66 pub fn name(self) -> &'static str {
67 match self {
68 Device::Cpu => "CPU",
69 Device::Metal => "Metal",
70 Device::Mlx => "MLX",
71 Device::Ane => "ANE",
72 Device::Cuda => "CUDA",
73 Device::Rocm => "ROCm",
74 Device::Tpu => "TPU",
75 Device::Gpu => "GPU (wgpu)",
76 Device::Vulkan => "Vulkan",
77 Device::OpenGl => "OpenGL",
78 Device::DirectX => "DirectX 12",
79 Device::WebGpu => "WebGPU",
80 }
81 }
82
83 /// All variant labels — convenience for callers that want to
84 /// enumerate without listing every variant manually. Pair
85 /// with `rlx_runtime::available_devices()` to filter.
86 pub fn all() -> &'static [Device] {
87 &[
88 Device::Cpu,
89 Device::Metal,
90 Device::Mlx,
91 Device::Ane,
92 Device::Cuda,
93 Device::Rocm,
94 Device::Tpu,
95 Device::Gpu,
96 Device::Vulkan,
97 Device::OpenGl,
98 Device::DirectX,
99 Device::WebGpu,
100 ]
101 }
102}
103
104impl std::fmt::Display for Device {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 write!(f, "{}", self.name())
107 }
108}