Skip to main content

ferrotorch_cubecl/
runtime.rs

1//! Unified runtime selection for CubeCL backends.
2//!
3//! [`CubeDevice`] enumerates the three supported backends (CUDA, ROCm, WGPU),
4//! each parameterised by a device ordinal. [`CubeRuntime`] resolves that
5//! selection into a real CubeCL [`ComputeClient`] — one per backend — which
6//! owns the on-device memory and the compiled-kernel cache for that device.
7//!
8//! ## REQ status (per `.design/ferrotorch-cubecl/runtime.md`)
9//!
10//! Full evidence rows (impl + non-test production consumer + upstream
11//! cites) live in the design doc; this synopsis is a one-line summary per
12//! REQ.
13//!
14//! | REQ | Status | Evidence |
15//! |---|---|---|
16//! | REQ-1 (`CubeDevice` enum) | SHIPPED | `pub enum CubeDevice in runtime.rs` with `Cuda(usize)/Wgpu(usize)/Rocm(usize)` + derived `Debug, Clone, Copy, PartialEq, Eq, Hash`; consumer `ferrotorch-xpu/src/lib.rs` constructs `CubeDevice::Wgpu(ordinal)` inside `XpuDevice::new` |
17//! | REQ-2 (`ordinal/backend_name/Display`) | SHIPPED | `impl CubeDevice in runtime.rs` + `impl fmt::Display for CubeDevice in runtime.rs`; pinned by `cube_device_display` test |
18//! | REQ-3 (`CubeClient` enum) | SHIPPED | `pub enum CubeClient in runtime.rs` with cfg-gated `Cuda/Wgpu/Rocm(ComputeClient<R>)` + always-present `Stub` (#1083); consumer `ferrotorch-cubecl/src/ops.rs` dispatch macros match on `CubeClient` |
19//! | REQ-4 (`CubeRuntime::new`) | SHIPPED | `pub fn new in runtime.rs` + `fn make_client in runtime.rs` with cfg-arm fallback to `DeviceUnavailable`; consumer `ferrotorch-xpu/src/lib.rs` calls `CubeRuntime::new(CubeDevice::Wgpu(ordinal))?` |
20//! | REQ-5 (`CubeRuntime::auto`) | SHIPPED | `pub fn auto in runtime.rs` with cfg-gated CUDA > ROCm > WGPU priority chain; consumer `lib.rs` rustdoc demo + meta-crate `ferrotorch::cubecl` on-ramp |
21//! | REQ-6 (`is_available`) | SHIPPED | `pub fn is_available in runtime.rs` returning `cfg!(any(feature = "cuda", feature = "rocm", feature = "wgpu"))`; consumer `ferrotorch-xpu/src/lib.rs::XpuDevice::is_available` |
22//! | REQ-7 (`read_f32s`) | SHIPPED | `pub fn read_f32s in runtime.rs` (cfg `any(wgpu,cuda,rocm)`) dispatching `c.read_one(handle)`; consumer `ferrotorch-cubecl/src/storage.rs::CubeclStorageHandle::read_to_host` |
23//! | REQ-8 (`new_for_testing`) | SHIPPED | `#[doc(hidden)] pub fn new_for_testing in runtime.rs` returning `CubeClient::Stub` (#1083); test-infrastructure boundary API grandfathered per goal.md S5 |
24
25use std::fmt;
26
27#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
28use cubecl::Runtime;
29#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
30use cubecl::prelude::ComputeClient;
31use ferrotorch_core::FerrotorchResult;
32
33#[cfg(feature = "cuda")]
34use cubecl_cuda::{CudaDevice, CudaRuntime};
35#[cfg(feature = "rocm")]
36use cubecl_hip::{AmdDevice, HipRuntime};
37#[cfg(feature = "wgpu")]
38use cubecl_wgpu::{WgpuDevice, WgpuRuntime};
39
40// ---------------------------------------------------------------------------
41// CubeDevice
42// ---------------------------------------------------------------------------
43
/// Selects which CubeCL backend — and which device on it — a runtime targets.
///
/// Every variant wraps a `usize` device ordinal (e.g., GPU index 0, 1, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CubeDevice {
    /// NVIDIA GPU via CUDA PTX codegen.
    Cuda(usize),
    /// Portable GPU via WGPU — AMD (Vulkan), Intel (Vulkan), Apple (Metal).
    Wgpu(usize),
    /// AMD GPU via native HIP/ROCm runtime.
    Rocm(usize),
}

impl CubeDevice {
    /// Returns the device ordinal carried by the selector, whatever the backend.
    #[inline]
    pub fn ordinal(&self) -> usize {
        // Every variant carries exactly one ordinal, so a single irrefutable
        // or-pattern destructures all of them.
        let (Self::Cuda(index) | Self::Wgpu(index) | Self::Rocm(index)) = *self;
        index
    }

    /// Returns the lowercase backend name: `"cuda"`, `"wgpu"` or `"rocm"`.
    #[inline]
    pub fn backend_name(&self) -> &'static str {
        match *self {
            Self::Rocm(_) => "rocm",
            Self::Wgpu(_) => "wgpu",
            Self::Cuda(_) => "cuda",
        }
    }
}

/// Renders the selector as `<backend>:<ordinal>`, e.g. `cuda:2`.
impl fmt::Display for CubeDevice {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let backend = self.backend_name();
        let ordinal = self.ordinal();
        write!(f, "{backend}:{ordinal}")
    }
}
82
83// ---------------------------------------------------------------------------
84// CubeClient — per-backend real compute client
85// ---------------------------------------------------------------------------
86
/// Backend-tagged handle to an initialised CubeCL compute client.
///
/// Which variants exist depends on the runtime features compiled in; which
/// one a runtime holds depends on the [`CubeDevice`] it was built for.
/// `ops.rs` matches on this enum to dispatch generic CubeCL kernels to the
/// correct backend.
///
/// [`CubeClient::Stub`] is the one variant with no cfg gate. It is reserved
/// for tests: it holds no client state and every kernel dispatch macro has
/// a `Stub => unreachable!()` arm. A test that only needs pre-dispatch
/// paths (shape checks, signature pins) on a machine without a real backend
/// can obtain a [`CubeRuntime`] from [`CubeRuntime::new_for_testing`], whose
/// `client` field is `Stub`. Production code paths never construct or
/// observe `Stub`, because [`CubeRuntime::new`] only ever yields a real
/// backend client. (#1083)
#[derive(Clone)]
pub enum CubeClient {
    /// A real Wgpu (Vulkan/Metal/DX12) compute client.
    #[cfg(feature = "wgpu")]
    Wgpu(ComputeClient<WgpuRuntime>),
    /// A real CUDA compute client.
    #[cfg(feature = "cuda")]
    Cuda(ComputeClient<CudaRuntime>),
    /// A real HIP/ROCm compute client.
    #[cfg(feature = "rocm")]
    Rocm(ComputeClient<HipRuntime>),
    /// Test stub — every kernel dispatch panics; only pre-dispatch
    /// paths (shape checks, signature pins) are reachable.
    ///
    /// Always compiled in (no cfg gate), so tests in any feature
    /// configuration can build a runtime without a real backend client.
    /// Reaching a kernel dispatch arm with `Stub` is a test-discipline
    /// bug — the shape-mismatch test must fire before dispatch. (#1083)
    Stub,
}

/// Opaque `Debug`: prints only the variant name, eliding real clients as `(..)`.
impl fmt::Debug for CubeClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            #[cfg(feature = "wgpu")]
            Self::Wgpu(_) => "CubeClient::Wgpu(..)",
            #[cfg(feature = "cuda")]
            Self::Cuda(_) => "CubeClient::Cuda(..)",
            #[cfg(feature = "rocm")]
            Self::Rocm(_) => "CubeClient::Rocm(..)",
            Self::Stub => "CubeClient::Stub",
        };
        f.write_str(label)
    }
}
136
137// ---------------------------------------------------------------------------
138// CubeRuntime
139// ---------------------------------------------------------------------------
140
141/// CubeCL runtime wrapper that holds a real compute client for one device.
142#[derive(Clone, Debug)]
143pub struct CubeRuntime {
144    device: CubeDevice,
145    client: CubeClient,
146}
147
148impl CubeRuntime {
149    /// Create a runtime targeting the given device.
150    ///
151    /// Returns `Err(FerrotorchError::DeviceUnavailable)` if the required
152    /// backend feature was not compiled in.
153    pub fn new(device: CubeDevice) -> FerrotorchResult<Self> {
154        let client = Self::make_client(device)?;
155        Ok(Self { device, client })
156    }
157
158    /// Construct a `CubeRuntime` whose [`CubeClient`] is the test-only
159    /// [`CubeClient::Stub`] variant.
160    ///
161    /// Reserved for conformance tests that want to exercise CPU-side
162    /// pre-dispatch paths — shape validation, dtype checks, signature
163    /// pins — when the test environment has no usable wgpu/CUDA/ROCm
164    /// adapter. Reaching kernel dispatch with this runtime is a test
165    /// bug: every dispatch macro arm panics on `Stub`.
166    ///
167    /// Production code paths must use [`Self::new`] (which never
168    /// returns `Stub`) or [`Self::auto`] (likewise). (#1083)
169    #[doc(hidden)]
170    pub fn new_for_testing(device: CubeDevice) -> Self {
171        Self {
172            device,
173            client: CubeClient::Stub,
174        }
175    }
176
177    /// The device this runtime targets.
178    #[inline]
179    pub fn device(&self) -> &CubeDevice {
180        &self.device
181    }
182
183    /// The underlying compute client (one variant per backend).
184    #[inline]
185    pub fn client(&self) -> &CubeClient {
186        &self.client
187    }
188
189    /// Auto-detect the best available backend, returning `None` if no GPU
190    /// backend feature is enabled.
191    ///
192    /// Priority order: CUDA > ROCm > WGPU.
193    #[allow(unreachable_code)] // reason: each cfg-gated branch unconditionally returns; subsequent branches are tried only when the prior feature is off
194    pub fn auto() -> Option<Self> {
195        // CUDA takes priority when available.
196        #[cfg(feature = "cuda")]
197        {
198            return Self::new(CubeDevice::Cuda(0)).ok();
199        }
200
201        // ROCm for AMD-native workloads.
202        #[cfg(feature = "rocm")]
203        {
204            return Self::new(CubeDevice::Rocm(0)).ok();
205        }
206
207        // WGPU is the most portable fallback.
208        #[cfg(feature = "wgpu")]
209        {
210            return Self::new(CubeDevice::Wgpu(0)).ok();
211        }
212
213        None
214    }
215
216    /// Returns `true` if any GPU backend feature was compiled in.
217    pub fn is_available() -> bool {
218        cfg!(any(feature = "cuda", feature = "rocm", feature = "wgpu"))
219    }
220
221    /// Read `n` `f32` values from a device-resident handle back to host memory.
222    ///
223    /// This is the single readback point for callers (e.g. `ferrotorch-xpu`)
224    /// that receive a `(cubecl::server::Handle, Vec<usize>)` from a
225    /// `portable_*` op and need CPU-resident data. Dispatches to the correct
226    /// backend client. ADR #663 item 4.
227    #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
228    pub fn read_f32s(
229        &self,
230        handle: cubecl::server::Handle,
231        n: usize,
232    ) -> ferrotorch_core::FerrotorchResult<Vec<f32>> {
233        use cubecl::prelude::*;
234        let bytes = match &self.client {
235            #[cfg(feature = "wgpu")]
236            CubeClient::Wgpu(c) => c.read_one(handle),
237            #[cfg(feature = "cuda")]
238            CubeClient::Cuda(c) => c.read_one(handle),
239            #[cfg(feature = "rocm")]
240            CubeClient::Rocm(c) => c.read_one(handle),
241            // #1083: the Stub variant is reserved for tests that exercise
242            // pre-dispatch paths only; reaching readback means a kernel
243            // would have already had to dispatch, which the dispatch
244            // macros refuse for Stub.
245            CubeClient::Stub => unreachable!(
246                "CubeClient::Stub reached read_f32s — Stub runtimes must not \
247                 reach kernel dispatch or readback; shape check or signature \
248                 pin should fire first (#1083)"
249            ),
250        }
251        .map_err(|e| ferrotorch_core::FerrotorchError::InvalidArgument {
252            message: format!("cubecl read_one failed: {e}"),
253        })?;
254        // SAFETY: `bytes` came from a `client.empty(n * size_of::<f32>())`
255        // buffer filled by a `#[cube]` kernel writing `f32` values. The byte
256        // length is `n * 4` by construction. `f32::from_bytes` reinterprets
257        // the slice as `&[f32]` (same alignment on all supported backends).
258        Ok(f32::from_bytes(&bytes)[..n].to_vec())
259    }
260
261    // ---------------------------------------------------------------------
262    // Backend client construction
263    // ---------------------------------------------------------------------
264
265    #[allow(clippy::unnecessary_wraps)] // reason: returns Err under #[cfg(not(feature=...))]; clippy only sees the all-features path
266    fn make_client(device: CubeDevice) -> FerrotorchResult<CubeClient> {
267        match device {
268            CubeDevice::Wgpu(idx) => {
269                #[cfg(feature = "wgpu")]
270                {
271                    let wgpu_device = wgpu_device_for_index(idx);
272                    let client = WgpuRuntime::client(&wgpu_device);
273                    Ok(CubeClient::Wgpu(client))
274                }
275                #[cfg(not(feature = "wgpu"))]
276                {
277                    let _ = idx;
278                    Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
279                }
280            }
281            CubeDevice::Cuda(idx) => {
282                #[cfg(feature = "cuda")]
283                {
284                    let cuda_device = CudaDevice { index: idx };
285                    let client = CudaRuntime::client(&cuda_device);
286                    Ok(CubeClient::Cuda(client))
287                }
288                #[cfg(not(feature = "cuda"))]
289                {
290                    let _ = idx;
291                    Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
292                }
293            }
294            CubeDevice::Rocm(idx) => {
295                #[cfg(feature = "rocm")]
296                {
297                    let amd_device = AmdDevice { index: idx };
298                    let client = HipRuntime::client(&amd_device);
299                    Ok(CubeClient::Rocm(client))
300                }
301                #[cfg(not(feature = "rocm"))]
302                {
303                    let _ = idx;
304                    Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
305                }
306            }
307        }
308    }
309}
310
/// Translate a device ordinal into the cubecl-wgpu device selector.
///
/// Ordinal 0 becomes `WgpuDevice::DefaultDevice`; any other ordinal `n`
/// becomes `WgpuDevice::DiscreteGpu(n)`.
#[cfg(feature = "wgpu")]
fn wgpu_device_for_index(index: usize) -> WgpuDevice {
    if index == 0 {
        // The system default adapter is the most portable choice and
        // matches how ferrotorch-gpu selects a GPU.
        WgpuDevice::DefaultDevice
    } else {
        // Non-zero ordinals explicitly select a discrete GPU slot.
        WgpuDevice::DiscreteGpu(index)
    }
}
321
322// ---------------------------------------------------------------------------
323// Tests
324// ---------------------------------------------------------------------------
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant reports the ordinal it was constructed with.
    #[test]
    fn cube_device_ordinal() {
        let cases = [
            (CubeDevice::Cuda(3), 3),
            (CubeDevice::Wgpu(1), 1),
            (CubeDevice::Rocm(0), 0),
        ];
        for (device, expected) in cases {
            assert_eq!(device.ordinal(), expected);
        }
    }

    /// Each variant maps to its lowercase backend name.
    #[test]
    fn cube_device_backend_name() {
        let cases = [
            (CubeDevice::Cuda(0), "cuda"),
            (CubeDevice::Wgpu(0), "wgpu"),
            (CubeDevice::Rocm(0), "rocm"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.backend_name(), expected);
        }
    }

    /// `Display` renders `<backend>:<ordinal>`.
    #[test]
    fn cube_device_display() {
        let cases = [
            (CubeDevice::Cuda(2), "cuda:2"),
            (CubeDevice::Wgpu(0), "wgpu:0"),
            (CubeDevice::Rocm(1), "rocm:1"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.to_string(), expected);
        }
    }

    /// Equality distinguishes both the ordinal and the backend.
    #[test]
    fn cube_device_equality() {
        let cuda0 = CubeDevice::Cuda(0);
        assert_eq!(cuda0, CubeDevice::Cuda(0));
        assert_ne!(cuda0, CubeDevice::Cuda(1));
        assert_ne!(cuda0, CubeDevice::Wgpu(0));
    }

    /// Selectors hash consistently with equality.
    #[test]
    fn cube_device_clone_and_hash() {
        use std::collections::HashSet;
        let mut seen = HashSet::from([
            CubeDevice::Cuda(0),
            CubeDevice::Wgpu(0),
            CubeDevice::Rocm(0),
        ]);
        assert_eq!(seen.len(), 3);

        // Re-inserting an existing selector must not grow the set.
        seen.insert(CubeDevice::Cuda(0));
        assert_eq!(seen.len(), 3);
    }

    /// Try to build a wgpu runtime, yielding `None` when that is not
    /// possible in the current environment.
    ///
    /// WSL2 lacks a Vulkan ICD by default, so the cubecl-wgpu worker thread
    /// panics during adapter selection and the panic surfaces on the main
    /// thread as `RecvError`. Both that panic and an `Err` from
    /// `CubeRuntime::new` collapse to `None` so tests skip cleanly instead
    /// of failing.
    #[cfg(feature = "wgpu")]
    fn wgpu_probe_runtime() -> Option<CubeRuntime> {
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            CubeRuntime::new(CubeDevice::Wgpu(0)).ok()
        }))
        .ok()
        .flatten()
    }

    /// A wgpu runtime reports the requested device and holds a wgpu client.
    #[cfg(feature = "wgpu")]
    #[test]
    fn wgpu_runtime_new_and_device() {
        match wgpu_probe_runtime() {
            None => eprintln!("[ferrotorch-cubecl] wgpu adapter unavailable; skipping"),
            Some(rt) => {
                assert_eq!(*rt.device(), CubeDevice::Wgpu(0));
                // The client variant must match the selected backend.
                assert!(matches!(rt.client(), CubeClient::Wgpu(_)));
            }
        }
    }

    /// With no backend feature compiled in, `new` reports `DeviceUnavailable`.
    #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
    #[test]
    fn no_backend_feature_yields_device_unavailable() {
        let outcome = CubeRuntime::new(CubeDevice::Wgpu(0));
        assert!(matches!(
            outcome,
            Err(ferrotorch_core::FerrotorchError::DeviceUnavailable)
        ));
    }

    /// `auto()` either yields a runtime (implying a compiled-in backend) or
    /// yields nothing.
    #[test]
    fn cube_runtime_auto_returns_something_or_none() {
        // `auto()` may panic on the worker thread if a backend feature is
        // compiled in but the actual hardware/driver isn't available
        // (e.g. wgpu in WSL without Vulkan). Catch that and treat it as
        // "not available" — matching the documented contract that this
        // function returns `Option`.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(CubeRuntime::auto));
        if let Ok(Some(_)) = outcome {
            assert!(CubeRuntime::is_available());
        } else {
            // Either no backend feature is compiled in, or the feature is
            // present but no adapter exists at runtime. Both are valid.
            eprintln!(
                "[ferrotorch-cubecl] auto() returned no runtime (no backend feature or no \
                 adapter); test passes"
            );
        }
    }

    /// When no backend feature is compiled in, `auto()` must yield nothing.
    #[test]
    fn cube_runtime_is_available_consistent() {
        // `is_available()` is a compile-time check (`cfg!(...)`). With a
        // feature compiled in but no hardware present, `auto()` may still
        // return `Some` (lazy init succeeds, kernel dispatch would fail
        // later). That asymmetry is accepted, so only the "nothing compiled
        // in" direction is asserted.
        if CubeRuntime::is_available() {
            return;
        }
        // Belt-and-suspenders: with no feature compiled, auto() must be None.
        let auto = std::panic::catch_unwind(std::panic::AssertUnwindSafe(CubeRuntime::auto));
        assert!(auto.ok().flatten().is_none());
    }
}