ferrotorch_cubecl/storage.rs
1//! Concrete [`CubeStorageHandle`] implementation for ferrotorch-cubecl.
2//!
3//! [`CubeclStorageHandle`] wraps a `cubecl::server::Handle` together with an
4//! `Arc<CubeRuntime>` so the runtime (and thus the device memory) stays alive
5//! as long as any tensor holds a reference to this handle.
6//!
7//! Core defines the [`CubeStorageHandle`] trait; this crate provides the only
8//! concrete implementation, keeping the core→cubecl dependency one-way.
9//!
10//! Issue #673: device-resident XPU storage.
11//!
12//! ## REQ status (per `.design/ferrotorch-cubecl/storage.md`)
13//!
14//! Full evidence rows (impl + non-test production consumer + upstream
15//! cites) live in the design doc; this synopsis is a one-line summary per
16//! REQ.
17//!
18//! | REQ | Status | Evidence |
19//! |---|---|---|
20//! | REQ-1 (`CubeclStorageHandle` struct) | SHIPPED | `pub struct CubeclStorageHandle in storage.rs` with `handle/runtime/len/ordinal` fields; consumer `ferrotorch-xpu/src/lib.rs` packs result of `upload_f32` into `TensorStorage::xpu_from_handle` |
21//! | REQ-2 (`CubeStorageHandle` trait impl) | SHIPPED | `impl CubeStorageHandle for CubeclStorageHandle in storage.rs` with `as_any/len/ordinal/read_to_host/clone_handle`; consumer `ferrotorch-core::TensorStorage::Cubecl` invokes via the trait on `Tensor::cpu()` readback |
22//! | REQ-3 (`from_raw`) | SHIPPED | `pub fn from_raw in storage.rs`; consumer `pub fn wrap_kernel_output in storage.rs` invoked from `ferrotorch-xpu/src/lib.rs` |
23//! | REQ-4 (`raw_handle`) | SHIPPED | `pub fn raw_handle in storage.rs`; consumer `ferrotorch-cubecl/src/ops.rs` dispatch macros call `ha.raw_handle().clone()` |
24//! | REQ-5 (`runtime` accessor) | SHIPPED | `pub fn runtime in storage.rs`; consumer downstream code reaching the handle via `cubecl_handle_of` to chain into read-back |
25//! | REQ-6 (`wrap_kernel_output`) | SHIPPED | `pub fn wrap_kernel_output in storage.rs`; consumer `ferrotorch-xpu/src/lib.rs` `xpu_binary!`/`xpu_unary!`/`xpu_polynomial!` macro expansions |
26//! | REQ-7 (`upload_f32`) | SHIPPED | `pub fn upload_f32 in storage.rs` with cfg arms (`create_from_slice` on feature, `DeviceUnavailable` off); consumer `ferrotorch-xpu/src/lib.rs::make_xpu_tensor` |
27//! | REQ-8 (`cubecl_handle_of`) | SHIPPED | `pub fn cubecl_handle_of in storage.rs` using `Any::downcast_ref`; consumer `ferrotorch-cubecl/src/ops.rs` dispatch macros |
28
29use std::sync::Arc;
30
31#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
32use ferrotorch_core::FerrotorchError;
33use ferrotorch_core::FerrotorchResult;
34use ferrotorch_core::storage::CubeStorageHandle;
35
36use crate::runtime::CubeRuntime;
37
38// ---------------------------------------------------------------------------
39// CubeclStorageHandle
40// ---------------------------------------------------------------------------
41
/// Device-resident buffer handle for CubeCL-backed tensors.
///
/// Holds:
/// - A `cubecl::server::Handle` pointing to the device buffer.
/// - An `Arc<CubeRuntime>` so the device client stays alive for as long as
///   this handle (or any clone of it) exists.
/// - `len`: element count (`f32` elements), not a byte count.
/// - `ordinal`: device ordinal.
///
/// This is the concrete type behind the [`CubeStorageHandle`] trait object
/// that core's CubeCL storage holds for XPU tensors; [`cubecl_handle_of`]
/// downcasts back to it.
///
/// Constructed by [`upload_f32`] (host-to-device copy), or by
/// [`CubeclStorageHandle::from_raw`] / [`wrap_kernel_output`], which wrap an
/// existing kernel output buffer without copying it.
#[derive(Debug)]
pub struct CubeclStorageHandle {
    /// Backend buffer handle; cloning it shares the buffer, it does not copy
    /// device memory.
    handle: cubecl::server::Handle,
    /// Shared runtime kept alive so the buffer's device client outlives it.
    runtime: Arc<CubeRuntime>,
    /// Number of `f32` elements in the buffer.
    len: usize,
    /// Device ordinal the buffer lives on.
    ordinal: usize,
}
58}
59
60impl CubeclStorageHandle {
61 /// Construct a handle from its parts (internal use by `upload_f32`).
62 #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
63 fn new(
64 handle: cubecl::server::Handle,
65 runtime: Arc<CubeRuntime>,
66 len: usize,
67 ordinal: usize,
68 ) -> Self {
69 Self {
70 handle,
71 runtime,
72 len,
73 ordinal,
74 }
75 }
76
77 /// Construct a handle from a raw `cubecl::server::Handle` returned by a
78 /// kernel launcher.
79 ///
80 /// Used by `ferrotorch-xpu`'s `wrap_result_handle` to turn the
81 /// `(cubecl::server::Handle, shape)` pair from `portable_*` into a
82 /// `CubeclStorageHandle` without an extra H2D upload. Issue #673.
83 pub fn from_raw(
84 handle: cubecl::server::Handle,
85 runtime: Arc<CubeRuntime>,
86 len: usize,
87 ordinal: usize,
88 ) -> Self {
89 Self {
90 handle,
91 runtime,
92 len,
93 ordinal,
94 }
95 }
96
97 /// Borrow the raw `cubecl::server::Handle`.
98 ///
99 /// Used by `ops.rs` to pass handles directly to kernel launchers without
100 /// an extra H2D upload. Issue #673.
101 pub fn raw_handle(&self) -> &cubecl::server::Handle {
102 &self.handle
103 }
104
105 /// Borrow the `CubeRuntime` this handle belongs to.
106 pub fn runtime(&self) -> &Arc<CubeRuntime> {
107 &self.runtime
108 }
109}
110
111impl CubeStorageHandle for CubeclStorageHandle {
112 fn as_any(&self) -> &dyn std::any::Any {
113 self
114 }
115
116 fn len(&self) -> usize {
117 self.len
118 }
119
120 fn ordinal(&self) -> usize {
121 self.ordinal
122 }
123
124 fn read_to_host(&self) -> FerrotorchResult<Vec<f32>> {
125 // `read_f32s` is only available when a backend feature is compiled in.
126 // Without a backend, this path is unreachable because the handle can
127 // only be constructed via `upload_f32`, which also requires a feature.
128 #[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
129 {
130 self.runtime.read_f32s(self.handle.clone(), self.len)
131 }
132 #[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
133 {
134 Err(FerrotorchError::DeviceUnavailable)
135 }
136 }
137
138 fn clone_handle(&self) -> Box<dyn CubeStorageHandle> {
139 // `cubecl::server::Handle` clone is cheap — it bumps an internal
140 // ref count in the cubecl server's handle table, not a buffer copy.
141 Box::new(CubeclStorageHandle {
142 handle: self.handle.clone(),
143 runtime: Arc::clone(&self.runtime),
144 len: self.len,
145 ordinal: self.ordinal,
146 })
147 }
148}
149
150// ---------------------------------------------------------------------------
151// Result-wrapping helper — for ferrotorch-xpu to use without depending on cubecl directly
152// ---------------------------------------------------------------------------
153
154/// Wrap the `(cubecl::server::Handle, Vec<usize>)` result of a `portable_*`
155/// kernel call into a `CubeclStorageHandle`.
156///
157/// `ferrotorch-xpu` uses this so it never needs to name `cubecl::server::Handle`
158/// directly (cubecl is not a direct dep of that crate). Issue #673.
159pub fn wrap_kernel_output(
160 handle: cubecl::server::Handle,
161 shape: &[usize],
162 runtime: Arc<CubeRuntime>,
163 ordinal: usize,
164) -> CubeclStorageHandle {
165 let numel: usize = shape.iter().product();
166 CubeclStorageHandle::from_raw(handle, runtime, numel, ordinal)
167}
168
169// ---------------------------------------------------------------------------
170// H2D upload helper
171// ---------------------------------------------------------------------------
172
/// Copy a host `f32` slice into freshly created device memory.
///
/// This is the single host-to-device upload point on the XPU path. The bytes
/// of `data` are handed to the backend client's `create_from_slice`. The
/// returned [`CubeclStorageHandle`] records `data.len()` as its element
/// count. Callers wrap it with `TensorStorage::xpu_from_handle` to obtain XPU
/// tensor storage.
///
/// # Errors
///
/// This backend-enabled variant never returns `Err`. The `DeviceUnavailable`
/// error comes from the fallback variant that is compiled when no backend
/// feature (`wgpu`/`cuda`/`rocm`) is enabled.
///
/// # Panics
///
/// Panics if `runtime` wraps `CubeClient::Stub`. Stub runtimes are reserved
/// for tests of pre-dispatch paths and must never own device buffers (#1083).
#[cfg(any(feature = "wgpu", feature = "cuda", feature = "rocm"))]
pub fn upload_f32(
    data: &[f32],
    runtime: Arc<CubeRuntime>,
    ordinal: usize,
) -> FerrotorchResult<CubeclStorageHandle> {
    use crate::runtime::CubeClient;
    use cubecl::prelude::*;

    // The device only sees raw bytes, so capture the element count up front.
    let element_count = data.len();
    let payload = f32::as_bytes(data);

    let device_buffer = match runtime.client() {
        #[cfg(feature = "wgpu")]
        CubeClient::Wgpu(client) => client.create_from_slice(payload),
        #[cfg(feature = "cuda")]
        CubeClient::Cuda(client) => client.create_from_slice(payload),
        #[cfg(feature = "rocm")]
        CubeClient::Rocm(client) => client.create_from_slice(payload),
        // #1083: uploading through a Stub would imply a kernel could later
        // consume the buffer, which the dispatch macros refuse.
        CubeClient::Stub => unreachable!(
            "CubeClient::Stub reached upload_f32 — Stub runtimes must not \
             reach kernel buffers; shape check or signature pin should fire \
             first (#1083)"
        ),
    };

    let storage = CubeclStorageHandle::new(device_buffer, runtime, element_count, ordinal);
    Ok(storage)
}
216
217#[cfg(not(any(feature = "wgpu", feature = "cuda", feature = "rocm")))]
218pub fn upload_f32(
219 _data: &[f32],
220 _runtime: Arc<CubeRuntime>,
221 _ordinal: usize,
222) -> FerrotorchResult<CubeclStorageHandle> {
223 Err(FerrotorchError::DeviceUnavailable)
224}
225
226// ---------------------------------------------------------------------------
227// Handle extraction helper — used by ops.rs to avoid re-uploading
228// ---------------------------------------------------------------------------
229
230/// Extract a `&CubeclStorageHandle` from a tensor's storage, if present.
231///
232/// Returns `None` when the tensor is not backed by a CubeCL device buffer
233/// (e.g. it is a CPU tensor). Ops use this to route device-resident inputs
234/// through the handle-direct kernel path (no H2D upload) vs. the slice-upload
235/// fallback path for CPU tensors passed to an XPU op.
236///
237/// # Example (in ops.rs)
238///
239/// ```ignore
240/// match (cubecl_handle_of(a), cubecl_handle_of(b)) {
241/// (Some(ha), Some(hb)) => run_add_handles(client, ha, hb),
242/// _ => { /* slice-upload fallback */ }
243/// }
244/// ```
245pub fn cubecl_handle_of(t: &ferrotorch_core::Tensor<f32>) -> Option<&CubeclStorageHandle> {
246 t.inner_storage_arc()
247 .cubecl_handle()
248 .and_then(|h| h.as_any().downcast_ref::<CubeclStorageHandle>())
249}