Skip to main content

ferrotorch_cubecl/
storage.rs

//! Concrete [`CubeStorageHandle`] implementation for ferrotorch-cubecl.
//!
//! [`CubeclStorageHandle`] wraps a `cubecl::server::Handle` together with an
//! `Arc<CubeRuntime>` so the runtime (and thus the device memory) stays alive
//! as long as any tensor holds a reference to this handle.
//!
//! Core defines the [`CubeStorageHandle`] trait; this crate provides the only
//! concrete implementation. That keeps the dependency one-way: this crate
//! depends on `ferrotorch-core`, never the reverse.
//!
//! Issue #673: device-resident XPU storage.
//!
//! ## REQ status (per `.design/ferrotorch-cubecl/storage.md`)
//!
//! The full evidence rows (implementation, non-test production consumer, and
//! upstream citations) live in the design doc; this synopsis gives a one-line
//! summary per REQ.
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (`CubeclStorageHandle` struct) | SHIPPED | `pub struct CubeclStorageHandle in storage.rs` with `handle/runtime/len/ordinal` fields; consumer `ferrotorch-xpu/src/lib.rs` packs result of `upload_f32` into `TensorStorage::xpu_from_handle` |
//! | REQ-2 (`CubeStorageHandle` trait impl) | SHIPPED | `impl CubeStorageHandle for CubeclStorageHandle in storage.rs` with `as_any/len/ordinal/read_to_host/clone_handle`; consumer `ferrotorch-core::TensorStorage::Cubecl` invokes via the trait on `Tensor::cpu()` readback |
//! | REQ-3 (`from_raw`) | SHIPPED | `pub fn from_raw in storage.rs`; consumer `pub fn wrap_kernel_output in storage.rs` invoked from `ferrotorch-xpu/src/lib.rs` |
//! | REQ-4 (`raw_handle`) | SHIPPED | `pub fn raw_handle in storage.rs`; consumer `ferrotorch-cubecl/src/ops.rs` dispatch macros call `ha.raw_handle().clone()` |
//! | REQ-5 (`runtime` accessor) | SHIPPED | `pub fn runtime in storage.rs`; consumer downstream code reaching the handle via `cubecl_handle_of` to chain into read-back |
//! | REQ-6 (`wrap_kernel_output`) | SHIPPED | `pub fn wrap_kernel_output in storage.rs`; consumer `ferrotorch-xpu/src/lib.rs` `xpu_binary!`/`xpu_unary!`/`xpu_polynomial!` macro expansions |
//! | REQ-7 (`upload_f32`) | SHIPPED | `pub fn upload_f32 in storage.rs` with cfg arms (`create_from_slice` on feature, `DeviceUnavailable` off); consumer `ferrotorch-xpu/src/lib.rs::make_xpu_tensor` |
//! | REQ-8 (`cubecl_handle_of`) | SHIPPED | `pub fn cubecl_handle_of in storage.rs` using `Any::downcast_ref`; consumer `ferrotorch-cubecl/src/ops.rs` dispatch macros |
28
29use std::sync::Arc;
30
31#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
32use ferrotorch_core::FerrotorchError;
33use ferrotorch_core::FerrotorchResult;
34use ferrotorch_core::storage::CubeStorageHandle;
35
36use crate::runtime::CubeRuntime;
37
38// ---------------------------------------------------------------------------
39// CubeclStorageHandle
40// ---------------------------------------------------------------------------
41
42/// Device-resident buffer handle for CubeCL-backed tensors.
43///
44/// Holds:
45/// - A `cubecl::server::Handle` pointing to GPU memory.
46/// - An `Arc<CubeRuntime>` so the device client stays alive.
47/// - `len`: element count (`f32` elements).
48/// - `ordinal`: device ordinal.
49///
50/// This is the concrete type stored inside `StorageBuffer::Cubecl` for XPU
51/// tensors. Constructed by [`upload_f32`].
52#[derive(Debug)]
53pub struct CubeclStorageHandle {
54    handle: cubecl::server::Handle,
55    runtime: Arc<CubeRuntime>,
56    len: usize,
57    ordinal: usize,
58}
59
60impl CubeclStorageHandle {
61    /// Construct a handle from its parts (internal use by `upload_f32`).
62    #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
63    fn new(
64        handle: cubecl::server::Handle,
65        runtime: Arc<CubeRuntime>,
66        len: usize,
67        ordinal: usize,
68    ) -> Self {
69        Self {
70            handle,
71            runtime,
72            len,
73            ordinal,
74        }
75    }
76
77    /// Construct a handle from a raw `cubecl::server::Handle` returned by a
78    /// kernel launcher.
79    ///
80    /// Used by `ferrotorch-xpu`'s `wrap_result_handle` to turn the
81    /// `(cubecl::server::Handle, shape)` pair from `portable_*` into a
82    /// `CubeclStorageHandle` without an extra H2D upload. Issue #673.
83    pub fn from_raw(
84        handle: cubecl::server::Handle,
85        runtime: Arc<CubeRuntime>,
86        len: usize,
87        ordinal: usize,
88    ) -> Self {
89        Self {
90            handle,
91            runtime,
92            len,
93            ordinal,
94        }
95    }
96
97    /// Borrow the raw `cubecl::server::Handle`.
98    ///
99    /// Used by `ops.rs` to pass handles directly to kernel launchers without
100    /// an extra H2D upload. Issue #673.
101    pub fn raw_handle(&self) -> &cubecl::server::Handle {
102        &self.handle
103    }
104
105    /// Borrow the `CubeRuntime` this handle belongs to.
106    pub fn runtime(&self) -> &Arc<CubeRuntime> {
107        &self.runtime
108    }
109}
110
111impl CubeStorageHandle for CubeclStorageHandle {
112    fn as_any(&self) -> &dyn std::any::Any {
113        self
114    }
115
116    fn len(&self) -> usize {
117        self.len
118    }
119
120    fn ordinal(&self) -> usize {
121        self.ordinal
122    }
123
124    fn read_to_host(&self) -> FerrotorchResult<Vec<f32>> {
125        // `read_f32s` is only available when a backend feature is compiled in.
126        // Without a backend, this path is unreachable because the handle can
127        // only be constructed via `upload_f32`, which also requires a feature.
128        #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
129        {
130            self.runtime.read_f32s(self.handle.clone(), self.len)
131        }
132        #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
133        {
134            Err(FerrotorchError::DeviceUnavailable)
135        }
136    }
137
138    fn clone_handle(&self) -> Box<dyn CubeStorageHandle> {
139        // `cubecl::server::Handle` clone is cheap — it bumps an internal
140        // ref count in the cubecl server's handle table, not a buffer copy.
141        Box::new(CubeclStorageHandle {
142            handle: self.handle.clone(),
143            runtime: Arc::clone(&self.runtime),
144            len: self.len,
145            ordinal: self.ordinal,
146        })
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Result-wrapping helper — for ferrotorch-xpu to use without depending on cubecl directly
152// ---------------------------------------------------------------------------
153
154/// Wrap the `(cubecl::server::Handle, Vec<usize>)` result of a `portable_*`
155/// kernel call into a `CubeclStorageHandle`.
156///
157/// `ferrotorch-xpu` uses this so it never needs to name `cubecl::server::Handle`
158/// directly (cubecl is not a direct dep of that crate). Issue #673.
159pub fn wrap_kernel_output(
160    handle: cubecl::server::Handle,
161    shape: &[usize],
162    runtime: Arc<CubeRuntime>,
163    ordinal: usize,
164) -> CubeclStorageHandle {
165    let numel: usize = shape.iter().product();
166    CubeclStorageHandle::from_raw(handle, runtime, numel, ordinal)
167}
168
169// ---------------------------------------------------------------------------
170// H2D upload helper
171// ---------------------------------------------------------------------------
172
/// Upload a host `f32` slice to device memory, returning a
/// [`CubeclStorageHandle`] wrapping the device-resident buffer.
///
/// This is the single H2D upload point for the XPU path. The caller wraps the
/// returned handle in `TensorStorage::xpu_from_handle` to produce XPU storage.
///
/// The returned handle records `data.len()` as its element count and
/// `ordinal` as its device ordinal, and keeps `runtime` alive.
///
/// # Errors
///
/// This backend-enabled variant always returns `Ok`. The `FerrotorchResult`
/// return type exists so the signature matches the no-backend variant of
/// `upload_f32`, which always returns `DeviceUnavailable`.
///
/// # Panics
///
/// Panics if `runtime`'s client is `CubeClient::Stub`. Stub runtimes are
/// reserved for tests of pre-dispatch paths and must never receive a device
/// buffer (#1083).
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn upload_f32(
    data: &[f32],
    runtime: Arc<CubeRuntime>,
    ordinal: usize,
) -> FerrotorchResult<CubeclStorageHandle> {
    use crate::runtime::CubeClient;
    use cubecl::prelude::*;

    // Byte view of the host slice; `as_bytes` comes from the cubecl prelude.
    let bytes = f32::as_bytes(data);
    let handle = match runtime.client() {
        #[cfg(feature = "wgpu")]
        CubeClient::Wgpu(c) => c.create_from_slice(bytes),
        #[cfg(feature = "cuda")]
        CubeClient::Cuda(c) => c.create_from_slice(bytes),
        #[cfg(feature = "rocm")]
        CubeClient::Rocm(c) => c.create_from_slice(bytes),
        // #1083: Stub is reserved for tests that exercise pre-dispatch
        // paths only; uploading data through a Stub runtime would imply
        // a kernel could subsequently consume the buffer, which the
        // dispatch macros refuse.
        CubeClient::Stub => unreachable!(
            "CubeClient::Stub reached upload_f32 — Stub runtimes must not \
             reach kernel buffers; shape check or signature pin should fire \
             first (#1083)"
        ),
    };
    Ok(CubeclStorageHandle::new(
        handle,
        runtime,
        data.len(),
        ordinal,
    ))
}
216
217#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
218pub fn upload_f32(
219    _data: &[f32],
220    _runtime: Arc<CubeRuntime>,
221    _ordinal: usize,
222) -> FerrotorchResult<CubeclStorageHandle> {
223    Err(FerrotorchError::DeviceUnavailable)
224}
225
226// ---------------------------------------------------------------------------
227// Handle extraction helper — used by ops.rs to avoid re-uploading
228// ---------------------------------------------------------------------------
229
230/// Extract a `&CubeclStorageHandle` from a tensor's storage, if present.
231///
232/// Returns `None` when the tensor is not backed by a CubeCL device buffer
233/// (e.g. it is a CPU tensor). Ops use this to route device-resident inputs
234/// through the handle-direct kernel path (no H2D upload) vs. the slice-upload
235/// fallback path for CPU tensors passed to an XPU op.
236///
237/// # Example (in ops.rs)
238///
239/// ```ignore
240/// match (cubecl_handle_of(a), cubecl_handle_of(b)) {
241///     (Some(ha), Some(hb)) => run_add_handles(client, ha, hb),
242///     _ => { /* slice-upload fallback */ }
243/// }
244/// ```
245pub fn cubecl_handle_of(t: &ferrotorch_core::Tensor<f32>) -> Option<&CubeclStorageHandle> {
246    t.inner_storage_arc()
247        .cubecl_handle()
248        .and_then(|h| h.as_any().downcast_ref::<CubeclStorageHandle>())
249}