ferrotorch_cubecl/runtime.rs
1//! Unified runtime selection for CubeCL backends.
2//!
3//! [`CubeDevice`] enumerates the three supported backends (CUDA, ROCm, WGPU),
4//! each parameterised by a device ordinal. [`CubeRuntime`] resolves that
5//! selection into a real CubeCL [`ComputeClient`] — one per backend — which
6//! owns the on-device memory and the compiled-kernel cache for that device.
7//!
8//! ## REQ status (per `.design/ferrotorch-cubecl/runtime.md`)
9//!
10//! Full evidence rows (impl + non-test production consumer + upstream
11//! cites) live in the design doc; this synopsis is a one-line summary per
12//! REQ.
13//!
14//! | REQ | Status | Evidence |
15//! |---|---|---|
16//! | REQ-1 (`CubeDevice` enum) | SHIPPED | `pub enum CubeDevice in runtime.rs` with `Cuda(usize)/Wgpu(usize)/Rocm(usize)` + derived `Debug, Clone, Copy, PartialEq, Eq, Hash`; consumer `ferrotorch-xpu/src/lib.rs` constructs `CubeDevice::Wgpu(ordinal)` inside `XpuDevice::new` |
17//! | REQ-2 (`ordinal/backend_name/Display`) | SHIPPED | `impl CubeDevice in runtime.rs` + `impl fmt::Display for CubeDevice in runtime.rs`; pinned by `cube_device_display` test |
18//! | REQ-3 (`CubeClient` enum) | SHIPPED | `pub enum CubeClient in runtime.rs` with cfg-gated `Cuda/Wgpu/Rocm(ComputeClient<R>)` + always-present `Stub` (#1083); consumer `ferrotorch-cubecl/src/ops.rs` dispatch macros match on `CubeClient` |
19//! | REQ-4 (`CubeRuntime::new`) | SHIPPED | `pub fn new in runtime.rs` + `fn make_client in runtime.rs` with cfg-arm fallback to `DeviceUnavailable`; consumer `ferrotorch-xpu/src/lib.rs` calls `CubeRuntime::new(CubeDevice::Wgpu(ordinal))?` |
20//! | REQ-5 (`CubeRuntime::auto`) | SHIPPED | `pub fn auto in runtime.rs` with cfg-gated CUDA > ROCm > WGPU priority chain; consumer `lib.rs` rustdoc demo + meta-crate `ferrotorch::cubecl` on-ramp |
21//! | REQ-6 (`is_available`) | SHIPPED | `pub fn is_available in runtime.rs` returning `cfg!(any(feature = "cuda", feature = "rocm", feature = "wgpu"))`; consumer `ferrotorch-xpu/src/lib.rs::XpuDevice::is_available` |
22//! | REQ-7 (`read_f32s`) | SHIPPED | `pub fn read_f32s in runtime.rs` (cfg `any(wgpu,cuda,rocm)`) dispatching `c.read_one(handle)`; consumer `ferrotorch-cubecl/src/storage.rs::CubeclStorageHandle::read_to_host` |
23//! | REQ-8 (`new_for_testing`) | SHIPPED | `#[doc(hidden)] pub fn new_for_testing in runtime.rs` returning `CubeClient::Stub` (#1083); test-infrastructure boundary API grandfathered per goal.md S5 |
24
25use std::fmt;
26
27#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
28use cubecl::Runtime;
29#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
30use cubecl::prelude::ComputeClient;
31use ferrotorch_core::FerrotorchResult;
32
33#[cfg(feature = "cuda")]
34use cubecl_cuda::{CudaDevice, CudaRuntime};
35#[cfg(feature = "rocm")]
36use cubecl_hip::{AmdDevice, HipRuntime};
37#[cfg(feature = "wgpu")]
38use cubecl_wgpu::{WgpuDevice, WgpuRuntime};
39
40// ---------------------------------------------------------------------------
41// CubeDevice
42// ---------------------------------------------------------------------------
43
/// Selects which CubeCL backend — and which device on that backend — a
/// runtime should target.
///
/// Every variant wraps a `usize` device ordinal (e.g., GPU index 0, 1, ...).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CubeDevice {
    /// NVIDIA GPU driven through CUDA PTX codegen.
    Cuda(usize),
    /// Portable GPU through WGPU — AMD (Vulkan), Intel (Vulkan), Apple (Metal).
    Wgpu(usize),
    /// AMD GPU through the native HIP/ROCm runtime.
    Rocm(usize),
}
56
57impl CubeDevice {
58 /// Device ordinal regardless of backend.
59 #[inline]
60 pub fn ordinal(&self) -> usize {
61 match self {
62 Self::Cuda(o) | Self::Wgpu(o) | Self::Rocm(o) => *o,
63 }
64 }
65
66 /// Human-readable backend name.
67 #[inline]
68 pub fn backend_name(&self) -> &'static str {
69 match self {
70 Self::Cuda(_) => "cuda",
71 Self::Wgpu(_) => "wgpu",
72 Self::Rocm(_) => "rocm",
73 }
74 }
75}
76
77impl fmt::Display for CubeDevice {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 write!(f, "{}:{}", self.backend_name(), self.ordinal())
80 }
81}
82
83// ---------------------------------------------------------------------------
84// CubeClient — per-backend real compute client
85// ---------------------------------------------------------------------------
86
/// The initialised CubeCL compute client behind a [`CubeRuntime`], one
/// variant per supported backend.
///
/// Which real-backend variants exist depends on the runtime features that
/// were compiled in; which one a runtime holds depends on the
/// [`CubeDevice`] it was built for. `ops.rs` matches on this enum to
/// dispatch generic CubeCL kernels to the correct backend.
///
/// [`CubeClient::Stub`] is never cfg-gated and is reserved for tests: it
/// holds no client state, and every kernel dispatch macro has a
/// `Stub => unreachable!()` arm. A test that only needs pre-dispatch paths
/// (shape checks, signature pins) on a machine without a real backend
/// obtains one through [`CubeRuntime::new_for_testing`]. Production code
/// paths never construct or observe `Stub`, because [`CubeRuntime::new`]
/// only ever yields a real backend client. (#1083)
#[derive(Clone)]
pub enum CubeClient {
    /// Real Wgpu (Vulkan/Metal/DX12) compute client.
    #[cfg(feature = "wgpu")]
    Wgpu(ComputeClient<WgpuRuntime>),
    /// Real CUDA compute client.
    #[cfg(feature = "cuda")]
    Cuda(ComputeClient<CudaRuntime>),
    /// Real HIP/ROCm compute client.
    #[cfg(feature = "rocm")]
    Rocm(ComputeClient<HipRuntime>),
    /// Test stub: every kernel dispatch panics, so only pre-dispatch paths
    /// (shape checks, signature pins) are reachable.
    ///
    /// Always compiled in (no cfg gate) so that tests in any feature
    /// configuration can construct a runtime without a real backend
    /// client. Reaching a kernel dispatch arm with `Stub` is a
    /// test-discipline bug — the shape-mismatch check must fire before
    /// dispatch. (#1083)
    Stub,
}
122
123impl fmt::Debug for CubeClient {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 match self {
126 #[cfg(feature = "wgpu")]
127 Self::Wgpu(_) => f.write_str("CubeClient::Wgpu(..)"),
128 #[cfg(feature = "cuda")]
129 Self::Cuda(_) => f.write_str("CubeClient::Cuda(..)"),
130 #[cfg(feature = "rocm")]
131 Self::Rocm(_) => f.write_str("CubeClient::Rocm(..)"),
132 Self::Stub => f.write_str("CubeClient::Stub"),
133 }
134 }
135}
136
137// ---------------------------------------------------------------------------
138// CubeRuntime
139// ---------------------------------------------------------------------------
140
141/// CubeCL runtime wrapper that holds a real compute client for one device.
142#[derive(Clone, Debug)]
143pub struct CubeRuntime {
144 device: CubeDevice,
145 client: CubeClient,
146}
147
148impl CubeRuntime {
149 /// Create a runtime targeting the given device.
150 ///
151 /// Returns `Err(FerrotorchError::DeviceUnavailable)` if the required
152 /// backend feature was not compiled in.
153 pub fn new(device: CubeDevice) -> FerrotorchResult<Self> {
154 let client = Self::make_client(device)?;
155 Ok(Self { device, client })
156 }
157
158 /// Construct a `CubeRuntime` whose [`CubeClient`] is the test-only
159 /// [`CubeClient::Stub`] variant.
160 ///
161 /// Reserved for conformance tests that want to exercise CPU-side
162 /// pre-dispatch paths — shape validation, dtype checks, signature
163 /// pins — when the test environment has no usable wgpu/CUDA/ROCm
164 /// adapter. Reaching kernel dispatch with this runtime is a test
165 /// bug: every dispatch macro arm panics on `Stub`.
166 ///
167 /// Production code paths must use [`Self::new`] (which never
168 /// returns `Stub`) or [`Self::auto`] (likewise). (#1083)
169 #[doc(hidden)]
170 pub fn new_for_testing(device: CubeDevice) -> Self {
171 Self {
172 device,
173 client: CubeClient::Stub,
174 }
175 }
176
177 /// The device this runtime targets.
178 #[inline]
179 pub fn device(&self) -> &CubeDevice {
180 &self.device
181 }
182
183 /// The underlying compute client (one variant per backend).
184 #[inline]
185 pub fn client(&self) -> &CubeClient {
186 &self.client
187 }
188
189 /// Auto-detect the best available backend, returning `None` if no GPU
190 /// backend feature is enabled.
191 ///
192 /// Priority order: CUDA > ROCm > WGPU.
193 #[allow(unreachable_code)] // reason: each cfg-gated branch unconditionally returns; subsequent branches are tried only when the prior feature is off
194 pub fn auto() -> Option<Self> {
195 // CUDA takes priority when available.
196 #[cfg(feature = "cuda")]
197 {
198 return Self::new(CubeDevice::Cuda(0)).ok();
199 }
200
201 // ROCm for AMD-native workloads.
202 #[cfg(feature = "rocm")]
203 {
204 return Self::new(CubeDevice::Rocm(0)).ok();
205 }
206
207 // WGPU is the most portable fallback.
208 #[cfg(feature = "wgpu")]
209 {
210 return Self::new(CubeDevice::Wgpu(0)).ok();
211 }
212
213 None
214 }
215
216 /// Returns `true` if any GPU backend feature was compiled in.
217 pub fn is_available() -> bool {
218 cfg!(any(feature = "cuda", feature = "rocm", feature = "wgpu"))
219 }
220
221 /// Read `n` `f32` values from a device-resident handle back to host memory.
222 ///
223 /// This is the single readback point for callers (e.g. `ferrotorch-xpu`)
224 /// that receive a `(cubecl::server::Handle, Vec<usize>)` from a
225 /// `portable_*` op and need CPU-resident data. Dispatches to the correct
226 /// backend client. ADR #663 item 4.
227 #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
228 pub fn read_f32s(
229 &self,
230 handle: cubecl::server::Handle,
231 n: usize,
232 ) -> ferrotorch_core::FerrotorchResult<Vec<f32>> {
233 use cubecl::prelude::*;
234 let bytes = match &self.client {
235 #[cfg(feature = "wgpu")]
236 CubeClient::Wgpu(c) => c.read_one(handle),
237 #[cfg(feature = "cuda")]
238 CubeClient::Cuda(c) => c.read_one(handle),
239 #[cfg(feature = "rocm")]
240 CubeClient::Rocm(c) => c.read_one(handle),
241 // #1083: the Stub variant is reserved for tests that exercise
242 // pre-dispatch paths only; reaching readback means a kernel
243 // would have already had to dispatch, which the dispatch
244 // macros refuse for Stub.
245 CubeClient::Stub => unreachable!(
246 "CubeClient::Stub reached read_f32s — Stub runtimes must not \
247 reach kernel dispatch or readback; shape check or signature \
248 pin should fire first (#1083)"
249 ),
250 }
251 .map_err(|e| ferrotorch_core::FerrotorchError::InvalidArgument {
252 message: format!("cubecl read_one failed: {e}"),
253 })?;
254 // SAFETY: `bytes` came from a `client.empty(n * size_of::<f32>())`
255 // buffer filled by a `#[cube]` kernel writing `f32` values. The byte
256 // length is `n * 4` by construction. `f32::from_bytes` reinterprets
257 // the slice as `&[f32]` (same alignment on all supported backends).
258 Ok(f32::from_bytes(&bytes)[..n].to_vec())
259 }
260
261 // ---------------------------------------------------------------------
262 // Backend client construction
263 // ---------------------------------------------------------------------
264
265 #[allow(clippy::unnecessary_wraps)] // reason: returns Err under #[cfg(not(feature=...))]; clippy only sees the all-features path
266 fn make_client(device: CubeDevice) -> FerrotorchResult<CubeClient> {
267 match device {
268 CubeDevice::Wgpu(idx) => {
269 #[cfg(feature = "wgpu")]
270 {
271 let wgpu_device = wgpu_device_for_index(idx);
272 let client = WgpuRuntime::client(&wgpu_device);
273 Ok(CubeClient::Wgpu(client))
274 }
275 #[cfg(not(feature = "wgpu"))]
276 {
277 let _ = idx;
278 Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
279 }
280 }
281 CubeDevice::Cuda(idx) => {
282 #[cfg(feature = "cuda")]
283 {
284 let cuda_device = CudaDevice { index: idx };
285 let client = CudaRuntime::client(&cuda_device);
286 Ok(CubeClient::Cuda(client))
287 }
288 #[cfg(not(feature = "cuda"))]
289 {
290 let _ = idx;
291 Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
292 }
293 }
294 CubeDevice::Rocm(idx) => {
295 #[cfg(feature = "rocm")]
296 {
297 let amd_device = AmdDevice { index: idx };
298 let client = HipRuntime::client(&amd_device);
299 Ok(CubeClient::Rocm(client))
300 }
301 #[cfg(not(feature = "rocm"))]
302 {
303 let _ = idx;
304 Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
305 }
306 }
307 }
308 }
309}
310
/// Translates a device ordinal into the `WgpuDevice` selector handed to the
/// WGPU runtime.
///
/// Ordinal 0 maps to the system default adapter — the most portable choice,
/// and the same way ferrotorch-gpu selects a GPU. Any other ordinal
/// explicitly selects the discrete GPU slot with that number.
#[cfg(feature = "wgpu")]
fn wgpu_device_for_index(index: usize) -> WgpuDevice {
    if index == 0 {
        WgpuDevice::DefaultDevice
    } else {
        WgpuDevice::DiscreteGpu(index)
    }
}
321
322// ---------------------------------------------------------------------------
323// Tests
324// ---------------------------------------------------------------------------
325
#[cfg(test)]
mod tests {
    use super::*;

    /// `ordinal()` returns the wrapped index for every backend.
    #[test]
    fn cube_device_ordinal() {
        let cases = [
            (CubeDevice::Cuda(3), 3),
            (CubeDevice::Wgpu(1), 1),
            (CubeDevice::Rocm(0), 0),
        ];
        for (device, expected) in cases {
            assert_eq!(device.ordinal(), expected);
        }
    }

    /// `backend_name()` returns the lowercase backend identifier.
    #[test]
    fn cube_device_backend_name() {
        let cases = [
            (CubeDevice::Cuda(0), "cuda"),
            (CubeDevice::Wgpu(0), "wgpu"),
            (CubeDevice::Rocm(0), "rocm"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.backend_name(), expected);
        }
    }

    /// `Display` renders `<backend>:<ordinal>`.
    #[test]
    fn cube_device_display() {
        let cases = [
            (CubeDevice::Cuda(2), "cuda:2"),
            (CubeDevice::Wgpu(0), "wgpu:0"),
            (CubeDevice::Rocm(1), "rocm:1"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.to_string(), expected);
        }
    }

    /// Equality distinguishes both the ordinal and the backend.
    #[test]
    fn cube_device_equality() {
        let reference = CubeDevice::Cuda(0);
        assert_eq!(reference, CubeDevice::Cuda(0));
        assert_ne!(reference, CubeDevice::Cuda(1));
        assert_ne!(reference, CubeDevice::Wgpu(0));
    }

    /// Devices hash consistently with equality.
    #[test]
    fn cube_device_clone_and_hash() {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        for device in [
            CubeDevice::Cuda(0),
            CubeDevice::Wgpu(0),
            CubeDevice::Rocm(0),
        ] {
            seen.insert(device);
        }
        assert_eq!(seen.len(), 3);

        // Re-inserting an existing device must leave the set unchanged.
        seen.insert(CubeDevice::Cuda(0));
        assert_eq!(seen.len(), 3);
    }

    /// Probes whether wgpu can construct a runtime in the current
    /// environment. WSL2 lacks a Vulkan ICD by default, so the cubecl-wgpu
    /// worker thread panics during adapter selection and the panic surfaces
    /// on the main thread as `RecvError`. That panic is caught here so that
    /// tests skip cleanly instead of failing.
    #[cfg(feature = "wgpu")]
    fn wgpu_probe_runtime() -> Option<CubeRuntime> {
        let attempt = || CubeRuntime::new(CubeDevice::Wgpu(0)).ok();
        // A panic (outer `Err`) and a construction error (inner `None`)
        // both collapse to `None`.
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(attempt))
            .ok()
            .flatten()
    }

    /// A wgpu runtime reports the device it was built for and holds a wgpu
    /// client.
    #[cfg(feature = "wgpu")]
    #[test]
    fn wgpu_runtime_new_and_device() {
        let Some(runtime) = wgpu_probe_runtime() else {
            eprintln!("[ferrotorch-cubecl] wgpu adapter unavailable; skipping");
            return;
        };
        assert_eq!(*runtime.device(), CubeDevice::Wgpu(0));
        // The client variant has to match the selected backend.
        assert!(matches!(runtime.client(), CubeClient::Wgpu(_)));
    }

    /// Without any backend feature, construction fails with
    /// `DeviceUnavailable`.
    #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
    #[test]
    fn no_backend_feature_yields_device_unavailable() {
        let outcome = CubeRuntime::new(CubeDevice::Wgpu(0));
        assert!(matches!(
            outcome,
            Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
        ));
    }

    #[test]
    fn cube_runtime_auto_returns_something_or_none() {
        // `auto()` may panic on the worker thread if a backend feature is
        // compiled in but the actual hardware/driver isn't available
        // (e.g. wgpu in WSL without Vulkan). Catch that and treat it as
        // "not available" — matching the documented contract that this
        // function returns `Option`.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(CubeRuntime::auto));
        if let Ok(Some(_)) = outcome {
            assert!(CubeRuntime::is_available());
        } else {
            // Either no backend feature is compiled in, or the feature is
            // present but no adapter exists at runtime. Both are valid
            // outcomes.
            eprintln!(
                "[ferrotorch-cubecl] auto() returned no runtime (no backend feature or no \
                 adapter); test passes"
            );
        }
    }

    #[test]
    fn cube_runtime_is_available_consistent() {
        // `is_available()` is a compile-time check (`cfg!(...)`). When a
        // feature is compiled in but no hardware exists at runtime, `auto()`
        // may still return `Some` (lazy init succeeds, kernel dispatch
        // would fail later). That asymmetry is accepted here — the test
        // only verifies that "no feature compiled in" implies `auto()`
        // yields nothing.
        if CubeRuntime::is_available() {
            return;
        }
        // Belt-and-suspenders: with no feature compiled, auto() must be
        // None.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(CubeRuntime::auto));
        assert!(outcome.ok().flatten().is_none());
    }
}
448}