ferrotorch_gpu/backend_impl.rs

//! CUDA implementation of the [`GpuBackend`] trait from ferrotorch-core.
//!
//! This module bridges the existing GPU operations (`gpu_add`, `gpu_matmul_f32`,
//! etc.) to the type-erased [`GpuBackend`] dispatch interface, enabling
//! ferrotorch-core to call GPU operations without depending on this crate
//! directly.
//!
//! # Initialization
//!
//! Call [`init_cuda_backend`] once at startup (typically via `ferrotorch::init()`).
//! This creates a [`CudaBackendImpl`], initializes CUDA device 0, and registers
//! it with [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
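//!
//! A minimal sketch of that flow (the paths below are assumptions; adjust them
//! to however the crate actually re-exports these items):
//!
//! ```no_run
//! // Register the CUDA backend once; later calls are no-ops.
//! ferrotorch_gpu::init_cuda_backend().expect("CUDA backend init");
//! assert!(ferrotorch_core::gpu_dispatch::has_gpu_backend());
//! ```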

use std::sync::Arc;

use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};

use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;

// ---------------------------------------------------------------------------
// CudaBackendImpl
// ---------------------------------------------------------------------------

/// CUDA implementation of the [`GpuBackend`] trait.
///
/// Holds one or more [`GpuDevice`] handles (currently device 0 only) and
/// delegates every trait method to the corresponding function in
/// [`crate::kernels`], [`crate::blas`], or [`crate::transfer`].
pub struct CudaBackendImpl {
    devices: Vec<Arc<GpuDevice>>,
}

impl CudaBackendImpl {
    /// Create a new CUDA backend, initializing device 0.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if CUDA initialization fails
    /// (e.g. no GPU available, driver not loaded).
    pub fn new() -> FerrotorchResult<Self> {
        let device = Arc::new(
            GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA init failed: {e}"),
            })?,
        );
        Ok(Self {
            devices: vec![device],
        })
    }

    /// Get the device for ordinal 0 (the default device).
    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.device(0)
    }

    /// Look up a device by ordinal.
    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.devices
            .get(ordinal)
            .ok_or(FerrotorchError::InvalidArgument {
                message: format!("CUDA device {ordinal} not available"),
            })
    }

    /// Wrap a `CudaBuffer<f32>` into a type-erased [`GpuBufferHandle`].
    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    /// Wrap a `CudaBuffer<f64>` into a type-erased [`GpuBufferHandle`].
    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    /// Extract a `&CudaBuffer<f32>` from a [`GpuBufferHandle`].
    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
        handle
            .downcast_ref::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    /// Extract a `&mut CudaBuffer<f32>` from a [`GpuBufferHandle`].
    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
        handle
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    /// Extract a `&CudaBuffer<f64>` from a [`GpuBufferHandle`].
    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
        handle
            .downcast_ref::<CudaBuffer<f64>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    /// Convert a [`crate::error::GpuError`] into a [`FerrotorchError`].
    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
        FerrotorchError::InvalidArgument {
            message: format!("{e}"),
        }
    }
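
    /// Copy-based byte -> f32 conversion, shown only as an illustrative
    /// alternative to the pointer cast used in [`GpuBackend::cpu_to_gpu`]: it
    /// avoids the alignment assumption at the cost of an extra copy. This is a
    /// hypothetical helper and is not called anywhere in this backend.
    #[allow(dead_code)]
    fn bytes_to_f32_copied(data: &[u8]) -> Vec<f32> {
        data.chunks_exact(4)
            .map(|chunk| f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }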
}

// ---------------------------------------------------------------------------
// GpuBackend implementation
// ---------------------------------------------------------------------------

impl GpuBackend for CudaBackendImpl {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn cpu_to_gpu(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
                // was originally an f32 slice serialised to bytes, so it is
                // valid and suitably aligned for reading `count` f32 values.
                let count = data.len() / 4;
                let f32_data: &[f32] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
                let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
                // was originally an f64 slice serialised to bytes, so it is
                // valid and suitably aligned for reading `count` f64 values.
                let count = data.len() / 8;
                let f64_data: &[f64] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
                let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    fn cpu_to_gpu_pinned(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                // SAFETY: as in `cpu_to_gpu`, the caller guarantees that `data`
                // was originally an f32 slice serialised to bytes.
                let count = data.len() / 4;
                let f32_data: &[f32] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
                let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                // SAFETY: as in `cpu_to_gpu`, the caller guarantees that `data`
                // was originally an f64 slice serialised to bytes.
                let count = data.len() / 8;
                let f64_data: &[f64] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
                let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
                ),
            }),
        }
    }

    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
        let dev = self.device(handle.device_ordinal())?;

        // Try f32 first, then f64.
        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;

            // Reinterpret Vec<f32> as Vec<u8> without copying.
            // SAFETY: f32 has alignment 4 and size 4. We adjust len and capacity
            // accordingly. The original Vec is consumed via ManuallyDrop so its
            // destructor won't free the allocation.
            let bytes = unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_data);
                let ptr = v.as_mut_ptr() as *mut u8;
                let len = v.len() * 4;
                let cap = v.capacity() * 4;
                Vec::from_raw_parts(ptr, len, cap)
            };
            Ok(bytes)
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;

            // Reinterpret Vec<f64> as Vec<u8> without copying.
            // SAFETY: f64 has alignment 8 and size 8. We adjust len and capacity
            // accordingly. The original Vec is consumed via ManuallyDrop so its
            // destructor won't free the allocation.
            let bytes = unsafe {
                let mut v = std::mem::ManuallyDrop::new(f64_data);
                let ptr = v.as_mut_ptr() as *mut u8;
                let len = v.len() * 8;
                let cap = v.capacity() * 8;
                Vec::from_raw_parts(ptr, len, cap)
            };
            Ok(bytes)
        } else {
            Err(FerrotorchError::InvalidArgument {
                message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
            })
        }
    }

    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        // Clone via GPU -> CPU -> GPU round-trip.
        // Correct but not optimal; a device-to-device memcpy would be better.
        let bytes = self.gpu_to_cpu(handle)?;
        // Determine elem_size from the concrete buffer type.
        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
            8
        } else {
            4
        };
        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
    }

    fn alloc_zeros(
        &self,
        len: usize,
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    // -- Elementwise f32 ------------------------------------------------------

    fn add_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sub_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn mul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn div_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    #[allow(clippy::too_many_arguments)]
    fn fused_adam_f32(
        &self,
        param: &mut GpuBufferHandle,
        grad: &GpuBufferHandle,
        exp_avg: &mut GpuBufferHandle,
        exp_avg_sq: &mut GpuBufferHandle,
        beta1: f32,
        beta2: f32,
        lr: f32,
        eps: f32,
        bc1: f32,
        bc2: f32,
        weight_decay: f32,
    ) -> FerrotorchResult<()> {
        let ordinal = param.device_ordinal();
        let dev = self.device(ordinal)?;
        let p_buf = Self::unwrap_buffer_mut(param)?;
        let g_buf = Self::unwrap_buffer(grad)?;
        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
        crate::kernels::gpu_fused_adam(
            p_buf,
            g_buf,
            m_buf,
            v_buf,
            beta1,
            beta2,
            lr,
            eps,
            bc1,
            bc2,
            weight_decay,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    fn maxpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_maxpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }

    #[allow(clippy::too_many_arguments)]
    fn avgpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_avgpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }

    #[allow(clippy::too_many_arguments)]
    fn conv2d_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: Option<&GpuBufferHandle>,
        input_shape: [usize; 4],
        weight_shape: [usize; 4],
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let input_buf = Self::unwrap_buffer(input)?;
        let weight_buf = Self::unwrap_buffer(weight)?;
        let bias_buf = match bias {
            Some(b) => Some(Self::unwrap_buffer(b)?),
            None => None,
        };
        let dev = self.device(input.device_ordinal())?;
        let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
            input_buf,
            weight_buf,
            bias_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
    }

    fn fused_gru_cell_f32(
        &self,
        input_gates: &GpuBufferHandle,
        hidden_gates: &GpuBufferHandle,
        bias_ih: &GpuBufferHandle,
        bias_hh: &GpuBufferHandle,
        hx: &GpuBufferHandle,
        hidden_size: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let ig = Self::unwrap_buffer(input_gates)?;
        let hg = Self::unwrap_buffer(hidden_gates)?;
        let bih = Self::unwrap_buffer(bias_ih)?;
        let bhh = Self::unwrap_buffer(bias_hh)?;
        let hx_buf = Self::unwrap_buffer(hx)?;
        let dev = self.device(input_gates.device_ordinal())?;
        let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
            ig, hg, bih, bhh, hx_buf, hidden_size, dev,
        )
        .map_err(Self::map_gpu_err)?;
        let ord = input_gates.device_ordinal();
        Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
    }

    fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
        let dev = self.device(device)?;
        dev.stream()
            .synchronize()
            .map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA synchronize failed: {e}"),
            })?;
        Ok(())
    }

    fn stream_count(&self, device: usize) -> usize {
        crate::stream::StreamPool::pool_size(device)
    }

    // -- Linalg f32 -----------------------------------------------------------

    fn matmul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    // -- Reduction f32 --------------------------------------------------------

    fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    // -- Linalg f64 (cuBLAS DGEMM) --------------------------------------------

    fn matmul_f64(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // -- Broadcast binary f32 -------------------------------------------------

    fn broadcast_add_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_sub_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_mul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_div_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn softmax_f32(
        &self,
        a: &GpuBufferHandle,
        rows: usize,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn dropout_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
        seed: u32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn dropout_philox_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Snapshot the current RNG state and advance it.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            // Advance by ceil(n/4) counters (each counter produces 4 u32 values)
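            // Illustrative arithmetic (example values, not from the source):
            // n = 10 elements -> ceil(10 / 4) = 3 counter increments.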
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        // Use the Philox state as the seed for the dropout kernel.
        // We encode the Philox counter+seed into a u32 seed that the existing
        // dropout kernel can use. For full correctness on GPU, we should use
        // the Philox uniform kernel to generate the mask, then apply it.
        // However, for consistency between GPU forward and CPU backward mask
        // regeneration, we use the Philox state to deterministically derive a
        // seed for the existing kernel.
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(device_ordinal)?;

        // Use the Philox counter XOR seed as the dropout kernel's seed.
        // This gives us deterministic behavior tied to the Philox state.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
    }

    fn transpose_2d_f32(
        &self,
        a: &GpuBufferHandle,
        m: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn permute_0213_f32(
        &self,
        a: &GpuBufferHandle,
        d0: usize,
        d1: usize,
        d2: usize,
        d3: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn bmm_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn bmm_f16_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn layernorm_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let b_buf = Self::unwrap_buffer(bias)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn rmsnorm_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_rmsnorm(in_buf, w_buf, rows, cols, eps, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn rmsnorm_backward_f32(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer(input)?;
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw) =
            crate::kernels::gpu_rmsnorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((Self::wrap_buffer(gi, ordinal), Self::wrap_buffer(gw, ordinal)))
    }

    fn slice_write_f32(
        &self,
        src: &GpuBufferHandle,
        dst: &mut GpuBufferHandle,
        n_batch: usize,
        d: usize,
        max_len: usize,
        pos: usize,
    ) -> FerrotorchResult<()> {
        let src_buf = Self::unwrap_buffer(src)?;
        let dst_buf =
            dst.downcast_mut::<CudaBuffer<f32>>()
                .ok_or(FerrotorchError::InvalidArgument {
                    message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
                })?;
        let dev = self.device(src.device_ordinal())?;
        crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    fn slice_read_f32(
        &self,
        src: &GpuBufferHandle,
        n_batch: usize,
        d: usize,
        len: usize,
        max_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let src_buf = Self::unwrap_buffer(src)?;
        let dev = self.device(src.device_ordinal())?;
        let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, src.device_ordinal()))
    }

    fn embed_lookup_f32(
        &self,
        idx: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(idx)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(idx.device_ordinal())?;
        let result =
            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
    }

    fn embed_lookup_batch_f32(
        &self,
        indices: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        n: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(indices)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(indices.device_ordinal())?;
        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
    }

    fn scatter_add_rows_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        num_embeddings: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }

    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn relu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_tanh_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_erf_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn cumsum_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_cumsum(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn cumprod_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_cumprod(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn cummax_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummax(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
    }

    fn cummin_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummin(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
    }

    fn logcumsumexp_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_logcumsumexp(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn clamp_f32(
        &self,
        a: &GpuBufferHandle,
        min_val: f32,
        max_val: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_clamp(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_silu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn silu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_silu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_elu(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn elu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
        alpha: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_elu_backward(grad_buf, input_buf, alpha, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mish(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn mish_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_mish_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn log_softmax_f32(
        &self,
        a: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_log_softmax(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn log_softmax_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result =
            crate::kernels::gpu_log_softmax_backward(grad_buf, output_buf, cols, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn index_select_1d_f32(
        &self,
        input: &GpuBufferHandle,
        indices: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer(input)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn scatter_add_1d_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        input_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }

    fn masked_fill_f32(
        &self,
        input: &GpuBufferHandle,
        mask: &GpuBufferHandle,
        value: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer(input)?;
        let mask_buf = Self::unwrap_buffer(mask)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn masked_zero_f32(
        &self,
        grad: &GpuBufferHandle,
        mask: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let mask_buf = Self::unwrap_buffer(mask)?;
        let dev = self.device(grad.device_ordinal())?;
        let result =
            crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn sigmoid_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn tanh_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn softmax_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn layernorm_backward_f32(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer(input)?;
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw, gb) =
            crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((
            Self::wrap_buffer(gi, ordinal),
            Self::wrap_buffer(gw, ordinal),
            Self::wrap_buffer(gb, ordinal),
        ))
    }

    fn sum_axis_f32(
        &self,
        a: &GpuBufferHandle,
        shape: &[usize],
        axis: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
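        // Decompose the flat layout around `axis`. Worked example with assumed
        // values (not from the source): shape = [2, 3, 4] and axis = 1 give
        // outer = 2, axis_size = 3, inner = 4, so each output element reduces
        // 3 input values strided by 4.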
        let outer: usize = shape[..axis].iter().product();
        let axis_size = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
        let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn matmul_f16_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        let state = mgr.get_rng_state(device);
        Ok(GpuRngState {
            counter: state.counter,
            seed: state.seed,
            offset: state.offset,
            device,
        })
    }

    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        mgr.set_rng_state(
            state.device,
            crate::rng::PhiloxState {
                counter: state.counter,
                seed: state.seed,
                offset: state.offset,
            },
        );
        Ok(())
    }

    fn strided_split_f32(
        &self,
        input: &GpuBufferHandle,
        total_along_axis: usize,
        split_offset: usize,
        split_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_strided_split(
            in_buf,
            total_along_axis,
            split_offset,
            split_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }
1377
1378    fn strided_cat_f32(
1379        &self,
1380        input: &GpuBufferHandle,
1381        output: &mut GpuBufferHandle,
1382        total_along_axis: usize,
1383        cat_offset: usize,
1384        part_size: usize,
1385        inner_size: usize,
1386        n: usize,
1387    ) -> FerrotorchResult<()> {
1388        let in_buf = Self::unwrap_buffer(input)?;
1389        let dev = self.device(input.device_ordinal())?;
1390        let out_buf =
1391            output
1392                .downcast_mut::<CudaBuffer<f32>>()
1393                .ok_or(FerrotorchError::InvalidArgument {
1394                    message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
1395                })?;
1396        crate::kernels::gpu_strided_cat(
1397            in_buf,
1398            out_buf,
1399            total_along_axis,
1400            cat_offset,
1401            part_size,
1402            inner_size,
1403            n,
1404            dev,
1405        )
1406        .map_err(Self::map_gpu_err)?;
1407        Ok(())
1408    }
1409}
1410
1411// ---------------------------------------------------------------------------
1412// Registration
1413// ---------------------------------------------------------------------------
1414
1415/// Get the `GpuDevice` from the registered CUDA backend.
1416///
1417/// This retrieves the device that was created during [`init_cuda_backend`],
1418/// ensuring all kernel modules and cuBLAS handles are shared. Creating a
1419/// second `GpuDevice` via `GpuDevice::new(0)` would open a separate CUDA
1420/// context with its own module cache, which cannot interoperate with this one.
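///
/// # Example
///
/// A minimal sketch of the intended usage, assuming [`init_cuda_backend`] has
/// already run and that both functions are reachable from the caller's crate
/// (exact re-export paths are an assumption):
///
/// ```ignore
/// init_cuda_backend().expect("CUDA backend init");
/// // Reuse the GpuDevice (and its module / cuBLAS handle caches) created at
/// // init time rather than opening a second context with GpuDevice::new(0).
/// let device = get_cuda_device().expect("CUDA device");
/// let _ = &device; // pass the Arc<GpuDevice> to kernel or BLAS calls as needed
/// ```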
1421pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
1422    let backend =
1423        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
1424    // The global backend is a &dyn GpuBackend. We know it's CudaBackendImpl
1425    // because init_cuda_backend registered it. Downcast via Any.
1426    let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
1427        FerrotorchError::InvalidArgument {
1428            message: "registered GPU backend is not CudaBackendImpl".into(),
1429        },
1430    )?;
1431    Ok(Arc::clone(cuda_backend.default_device()?))
1432}
1433
1434/// Initialize the CUDA backend and register it with ferrotorch-core.
1435///
1436/// This must be called before any GPU tensor operations. It creates a
1437/// [`CudaBackendImpl`] (initializing CUDA device 0) and registers it via
1438/// [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
1439///
1440/// Calling this again after a backend has been registered is a no-op that
1441/// returns `Ok(())`; the existing registration is kept.
1442///
1443/// # Errors
1444///
1445/// - [`FerrotorchError::InvalidArgument`] if CUDA device initialization
1446///   fails before any backend has been registered.
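///
/// # Example
///
/// A minimal startup sketch, assuming it is acceptable to call this function
/// directly from `main()` rather than through a higher-level init wrapper
/// (the exact re-export path from the crate root is an assumption and is not
/// shown):
///
/// ```ignore
/// // Register the CUDA backend once, early in the program; later calls are no-ops.
/// init_cuda_backend().expect("CUDA backend init");
/// assert!(ferrotorch_core::gpu_dispatch::has_gpu_backend());
/// ```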
1447pub fn init_cuda_backend() -> FerrotorchResult<()> {
1448    // Idempotent: if already registered, return Ok silently.
1449    if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
1450        return Ok(());
1451    }
1452    let backend = CudaBackendImpl::new()?;
1453    // OnceLock::set can still race if two threads call init concurrently —
1454    // if that happens, the second set() fails but the backend is registered
1455    // by the first. We treat that as success.
1456    let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
1457    Ok(())
1458}
1459
1460// ---------------------------------------------------------------------------
1461// Tests
1462// ---------------------------------------------------------------------------
1463
1464#[cfg(test)]
1465#[cfg(feature = "cuda")]
1466mod tests {
1467    use super::*;
1468    use ferrotorch_core::gpu_dispatch;
1469
1470    // Note: Because `register_gpu_backend` uses a `OnceLock`, only the first
1471    // test to call `init_cuda_backend()` will succeed at registration. The
1472    // others will see the backend as already registered. We handle this by
1473    // checking `has_gpu_backend()` before calling init.
1474
1475    /// Ensure the backend can be initialized (or was already initialized).
1476    fn ensure_init() {
1477        if !gpu_dispatch::has_gpu_backend() {
1478            init_cuda_backend().expect("init_cuda_backend");
1479        }
1480    }
1481
1482    #[test]
1483    fn test_init_cuda_backend() {
1484        // First call succeeds (or backend was already registered by another test).
1485        ensure_init();
1486        assert!(gpu_dispatch::has_gpu_backend());
1487    }
1488
1489    #[test]
1490    fn test_gpu_backend_returns_some() {
1491        ensure_init();
1492        assert!(gpu_dispatch::gpu_backend().is_some());
1493    }
1494
1495    #[test]
1496    fn test_roundtrip_cpu_gpu_cpu() {
1497        ensure_init();
1498        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
1499
1500        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1501        let bytes: &[u8] = unsafe {
1502            std::slice::from_raw_parts(
1503                host.as_ptr() as *const u8,
1504                host.len() * std::mem::size_of::<f32>(),
1505            )
1506        };
1507
1508        let handle = backend.cpu_to_gpu(bytes, 4, 0).expect("cpu_to_gpu");
1509        assert_eq!(handle.len(), 5);
1510        assert_eq!(handle.device_ordinal(), 0);
1511
1512        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
1513        let back: &[f32] = unsafe {
1514            std::slice::from_raw_parts(back_bytes.as_ptr() as *const f32, back_bytes.len() / 4)
1515        };
1516        assert_eq!(back, &host[..]);
1517    }
1518
1519    #[test]
1520    fn test_add_f32() {
1521        ensure_init();
1522        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
1523
1524        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1525        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
1526        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];
1527
1528        let a_bytes: &[u8] =
1529            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
1530        let b_bytes: &[u8] =
1531            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };
1532
1533        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
1534        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");
1535
1536        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
1537        assert_eq!(result.len(), 4);
1538
1539        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
1540        let result_f32: &[f32] = unsafe {
1541            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
1542        };
1543
1544        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
1545            assert!(
1546                (got - exp).abs() < 1e-6,
1547                "element {i}: got {got}, expected {exp}",
1548            );
1549        }
1550    }
1551
1552    #[test]
1553    fn test_matmul_f32() {
1554        ensure_init();
1555        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
1556
1557        // A = [[1, 2, 3],
1558        //      [4, 5, 6]]  (2x3)
1559        // B = [[7, 8],
1560        //      [9, 10],
1561        //      [11, 12]]   (3x2)
1562        // C = [[58, 64],
1563        //      [139, 154]] (2x2)
1564        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1565        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
1566        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
1567
1568        let a_bytes: &[u8] =
1569            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
1570        let b_bytes: &[u8] =
1571            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };
1572
1573        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
1574        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");
1575
1576        let result = backend
1577            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
1578            .expect("matmul_f32");
1579        assert_eq!(result.len(), 4);
1580
1581        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
1582        let result_f32: &[f32] = unsafe {
1583            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
1584        };
1585
1586        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
1587            assert!(
1588                (got - exp).abs() < 1e-3,
1589                "element {i}: got {got}, expected {exp}",
1590            );
1591        }
1592    }
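
    /// RNG state round-trip through the backend trait. This is a sketch under
    /// two assumptions: the `GpuRngState` fields are plain `Copy` integers,
    /// and a save, restore, save sequence reproduces the same Philox
    /// counter/seed/offset (i.e. save and restore have no side effects).
    #[test]
    fn test_rng_state_roundtrip() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let state = backend.save_rng_state(0).expect("save_rng_state");
        assert_eq!(state.device, 0);

        // Copy out the scalar fields before `state` is moved into restore.
        let (counter, seed, offset) = (state.counter, state.seed, state.offset);

        backend.restore_rng_state(state).expect("restore_rng_state");
        let again = backend.save_rng_state(0).expect("save_rng_state (again)");
        assert_eq!(again.counter, counter);
        assert_eq!(again.seed, seed);
        assert_eq!(again.offset, offset);
    }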
1593}