ferrotorch_gpu/backend_impl.rs

1//! CUDA implementation of the [`GpuBackend`] trait from ferrotorch-core.
2//!
3//! This module bridges the existing GPU operations (`gpu_add`, `gpu_matmul_f32`,
4//! etc.) to the type-erased [`GpuBackend`] dispatch interface, enabling
5//! ferrotorch-core to call GPU operations without depending on this crate
6//! directly.
7//!
8//! # Initialization
9//!
10//! Call [`init_cuda_backend`] once at startup (typically via `ferrotorch::init()`).
11//! This creates a [`CudaBackendImpl`], initializes CUDA device 0, and registers
12//! it with [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
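//!
//! A minimal startup sketch (hedged: marked `ignore` because it needs a
//! CUDA-capable machine and assumes this module is publicly reachable as
//! `ferrotorch_gpu::backend_impl`):
//!
//! ```ignore
//! // Register the CUDA backend once; later calls are no-ops.
//! ferrotorch_gpu::backend_impl::init_cuda_backend()?;
//! let backend = ferrotorch_core::gpu_dispatch::gpu_backend()
//!     .expect("backend registered by init_cuda_backend");
//!
//! // Upload five f32 values as raw bytes to device 0 and read them back.
//! let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
//! let bytes = unsafe {
//!     std::slice::from_raw_parts(host.as_ptr() as *const u8, host.len() * 4)
//! };
//! let handle = backend.cpu_to_gpu(bytes, 4, 0)?;
//! let back = backend.gpu_to_cpu(&handle)?;
//! assert_eq!(back.len(), host.len() * 4);
//! # Ok::<(), ferrotorch_core::error::FerrotorchError>(())
//! ```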
13
14use std::sync::Arc;
15
16use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
17use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};
18
19use crate::buffer::CudaBuffer;
20use crate::device::GpuDevice;
21
22// ---------------------------------------------------------------------------
23// CudaBackendImpl
24// ---------------------------------------------------------------------------
25
26/// CUDA implementation of the [`GpuBackend`] trait.
27///
28/// Holds one or more [`GpuDevice`] handles (currently device 0 only) and
29/// delegates every trait method to the corresponding function in
30/// [`crate::kernels`], [`crate::blas`], or [`crate::transfer`].
31pub struct CudaBackendImpl {
32    devices: Vec<Arc<GpuDevice>>,
33}
34
35impl CudaBackendImpl {
36    /// Create a new CUDA backend, initializing device 0.
37    ///
38    /// # Errors
39    ///
40    /// Returns [`FerrotorchError::InvalidArgument`] if CUDA initialization fails
41    /// (e.g. no GPU available, driver not loaded).
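    ///
    /// A hedged construction sketch; most callers should go through
    /// [`init_cuda_backend`] instead, and the `ferrotorch_gpu::backend_impl`
    /// path below is an assumption about how this module is exported:
    ///
    /// ```ignore
    /// use ferrotorch_gpu::backend_impl::CudaBackendImpl;
    ///
    /// let backend = CudaBackendImpl::new()?;
    /// let _device = backend.default_device()?;
    /// # Ok::<(), ferrotorch_core::error::FerrotorchError>(())
    /// ```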
42    pub fn new() -> FerrotorchResult<Self> {
43        let device = Arc::new(
44            GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
45                message: format!("CUDA init failed: {e}"),
46            })?,
47        );
48        Ok(Self {
49            devices: vec![device],
50        })
51    }
52
53    /// Get the device for ordinal 0 (the default device).
54    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
55        self.device(0)
56    }
57
58    /// Look up a device by ordinal.
59    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
60        self.devices
61            .get(ordinal)
62            .ok_or(FerrotorchError::InvalidArgument {
63                message: format!("CUDA device {ordinal} not available"),
64            })
65    }
66
67    /// Wrap a `CudaBuffer<f32>` into a type-erased [`GpuBufferHandle`].
68    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
69        let len = buf.len();
70        GpuBufferHandle::new(Box::new(buf), ordinal, len)
71    }
72
73    /// Wrap a `CudaBuffer<f64>` into a type-erased [`GpuBufferHandle`].
74    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
75        let len = buf.len();
76        GpuBufferHandle::new(Box::new(buf), ordinal, len)
77    }
78
79    /// Extract a `&CudaBuffer<f32>` from a [`GpuBufferHandle`].
80    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
81        handle
82            .downcast_ref::<CudaBuffer<f32>>()
83            .ok_or(FerrotorchError::InvalidArgument {
84                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
85            })
86    }
87
88    /// Extract a `&mut CudaBuffer<f32>` from a [`GpuBufferHandle`].
89    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
90        handle
91            .downcast_mut::<CudaBuffer<f32>>()
92            .ok_or(FerrotorchError::InvalidArgument {
93                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
94            })
95    }
96
97    /// Extract a `&CudaBuffer<f64>` from a [`GpuBufferHandle`].
98    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
99        handle
100            .downcast_ref::<CudaBuffer<f64>>()
101            .ok_or(FerrotorchError::InvalidArgument {
102                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
103            })
104    }
105
106    /// Convert a [`crate::error::GpuError`] into a [`FerrotorchError`].
107    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
108        FerrotorchError::InvalidArgument {
109            message: format!("{e}"),
110        }
111    }
112}
113
114// ---------------------------------------------------------------------------
115// GpuBackend implementation
116// ---------------------------------------------------------------------------
117
118impl GpuBackend for CudaBackendImpl {
119    fn as_any(&self) -> &dyn std::any::Any {
120        self
121    }
122
123    fn cpu_to_gpu(
124        &self,
125        data: &[u8],
126        elem_size: usize,
127        device: usize,
128    ) -> FerrotorchResult<GpuBufferHandle> {
129        let dev = self.device(device)?;
130        match elem_size {
131            4 => {
132                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
133                // was originally an f32 slice serialised to bytes.
134                let count = data.len() / 4;
135                let f32_data: &[f32] =
136                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
137                let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
138                Ok(Self::wrap_buffer(buf, device))
139            }
140            8 => {
141                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
142                // was originally an f64 slice serialised to bytes.
143                let count = data.len() / 8;
144                let f64_data: &[f64] =
145                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
146                let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
147                Ok(Self::wrap_buffer_f64(buf, device))
148            }
149            other => Err(FerrotorchError::InvalidArgument {
150                message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
151            }),
152        }
153    }
154
155    fn cpu_to_gpu_pinned(
156        &self,
157        data: &[u8],
158        elem_size: usize,
159        device: usize,
160    ) -> FerrotorchResult<GpuBufferHandle> {
161        let dev = self.device(device)?;
162        match elem_size {
163            4 => {
164                let count = data.len() / 4;
165                let f32_data: &[f32] =
166                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
167                let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
168                    .map_err(Self::map_gpu_err)?;
169                Ok(Self::wrap_buffer(buf, device))
170            }
171            8 => {
172                let count = data.len() / 8;
173                let f64_data: &[f64] =
174                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
175                let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
176                    .map_err(Self::map_gpu_err)?;
177                Ok(Self::wrap_buffer_f64(buf, device))
178            }
179            other => Err(FerrotorchError::InvalidArgument {
180                message: format!(
181                    "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
182                ),
183            }),
184        }
185    }
186
187    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
188        let dev = self.device(handle.device_ordinal())?;
189
190        // Try f32 first, then f64.
191        if let Ok(buf) = Self::unwrap_buffer(handle) {
192            let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
193
194            // Reinterpret Vec<f32> as Vec<u8> without copying.
195            // SAFETY: f32 has alignment 4 and size 4. We adjust len and capacity
196            // accordingly. The original Vec is consumed via ManuallyDrop so its
197            // destructor won't free the allocation.
198            let bytes = unsafe {
199                let mut v = std::mem::ManuallyDrop::new(f32_data);
200                let ptr = v.as_mut_ptr() as *mut u8;
201                let len = v.len() * 4;
202                let cap = v.capacity() * 4;
203                Vec::from_raw_parts(ptr, len, cap)
204            };
205            Ok(bytes)
206        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
207            let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
208
209            // Reinterpret Vec<f64> as Vec<u8> without copying.
210            // SAFETY: f64 has alignment 8 and size 8. We adjust len and capacity
211            // accordingly. The original Vec is consumed via ManuallyDrop so its
212            // destructor won't free the allocation.
213            let bytes = unsafe {
214                let mut v = std::mem::ManuallyDrop::new(f64_data);
215                let ptr = v.as_mut_ptr() as *mut u8;
216                let len = v.len() * 8;
217                let cap = v.capacity() * 8;
218                Vec::from_raw_parts(ptr, len, cap)
219            };
220            Ok(bytes)
221        } else {
222            Err(FerrotorchError::InvalidArgument {
223                message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
224            })
225        }
226    }
227
228    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
229        // Clone via GPU -> CPU -> GPU round-trip.
230        // Correct but not optimal; a device-to-device memcpy would be better.
231        let bytes = self.gpu_to_cpu(handle)?;
232        // Determine elem_size from the concrete buffer type.
233        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
234            8
235        } else {
236            4
237        };
238        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
239    }
240
241    fn alloc_zeros(
242        &self,
243        len: usize,
244        elem_size: usize,
245        device: usize,
246    ) -> FerrotorchResult<GpuBufferHandle> {
247        let dev = self.device(device)?;
248        match elem_size {
249            4 => {
250                let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
251                Ok(Self::wrap_buffer(buf, device))
252            }
253            8 => {
254                let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
255                Ok(Self::wrap_buffer_f64(buf, device))
256            }
257            other => Err(FerrotorchError::InvalidArgument {
258                message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
259            }),
260        }
261    }
262
263    // -- Elementwise f32 ------------------------------------------------------
264
265    fn add_f32(
266        &self,
267        a: &GpuBufferHandle,
268        b: &GpuBufferHandle,
269    ) -> FerrotorchResult<GpuBufferHandle> {
270        let a_buf = Self::unwrap_buffer(a)?;
271        let b_buf = Self::unwrap_buffer(b)?;
272        let dev = self.device(a.device_ordinal())?;
273        let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
274        Ok(Self::wrap_buffer(result, a.device_ordinal()))
275    }
276
277    fn sub_f32(
278        &self,
279        a: &GpuBufferHandle,
280        b: &GpuBufferHandle,
281    ) -> FerrotorchResult<GpuBufferHandle> {
282        let a_buf = Self::unwrap_buffer(a)?;
283        let b_buf = Self::unwrap_buffer(b)?;
284        let dev = self.device(a.device_ordinal())?;
285        let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
286        Ok(Self::wrap_buffer(result, a.device_ordinal()))
287    }
288
289    fn mul_f32(
290        &self,
291        a: &GpuBufferHandle,
292        b: &GpuBufferHandle,
293    ) -> FerrotorchResult<GpuBufferHandle> {
294        let a_buf = Self::unwrap_buffer(a)?;
295        let b_buf = Self::unwrap_buffer(b)?;
296        let dev = self.device(a.device_ordinal())?;
297        let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
298        Ok(Self::wrap_buffer(result, a.device_ordinal()))
299    }
300
301    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
302        let a_buf = Self::unwrap_buffer(a)?;
303        let dev = self.device(a.device_ordinal())?;
304        let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
305        Ok(Self::wrap_buffer(result, a.device_ordinal()))
306    }
307
308    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
309        let a_buf = Self::unwrap_buffer(a)?;
310        let dev = self.device(a.device_ordinal())?;
311        let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
312        Ok(Self::wrap_buffer(result, a.device_ordinal()))
313    }
314
315    fn div_f32(
316        &self,
317        a: &GpuBufferHandle,
318        b: &GpuBufferHandle,
319    ) -> FerrotorchResult<GpuBufferHandle> {
320        let a_buf = Self::unwrap_buffer(a)?;
321        let b_buf = Self::unwrap_buffer(b)?;
322        let dev = self.device(a.device_ordinal())?;
323        let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
324        Ok(Self::wrap_buffer(result, a.device_ordinal()))
325    }
326
327    fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
328        let a_buf = Self::unwrap_buffer(a)?;
329        let dev = self.device(a.device_ordinal())?;
330        let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
331        Ok(Self::wrap_buffer(result, a.device_ordinal()))
332    }
333
334    fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
335        let a_buf = Self::unwrap_buffer(a)?;
336        let dev = self.device(a.device_ordinal())?;
337        let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
338        Ok(Self::wrap_buffer(result, a.device_ordinal()))
339    }
340
341    fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
342        let a_buf = Self::unwrap_buffer(a)?;
343        let dev = self.device(a.device_ordinal())?;
344        let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
345        Ok(Self::wrap_buffer(result, a.device_ordinal()))
346    }
347
348    fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
349        let a_buf = Self::unwrap_buffer(a)?;
350        let dev = self.device(a.device_ordinal())?;
351        let result =
352            crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
353        Ok(Self::wrap_buffer(result, a.device_ordinal()))
354    }
355
356    fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
357        let a_buf = Self::unwrap_buffer(a)?;
358        let dev = self.device(a.device_ordinal())?;
359        let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
360        Ok(Self::wrap_buffer(result, a.device_ordinal()))
361    }
362
363    fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
364        let a_buf = Self::unwrap_buffer(a)?;
365        let dev = self.device(a.device_ordinal())?;
366        let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
367        Ok(Self::wrap_buffer(result, a.device_ordinal()))
368    }
369
370    fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
371        let a_buf = Self::unwrap_buffer(a)?;
372        let dev = self.device(a.device_ordinal())?;
373        let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
374        Ok(Self::wrap_buffer(result, a.device_ordinal()))
375    }
376
377    #[allow(clippy::too_many_arguments)]
378    fn fused_adam_f32(
379        &self,
380        param: &mut GpuBufferHandle,
381        grad: &GpuBufferHandle,
382        exp_avg: &mut GpuBufferHandle,
383        exp_avg_sq: &mut GpuBufferHandle,
384        beta1: f32,
385        beta2: f32,
386        lr: f32,
387        eps: f32,
388        bc1: f32,
389        bc2: f32,
390        weight_decay: f32,
391    ) -> FerrotorchResult<()> {
392        let ordinal = param.device_ordinal();
393        let dev = self.device(ordinal)?;
394        let p_buf = Self::unwrap_buffer_mut(param)?;
395        let g_buf = Self::unwrap_buffer(grad)?;
396        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
397        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
398        crate::kernels::gpu_fused_adam(
399            p_buf,
400            g_buf,
401            m_buf,
402            v_buf,
403            beta1,
404            beta2,
405            lr,
406            eps,
407            bc1,
408            bc2,
409            weight_decay,
410            dev,
411        )
412        .map_err(Self::map_gpu_err)?;
413        Ok(())
414    }
415
416    #[allow(clippy::too_many_arguments)]
417    fn maxpool2d_f32(
418        &self,
419        input: &GpuBufferHandle,
420        batch: usize,
421        channels: usize,
422        h_in: usize,
423        w_in: usize,
424        kh: usize,
425        kw: usize,
426        sh: usize,
427        sw: usize,
428        ph: usize,
429        pw: usize,
430    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
431        let buf = Self::unwrap_buffer(input)?;
432        let dev = self.device(input.device_ordinal())?;
433        let (out, shape) = crate::kernels::gpu_maxpool2d(
434            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
435        ).map_err(Self::map_gpu_err)?;
436        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
437    }
438
439    #[allow(clippy::too_many_arguments)]
440    fn avgpool2d_f32(
441        &self,
442        input: &GpuBufferHandle,
443        batch: usize,
444        channels: usize,
445        h_in: usize,
446        w_in: usize,
447        kh: usize,
448        kw: usize,
449        sh: usize,
450        sw: usize,
451        ph: usize,
452        pw: usize,
453    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
454        let buf = Self::unwrap_buffer(input)?;
455        let dev = self.device(input.device_ordinal())?;
456        let (out, shape) = crate::kernels::gpu_avgpool2d(
457            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
458        ).map_err(Self::map_gpu_err)?;
459        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
460    }
461
462    #[allow(clippy::too_many_arguments)]
463    fn conv2d_f32(
464        &self,
465        input: &GpuBufferHandle,
466        weight: &GpuBufferHandle,
467        bias: Option<&GpuBufferHandle>,
468        input_shape: [usize; 4],
469        weight_shape: [usize; 4],
470        stride: (usize, usize),
471        padding: (usize, usize),
472    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
473        let input_buf = Self::unwrap_buffer(input)?;
474        let weight_buf = Self::unwrap_buffer(weight)?;
475        let bias_buf = match bias {
476            Some(b) => Some(Self::unwrap_buffer(b)?),
477            None => None,
478        };
479        let dev = self.device(input.device_ordinal())?;
480        let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
481            input_buf,
482            weight_buf,
483            bias_buf,
484            input_shape,
485            weight_shape,
486            stride,
487            padding,
488            dev,
489        )
490        .map_err(Self::map_gpu_err)?;
491        Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
492    }
493
494    fn fused_gru_cell_f32(
495        &self,
496        input_gates: &GpuBufferHandle,
497        hidden_gates: &GpuBufferHandle,
498        bias_ih: &GpuBufferHandle,
499        bias_hh: &GpuBufferHandle,
500        hx: &GpuBufferHandle,
501        hidden_size: usize,
502    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
503        let ig = Self::unwrap_buffer(input_gates)?;
504        let hg = Self::unwrap_buffer(hidden_gates)?;
505        let bih = Self::unwrap_buffer(bias_ih)?;
506        let bhh = Self::unwrap_buffer(bias_hh)?;
507        let hx_buf = Self::unwrap_buffer(hx)?;
508        let dev = self.device(input_gates.device_ordinal())?;
509        let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
510            ig, hg, bih, bhh, hx_buf, hidden_size, dev,
511        )
512        .map_err(Self::map_gpu_err)?;
513        let ord = input_gates.device_ordinal();
514        Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
515    }
516
517    fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
518        let dev = self.device(device)?;
519        dev.stream()
520            .synchronize()
521            .map_err(|e| FerrotorchError::InvalidArgument {
522                message: format!("CUDA synchronize failed: {e}"),
523            })?;
524        Ok(())
525    }
526
527    fn stream_count(&self, device: usize) -> usize {
528        crate::stream::StreamPool::pool_size(device)
529    }
530
531    // -- Linalg f32 -----------------------------------------------------------
532
533    fn matmul_f32(
534        &self,
535        a: &GpuBufferHandle,
536        b: &GpuBufferHandle,
537        m: usize,
538        k: usize,
539        n: usize,
540    ) -> FerrotorchResult<GpuBufferHandle> {
541        let a_buf = Self::unwrap_buffer(a)?;
542        let b_buf = Self::unwrap_buffer(b)?;
543        let dev = self.device(a.device_ordinal())?;
544        let result =
545            crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
546        Ok(Self::wrap_buffer(result, a.device_ordinal()))
547    }
548
549    // -- Reduction f32 --------------------------------------------------------
550
551    fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
552        let a_buf = Self::unwrap_buffer(a)?;
553        let dev = self.device(a.device_ordinal())?;
554        let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
555        Ok(Self::wrap_buffer(result, a.device_ordinal()))
556    }
557
558    // -- Linalg f64 (cuBLAS DGEMM) --------------------------------------------
559
560    fn matmul_f64(
561        &self,
562        a: &GpuBufferHandle,
563        b: &GpuBufferHandle,
564        m: usize,
565        k: usize,
566        n: usize,
567    ) -> FerrotorchResult<GpuBufferHandle> {
568        let a_buf = Self::unwrap_buffer_f64(a)?;
569        let b_buf = Self::unwrap_buffer_f64(b)?;
570        let dev = self.device(a.device_ordinal())?;
571        let result =
572            crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
573        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
574    }
575
576    // -- Broadcast binary f32 -------------------------------------------------
577
578    fn broadcast_add_f32(
579        &self,
580        a: &GpuBufferHandle,
581        b: &GpuBufferHandle,
582        a_shape: &[usize],
583        b_shape: &[usize],
584        out_shape: &[usize],
585    ) -> FerrotorchResult<GpuBufferHandle> {
586        let a_buf = Self::unwrap_buffer(a)?;
587        let b_buf = Self::unwrap_buffer(b)?;
588        let dev = self.device(a.device_ordinal())?;
589        let result =
590            crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
591                .map_err(Self::map_gpu_err)?;
592        Ok(Self::wrap_buffer(result, a.device_ordinal()))
593    }
594
595    fn broadcast_sub_f32(
596        &self,
597        a: &GpuBufferHandle,
598        b: &GpuBufferHandle,
599        a_shape: &[usize],
600        b_shape: &[usize],
601        out_shape: &[usize],
602    ) -> FerrotorchResult<GpuBufferHandle> {
603        let a_buf = Self::unwrap_buffer(a)?;
604        let b_buf = Self::unwrap_buffer(b)?;
605        let dev = self.device(a.device_ordinal())?;
606        let result =
607            crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
608                .map_err(Self::map_gpu_err)?;
609        Ok(Self::wrap_buffer(result, a.device_ordinal()))
610    }
611
612    fn broadcast_mul_f32(
613        &self,
614        a: &GpuBufferHandle,
615        b: &GpuBufferHandle,
616        a_shape: &[usize],
617        b_shape: &[usize],
618        out_shape: &[usize],
619    ) -> FerrotorchResult<GpuBufferHandle> {
620        let a_buf = Self::unwrap_buffer(a)?;
621        let b_buf = Self::unwrap_buffer(b)?;
622        let dev = self.device(a.device_ordinal())?;
623        let result =
624            crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
625                .map_err(Self::map_gpu_err)?;
626        Ok(Self::wrap_buffer(result, a.device_ordinal()))
627    }
628
629    fn broadcast_div_f32(
630        &self,
631        a: &GpuBufferHandle,
632        b: &GpuBufferHandle,
633        a_shape: &[usize],
634        b_shape: &[usize],
635        out_shape: &[usize],
636    ) -> FerrotorchResult<GpuBufferHandle> {
637        let a_buf = Self::unwrap_buffer(a)?;
638        let b_buf = Self::unwrap_buffer(b)?;
639        let dev = self.device(a.device_ordinal())?;
640        let result =
641            crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
642                .map_err(Self::map_gpu_err)?;
643        Ok(Self::wrap_buffer(result, a.device_ordinal()))
644    }
645
646    fn softmax_f32(
647        &self,
648        a: &GpuBufferHandle,
649        rows: usize,
650        cols: usize,
651    ) -> FerrotorchResult<GpuBufferHandle> {
652        let a_buf = Self::unwrap_buffer(a)?;
653        let dev = self.device(a.device_ordinal())?;
654        let result =
655            crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
656        Ok(Self::wrap_buffer(result, a.device_ordinal()))
657    }
658
659    fn dropout_f32(
660        &self,
661        a: &GpuBufferHandle,
662        threshold: u32,
663        scale: f32,
664        seed: u32,
665    ) -> FerrotorchResult<GpuBufferHandle> {
666        let a_buf = Self::unwrap_buffer(a)?;
667        let dev = self.device(a.device_ordinal())?;
668        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
669            .map_err(Self::map_gpu_err)?;
670        Ok(Self::wrap_buffer(result, a.device_ordinal()))
671    }
672
673    fn dropout_philox_f32(
674        &self,
675        a: &GpuBufferHandle,
676        threshold: u32,
677        scale: f32,
678    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
679        let device_ordinal = a.device_ordinal();
680        let n = a.len();
681
682        // Snapshot the current RNG state and advance it.
683        let rng_state = {
684            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
685                FerrotorchError::InvalidArgument {
686                    message: "failed to lock CUDA RNG manager".into(),
687                }
688            })?;
689            let philox_gen = mgr.generator(device_ordinal);
690            let state = philox_gen.get_state();
691            // Advance by ceil(n/4) counters (each counter produces 4 u32 values)
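            // Illustrative sizing: n = 10 elements needs div_ceil(10, 4) = 3
            // counter steps.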
692            let counters_needed = n.div_ceil(4);
693            philox_gen.advance(counters_needed as u64);
694            state
695        };
696
        // Derive the dropout kernel's u32 seed from the Philox state.
        // Ideally the GPU would generate the mask directly with the Philox
        // uniform kernel; until then, deterministically deriving a seed from
        // the Philox state keeps the GPU forward pass consistent with
        // CPU-side mask regeneration in the backward pass.
704        let a_buf = Self::unwrap_buffer(a)?;
705        let dev = self.device(device_ordinal)?;
706
707        // Use the Philox counter XOR seed as the dropout kernel's seed.
708        // This gives us deterministic behavior tied to the Philox state.
709        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
710        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
711            .map_err(Self::map_gpu_err)?;
712
713        let gpu_rng_state = GpuRngState {
714            counter: rng_state.counter,
715            seed: rng_state.seed,
716            offset: rng_state.offset,
717            device: device_ordinal,
718        };
719
720        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
721    }
722
723    fn transpose_2d_f32(
724        &self,
725        a: &GpuBufferHandle,
726        m: usize,
727        n: usize,
728    ) -> FerrotorchResult<GpuBufferHandle> {
729        let a_buf = Self::unwrap_buffer(a)?;
730        let dev = self.device(a.device_ordinal())?;
731        let result =
732            crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
733        Ok(Self::wrap_buffer(result, a.device_ordinal()))
734    }
735
736    fn permute_0213_f32(
737        &self,
738        a: &GpuBufferHandle,
739        d0: usize,
740        d1: usize,
741        d2: usize,
742        d3: usize,
743    ) -> FerrotorchResult<GpuBufferHandle> {
744        let a_buf = Self::unwrap_buffer(a)?;
745        let dev = self.device(a.device_ordinal())?;
746        let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
747            .map_err(Self::map_gpu_err)?;
748        Ok(Self::wrap_buffer(result, a.device_ordinal()))
749    }
750
751    fn bmm_f32(
752        &self,
753        a: &GpuBufferHandle,
754        b: &GpuBufferHandle,
755        batch: usize,
756        m: usize,
757        k: usize,
758        n: usize,
759    ) -> FerrotorchResult<GpuBufferHandle> {
760        let a_buf = Self::unwrap_buffer(a)?;
761        let b_buf = Self::unwrap_buffer(b)?;
762        let dev = self.device(a.device_ordinal())?;
763        let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
764            .map_err(Self::map_gpu_err)?;
765        Ok(Self::wrap_buffer(result, a.device_ordinal()))
766    }
767
768    fn bmm_f16_f32(
769        &self,
770        a: &GpuBufferHandle,
771        b: &GpuBufferHandle,
772        batch: usize,
773        m: usize,
774        k: usize,
775        n: usize,
776    ) -> FerrotorchResult<GpuBufferHandle> {
777        let a_buf = Self::unwrap_buffer(a)?;
778        let b_buf = Self::unwrap_buffer(b)?;
779        let dev = self.device(a.device_ordinal())?;
780        let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
781            .map_err(Self::map_gpu_err)?;
782        Ok(Self::wrap_buffer(result, a.device_ordinal()))
783    }
784
785    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
786        let a_buf = Self::unwrap_buffer(a)?;
787        let dev = self.device(a.device_ordinal())?;
788        let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
789        Ok(Self::wrap_buffer(result, a.device_ordinal()))
790    }
791
792    fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
793        let a_buf = Self::unwrap_buffer(a)?;
794        let dev = self.device(a.device_ordinal())?;
795        let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
796        Ok(Self::wrap_buffer(result, a.device_ordinal()))
797    }
798
799    fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
800        let a_buf = Self::unwrap_buffer(a)?;
801        let dev = self.device(a.device_ordinal())?;
802        let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
803        Ok(Self::wrap_buffer(result, a.device_ordinal()))
804    }
805
806    fn layernorm_f32(
807        &self,
808        input: &GpuBufferHandle,
809        weight: &GpuBufferHandle,
810        bias: &GpuBufferHandle,
811        rows: usize,
812        cols: usize,
813        eps: f32,
814    ) -> FerrotorchResult<GpuBufferHandle> {
815        let in_buf = Self::unwrap_buffer(input)?;
816        let w_buf = Self::unwrap_buffer(weight)?;
817        let b_buf = Self::unwrap_buffer(bias)?;
818        let dev = self.device(input.device_ordinal())?;
819        let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
820            .map_err(Self::map_gpu_err)?;
821        Ok(Self::wrap_buffer(result, input.device_ordinal()))
822    }
823
824    fn slice_write_f32(
825        &self,
826        src: &GpuBufferHandle,
827        dst: &mut GpuBufferHandle,
828        n_batch: usize,
829        d: usize,
830        max_len: usize,
831        pos: usize,
832    ) -> FerrotorchResult<()> {
833        let src_buf = Self::unwrap_buffer(src)?;
834        let dst_buf =
835            dst.downcast_mut::<CudaBuffer<f32>>()
836                .ok_or(FerrotorchError::InvalidArgument {
837                    message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
838                })?;
839        let dev = self.device(src.device_ordinal())?;
840        crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
841            .map_err(Self::map_gpu_err)?;
842        Ok(())
843    }
844
845    fn slice_read_f32(
846        &self,
847        src: &GpuBufferHandle,
848        n_batch: usize,
849        d: usize,
850        len: usize,
851        max_len: usize,
852    ) -> FerrotorchResult<GpuBufferHandle> {
853        let src_buf = Self::unwrap_buffer(src)?;
854        let dev = self.device(src.device_ordinal())?;
855        let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
856            .map_err(Self::map_gpu_err)?;
857        Ok(Self::wrap_buffer(result, src.device_ordinal()))
858    }
859
860    fn embed_lookup_f32(
861        &self,
862        idx: &GpuBufferHandle,
863        weight: &GpuBufferHandle,
864        d: usize,
865    ) -> FerrotorchResult<GpuBufferHandle> {
866        let idx_buf = Self::unwrap_buffer(idx)?;
867        let w_buf = Self::unwrap_buffer(weight)?;
868        let dev = self.device(idx.device_ordinal())?;
869        let result =
870            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
871        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
872    }
873
874    fn embed_lookup_batch_f32(
875        &self,
876        indices: &GpuBufferHandle,
877        weight: &GpuBufferHandle,
878        n: usize,
879        d: usize,
880    ) -> FerrotorchResult<GpuBufferHandle> {
881        let idx_buf = Self::unwrap_buffer(indices)?;
882        let w_buf = Self::unwrap_buffer(weight)?;
883        let dev = self.device(indices.device_ordinal())?;
884        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
885            .map_err(Self::map_gpu_err)?;
886        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
887    }
888
889    fn scatter_add_rows_f32(
890        &self,
891        grad_output: &GpuBufferHandle,
892        indices: &GpuBufferHandle,
893        num_embeddings: usize,
894        d: usize,
895    ) -> FerrotorchResult<GpuBufferHandle> {
896        let go_buf = Self::unwrap_buffer(grad_output)?;
897        let idx_buf = Self::unwrap_buffer(indices)?;
898        let dev = self.device(grad_output.device_ordinal())?;
899        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
900            .map_err(Self::map_gpu_err)?;
901        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
902    }
903
904    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
905        let a_buf = Self::unwrap_buffer(a)?;
906        let dev = self.device(a.device_ordinal())?;
907        let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
908        Ok(Self::wrap_buffer(result, a.device_ordinal()))
909    }
910
911    fn relu_backward_f32(
912        &self,
913        grad: &GpuBufferHandle,
914        input: &GpuBufferHandle,
915    ) -> FerrotorchResult<GpuBufferHandle> {
916        let grad_buf = Self::unwrap_buffer(grad)?;
917        let input_buf = Self::unwrap_buffer(input)?;
918        let dev = self.device(grad.device_ordinal())?;
919        let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
920            .map_err(Self::map_gpu_err)?;
921        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
922    }
923
924    fn gelu_backward_f32(
925        &self,
926        grad: &GpuBufferHandle,
927        input: &GpuBufferHandle,
928    ) -> FerrotorchResult<GpuBufferHandle> {
929        let grad_buf = Self::unwrap_buffer(grad)?;
930        let input_buf = Self::unwrap_buffer(input)?;
931        let dev = self.device(grad.device_ordinal())?;
932        let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
933            .map_err(Self::map_gpu_err)?;
934        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
935    }
936
937    fn gelu_backward_tanh_f32(
938        &self,
939        grad: &GpuBufferHandle,
940        input: &GpuBufferHandle,
941    ) -> FerrotorchResult<GpuBufferHandle> {
942        let grad_buf = Self::unwrap_buffer(grad)?;
943        let input_buf = Self::unwrap_buffer(input)?;
944        let dev = self.device(grad.device_ordinal())?;
945        let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
946            .map_err(Self::map_gpu_err)?;
947        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
948    }
949
950    fn gelu_backward_erf_f32(
951        &self,
952        grad: &GpuBufferHandle,
953        input: &GpuBufferHandle,
954    ) -> FerrotorchResult<GpuBufferHandle> {
955        let grad_buf = Self::unwrap_buffer(grad)?;
956        let input_buf = Self::unwrap_buffer(input)?;
957        let dev = self.device(grad.device_ordinal())?;
958        let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
959            .map_err(Self::map_gpu_err)?;
960        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
961    }
962
963    fn index_select_1d_f32(
964        &self,
965        input: &GpuBufferHandle,
966        indices: &GpuBufferHandle,
967    ) -> FerrotorchResult<GpuBufferHandle> {
968        let input_buf = Self::unwrap_buffer(input)?;
969        let idx_buf = Self::unwrap_buffer(indices)?;
970        let dev = self.device(input.device_ordinal())?;
971        let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
972            .map_err(Self::map_gpu_err)?;
973        Ok(Self::wrap_buffer(result, input.device_ordinal()))
974    }
975
976    fn scatter_add_1d_f32(
977        &self,
978        grad_output: &GpuBufferHandle,
979        indices: &GpuBufferHandle,
980        input_len: usize,
981    ) -> FerrotorchResult<GpuBufferHandle> {
982        let go_buf = Self::unwrap_buffer(grad_output)?;
983        let idx_buf = Self::unwrap_buffer(indices)?;
984        let dev = self.device(grad_output.device_ordinal())?;
985        let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
986            .map_err(Self::map_gpu_err)?;
987        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
988    }
989
990    fn masked_fill_f32(
991        &self,
992        input: &GpuBufferHandle,
993        mask: &GpuBufferHandle,
994        value: f32,
995    ) -> FerrotorchResult<GpuBufferHandle> {
996        let input_buf = Self::unwrap_buffer(input)?;
997        let mask_buf = Self::unwrap_buffer(mask)?;
998        let dev = self.device(input.device_ordinal())?;
999        let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
1000            .map_err(Self::map_gpu_err)?;
1001        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1002    }
1003
1004    fn masked_zero_f32(
1005        &self,
1006        grad: &GpuBufferHandle,
1007        mask: &GpuBufferHandle,
1008    ) -> FerrotorchResult<GpuBufferHandle> {
1009        let grad_buf = Self::unwrap_buffer(grad)?;
1010        let mask_buf = Self::unwrap_buffer(mask)?;
1011        let dev = self.device(grad.device_ordinal())?;
1012        let result =
1013            crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
1014        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1015    }
1016
1017    fn sigmoid_backward_f32(
1018        &self,
1019        grad: &GpuBufferHandle,
1020        output: &GpuBufferHandle,
1021    ) -> FerrotorchResult<GpuBufferHandle> {
1022        let grad_buf = Self::unwrap_buffer(grad)?;
1023        let output_buf = Self::unwrap_buffer(output)?;
1024        let dev = self.device(grad.device_ordinal())?;
1025        let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
1026            .map_err(Self::map_gpu_err)?;
1027        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1028    }
1029
1030    fn tanh_backward_f32(
1031        &self,
1032        grad: &GpuBufferHandle,
1033        output: &GpuBufferHandle,
1034    ) -> FerrotorchResult<GpuBufferHandle> {
1035        let grad_buf = Self::unwrap_buffer(grad)?;
1036        let output_buf = Self::unwrap_buffer(output)?;
1037        let dev = self.device(grad.device_ordinal())?;
1038        let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
1039            .map_err(Self::map_gpu_err)?;
1040        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1041    }
1042
1043    fn softmax_backward_f32(
1044        &self,
1045        grad: &GpuBufferHandle,
1046        output: &GpuBufferHandle,
1047        cols: usize,
1048    ) -> FerrotorchResult<GpuBufferHandle> {
1049        let grad_buf = Self::unwrap_buffer(grad)?;
1050        let output_buf = Self::unwrap_buffer(output)?;
1051        let dev = self.device(grad.device_ordinal())?;
1052        let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
1053            .map_err(Self::map_gpu_err)?;
1054        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1055    }
1056
1057    fn layernorm_backward_f32(
1058        &self,
1059        input: &GpuBufferHandle,
1060        grad_output: &GpuBufferHandle,
1061        weight: &GpuBufferHandle,
1062        rows: usize,
1063        cols: usize,
1064        eps: f32,
1065    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
1066        let in_buf = Self::unwrap_buffer(input)?;
1067        let go_buf = Self::unwrap_buffer(grad_output)?;
1068        let w_buf = Self::unwrap_buffer(weight)?;
1069        let dev = self.device(input.device_ordinal())?;
1070        let (gi, gw, gb) =
1071            crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
1072                .map_err(Self::map_gpu_err)?;
1073        let ordinal = input.device_ordinal();
1074        Ok((
1075            Self::wrap_buffer(gi, ordinal),
1076            Self::wrap_buffer(gw, ordinal),
1077            Self::wrap_buffer(gb, ordinal),
1078        ))
1079    }
1080
1081    fn sum_axis_f32(
1082        &self,
1083        a: &GpuBufferHandle,
1084        shape: &[usize],
1085        axis: usize,
1086    ) -> FerrotorchResult<GpuBufferHandle> {
1087        let a_buf = Self::unwrap_buffer(a)?;
1088        let dev = self.device(a.device_ordinal())?;
1089        let outer: usize = shape[..axis].iter().product();
1090        let axis_size = shape[axis];
1091        let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
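        // Worked example (illustrative): shape = [2, 3, 4] with axis = 1 gives
        // outer = 2, axis_size = 3, inner = 4, i.e. the kernel sums over the
        // middle dimension for every (outer, inner) pair.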
1092        let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
1093            .map_err(Self::map_gpu_err)?;
1094        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1095    }
1096
1097    fn matmul_f16_f32(
1098        &self,
1099        a: &GpuBufferHandle,
1100        b: &GpuBufferHandle,
1101        m: usize,
1102        k: usize,
1103        n: usize,
1104    ) -> FerrotorchResult<GpuBufferHandle> {
1105        let a_buf = Self::unwrap_buffer(a)?;
1106        let b_buf = Self::unwrap_buffer(b)?;
1107        let dev = self.device(a.device_ordinal())?;
1108        let result =
1109            crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1110        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1111    }
1112
1113    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
1114        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1115            FerrotorchError::InvalidArgument {
1116                message: "failed to lock CUDA RNG manager".into(),
1117            }
1118        })?;
1119        let state = mgr.get_rng_state(device);
1120        Ok(GpuRngState {
1121            counter: state.counter,
1122            seed: state.seed,
1123            offset: state.offset,
1124            device,
1125        })
1126    }
1127
1128    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
1129        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1130            FerrotorchError::InvalidArgument {
1131                message: "failed to lock CUDA RNG manager".into(),
1132            }
1133        })?;
1134        mgr.set_rng_state(
1135            state.device,
1136            crate::rng::PhiloxState {
1137                counter: state.counter,
1138                seed: state.seed,
1139                offset: state.offset,
1140            },
1141        );
1142        Ok(())
1143    }
1144
1145    fn strided_split_f32(
1146        &self,
1147        input: &GpuBufferHandle,
1148        total_along_axis: usize,
1149        split_offset: usize,
1150        split_size: usize,
1151        inner_size: usize,
1152        n: usize,
1153    ) -> FerrotorchResult<GpuBufferHandle> {
1154        let in_buf = Self::unwrap_buffer(input)?;
1155        let dev = self.device(input.device_ordinal())?;
1156        let result = crate::kernels::gpu_strided_split(
1157            in_buf,
1158            total_along_axis,
1159            split_offset,
1160            split_size,
1161            inner_size,
1162            n,
1163            dev,
1164        )
1165        .map_err(Self::map_gpu_err)?;
1166        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1167    }
1168
1169    fn strided_cat_f32(
1170        &self,
1171        input: &GpuBufferHandle,
1172        output: &mut GpuBufferHandle,
1173        total_along_axis: usize,
1174        cat_offset: usize,
1175        part_size: usize,
1176        inner_size: usize,
1177        n: usize,
1178    ) -> FerrotorchResult<()> {
1179        let in_buf = Self::unwrap_buffer(input)?;
1180        let dev = self.device(input.device_ordinal())?;
1181        let out_buf =
1182            output
1183                .downcast_mut::<CudaBuffer<f32>>()
1184                .ok_or(FerrotorchError::InvalidArgument {
1185                    message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
1186                })?;
1187        crate::kernels::gpu_strided_cat(
1188            in_buf,
1189            out_buf,
1190            total_along_axis,
1191            cat_offset,
1192            part_size,
1193            inner_size,
1194            n,
1195            dev,
1196        )
1197        .map_err(Self::map_gpu_err)?;
1198        Ok(())
1199    }
1200}
1201
1202// ---------------------------------------------------------------------------
1203// Registration
1204// ---------------------------------------------------------------------------
1205
1206/// Get the `GpuDevice` from the registered CUDA backend.
1207///
1208/// This retrieves the device that was created during [`init_cuda_backend`],
1209/// ensuring all kernel modules and cuBLAS handles are shared. Creating a
1210/// second `GpuDevice` via `GpuDevice::new(0)` would create a separate
1211/// CUDA context with its own module cache, which is not interoperable.
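///
/// A hedged usage sketch (marked `ignore`; assumes [`init_cuda_backend`] has
/// already run and that this module is publicly reachable as
/// `ferrotorch_gpu::backend_impl`):
///
/// ```ignore
/// ferrotorch_gpu::backend_impl::init_cuda_backend()?;
/// // `_device` shares the CUDA context created by the registered backend.
/// let _device = ferrotorch_gpu::backend_impl::get_cuda_device()?;
/// # Ok::<(), ferrotorch_core::error::FerrotorchError>(())
/// ```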
1212pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
1213    let backend =
1214        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
1215    // The global backend is a &dyn GpuBackend. We know it's CudaBackendImpl
1216    // because init_cuda_backend registered it. Downcast via Any.
1217    let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
1218        FerrotorchError::InvalidArgument {
1219            message: "registered GPU backend is not CudaBackendImpl".into(),
1220        },
1221    )?;
1222    Ok(Arc::clone(cuda_backend.default_device()?))
1223}
1224
1225/// Initialize the CUDA backend and register it with ferrotorch-core.
1226///
1227/// This must be called before any GPU tensor operations. It creates a
1228/// [`CudaBackendImpl`] (initializing CUDA device 0) and registers it via
1229/// [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
1230///
/// Calling this again after a backend has been registered is a no-op that
/// returns `Ok(())`, so it is safe to invoke from multiple code paths.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] if CUDA initialization fails
/// (e.g. no GPU available, driver not loaded).
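///
/// A hedged call-site sketch (marked `ignore`; the
/// `ferrotorch_gpu::backend_impl` path is an assumption about how this module
/// is exported):
///
/// ```ignore
/// // Safe to call more than once; later calls are no-ops.
/// ferrotorch_gpu::backend_impl::init_cuda_backend()?;
/// ferrotorch_gpu::backend_impl::init_cuda_backend()?;
/// assert!(ferrotorch_core::gpu_dispatch::has_gpu_backend());
/// # Ok::<(), ferrotorch_core::error::FerrotorchError>(())
/// ```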
1238pub fn init_cuda_backend() -> FerrotorchResult<()> {
1239    // Idempotent: if already registered, return Ok silently.
1240    if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
1241        return Ok(());
1242    }
1243    let backend = CudaBackendImpl::new()?;
1244    // OnceLock::set can still race if two threads call init concurrently —
1245    // if that happens, the second set() fails but the backend is registered
1246    // by the first. We treat that as success.
1247    let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
1248    Ok(())
1249}
1250
1251// ---------------------------------------------------------------------------
1252// Tests
1253// ---------------------------------------------------------------------------
1254
1255#[cfg(test)]
1256#[cfg(feature = "cuda")]
1257mod tests {
1258    use super::*;
1259    use ferrotorch_core::gpu_dispatch;
1260
1261    // Note: Because `register_gpu_backend` uses a `OnceLock`, only the first
1262    // test to call `init_cuda_backend()` will succeed at registration. The
1263    // others will see the backend as already registered. We handle this by
1264    // checking `has_gpu_backend()` before calling init.
1265
1266    /// Ensure the backend can be initialized (or was already initialized).
1267    fn ensure_init() {
1268        if !gpu_dispatch::has_gpu_backend() {
1269            init_cuda_backend().expect("init_cuda_backend");
1270        }
1271    }
1272
1273    #[test]
1274    fn test_init_cuda_backend() {
1275        // First call succeeds (or backend was already registered by another test).
1276        ensure_init();
1277        assert!(gpu_dispatch::has_gpu_backend());
1278    }
1279
1280    #[test]
1281    fn test_gpu_backend_returns_some() {
1282        ensure_init();
1283        assert!(gpu_dispatch::gpu_backend().is_some());
1284    }
1285
1286    #[test]
1287    fn test_roundtrip_cpu_gpu_cpu() {
1288        ensure_init();
1289        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
1290
1291        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1292        let bytes: &[u8] = unsafe {
1293            std::slice::from_raw_parts(
1294                host.as_ptr() as *const u8,
1295                host.len() * std::mem::size_of::<f32>(),
1296            )
1297        };
1298
1299        let handle = backend.cpu_to_gpu(bytes, 4, 0).expect("cpu_to_gpu");
1300        assert_eq!(handle.len(), 5);
1301        assert_eq!(handle.device_ordinal(), 0);
1302
1303        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
1304        let back: &[f32] = unsafe {
1305            std::slice::from_raw_parts(back_bytes.as_ptr() as *const f32, back_bytes.len() / 4)
1306        };
1307        assert_eq!(back, &host[..]);
1308    }
1309
1310    #[test]
1311    fn test_add_f32() {
1312        ensure_init();
1313        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
1314
1315        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1316        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
1317        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];
1318
1319        let a_bytes: &[u8] =
1320            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
1321        let b_bytes: &[u8] =
1322            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };
1323
1324        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
1325        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");
1326
1327        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
1328        assert_eq!(result.len(), 4);
1329
1330        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
1331        let result_f32: &[f32] = unsafe {
1332            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
1333        };
1334
1335        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
1336            assert!(
1337                (got - exp).abs() < 1e-6,
1338                "element {i}: got {got}, expected {exp}",
1339            );
1340        }
1341    }
1342
1343    #[test]
1344    fn test_matmul_f32() {
1345        ensure_init();
1346        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
1347
1348        // A = [[1, 2, 3],
1349        //      [4, 5, 6]]  (2x3)
1350        // B = [[7, 8],
1351        //      [9, 10],
1352        //      [11, 12]]   (3x2)
1353        // C = [[58, 64],
1354        //      [139, 154]] (2x2)
1355        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1356        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
1357        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
1358
1359        let a_bytes: &[u8] =
1360            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
1361        let b_bytes: &[u8] =
1362            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };
1363
1364        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
1365        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");
1366
1367        let result = backend
1368            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
1369            .expect("matmul_f32");
1370        assert_eq!(result.len(), 4);
1371
1372        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
1373        let result_f32: &[f32] = unsafe {
1374            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
1375        };
1376
1377        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
1378            assert!(
1379                (got - exp).abs() < 1e-3,
1380                "element {i}: got {got}, expected {exp}",
1381            );
1382        }
1383    }
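
    // Two additional hedged sketches: they only exercise methods implemented
    // above (`cpu_to_gpu`, `clone_buffer`, `gpu_to_cpu`) and mirror the byte
    // reinterpretation pattern of the existing tests.
    #[test]
    fn test_clone_buffer_roundtrip_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f32> = vec![0.5, -1.5, 2.25, 8.0];
        let bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, host.len() * 4) };

        let handle = backend.cpu_to_gpu(bytes, 4, 0).expect("cpu_to_gpu");
        let cloned = backend.clone_buffer(&handle).expect("clone_buffer");
        assert_eq!(cloned.len(), handle.len());

        let back_bytes = backend.gpu_to_cpu(&cloned).expect("gpu_to_cpu");
        let back: &[f32] = unsafe {
            std::slice::from_raw_parts(back_bytes.as_ptr() as *const f32, back_bytes.len() / 4)
        };
        assert_eq!(back, &host[..]);
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu_f64() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f64> = vec![1.0, -2.0, 3.5];
        let bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, host.len() * 8) };

        let handle = backend.cpu_to_gpu(bytes, 8, 0).expect("cpu_to_gpu f64");
        assert_eq!(handle.len(), 3);

        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu f64");
        let back: &[f64] = unsafe {
            std::slice::from_raw_parts(back_bytes.as_ptr() as *const f64, back_bytes.len() / 8)
        };
        assert_eq!(back, &host[..]);
    }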
1384}