Skip to main content

ferrotorch_gpu/
backend_impl.rs

1//! CUDA implementation of the [`GpuBackend`] trait from ferrotorch-core.
2//!
3//! This module bridges the existing GPU operations (`gpu_add`, `gpu_matmul_f32`,
4//! etc.) to the type-erased [`GpuBackend`] dispatch interface, enabling
5//! ferrotorch-core to call GPU operations without depending on this crate
6//! directly.
7//!
8//! # Initialization
9//!
10//! Call [`init_cuda_backend`] once at startup (typically via `ferrotorch::init()`).
11//! This creates a [`CudaBackendImpl`], initializes CUDA device 0, and registers
12//! it with [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
13
14use std::sync::Arc;
15
16use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
17use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};
18
19use crate::buffer::CudaBuffer;
20use crate::device::GpuDevice;
21
22// ---------------------------------------------------------------------------
23// CudaBackendImpl
24// ---------------------------------------------------------------------------
25
/// CUDA implementation of the [`GpuBackend`] trait.
///
/// Holds one or more [`GpuDevice`] handles (currently device 0 only) and
/// delegates every trait method to the corresponding function in
/// [`crate::kernels`], [`crate::blas`], or [`crate::transfer`].
pub struct CudaBackendImpl {
    // One entry per initialized CUDA device, indexed by ordinal.
    // `new()` currently populates only ordinal 0.
    devices: Vec<Arc<GpuDevice>>,
}
34
35impl CudaBackendImpl {
36    /// Create a new CUDA backend, initializing device 0.
37    ///
38    /// # Errors
39    ///
40    /// Returns [`FerrotorchError::InvalidArgument`] if CUDA initialization fails
41    /// (e.g. no GPU available, driver not loaded).
42    pub fn new() -> FerrotorchResult<Self> {
43        let device = Arc::new(
44            GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
45                message: format!("CUDA init failed: {e}"),
46            })?,
47        );
48        Ok(Self {
49            devices: vec![device],
50        })
51    }
52
53    /// Get the device for ordinal 0 (the default device).
54    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
55        self.device(0)
56    }
57
58    /// Look up a device by ordinal.
59    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
60        self.devices
61            .get(ordinal)
62            .ok_or(FerrotorchError::InvalidArgument {
63                message: format!("CUDA device {ordinal} not available"),
64            })
65    }
66
67    /// Wrap a `CudaBuffer<f32>` into a type-erased [`GpuBufferHandle`].
68    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
69        let len = buf.len();
70        GpuBufferHandle::new(Box::new(buf), ordinal, len)
71    }
72
73    /// Wrap a `CudaBuffer<f64>` into a type-erased [`GpuBufferHandle`].
74    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
75        let len = buf.len();
76        GpuBufferHandle::new(Box::new(buf), ordinal, len)
77    }
78
79    /// Extract a `&CudaBuffer<f32>` from a [`GpuBufferHandle`].
80    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
81        handle
82            .downcast_ref::<CudaBuffer<f32>>()
83            .ok_or(FerrotorchError::InvalidArgument {
84                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
85            })
86    }
87
88    /// Extract a `&mut CudaBuffer<f32>` from a [`GpuBufferHandle`].
89    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
90        handle
91            .downcast_mut::<CudaBuffer<f32>>()
92            .ok_or(FerrotorchError::InvalidArgument {
93                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
94            })
95    }
96
97    /// Extract a `&mut CudaBuffer<f64>` from a [`GpuBufferHandle`].
98    fn unwrap_buffer_f64_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f64>> {
99        handle
100            .downcast_mut::<CudaBuffer<f64>>()
101            .ok_or(FerrotorchError::InvalidArgument {
102                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
103            })
104    }
105
106    /// Extract a `&CudaBuffer<f64>` from a [`GpuBufferHandle`].
107    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
108        handle
109            .downcast_ref::<CudaBuffer<f64>>()
110            .ok_or(FerrotorchError::InvalidArgument {
111                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
112            })
113    }
114
115    /// Convert a [`crate::error::GpuError`] into a [`FerrotorchError`].
116    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
117        FerrotorchError::InvalidArgument {
118            message: format!("{e}"),
119        }
120    }
121}
122
123// ---------------------------------------------------------------------------
124// GpuBackend implementation
125// ---------------------------------------------------------------------------
126
127impl GpuBackend for CudaBackendImpl {
    /// Expose the concrete backend as `&dyn Any` so callers can downcast
    /// back to [`CudaBackendImpl`] when they need CUDA-specific access.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
131
132    fn raw_device_ptr(&self, handle: &GpuBufferHandle) -> *const std::ffi::c_void {
133        use cudarc::driver::DevicePtr;
134        let dev = match self.device(handle.device_ordinal()) {
135            Ok(d) => d,
136            Err(_) => return std::ptr::null(),
137        };
138        let stream = dev.stream();
139        if let Ok(buf) = Self::unwrap_buffer(handle) {
140            let (ptr, _sync) = buf.inner().device_ptr(&stream);
141            ptr as *const std::ffi::c_void
142        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
143            let (ptr, _sync) = buf.inner().device_ptr(&stream);
144            ptr as *const std::ffi::c_void
145        } else {
146            std::ptr::null()
147        }
148    }
149
150    fn raw_device_ptr_mut(&self, handle: &mut GpuBufferHandle) -> *mut std::ffi::c_void {
151        use cudarc::driver::DevicePtrMut;
152        let ordinal = handle.device_ordinal();
153        let dev = match self.device(ordinal) {
154            Ok(d) => d,
155            Err(_) => return std::ptr::null_mut(),
156        };
157        let stream = dev.stream();
158        if let Some(buf) = handle.downcast_mut::<CudaBuffer<f32>>() {
159            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
160            ptr as *mut std::ffi::c_void
161        } else if let Some(buf) = handle.downcast_mut::<CudaBuffer<f64>>() {
162            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
163            ptr as *mut std::ffi::c_void
164        } else {
165            std::ptr::null_mut()
166        }
167    }
168
169    fn buffer_elem_size(&self, handle: &GpuBufferHandle) -> usize {
170        if Self::unwrap_buffer(handle).is_ok() {
171            4 // f32
172        } else if Self::unwrap_buffer_f64(handle).is_ok() {
173            8 // f64
174        } else {
175            0
176        }
177    }
178
179    fn cpu_to_gpu(
180        &self,
181        data: &[u8],
182        elem_size: usize,
183        device: usize,
184    ) -> FerrotorchResult<GpuBufferHandle> {
185        let dev = self.device(device)?;
186        match elem_size {
187            4 => {
188                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
189                // was originally an f32 slice serialised to bytes.
190                let count = data.len() / 4;
191                let f32_data: &[f32] =
192                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
193                let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
194                Ok(Self::wrap_buffer(buf, device))
195            }
196            8 => {
197                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
198                // was originally an f64 slice serialised to bytes.
199                let count = data.len() / 8;
200                let f64_data: &[f64] =
201                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
202                let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
203                Ok(Self::wrap_buffer_f64(buf, device))
204            }
205            other => Err(FerrotorchError::InvalidArgument {
206                message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
207            }),
208        }
209    }
210
211    fn cpu_to_gpu_pinned(
212        &self,
213        data: &[u8],
214        elem_size: usize,
215        device: usize,
216    ) -> FerrotorchResult<GpuBufferHandle> {
217        let dev = self.device(device)?;
218        match elem_size {
219            4 => {
220                let count = data.len() / 4;
221                let f32_data: &[f32] =
222                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
223                let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
224                    .map_err(Self::map_gpu_err)?;
225                Ok(Self::wrap_buffer(buf, device))
226            }
227            8 => {
228                let count = data.len() / 8;
229                let f64_data: &[f64] =
230                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
231                let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
232                    .map_err(Self::map_gpu_err)?;
233                Ok(Self::wrap_buffer_f64(buf, device))
234            }
235            other => Err(FerrotorchError::InvalidArgument {
236                message: format!(
237                    "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
238                ),
239            }),
240        }
241    }
242
243    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
244        let dev = self.device(handle.device_ordinal())?;
245
246        // Try f32 first, then f64.
247        if let Ok(buf) = Self::unwrap_buffer(handle) {
248            let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
249
250            // Reinterpret Vec<f32> as Vec<u8> without copying.
251            // SAFETY: f32 has alignment 4 and size 4. We adjust len and capacity
252            // accordingly. The original Vec is consumed via ManuallyDrop so its
253            // destructor won't free the allocation.
254            let bytes = unsafe {
255                let mut v = std::mem::ManuallyDrop::new(f32_data);
256                let ptr = v.as_mut_ptr() as *mut u8;
257                let len = v.len() * 4;
258                let cap = v.capacity() * 4;
259                Vec::from_raw_parts(ptr, len, cap)
260            };
261            Ok(bytes)
262        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
263            let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
264
265            // Reinterpret Vec<f64> as Vec<u8> without copying.
266            // SAFETY: f64 has alignment 8 and size 8. We adjust len and capacity
267            // accordingly. The original Vec is consumed via ManuallyDrop so its
268            // destructor won't free the allocation.
269            let bytes = unsafe {
270                let mut v = std::mem::ManuallyDrop::new(f64_data);
271                let ptr = v.as_mut_ptr() as *mut u8;
272                let len = v.len() * 8;
273                let cap = v.capacity() * 8;
274                Vec::from_raw_parts(ptr, len, cap)
275            };
276            Ok(bytes)
277        } else {
278            Err(FerrotorchError::InvalidArgument {
279                message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
280            })
281        }
282    }
283
284    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
285        // Clone via GPU -> CPU -> GPU round-trip.
286        // Correct but not optimal; a device-to-device memcpy would be better.
287        let bytes = self.gpu_to_cpu(handle)?;
288        // Determine elem_size from the concrete buffer type.
289        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
290            8
291        } else {
292            4
293        };
294        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
295    }
296
297    fn alloc_zeros(
298        &self,
299        len: usize,
300        elem_size: usize,
301        device: usize,
302    ) -> FerrotorchResult<GpuBufferHandle> {
303        let dev = self.device(device)?;
304        match elem_size {
305            4 => {
306                let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
307                Ok(Self::wrap_buffer(buf, device))
308            }
309            8 => {
310                let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
311                Ok(Self::wrap_buffer_f64(buf, device))
312            }
313            other => Err(FerrotorchError::InvalidArgument {
314                message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
315            }),
316        }
317    }
318
319    // -- Elementwise f32 ------------------------------------------------------
320
321    fn add_f32(
322        &self,
323        a: &GpuBufferHandle,
324        b: &GpuBufferHandle,
325    ) -> FerrotorchResult<GpuBufferHandle> {
326        let a_buf = Self::unwrap_buffer(a)?;
327        let b_buf = Self::unwrap_buffer(b)?;
328        let dev = self.device(a.device_ordinal())?;
329        let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
330        Ok(Self::wrap_buffer(result, a.device_ordinal()))
331    }
332
333    fn sub_f32(
334        &self,
335        a: &GpuBufferHandle,
336        b: &GpuBufferHandle,
337    ) -> FerrotorchResult<GpuBufferHandle> {
338        let a_buf = Self::unwrap_buffer(a)?;
339        let b_buf = Self::unwrap_buffer(b)?;
340        let dev = self.device(a.device_ordinal())?;
341        let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
342        Ok(Self::wrap_buffer(result, a.device_ordinal()))
343    }
344
345    fn mul_f32(
346        &self,
347        a: &GpuBufferHandle,
348        b: &GpuBufferHandle,
349    ) -> FerrotorchResult<GpuBufferHandle> {
350        let a_buf = Self::unwrap_buffer(a)?;
351        let b_buf = Self::unwrap_buffer(b)?;
352        let dev = self.device(a.device_ordinal())?;
353        let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
354        Ok(Self::wrap_buffer(result, a.device_ordinal()))
355    }
356
357    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
358        let a_buf = Self::unwrap_buffer(a)?;
359        let dev = self.device(a.device_ordinal())?;
360        let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
361        Ok(Self::wrap_buffer(result, a.device_ordinal()))
362    }
363
364    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
365        let a_buf = Self::unwrap_buffer(a)?;
366        let dev = self.device(a.device_ordinal())?;
367        let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
368        Ok(Self::wrap_buffer(result, a.device_ordinal()))
369    }
370
371    fn div_f32(
372        &self,
373        a: &GpuBufferHandle,
374        b: &GpuBufferHandle,
375    ) -> FerrotorchResult<GpuBufferHandle> {
376        let a_buf = Self::unwrap_buffer(a)?;
377        let b_buf = Self::unwrap_buffer(b)?;
378        let dev = self.device(a.device_ordinal())?;
379        let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
380        Ok(Self::wrap_buffer(result, a.device_ordinal()))
381    }
382
383    fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
384        let a_buf = Self::unwrap_buffer(a)?;
385        let dev = self.device(a.device_ordinal())?;
386        let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
387        Ok(Self::wrap_buffer(result, a.device_ordinal()))
388    }
389
390    fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
391        let a_buf = Self::unwrap_buffer(a)?;
392        let dev = self.device(a.device_ordinal())?;
393        let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
394        Ok(Self::wrap_buffer(result, a.device_ordinal()))
395    }
396
397    fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
398        let a_buf = Self::unwrap_buffer(a)?;
399        let dev = self.device(a.device_ordinal())?;
400        let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
401        Ok(Self::wrap_buffer(result, a.device_ordinal()))
402    }
403
404    fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
405        let a_buf = Self::unwrap_buffer(a)?;
406        let dev = self.device(a.device_ordinal())?;
407        let result =
408            crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
409        Ok(Self::wrap_buffer(result, a.device_ordinal()))
410    }
411
412    fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
413        let a_buf = Self::unwrap_buffer(a)?;
414        let dev = self.device(a.device_ordinal())?;
415        let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
416        Ok(Self::wrap_buffer(result, a.device_ordinal()))
417    }
418
419    fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
420        let a_buf = Self::unwrap_buffer(a)?;
421        let dev = self.device(a.device_ordinal())?;
422        let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
423        Ok(Self::wrap_buffer(result, a.device_ordinal()))
424    }
425
426    fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
427        let a_buf = Self::unwrap_buffer(a)?;
428        let dev = self.device(a.device_ordinal())?;
429        let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
430        Ok(Self::wrap_buffer(result, a.device_ordinal()))
431    }
432
433    // -----------------------------------------------------------------------
434    // f64 elementwise ops
435    // -----------------------------------------------------------------------
436
437    fn add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
438        let a_buf = Self::unwrap_buffer_f64(a)?;
439        let b_buf = Self::unwrap_buffer_f64(b)?;
440        let dev = self.device(a.device_ordinal())?;
441        let result = crate::kernels::gpu_add_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
442        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
443    }
444
445    fn sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
446        let a_buf = Self::unwrap_buffer_f64(a)?;
447        let b_buf = Self::unwrap_buffer_f64(b)?;
448        let dev = self.device(a.device_ordinal())?;
449        let result = crate::kernels::gpu_sub_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
450        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
451    }
452
453    fn mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
454        let a_buf = Self::unwrap_buffer_f64(a)?;
455        let b_buf = Self::unwrap_buffer_f64(b)?;
456        let dev = self.device(a.device_ordinal())?;
457        let result = crate::kernels::gpu_mul_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
458        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
459    }
460
461    fn div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
462        let a_buf = Self::unwrap_buffer_f64(a)?;
463        let b_buf = Self::unwrap_buffer_f64(b)?;
464        let dev = self.device(a.device_ordinal())?;
465        let result = crate::kernels::gpu_div_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
466        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
467    }
468
469    fn neg_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
470        let a_buf = Self::unwrap_buffer_f64(a)?;
471        let dev = self.device(a.device_ordinal())?;
472        let result = crate::kernels::gpu_neg_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
473        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
474    }
475
476    fn relu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
477        let a_buf = Self::unwrap_buffer_f64(a)?;
478        let dev = self.device(a.device_ordinal())?;
479        let result = crate::kernels::gpu_relu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
480        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
481    }
482
483    fn scale_f64(&self, a: &GpuBufferHandle, scalar: f64) -> FerrotorchResult<GpuBufferHandle> {
484        let a_buf = Self::unwrap_buffer_f64(a)?;
485        let dev = self.device(a.device_ordinal())?;
486        let result = crate::kernels::gpu_scale_f64(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
487        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
488    }
489
490    fn exp_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
491        let a_buf = Self::unwrap_buffer_f64(a)?;
492        let dev = self.device(a.device_ordinal())?;
493        let result = crate::kernels::gpu_exp_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
494        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
495    }
496
497    fn log_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
498        let a_buf = Self::unwrap_buffer_f64(a)?;
499        let dev = self.device(a.device_ordinal())?;
500        let result = crate::kernels::gpu_log_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
501        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
502    }
503
504    fn sqrt_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
505        let a_buf = Self::unwrap_buffer_f64(a)?;
506        let dev = self.device(a.device_ordinal())?;
507        let result = crate::kernels::gpu_sqrt_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
508        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
509    }
510
511    fn pow_f64(&self, a: &GpuBufferHandle, exponent: f64) -> FerrotorchResult<GpuBufferHandle> {
512        let a_buf = Self::unwrap_buffer_f64(a)?;
513        let dev = self.device(a.device_ordinal())?;
514        let result = crate::kernels::gpu_pow_f64(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
515        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
516    }
517
518    fn abs_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
519        let a_buf = Self::unwrap_buffer_f64(a)?;
520        let dev = self.device(a.device_ordinal())?;
521        let result = crate::kernels::gpu_abs_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
522        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
523    }
524
525    fn sigmoid_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
526        let a_buf = Self::unwrap_buffer_f64(a)?;
527        let dev = self.device(a.device_ordinal())?;
528        let result = crate::kernels::gpu_sigmoid_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
529        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
530    }
531
532    fn tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
533        let a_buf = Self::unwrap_buffer_f64(a)?;
534        let dev = self.device(a.device_ordinal())?;
535        let result = crate::kernels::gpu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
536        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
537    }
538
539    // f64 backward ops
540    fn relu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
541        let g_buf = Self::unwrap_buffer_f64(grad)?;
542        let i_buf = Self::unwrap_buffer_f64(input)?;
543        let dev = self.device(grad.device_ordinal())?;
544        let result = crate::kernels::gpu_relu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
545        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
546    }
547
548    fn sigmoid_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
549        let g_buf = Self::unwrap_buffer_f64(grad)?;
550        let o_buf = Self::unwrap_buffer_f64(output)?;
551        let dev = self.device(grad.device_ordinal())?;
552        let result = crate::kernels::gpu_sigmoid_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
553        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
554    }
555
556    fn tanh_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
557        let g_buf = Self::unwrap_buffer_f64(grad)?;
558        let o_buf = Self::unwrap_buffer_f64(output)?;
559        let dev = self.device(grad.device_ordinal())?;
560        let result = crate::kernels::gpu_tanh_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
561        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
562    }
563
564    // f64 activation forward ops
565
566    fn gelu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
567        let a_buf = Self::unwrap_buffer_f64(a)?;
568        let dev = self.device(a.device_ordinal())?;
569        let result = crate::kernels::gpu_gelu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
570        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
571    }
572
573    fn gelu_tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
574        let a_buf = Self::unwrap_buffer_f64(a)?;
575        let dev = self.device(a.device_ordinal())?;
576        let result = crate::kernels::gpu_gelu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
577        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
578    }
579
580    fn gelu_erf_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
581        let a_buf = Self::unwrap_buffer_f64(a)?;
582        let dev = self.device(a.device_ordinal())?;
583        let result = crate::kernels::gpu_gelu_erf_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
584        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
585    }
586
587    fn silu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
588        let a_buf = Self::unwrap_buffer_f64(a)?;
589        let dev = self.device(a.device_ordinal())?;
590        let result = crate::kernels::gpu_silu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
591        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
592    }
593
594    fn elu_f64(&self, a: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
595        let a_buf = Self::unwrap_buffer_f64(a)?;
596        let dev = self.device(a.device_ordinal())?;
597        let result = crate::kernels::gpu_elu_f64(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
598        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
599    }
600
601    fn mish_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
602        let a_buf = Self::unwrap_buffer_f64(a)?;
603        let dev = self.device(a.device_ordinal())?;
604        let result = crate::kernels::gpu_mish_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
605        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
606    }
607
608    fn clamp_f64(&self, a: &GpuBufferHandle, min_val: f64, max_val: f64) -> FerrotorchResult<GpuBufferHandle> {
609        let a_buf = Self::unwrap_buffer_f64(a)?;
610        let dev = self.device(a.device_ordinal())?;
611        let result = crate::kernels::gpu_clamp_f64(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
612        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
613    }
614
615    // f64 activation backward ops
616
617    fn gelu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
618        let g_buf = Self::unwrap_buffer_f64(grad)?;
619        let i_buf = Self::unwrap_buffer_f64(input)?;
620        let dev = self.device(grad.device_ordinal())?;
621        let result = crate::kernels::gpu_gelu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
622        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
623    }
624
625    fn gelu_backward_tanh_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
626        let g_buf = Self::unwrap_buffer_f64(grad)?;
627        let i_buf = Self::unwrap_buffer_f64(input)?;
628        let dev = self.device(grad.device_ordinal())?;
629        let result = crate::kernels::gpu_gelu_backward_tanh_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
630        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
631    }
632
633    fn gelu_backward_erf_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
634        let g_buf = Self::unwrap_buffer_f64(grad)?;
635        let i_buf = Self::unwrap_buffer_f64(input)?;
636        let dev = self.device(grad.device_ordinal())?;
637        let result = crate::kernels::gpu_gelu_backward_erf_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
638        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
639    }
640
641    fn silu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
642        let g_buf = Self::unwrap_buffer_f64(grad)?;
643        let i_buf = Self::unwrap_buffer_f64(input)?;
644        let dev = self.device(grad.device_ordinal())?;
645        let result = crate::kernels::gpu_silu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
646        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
647    }
648
649    fn elu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
650        let g_buf = Self::unwrap_buffer_f64(grad)?;
651        let i_buf = Self::unwrap_buffer_f64(input)?;
652        let dev = self.device(grad.device_ordinal())?;
653        let result = crate::kernels::gpu_elu_backward_f64(g_buf, i_buf, alpha, dev).map_err(Self::map_gpu_err)?;
654        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
655    }
656
657    fn mish_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
658        let g_buf = Self::unwrap_buffer_f64(grad)?;
659        let i_buf = Self::unwrap_buffer_f64(input)?;
660        let dev = self.device(grad.device_ordinal())?;
661        let result = crate::kernels::gpu_mish_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
662        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
663    }
664
665    // f64 cumulative ops
666    fn cumsum_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
667        let a_buf = Self::unwrap_buffer_f64(a)?;
668        let dev = self.device(a.device_ordinal())?;
669        let result = crate::kernels::gpu_cumsum_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
670        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
671    }
672
673    fn cumprod_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
674        let a_buf = Self::unwrap_buffer_f64(a)?;
675        let dev = self.device(a.device_ordinal())?;
676        let result = crate::kernels::gpu_cumprod_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
677        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
678    }
679
680    fn cummax_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
681        let a_buf = Self::unwrap_buffer_f64(a)?;
682        let dev = self.device(a.device_ordinal())?;
683        let (vals, idxs) = crate::kernels::gpu_cummax_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
684        let ord = a.device_ordinal();
685        Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
686    }
687
688    fn cummin_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
689        let a_buf = Self::unwrap_buffer_f64(a)?;
690        let dev = self.device(a.device_ordinal())?;
691        let (vals, idxs) = crate::kernels::gpu_cummin_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
692        let ord = a.device_ordinal();
693        Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
694    }
695
696    fn logcumsumexp_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
697        let a_buf = Self::unwrap_buffer_f64(a)?;
698        let dev = self.device(a.device_ordinal())?;
699        let result = crate::kernels::gpu_logcumsumexp_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
700        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
701    }
702
703    // f64 shape ops
704    fn transpose_2d_f64(&self, a: &GpuBufferHandle, m: usize, n: usize) -> FerrotorchResult<GpuBufferHandle> {
705        let a_buf = Self::unwrap_buffer_f64(a)?;
706        let dev = self.device(a.device_ordinal())?;
707        let result = crate::kernels::gpu_transpose_2d_f64(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
708        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
709    }
710
711    fn permute_0213_f64(&self, a: &GpuBufferHandle, d0: usize, d1: usize, d2: usize, d3: usize) -> FerrotorchResult<GpuBufferHandle> {
712        let a_buf = Self::unwrap_buffer_f64(a)?;
713        let dev = self.device(a.device_ordinal())?;
714        let result = crate::kernels::gpu_permute_0213_f64(a_buf, d0, d1, d2, d3, dev).map_err(Self::map_gpu_err)?;
715        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
716    }
717
718    // f64 broadcast ops
719    fn broadcast_add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
720        let a_buf = Self::unwrap_buffer_f64(a)?;
721        let b_buf = Self::unwrap_buffer_f64(b)?;
722        let dev = self.device(a.device_ordinal())?;
723        let result = crate::kernels::gpu_broadcast_add_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
724        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
725    }
726
727    fn broadcast_sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
728        let a_buf = Self::unwrap_buffer_f64(a)?;
729        let b_buf = Self::unwrap_buffer_f64(b)?;
730        let dev = self.device(a.device_ordinal())?;
731        let result = crate::kernels::gpu_broadcast_sub_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
732        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
733    }
734
735    fn broadcast_mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
736        let a_buf = Self::unwrap_buffer_f64(a)?;
737        let b_buf = Self::unwrap_buffer_f64(b)?;
738        let dev = self.device(a.device_ordinal())?;
739        let result = crate::kernels::gpu_broadcast_mul_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
740        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
741    }
742
743    fn broadcast_div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
744        let a_buf = Self::unwrap_buffer_f64(a)?;
745        let b_buf = Self::unwrap_buffer_f64(b)?;
746        let dev = self.device(a.device_ordinal())?;
747        let result = crate::kernels::gpu_broadcast_div_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
748        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
749    }
750
751    // f64 reduction ops
752    fn sum_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
753        let a_buf = Self::unwrap_buffer_f64(a)?;
754        let dev = self.device(a.device_ordinal())?;
755        let result = crate::kernels::gpu_reduce_sum_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
756        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
757    }
758
759    fn sum_axis_f64(&self, a: &GpuBufferHandle, shape: &[usize], axis: usize) -> FerrotorchResult<GpuBufferHandle> {
760        let a_buf = Self::unwrap_buffer_f64(a)?;
761        let dev = self.device(a.device_ordinal())?;
762        let outer: usize = shape[..axis].iter().product();
763        let axis_size = shape[axis];
764        let inner: usize = shape[axis + 1..].iter().product();
765        let result = crate::kernels::gpu_sum_axis_f64(a_buf, outer, axis_size, inner, dev).map_err(Self::map_gpu_err)?;
766        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
767    }
768
769    // f64 softmax / log-softmax / layernorm / rmsnorm
770
771    fn softmax_f64(&self, a: &GpuBufferHandle, rows: usize, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
772        let a_buf = Self::unwrap_buffer_f64(a)?;
773        let dev = self.device(a.device_ordinal())?;
774        let result = crate::kernels::gpu_softmax_f64(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
775        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
776    }
777
778    fn softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
779        let grad_buf = Self::unwrap_buffer_f64(grad)?;
780        let output_buf = Self::unwrap_buffer_f64(output)?;
781        let dev = self.device(grad.device_ordinal())?;
782        let result = crate::kernels::gpu_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
783        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
784    }
785
786    fn log_softmax_f64(&self, a: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
787        let a_buf = Self::unwrap_buffer_f64(a)?;
788        let dev = self.device(a.device_ordinal())?;
789        let result = crate::kernels::gpu_log_softmax_f64(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
790        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
791    }
792
793    fn log_softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
794        let grad_buf = Self::unwrap_buffer_f64(grad)?;
795        let output_buf = Self::unwrap_buffer_f64(output)?;
796        let dev = self.device(grad.device_ordinal())?;
797        let result = crate::kernels::gpu_log_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
798        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
799    }
800
801    fn layernorm_f64(
802        &self,
803        input: &GpuBufferHandle,
804        weight: &GpuBufferHandle,
805        bias: &GpuBufferHandle,
806        rows: usize,
807        cols: usize,
808        eps: f64,
809    ) -> FerrotorchResult<GpuBufferHandle> {
810        let in_buf = Self::unwrap_buffer_f64(input)?;
811        let w_buf = Self::unwrap_buffer_f64(weight)?;
812        let b_buf = Self::unwrap_buffer_f64(bias)?;
813        let dev = self.device(input.device_ordinal())?;
814        let result = crate::kernels::gpu_layernorm_f64(in_buf, w_buf, b_buf, rows, cols, eps, dev)
815            .map_err(Self::map_gpu_err)?;
816        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
817    }
818
819    fn layernorm_backward_f64(
820        &self,
821        input: &GpuBufferHandle,
822        grad_output: &GpuBufferHandle,
823        weight: &GpuBufferHandle,
824        rows: usize,
825        cols: usize,
826        eps: f64,
827    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
828        let in_buf = Self::unwrap_buffer_f64(input)?;
829        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
830        let w_buf = Self::unwrap_buffer_f64(weight)?;
831        let dev = self.device(input.device_ordinal())?;
832        let (gi, gw, gb) =
833            crate::kernels::gpu_layernorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
834                .map_err(Self::map_gpu_err)?;
835        let ordinal = input.device_ordinal();
836        Ok((
837            Self::wrap_buffer_f64(gi, ordinal),
838            Self::wrap_buffer_f64(gw, ordinal),
839            Self::wrap_buffer_f64(gb, ordinal),
840        ))
841    }
842
843    fn rmsnorm_f64(
844        &self,
845        input: &GpuBufferHandle,
846        weight: &GpuBufferHandle,
847        rows: usize,
848        cols: usize,
849        eps: f64,
850    ) -> FerrotorchResult<GpuBufferHandle> {
851        let in_buf = Self::unwrap_buffer_f64(input)?;
852        let w_buf = Self::unwrap_buffer_f64(weight)?;
853        let dev = self.device(input.device_ordinal())?;
854        let result = crate::kernels::gpu_rmsnorm_f64(in_buf, w_buf, rows, cols, eps, dev)
855            .map_err(Self::map_gpu_err)?;
856        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
857    }
858
859    fn rmsnorm_backward_f64(
860        &self,
861        input: &GpuBufferHandle,
862        grad_output: &GpuBufferHandle,
863        weight: &GpuBufferHandle,
864        rows: usize,
865        cols: usize,
866        eps: f64,
867    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
868        let in_buf = Self::unwrap_buffer_f64(input)?;
869        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
870        let w_buf = Self::unwrap_buffer_f64(weight)?;
871        let dev = self.device(input.device_ordinal())?;
872        let (gi, gw) =
873            crate::kernels::gpu_rmsnorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
874                .map_err(Self::map_gpu_err)?;
875        let ordinal = input.device_ordinal();
876        Ok((Self::wrap_buffer_f64(gi, ordinal), Self::wrap_buffer_f64(gw, ordinal)))
877    }
878
879    // f64 embedding / scatter / indexing
880
881    fn embed_lookup_f64(
882        &self,
883        idx: &GpuBufferHandle,
884        weight: &GpuBufferHandle,
885        d: usize,
886    ) -> FerrotorchResult<GpuBufferHandle> {
887        // indices are always f32-encoded
888        let idx_buf = Self::unwrap_buffer(idx)?;
889        let w_buf = Self::unwrap_buffer_f64(weight)?;
890        let dev = self.device(idx.device_ordinal())?;
891        let result =
892            crate::kernels::gpu_embed_lookup_f64(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
893        Ok(Self::wrap_buffer_f64(result, idx.device_ordinal()))
894    }
895
896    fn embed_lookup_batch_f64(
897        &self,
898        indices: &GpuBufferHandle,
899        weight: &GpuBufferHandle,
900        n: usize,
901        d: usize,
902    ) -> FerrotorchResult<GpuBufferHandle> {
903        // indices are always f32-encoded
904        let idx_buf = Self::unwrap_buffer(indices)?;
905        let w_buf = Self::unwrap_buffer_f64(weight)?;
906        let dev = self.device(indices.device_ordinal())?;
907        let result = crate::kernels::gpu_embed_lookup_batch_f64(idx_buf, w_buf, n, d, dev)
908            .map_err(Self::map_gpu_err)?;
909        Ok(Self::wrap_buffer_f64(result, indices.device_ordinal()))
910    }
911
912    fn scatter_add_rows_f64(
913        &self,
914        grad_output: &GpuBufferHandle,
915        indices: &GpuBufferHandle,
916        num_embeddings: usize,
917        d: usize,
918    ) -> FerrotorchResult<GpuBufferHandle> {
919        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
920        // indices are always f32-encoded
921        let idx_buf = Self::unwrap_buffer(indices)?;
922        let dev = self.device(grad_output.device_ordinal())?;
923        let result = crate::kernels::gpu_scatter_add_rows_f64(go_buf, idx_buf, num_embeddings, d, dev)
924            .map_err(Self::map_gpu_err)?;
925        Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
926    }
927
928    // f64 masked fill / masked zero
929    //
930    // The f64 kernels expect CudaBuffer<u8> for the mask, but the trait
931    // provides a GpuBufferHandle containing CudaBuffer<f32> (1.0/0.0 encoding).
932    // We convert f32 mask -> u8 mask via a CPU roundtrip.
933
934    fn masked_fill_f64(
935        &self,
936        input: &GpuBufferHandle,
937        mask: &GpuBufferHandle,
938        value: f64,
939    ) -> FerrotorchResult<GpuBufferHandle> {
940        let input_buf = Self::unwrap_buffer_f64(input)?;
941        let mask_f32 = Self::unwrap_buffer(mask)?;
942        let dev = self.device(input.device_ordinal())?;
943        // Convert f32 mask to u8 mask on GPU via CPU roundtrip
944        let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
945        let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
946        let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
947        let result = crate::kernels::gpu_masked_fill_f64(input_buf, &mask_gpu, value, dev)
948            .map_err(Self::map_gpu_err)?;
949        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
950    }
951
952    fn masked_zero_f64(
953        &self,
954        grad: &GpuBufferHandle,
955        mask: &GpuBufferHandle,
956    ) -> FerrotorchResult<GpuBufferHandle> {
957        let grad_buf = Self::unwrap_buffer_f64(grad)?;
958        let mask_f32 = Self::unwrap_buffer(mask)?;
959        let dev = self.device(grad.device_ordinal())?;
960        // Convert f32 mask to u8 mask on GPU via CPU roundtrip
961        let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
962        let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
963        let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
964        let result = crate::kernels::gpu_masked_zero_f64(grad_buf, &mask_gpu, dev)
965            .map_err(Self::map_gpu_err)?;
966        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
967    }
968
969    // f64 slice ops
970
971    fn slice_write_f64(
972        &self,
973        src: &GpuBufferHandle,
974        dst: &mut GpuBufferHandle,
975        n_batch: usize,
976        d: usize,
977        max_len: usize,
978        pos: usize,
979    ) -> FerrotorchResult<()> {
980        let src_buf = Self::unwrap_buffer_f64(src)?;
981        let dst_buf = Self::unwrap_buffer_f64_mut(dst)?;
982        let dev = self.device(src.device_ordinal())?;
983        crate::kernels::gpu_slice_write_f64(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
984            .map_err(Self::map_gpu_err)?;
985        Ok(())
986    }
987
988    fn slice_read_f64(
989        &self,
990        src: &GpuBufferHandle,
991        n_batch: usize,
992        d: usize,
993        len: usize,
994        max_len: usize,
995    ) -> FerrotorchResult<GpuBufferHandle> {
996        let src_buf = Self::unwrap_buffer_f64(src)?;
997        let dev = self.device(src.device_ordinal())?;
998        let result = crate::kernels::gpu_slice_read_f64(src_buf, n_batch, d, len, max_len, dev)
999            .map_err(Self::map_gpu_err)?;
1000        Ok(Self::wrap_buffer_f64(result, src.device_ordinal()))
1001    }
1002
1003    // f64 strided split / cat
1004
1005    fn strided_split_f64(
1006        &self,
1007        input: &GpuBufferHandle,
1008        total_along_axis: usize,
1009        split_offset: usize,
1010        split_size: usize,
1011        inner_size: usize,
1012        n: usize,
1013    ) -> FerrotorchResult<GpuBufferHandle> {
1014        let in_buf = Self::unwrap_buffer_f64(input)?;
1015        let dev = self.device(input.device_ordinal())?;
1016        let result = crate::kernels::gpu_strided_split_f64(
1017            in_buf,
1018            total_along_axis,
1019            split_offset,
1020            split_size,
1021            inner_size,
1022            n,
1023            dev,
1024        )
1025        .map_err(Self::map_gpu_err)?;
1026        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
1027    }
1028
1029    fn strided_cat_f64(
1030        &self,
1031        input: &GpuBufferHandle,
1032        output: &mut GpuBufferHandle,
1033        total_along_axis: usize,
1034        cat_offset: usize,
1035        part_size: usize,
1036        inner_size: usize,
1037        n: usize,
1038    ) -> FerrotorchResult<()> {
1039        let in_buf = Self::unwrap_buffer_f64(input)?;
1040        let dev = self.device(input.device_ordinal())?;
1041        let out_buf = Self::unwrap_buffer_f64_mut(output)?;
1042        crate::kernels::gpu_strided_cat_f64(
1043            in_buf,
1044            out_buf,
1045            total_along_axis,
1046            cat_offset,
1047            part_size,
1048            inner_size,
1049            n,
1050            dev,
1051        )
1052        .map_err(Self::map_gpu_err)?;
1053        Ok(())
1054    }
1055
1056    // f64 indexing ops
1057
1058    fn index_select_1d_f64(
1059        &self,
1060        input: &GpuBufferHandle,
1061        indices: &GpuBufferHandle,
1062    ) -> FerrotorchResult<GpuBufferHandle> {
1063        let input_buf = Self::unwrap_buffer_f64(input)?;
1064        // indices are always f32-encoded
1065        let idx_buf = Self::unwrap_buffer(indices)?;
1066        let dev = self.device(input.device_ordinal())?;
1067        let result = crate::kernels::gpu_index_select_1d_f64(input_buf, idx_buf, dev)
1068            .map_err(Self::map_gpu_err)?;
1069        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
1070    }
1071
1072    fn scatter_add_1d_f64(
1073        &self,
1074        grad_output: &GpuBufferHandle,
1075        indices: &GpuBufferHandle,
1076        input_len: usize,
1077    ) -> FerrotorchResult<GpuBufferHandle> {
1078        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
1079        // indices are always f32-encoded
1080        let idx_buf = Self::unwrap_buffer(indices)?;
1081        let dev = self.device(grad_output.device_ordinal())?;
1082        let result = crate::kernels::gpu_scatter_add_1d_f64(go_buf, idx_buf, input_len, dev)
1083            .map_err(Self::map_gpu_err)?;
1084        Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
1085    }
1086
1087    fn bmm_f64(
1088        &self,
1089        a: &GpuBufferHandle,
1090        b: &GpuBufferHandle,
1091        batch: usize,
1092        m: usize,
1093        k: usize,
1094        n: usize,
1095    ) -> FerrotorchResult<GpuBufferHandle> {
1096        let a_buf = Self::unwrap_buffer_f64(a)?;
1097        let b_buf = Self::unwrap_buffer_f64(b)?;
1098        let dev = self.device(a.device_ordinal())?;
1099        let result = crate::blas::gpu_bmm_f64(a_buf, b_buf, batch, m, k, n, dev)
1100            .map_err(Self::map_gpu_err)?;
1101        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1102    }
1103
1104    #[allow(clippy::too_many_arguments)]
1105    fn fused_adam_f32(
1106        &self,
1107        param: &mut GpuBufferHandle,
1108        grad: &GpuBufferHandle,
1109        exp_avg: &mut GpuBufferHandle,
1110        exp_avg_sq: &mut GpuBufferHandle,
1111        beta1: f32,
1112        beta2: f32,
1113        lr: f32,
1114        eps: f32,
1115        bc1: f32,
1116        bc2: f32,
1117        weight_decay: f32,
1118    ) -> FerrotorchResult<()> {
1119        let ordinal = param.device_ordinal();
1120        let dev = self.device(ordinal)?;
1121        let p_buf = Self::unwrap_buffer_mut(param)?;
1122        let g_buf = Self::unwrap_buffer(grad)?;
1123        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
1124        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
1125        crate::kernels::gpu_fused_adam(
1126            p_buf,
1127            g_buf,
1128            m_buf,
1129            v_buf,
1130            beta1,
1131            beta2,
1132            lr,
1133            eps,
1134            bc1,
1135            bc2,
1136            weight_decay,
1137            dev,
1138        )
1139        .map_err(Self::map_gpu_err)?;
1140        Ok(())
1141    }
1142
1143    #[allow(clippy::too_many_arguments)]
1144    fn maxpool2d_f32(
1145        &self,
1146        input: &GpuBufferHandle,
1147        batch: usize,
1148        channels: usize,
1149        h_in: usize,
1150        w_in: usize,
1151        kh: usize,
1152        kw: usize,
1153        sh: usize,
1154        sw: usize,
1155        ph: usize,
1156        pw: usize,
1157    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1158        let buf = Self::unwrap_buffer(input)?;
1159        let dev = self.device(input.device_ordinal())?;
1160        let (out, shape) = crate::kernels::gpu_maxpool2d(
1161            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
1162        ).map_err(Self::map_gpu_err)?;
1163        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
1164    }
1165
1166    #[allow(clippy::too_many_arguments)]
1167    fn avgpool2d_f32(
1168        &self,
1169        input: &GpuBufferHandle,
1170        batch: usize,
1171        channels: usize,
1172        h_in: usize,
1173        w_in: usize,
1174        kh: usize,
1175        kw: usize,
1176        sh: usize,
1177        sw: usize,
1178        ph: usize,
1179        pw: usize,
1180    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1181        let buf = Self::unwrap_buffer(input)?;
1182        let dev = self.device(input.device_ordinal())?;
1183        let (out, shape) = crate::kernels::gpu_avgpool2d(
1184            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
1185        ).map_err(Self::map_gpu_err)?;
1186        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
1187    }
1188
1189    #[allow(clippy::too_many_arguments)]
1190    fn conv2d_f32(
1191        &self,
1192        input: &GpuBufferHandle,
1193        weight: &GpuBufferHandle,
1194        bias: Option<&GpuBufferHandle>,
1195        input_shape: [usize; 4],
1196        weight_shape: [usize; 4],
1197        stride: (usize, usize),
1198        padding: (usize, usize),
1199    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1200        let input_buf = Self::unwrap_buffer(input)?;
1201        let weight_buf = Self::unwrap_buffer(weight)?;
1202        let bias_buf = match bias {
1203            Some(b) => Some(Self::unwrap_buffer(b)?),
1204            None => None,
1205        };
1206        let dev = self.device(input.device_ordinal())?;
1207        let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
1208            input_buf,
1209            weight_buf,
1210            bias_buf,
1211            input_shape,
1212            weight_shape,
1213            stride,
1214            padding,
1215            dev,
1216        )
1217        .map_err(Self::map_gpu_err)?;
1218        Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
1219    }
1220
1221    fn fused_gru_cell_f32(
1222        &self,
1223        input_gates: &GpuBufferHandle,
1224        hidden_gates: &GpuBufferHandle,
1225        bias_ih: &GpuBufferHandle,
1226        bias_hh: &GpuBufferHandle,
1227        hx: &GpuBufferHandle,
1228        hidden_size: usize,
1229    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1230        let ig = Self::unwrap_buffer(input_gates)?;
1231        let hg = Self::unwrap_buffer(hidden_gates)?;
1232        let bih = Self::unwrap_buffer(bias_ih)?;
1233        let bhh = Self::unwrap_buffer(bias_hh)?;
1234        let hx_buf = Self::unwrap_buffer(hx)?;
1235        let dev = self.device(input_gates.device_ordinal())?;
1236        let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
1237            ig, hg, bih, bhh, hx_buf, hidden_size, dev,
1238        )
1239        .map_err(Self::map_gpu_err)?;
1240        let ord = input_gates.device_ordinal();
1241        Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
1242    }
1243
1244    fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
1245        let dev = self.device(device)?;
1246        dev.stream()
1247            .synchronize()
1248            .map_err(|e| FerrotorchError::InvalidArgument {
1249                message: format!("CUDA synchronize failed: {e}"),
1250            })?;
1251        Ok(())
1252    }
1253
    /// Number of CUDA streams in the stream pool for `device`.
    ///
    /// Pure delegation to [`crate::stream::StreamPool::pool_size`]; the
    /// ordinal is not validated against `self.devices`.
    fn stream_count(&self, device: usize) -> usize {
        crate::stream::StreamPool::pool_size(device)
    }
1257
1258    // -- Linalg f32 -----------------------------------------------------------
1259
1260    fn matmul_f32(
1261        &self,
1262        a: &GpuBufferHandle,
1263        b: &GpuBufferHandle,
1264        m: usize,
1265        k: usize,
1266        n: usize,
1267    ) -> FerrotorchResult<GpuBufferHandle> {
1268        let a_buf = Self::unwrap_buffer(a)?;
1269        let b_buf = Self::unwrap_buffer(b)?;
1270        let dev = self.device(a.device_ordinal())?;
1271        let result =
1272            crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1273        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1274    }
1275
1276    // -- Reduction f32 --------------------------------------------------------
1277
1278    fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
1279        let a_buf = Self::unwrap_buffer(a)?;
1280        let dev = self.device(a.device_ordinal())?;
1281        let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
1282        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1283    }
1284
1285    // -- Linalg f64 (cuBLAS DGEMM) --------------------------------------------
1286
1287    fn matmul_f64(
1288        &self,
1289        a: &GpuBufferHandle,
1290        b: &GpuBufferHandle,
1291        m: usize,
1292        k: usize,
1293        n: usize,
1294    ) -> FerrotorchResult<GpuBufferHandle> {
1295        let a_buf = Self::unwrap_buffer_f64(a)?;
1296        let b_buf = Self::unwrap_buffer_f64(b)?;
1297        let dev = self.device(a.device_ordinal())?;
1298        let result =
1299            crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1300        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1301    }
1302
1303    // -- Broadcast binary f32 -------------------------------------------------
1304
1305    fn broadcast_add_f32(
1306        &self,
1307        a: &GpuBufferHandle,
1308        b: &GpuBufferHandle,
1309        a_shape: &[usize],
1310        b_shape: &[usize],
1311        out_shape: &[usize],
1312    ) -> FerrotorchResult<GpuBufferHandle> {
1313        let a_buf = Self::unwrap_buffer(a)?;
1314        let b_buf = Self::unwrap_buffer(b)?;
1315        let dev = self.device(a.device_ordinal())?;
1316        let result =
1317            crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1318                .map_err(Self::map_gpu_err)?;
1319        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1320    }
1321
1322    fn broadcast_sub_f32(
1323        &self,
1324        a: &GpuBufferHandle,
1325        b: &GpuBufferHandle,
1326        a_shape: &[usize],
1327        b_shape: &[usize],
1328        out_shape: &[usize],
1329    ) -> FerrotorchResult<GpuBufferHandle> {
1330        let a_buf = Self::unwrap_buffer(a)?;
1331        let b_buf = Self::unwrap_buffer(b)?;
1332        let dev = self.device(a.device_ordinal())?;
1333        let result =
1334            crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1335                .map_err(Self::map_gpu_err)?;
1336        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1337    }
1338
1339    fn broadcast_mul_f32(
1340        &self,
1341        a: &GpuBufferHandle,
1342        b: &GpuBufferHandle,
1343        a_shape: &[usize],
1344        b_shape: &[usize],
1345        out_shape: &[usize],
1346    ) -> FerrotorchResult<GpuBufferHandle> {
1347        let a_buf = Self::unwrap_buffer(a)?;
1348        let b_buf = Self::unwrap_buffer(b)?;
1349        let dev = self.device(a.device_ordinal())?;
1350        let result =
1351            crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1352                .map_err(Self::map_gpu_err)?;
1353        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1354    }
1355
1356    fn broadcast_div_f32(
1357        &self,
1358        a: &GpuBufferHandle,
1359        b: &GpuBufferHandle,
1360        a_shape: &[usize],
1361        b_shape: &[usize],
1362        out_shape: &[usize],
1363    ) -> FerrotorchResult<GpuBufferHandle> {
1364        let a_buf = Self::unwrap_buffer(a)?;
1365        let b_buf = Self::unwrap_buffer(b)?;
1366        let dev = self.device(a.device_ordinal())?;
1367        let result =
1368            crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1369                .map_err(Self::map_gpu_err)?;
1370        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1371    }
1372
1373    fn softmax_f32(
1374        &self,
1375        a: &GpuBufferHandle,
1376        rows: usize,
1377        cols: usize,
1378    ) -> FerrotorchResult<GpuBufferHandle> {
1379        let a_buf = Self::unwrap_buffer(a)?;
1380        let dev = self.device(a.device_ordinal())?;
1381        let result =
1382            crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
1383        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1384    }
1385
1386    fn dropout_f32(
1387        &self,
1388        a: &GpuBufferHandle,
1389        threshold: u32,
1390        scale: f32,
1391        seed: u32,
1392    ) -> FerrotorchResult<GpuBufferHandle> {
1393        let a_buf = Self::unwrap_buffer(a)?;
1394        let dev = self.device(a.device_ordinal())?;
1395        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
1396            .map_err(Self::map_gpu_err)?;
1397        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1398    }
1399
    /// f32 dropout driven by the per-device stateful Philox RNG.
    ///
    /// Snapshots the generator state, advances it past the values this call
    /// consumes, derives a seed for the legacy dropout kernel from the
    /// snapshot, and returns the output together with the snapshot so the
    /// mask can be regenerated deterministically (e.g. for the backward pass).
    fn dropout_philox_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Snapshot the current RNG state and advance it.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            // Advance by ceil(n/4) counters (each counter produces 4 u32 values)
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        // Use the Philox state as the seed for the dropout kernel.
        // We encode the Philox counter+seed into a u32 seed that the existing
        // dropout kernel can use. For full correctness on GPU, we should use
        // the Philox uniform kernel to generate the mask, then apply it.
        // However, for consistency between GPU forward and CPU backward mask
        // regeneration, we use the Philox state to deterministically derive a
        // seed for the existing kernel.
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(device_ordinal)?;

        // Use the Philox counter XOR seed as the dropout kernel's seed.
        // This gives us deterministic behavior tied to the Philox state.
        // NOTE(review): the `as u32` cast keeps only the low 32 bits of the
        // 64-bit XOR; any CPU-side mask regeneration must derive its seed the
        // same way — confirm against the CPU backward implementation.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        // Return the pre-advance snapshot (not the advanced state), so the
        // caller can replay exactly the randomness this call consumed.
        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
    }
1449
1450    fn dropout_f64(
1451        &self,
1452        a: &GpuBufferHandle,
1453        threshold: u32,
1454        scale: f64,
1455        seed: u32,
1456    ) -> FerrotorchResult<GpuBufferHandle> {
1457        let a_buf = Self::unwrap_buffer_f64(a)?;
1458        let dev = self.device(a.device_ordinal())?;
1459        let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, seed, dev)
1460            .map_err(Self::map_gpu_err)?;
1461        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1462    }
1463
    /// `f64` counterpart of the Philox-driven dropout above: snapshots the
    /// per-device Philox state, advances the stream, and returns the
    /// pre-advance snapshot so the identical mask can be regenerated later.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if the RNG manager mutex
    /// is poisoned; propagates mapped kernel errors otherwise.
    fn dropout_philox_f64(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f64,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Snapshot the Philox state, then advance by ceil(n/4) counters
        // (each counter yields four u32 values). Lock scope ends before the
        // kernel launch.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(device_ordinal)?;
        // Derive the kernel seed from the snapshot (low 32 bits of
        // counter XOR seed), matching the f32 path so mask regeneration
        // stays consistent across dtypes.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        // Return the *pre-advance* snapshot.
        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer_f64(result, device_ordinal), gpu_rng_state))
    }
1501
1502    fn transpose_2d_f32(
1503        &self,
1504        a: &GpuBufferHandle,
1505        m: usize,
1506        n: usize,
1507    ) -> FerrotorchResult<GpuBufferHandle> {
1508        let a_buf = Self::unwrap_buffer(a)?;
1509        let dev = self.device(a.device_ordinal())?;
1510        let result =
1511            crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
1512        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1513    }
1514
1515    fn permute_0213_f32(
1516        &self,
1517        a: &GpuBufferHandle,
1518        d0: usize,
1519        d1: usize,
1520        d2: usize,
1521        d3: usize,
1522    ) -> FerrotorchResult<GpuBufferHandle> {
1523        let a_buf = Self::unwrap_buffer(a)?;
1524        let dev = self.device(a.device_ordinal())?;
1525        let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
1526            .map_err(Self::map_gpu_err)?;
1527        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1528    }
1529
1530    fn bmm_f32(
1531        &self,
1532        a: &GpuBufferHandle,
1533        b: &GpuBufferHandle,
1534        batch: usize,
1535        m: usize,
1536        k: usize,
1537        n: usize,
1538    ) -> FerrotorchResult<GpuBufferHandle> {
1539        let a_buf = Self::unwrap_buffer(a)?;
1540        let b_buf = Self::unwrap_buffer(b)?;
1541        let dev = self.device(a.device_ordinal())?;
1542        let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
1543            .map_err(Self::map_gpu_err)?;
1544        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1545    }
1546
1547    fn bmm_f16_f32(
1548        &self,
1549        a: &GpuBufferHandle,
1550        b: &GpuBufferHandle,
1551        batch: usize,
1552        m: usize,
1553        k: usize,
1554        n: usize,
1555    ) -> FerrotorchResult<GpuBufferHandle> {
1556        let a_buf = Self::unwrap_buffer(a)?;
1557        let b_buf = Self::unwrap_buffer(b)?;
1558        let dev = self.device(a.device_ordinal())?;
1559        let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
1560            .map_err(Self::map_gpu_err)?;
1561        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1562    }
1563
1564    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1565        let a_buf = Self::unwrap_buffer(a)?;
1566        let dev = self.device(a.device_ordinal())?;
1567        let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
1568        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1569    }
1570
1571    fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1572        let a_buf = Self::unwrap_buffer(a)?;
1573        let dev = self.device(a.device_ordinal())?;
1574        let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
1575        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1576    }
1577
1578    fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1579        let a_buf = Self::unwrap_buffer(a)?;
1580        let dev = self.device(a.device_ordinal())?;
1581        let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
1582        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1583    }
1584
1585    fn layernorm_f32(
1586        &self,
1587        input: &GpuBufferHandle,
1588        weight: &GpuBufferHandle,
1589        bias: &GpuBufferHandle,
1590        rows: usize,
1591        cols: usize,
1592        eps: f32,
1593    ) -> FerrotorchResult<GpuBufferHandle> {
1594        let in_buf = Self::unwrap_buffer(input)?;
1595        let w_buf = Self::unwrap_buffer(weight)?;
1596        let b_buf = Self::unwrap_buffer(bias)?;
1597        let dev = self.device(input.device_ordinal())?;
1598        let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
1599            .map_err(Self::map_gpu_err)?;
1600        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1601    }
1602
1603    fn rmsnorm_f32(
1604        &self,
1605        input: &GpuBufferHandle,
1606        weight: &GpuBufferHandle,
1607        rows: usize,
1608        cols: usize,
1609        eps: f32,
1610    ) -> FerrotorchResult<GpuBufferHandle> {
1611        let in_buf = Self::unwrap_buffer(input)?;
1612        let w_buf = Self::unwrap_buffer(weight)?;
1613        let dev = self.device(input.device_ordinal())?;
1614        let result = crate::kernels::gpu_rmsnorm(in_buf, w_buf, rows, cols, eps, dev)
1615            .map_err(Self::map_gpu_err)?;
1616        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1617    }
1618
1619    fn rmsnorm_backward_f32(
1620        &self,
1621        input: &GpuBufferHandle,
1622        grad_output: &GpuBufferHandle,
1623        weight: &GpuBufferHandle,
1624        rows: usize,
1625        cols: usize,
1626        eps: f32,
1627    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1628        let in_buf = Self::unwrap_buffer(input)?;
1629        let go_buf = Self::unwrap_buffer(grad_output)?;
1630        let w_buf = Self::unwrap_buffer(weight)?;
1631        let dev = self.device(input.device_ordinal())?;
1632        let (gi, gw) =
1633            crate::kernels::gpu_rmsnorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
1634                .map_err(Self::map_gpu_err)?;
1635        let ordinal = input.device_ordinal();
1636        Ok((Self::wrap_buffer(gi, ordinal), Self::wrap_buffer(gw, ordinal)))
1637    }
1638
1639    fn slice_write_f32(
1640        &self,
1641        src: &GpuBufferHandle,
1642        dst: &mut GpuBufferHandle,
1643        n_batch: usize,
1644        d: usize,
1645        max_len: usize,
1646        pos: usize,
1647    ) -> FerrotorchResult<()> {
1648        let src_buf = Self::unwrap_buffer(src)?;
1649        let dst_buf =
1650            dst.downcast_mut::<CudaBuffer<f32>>()
1651                .ok_or(FerrotorchError::InvalidArgument {
1652                    message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
1653                })?;
1654        let dev = self.device(src.device_ordinal())?;
1655        crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
1656            .map_err(Self::map_gpu_err)?;
1657        Ok(())
1658    }
1659
1660    fn slice_read_f32(
1661        &self,
1662        src: &GpuBufferHandle,
1663        n_batch: usize,
1664        d: usize,
1665        len: usize,
1666        max_len: usize,
1667    ) -> FerrotorchResult<GpuBufferHandle> {
1668        let src_buf = Self::unwrap_buffer(src)?;
1669        let dev = self.device(src.device_ordinal())?;
1670        let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
1671            .map_err(Self::map_gpu_err)?;
1672        Ok(Self::wrap_buffer(result, src.device_ordinal()))
1673    }
1674
1675    fn embed_lookup_f32(
1676        &self,
1677        idx: &GpuBufferHandle,
1678        weight: &GpuBufferHandle,
1679        d: usize,
1680    ) -> FerrotorchResult<GpuBufferHandle> {
1681        let idx_buf = Self::unwrap_buffer(idx)?;
1682        let w_buf = Self::unwrap_buffer(weight)?;
1683        let dev = self.device(idx.device_ordinal())?;
1684        let result =
1685            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
1686        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
1687    }
1688
1689    fn embed_lookup_batch_f32(
1690        &self,
1691        indices: &GpuBufferHandle,
1692        weight: &GpuBufferHandle,
1693        n: usize,
1694        d: usize,
1695    ) -> FerrotorchResult<GpuBufferHandle> {
1696        let idx_buf = Self::unwrap_buffer(indices)?;
1697        let w_buf = Self::unwrap_buffer(weight)?;
1698        let dev = self.device(indices.device_ordinal())?;
1699        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
1700            .map_err(Self::map_gpu_err)?;
1701        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
1702    }
1703
1704    fn scatter_add_rows_f32(
1705        &self,
1706        grad_output: &GpuBufferHandle,
1707        indices: &GpuBufferHandle,
1708        num_embeddings: usize,
1709        d: usize,
1710    ) -> FerrotorchResult<GpuBufferHandle> {
1711        let go_buf = Self::unwrap_buffer(grad_output)?;
1712        let idx_buf = Self::unwrap_buffer(indices)?;
1713        let dev = self.device(grad_output.device_ordinal())?;
1714        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
1715            .map_err(Self::map_gpu_err)?;
1716        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
1717    }
1718
1719    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
1720        let a_buf = Self::unwrap_buffer(a)?;
1721        let dev = self.device(a.device_ordinal())?;
1722        let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
1723        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1724    }
1725
1726    fn relu_backward_f32(
1727        &self,
1728        grad: &GpuBufferHandle,
1729        input: &GpuBufferHandle,
1730    ) -> FerrotorchResult<GpuBufferHandle> {
1731        let grad_buf = Self::unwrap_buffer(grad)?;
1732        let input_buf = Self::unwrap_buffer(input)?;
1733        let dev = self.device(grad.device_ordinal())?;
1734        let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
1735            .map_err(Self::map_gpu_err)?;
1736        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1737    }
1738
1739    fn gelu_backward_f32(
1740        &self,
1741        grad: &GpuBufferHandle,
1742        input: &GpuBufferHandle,
1743    ) -> FerrotorchResult<GpuBufferHandle> {
1744        let grad_buf = Self::unwrap_buffer(grad)?;
1745        let input_buf = Self::unwrap_buffer(input)?;
1746        let dev = self.device(grad.device_ordinal())?;
1747        let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
1748            .map_err(Self::map_gpu_err)?;
1749        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1750    }
1751
1752    fn gelu_backward_tanh_f32(
1753        &self,
1754        grad: &GpuBufferHandle,
1755        input: &GpuBufferHandle,
1756    ) -> FerrotorchResult<GpuBufferHandle> {
1757        let grad_buf = Self::unwrap_buffer(grad)?;
1758        let input_buf = Self::unwrap_buffer(input)?;
1759        let dev = self.device(grad.device_ordinal())?;
1760        let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
1761            .map_err(Self::map_gpu_err)?;
1762        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1763    }
1764
1765    fn gelu_backward_erf_f32(
1766        &self,
1767        grad: &GpuBufferHandle,
1768        input: &GpuBufferHandle,
1769    ) -> FerrotorchResult<GpuBufferHandle> {
1770        let grad_buf = Self::unwrap_buffer(grad)?;
1771        let input_buf = Self::unwrap_buffer(input)?;
1772        let dev = self.device(grad.device_ordinal())?;
1773        let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
1774            .map_err(Self::map_gpu_err)?;
1775        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1776    }
1777
1778    fn cumsum_f32(
1779        &self,
1780        a: &GpuBufferHandle,
1781        outer: usize,
1782        dim_size: usize,
1783        inner: usize,
1784    ) -> FerrotorchResult<GpuBufferHandle> {
1785        let a_buf = Self::unwrap_buffer(a)?;
1786        let dev = self.device(a.device_ordinal())?;
1787        let result = crate::kernels::gpu_cumsum(a_buf, outer, dim_size, inner, dev)
1788            .map_err(Self::map_gpu_err)?;
1789        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1790    }
1791
1792    fn cumprod_f32(
1793        &self,
1794        a: &GpuBufferHandle,
1795        outer: usize,
1796        dim_size: usize,
1797        inner: usize,
1798    ) -> FerrotorchResult<GpuBufferHandle> {
1799        let a_buf = Self::unwrap_buffer(a)?;
1800        let dev = self.device(a.device_ordinal())?;
1801        let result = crate::kernels::gpu_cumprod(a_buf, outer, dim_size, inner, dev)
1802            .map_err(Self::map_gpu_err)?;
1803        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1804    }
1805
1806    fn cummax_f32(
1807        &self,
1808        a: &GpuBufferHandle,
1809        outer: usize,
1810        dim_size: usize,
1811        inner: usize,
1812    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1813        let a_buf = Self::unwrap_buffer(a)?;
1814        let dev = self.device(a.device_ordinal())?;
1815        let (vals, idxs) = crate::kernels::gpu_cummax(a_buf, outer, dim_size, inner, dev)
1816            .map_err(Self::map_gpu_err)?;
1817        let ord = a.device_ordinal();
1818        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
1819    }
1820
1821    fn cummin_f32(
1822        &self,
1823        a: &GpuBufferHandle,
1824        outer: usize,
1825        dim_size: usize,
1826        inner: usize,
1827    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1828        let a_buf = Self::unwrap_buffer(a)?;
1829        let dev = self.device(a.device_ordinal())?;
1830        let (vals, idxs) = crate::kernels::gpu_cummin(a_buf, outer, dim_size, inner, dev)
1831            .map_err(Self::map_gpu_err)?;
1832        let ord = a.device_ordinal();
1833        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
1834    }
1835
1836    fn logcumsumexp_f32(
1837        &self,
1838        a: &GpuBufferHandle,
1839        outer: usize,
1840        dim_size: usize,
1841        inner: usize,
1842    ) -> FerrotorchResult<GpuBufferHandle> {
1843        let a_buf = Self::unwrap_buffer(a)?;
1844        let dev = self.device(a.device_ordinal())?;
1845        let result = crate::kernels::gpu_logcumsumexp(a_buf, outer, dim_size, inner, dev)
1846            .map_err(Self::map_gpu_err)?;
1847        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1848    }
1849
1850    fn clamp_f32(
1851        &self,
1852        a: &GpuBufferHandle,
1853        min_val: f32,
1854        max_val: f32,
1855    ) -> FerrotorchResult<GpuBufferHandle> {
1856        let a_buf = Self::unwrap_buffer(a)?;
1857        let dev = self.device(a.device_ordinal())?;
1858        let result =
1859            crate::kernels::gpu_clamp(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
1860        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1861    }
1862
1863    fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1864        let a_buf = Self::unwrap_buffer(a)?;
1865        let dev = self.device(a.device_ordinal())?;
1866        let result = crate::kernels::gpu_silu(a_buf, dev).map_err(Self::map_gpu_err)?;
1867        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1868    }
1869
1870    fn silu_backward_f32(
1871        &self,
1872        grad: &GpuBufferHandle,
1873        input: &GpuBufferHandle,
1874    ) -> FerrotorchResult<GpuBufferHandle> {
1875        let grad_buf = Self::unwrap_buffer(grad)?;
1876        let input_buf = Self::unwrap_buffer(input)?;
1877        let dev = self.device(grad.device_ordinal())?;
1878        let result = crate::kernels::gpu_silu_backward(grad_buf, input_buf, dev)
1879            .map_err(Self::map_gpu_err)?;
1880        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1881    }
1882
1883    fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
1884        let a_buf = Self::unwrap_buffer(a)?;
1885        let dev = self.device(a.device_ordinal())?;
1886        let result = crate::kernels::gpu_elu(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
1887        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1888    }
1889
1890    fn elu_backward_f32(
1891        &self,
1892        grad: &GpuBufferHandle,
1893        input: &GpuBufferHandle,
1894        alpha: f32,
1895    ) -> FerrotorchResult<GpuBufferHandle> {
1896        let grad_buf = Self::unwrap_buffer(grad)?;
1897        let input_buf = Self::unwrap_buffer(input)?;
1898        let dev = self.device(grad.device_ordinal())?;
1899        let result = crate::kernels::gpu_elu_backward(grad_buf, input_buf, alpha, dev)
1900            .map_err(Self::map_gpu_err)?;
1901        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1902    }
1903
1904    fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1905        let a_buf = Self::unwrap_buffer(a)?;
1906        let dev = self.device(a.device_ordinal())?;
1907        let result = crate::kernels::gpu_mish(a_buf, dev).map_err(Self::map_gpu_err)?;
1908        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1909    }
1910
1911    fn mish_backward_f32(
1912        &self,
1913        grad: &GpuBufferHandle,
1914        input: &GpuBufferHandle,
1915    ) -> FerrotorchResult<GpuBufferHandle> {
1916        let grad_buf = Self::unwrap_buffer(grad)?;
1917        let input_buf = Self::unwrap_buffer(input)?;
1918        let dev = self.device(grad.device_ordinal())?;
1919        let result = crate::kernels::gpu_mish_backward(grad_buf, input_buf, dev)
1920            .map_err(Self::map_gpu_err)?;
1921        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1922    }
1923
1924    fn log_softmax_f32(
1925        &self,
1926        a: &GpuBufferHandle,
1927        cols: usize,
1928    ) -> FerrotorchResult<GpuBufferHandle> {
1929        let a_buf = Self::unwrap_buffer(a)?;
1930        let dev = self.device(a.device_ordinal())?;
1931        let result =
1932            crate::kernels::gpu_log_softmax(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
1933        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1934    }
1935
1936    fn log_softmax_backward_f32(
1937        &self,
1938        grad: &GpuBufferHandle,
1939        output: &GpuBufferHandle,
1940        cols: usize,
1941    ) -> FerrotorchResult<GpuBufferHandle> {
1942        let grad_buf = Self::unwrap_buffer(grad)?;
1943        let output_buf = Self::unwrap_buffer(output)?;
1944        let dev = self.device(grad.device_ordinal())?;
1945        let result =
1946            crate::kernels::gpu_log_softmax_backward(grad_buf, output_buf, cols, dev)
1947                .map_err(Self::map_gpu_err)?;
1948        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1949    }
1950
1951    fn index_select_1d_f32(
1952        &self,
1953        input: &GpuBufferHandle,
1954        indices: &GpuBufferHandle,
1955    ) -> FerrotorchResult<GpuBufferHandle> {
1956        let input_buf = Self::unwrap_buffer(input)?;
1957        let idx_buf = Self::unwrap_buffer(indices)?;
1958        let dev = self.device(input.device_ordinal())?;
1959        let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
1960            .map_err(Self::map_gpu_err)?;
1961        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1962    }
1963
1964    fn scatter_add_1d_f32(
1965        &self,
1966        grad_output: &GpuBufferHandle,
1967        indices: &GpuBufferHandle,
1968        input_len: usize,
1969    ) -> FerrotorchResult<GpuBufferHandle> {
1970        let go_buf = Self::unwrap_buffer(grad_output)?;
1971        let idx_buf = Self::unwrap_buffer(indices)?;
1972        let dev = self.device(grad_output.device_ordinal())?;
1973        let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
1974            .map_err(Self::map_gpu_err)?;
1975        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
1976    }
1977
1978    fn masked_fill_f32(
1979        &self,
1980        input: &GpuBufferHandle,
1981        mask: &GpuBufferHandle,
1982        value: f32,
1983    ) -> FerrotorchResult<GpuBufferHandle> {
1984        let input_buf = Self::unwrap_buffer(input)?;
1985        let mask_buf = Self::unwrap_buffer(mask)?;
1986        let dev = self.device(input.device_ordinal())?;
1987        let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
1988            .map_err(Self::map_gpu_err)?;
1989        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1990    }
1991
1992    fn masked_zero_f32(
1993        &self,
1994        grad: &GpuBufferHandle,
1995        mask: &GpuBufferHandle,
1996    ) -> FerrotorchResult<GpuBufferHandle> {
1997        let grad_buf = Self::unwrap_buffer(grad)?;
1998        let mask_buf = Self::unwrap_buffer(mask)?;
1999        let dev = self.device(grad.device_ordinal())?;
2000        let result =
2001            crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
2002        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2003    }
2004
2005    fn sigmoid_backward_f32(
2006        &self,
2007        grad: &GpuBufferHandle,
2008        output: &GpuBufferHandle,
2009    ) -> FerrotorchResult<GpuBufferHandle> {
2010        let grad_buf = Self::unwrap_buffer(grad)?;
2011        let output_buf = Self::unwrap_buffer(output)?;
2012        let dev = self.device(grad.device_ordinal())?;
2013        let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
2014            .map_err(Self::map_gpu_err)?;
2015        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2016    }
2017
2018    fn tanh_backward_f32(
2019        &self,
2020        grad: &GpuBufferHandle,
2021        output: &GpuBufferHandle,
2022    ) -> FerrotorchResult<GpuBufferHandle> {
2023        let grad_buf = Self::unwrap_buffer(grad)?;
2024        let output_buf = Self::unwrap_buffer(output)?;
2025        let dev = self.device(grad.device_ordinal())?;
2026        let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
2027            .map_err(Self::map_gpu_err)?;
2028        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2029    }
2030
2031    fn softmax_backward_f32(
2032        &self,
2033        grad: &GpuBufferHandle,
2034        output: &GpuBufferHandle,
2035        cols: usize,
2036    ) -> FerrotorchResult<GpuBufferHandle> {
2037        let grad_buf = Self::unwrap_buffer(grad)?;
2038        let output_buf = Self::unwrap_buffer(output)?;
2039        let dev = self.device(grad.device_ordinal())?;
2040        let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
2041            .map_err(Self::map_gpu_err)?;
2042        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2043    }
2044
2045    fn layernorm_backward_f32(
2046        &self,
2047        input: &GpuBufferHandle,
2048        grad_output: &GpuBufferHandle,
2049        weight: &GpuBufferHandle,
2050        rows: usize,
2051        cols: usize,
2052        eps: f32,
2053    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2054        let in_buf = Self::unwrap_buffer(input)?;
2055        let go_buf = Self::unwrap_buffer(grad_output)?;
2056        let w_buf = Self::unwrap_buffer(weight)?;
2057        let dev = self.device(input.device_ordinal())?;
2058        let (gi, gw, gb) =
2059            crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
2060                .map_err(Self::map_gpu_err)?;
2061        let ordinal = input.device_ordinal();
2062        Ok((
2063            Self::wrap_buffer(gi, ordinal),
2064            Self::wrap_buffer(gw, ordinal),
2065            Self::wrap_buffer(gb, ordinal),
2066        ))
2067    }
2068
2069    fn sum_axis_f32(
2070        &self,
2071        a: &GpuBufferHandle,
2072        shape: &[usize],
2073        axis: usize,
2074    ) -> FerrotorchResult<GpuBufferHandle> {
2075        let a_buf = Self::unwrap_buffer(a)?;
2076        let dev = self.device(a.device_ordinal())?;
2077        let outer: usize = shape[..axis].iter().product();
2078        let axis_size = shape[axis];
2079        let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
2080        let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
2081            .map_err(Self::map_gpu_err)?;
2082        Ok(Self::wrap_buffer(result, a.device_ordinal()))
2083    }
2084
2085    fn matmul_f16_f32(
2086        &self,
2087        a: &GpuBufferHandle,
2088        b: &GpuBufferHandle,
2089        m: usize,
2090        k: usize,
2091        n: usize,
2092    ) -> FerrotorchResult<GpuBufferHandle> {
2093        let a_buf = Self::unwrap_buffer(a)?;
2094        let b_buf = Self::unwrap_buffer(b)?;
2095        let dev = self.device(a.device_ordinal())?;
2096        let result =
2097            crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
2098        Ok(Self::wrap_buffer(result, a.device_ordinal()))
2099    }
2100
2101    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
2102        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
2103            FerrotorchError::InvalidArgument {
2104                message: "failed to lock CUDA RNG manager".into(),
2105            }
2106        })?;
2107        let state = mgr.get_rng_state(device);
2108        Ok(GpuRngState {
2109            counter: state.counter,
2110            seed: state.seed,
2111            offset: state.offset,
2112            device,
2113        })
2114    }
2115
2116    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
2117        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
2118            FerrotorchError::InvalidArgument {
2119                message: "failed to lock CUDA RNG manager".into(),
2120            }
2121        })?;
2122        mgr.set_rng_state(
2123            state.device,
2124            crate::rng::PhiloxState {
2125                counter: state.counter,
2126                seed: state.seed,
2127                offset: state.offset,
2128            },
2129        );
2130        Ok(())
2131    }
2132
2133    fn strided_split_f32(
2134        &self,
2135        input: &GpuBufferHandle,
2136        total_along_axis: usize,
2137        split_offset: usize,
2138        split_size: usize,
2139        inner_size: usize,
2140        n: usize,
2141    ) -> FerrotorchResult<GpuBufferHandle> {
2142        let in_buf = Self::unwrap_buffer(input)?;
2143        let dev = self.device(input.device_ordinal())?;
2144        let result = crate::kernels::gpu_strided_split(
2145            in_buf,
2146            total_along_axis,
2147            split_offset,
2148            split_size,
2149            inner_size,
2150            n,
2151            dev,
2152        )
2153        .map_err(Self::map_gpu_err)?;
2154        Ok(Self::wrap_buffer(result, input.device_ordinal()))
2155    }
2156
2157    fn strided_copy_f32(
2158        &self,
2159        input: &GpuBufferHandle,
2160        out_shape: &[usize],
2161        src_strides: &[isize],
2162        src_offset: usize,
2163    ) -> FerrotorchResult<GpuBufferHandle> {
2164        let in_buf = Self::unwrap_buffer(input)?;
2165        let dev = self.device(input.device_ordinal())?;
2166        let result =
2167            crate::kernels::gpu_strided_copy(in_buf, out_shape, src_strides, src_offset, dev)
2168                .map_err(Self::map_gpu_err)?;
2169        Ok(Self::wrap_buffer(result, input.device_ordinal()))
2170    }
2171
2172    fn strided_copy_f64(
2173        &self,
2174        input: &GpuBufferHandle,
2175        out_shape: &[usize],
2176        src_strides: &[isize],
2177        src_offset: usize,
2178    ) -> FerrotorchResult<GpuBufferHandle> {
2179        let in_buf = Self::unwrap_buffer_f64(input)?;
2180        let dev = self.device(input.device_ordinal())?;
2181        let result = crate::kernels::gpu_strided_copy_f64(
2182            in_buf,
2183            out_shape,
2184            src_strides,
2185            src_offset,
2186            dev,
2187        )
2188        .map_err(Self::map_gpu_err)?;
2189        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
2190    }
2191
2192    fn strided_cat_f32(
2193        &self,
2194        input: &GpuBufferHandle,
2195        output: &mut GpuBufferHandle,
2196        total_along_axis: usize,
2197        cat_offset: usize,
2198        part_size: usize,
2199        inner_size: usize,
2200        n: usize,
2201    ) -> FerrotorchResult<()> {
2202        let in_buf = Self::unwrap_buffer(input)?;
2203        let dev = self.device(input.device_ordinal())?;
2204        let out_buf =
2205            output
2206                .downcast_mut::<CudaBuffer<f32>>()
2207                .ok_or(FerrotorchError::InvalidArgument {
2208                    message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
2209                })?;
2210        crate::kernels::gpu_strided_cat(
2211            in_buf,
2212            out_buf,
2213            total_along_axis,
2214            cat_offset,
2215            part_size,
2216            inner_size,
2217            n,
2218            dev,
2219        )
2220        .map_err(Self::map_gpu_err)?;
2221        Ok(())
2222    }
2223
2224    // -- cuSOLVER linear algebra -------------------------------------------------
2225
2226    fn svd_f32(
2227        &self,
2228        a: &GpuBufferHandle,
2229        m: usize,
2230        n: usize,
2231    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2232        let a_buf = Self::unwrap_buffer(a)?;
2233        let dev = self.device(a.device_ordinal())?;
2234        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2235        let (u, s, vt) =
2236            crate::cusolver::gpu_svd_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2237        let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
2238        let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
2239        let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
2240        let ord = a.device_ordinal();
2241        Ok((
2242            Self::wrap_buffer(u_buf, ord),
2243            Self::wrap_buffer(s_buf, ord),
2244            Self::wrap_buffer(vt_buf, ord),
2245        ))
2246    }
2247
2248    fn svd_f64(
2249        &self,
2250        a: &GpuBufferHandle,
2251        m: usize,
2252        n: usize,
2253    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2254        let a_buf = Self::unwrap_buffer_f64(a)?;
2255        let dev = self.device(a.device_ordinal())?;
2256        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2257        let (u, s, vt) =
2258            crate::cusolver::gpu_svd_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2259        let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
2260        let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
2261        let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
2262        let ord = a.device_ordinal();
2263        Ok((
2264            Self::wrap_buffer_f64(u_buf, ord),
2265            Self::wrap_buffer_f64(s_buf, ord),
2266            Self::wrap_buffer_f64(vt_buf, ord),
2267        ))
2268    }
2269
2270    fn cholesky_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
2271        let a_buf = Self::unwrap_buffer(a)?;
2272        let dev = self.device(a.device_ordinal())?;
2273        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2274        let l = crate::cusolver::gpu_cholesky_f32(&a_host, n, dev).map_err(Self::map_gpu_err)?;
2275        let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
2276        Ok(Self::wrap_buffer(l_buf, a.device_ordinal()))
2277    }
2278
2279    fn cholesky_f64(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
2280        let a_buf = Self::unwrap_buffer_f64(a)?;
2281        let dev = self.device(a.device_ordinal())?;
2282        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2283        let l = crate::cusolver::gpu_cholesky_f64(&a_host, n, dev).map_err(Self::map_gpu_err)?;
2284        let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
2285        Ok(Self::wrap_buffer_f64(l_buf, a.device_ordinal()))
2286    }
2287
2288    fn solve_f32(
2289        &self,
2290        a: &GpuBufferHandle,
2291        b: &GpuBufferHandle,
2292        n: usize,
2293        nrhs: usize,
2294    ) -> FerrotorchResult<GpuBufferHandle> {
2295        let a_buf = Self::unwrap_buffer(a)?;
2296        let b_buf = Self::unwrap_buffer(b)?;
2297        let dev = self.device(a.device_ordinal())?;
2298        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2299        let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
2300        let x =
2301            crate::cusolver::gpu_solve_f32(&a_host, &b_host, n, nrhs, dev)
2302                .map_err(Self::map_gpu_err)?;
2303        let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
2304        Ok(Self::wrap_buffer(x_buf, a.device_ordinal()))
2305    }
2306
2307    fn solve_f64(
2308        &self,
2309        a: &GpuBufferHandle,
2310        b: &GpuBufferHandle,
2311        n: usize,
2312        nrhs: usize,
2313    ) -> FerrotorchResult<GpuBufferHandle> {
2314        let a_buf = Self::unwrap_buffer_f64(a)?;
2315        let b_buf = Self::unwrap_buffer_f64(b)?;
2316        let dev = self.device(a.device_ordinal())?;
2317        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2318        let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
2319        let x =
2320            crate::cusolver::gpu_solve_f64(&a_host, &b_host, n, nrhs, dev)
2321                .map_err(Self::map_gpu_err)?;
2322        let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
2323        Ok(Self::wrap_buffer_f64(x_buf, a.device_ordinal()))
2324    }
2325
2326    fn qr_f32(
2327        &self,
2328        a: &GpuBufferHandle,
2329        m: usize,
2330        n: usize,
2331    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
2332        let a_buf = Self::unwrap_buffer(a)?;
2333        let dev = self.device(a.device_ordinal())?;
2334        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2335        let (q, r) =
2336            crate::cusolver::gpu_qr_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2337        let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
2338        let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
2339        let ord = a.device_ordinal();
2340        Ok((Self::wrap_buffer(q_buf, ord), Self::wrap_buffer(r_buf, ord)))
2341    }
2342
2343    fn qr_f64(
2344        &self,
2345        a: &GpuBufferHandle,
2346        m: usize,
2347        n: usize,
2348    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
2349        let a_buf = Self::unwrap_buffer_f64(a)?;
2350        let dev = self.device(a.device_ordinal())?;
2351        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2352        let (q, r) =
2353            crate::cusolver::gpu_qr_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2354        let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
2355        let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
2356        let ord = a.device_ordinal();
2357        Ok((
2358            Self::wrap_buffer_f64(q_buf, ord),
2359            Self::wrap_buffer_f64(r_buf, ord),
2360        ))
2361    }
2362}
2363
2364// ---------------------------------------------------------------------------
2365// Registration
2366// ---------------------------------------------------------------------------
2367
2368/// Get the `GpuDevice` from the registered CUDA backend.
2369///
2370/// This retrieves the device that was created during [`init_cuda_backend`],
2371/// ensuring all kernel modules and cuBLAS handles are shared. Creating a
2372/// second `GpuDevice` via `GpuDevice::new(0)` would create a separate
2373/// CUDA context with its own module cache, which is not interoperable.
2374pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
2375    let backend =
2376        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
2377    // The global backend is a &dyn GpuBackend. We know it's CudaBackendImpl
2378    // because init_cuda_backend registered it. Downcast via Any.
2379    let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
2380        FerrotorchError::InvalidArgument {
2381            message: "registered GPU backend is not CudaBackendImpl".into(),
2382        },
2383    )?;
2384    Ok(Arc::clone(cuda_backend.default_device()?))
2385}
2386
2387/// Initialize the CUDA backend and register it with ferrotorch-core.
2388///
2389/// This must be called before any GPU tensor operations. It creates a
2390/// [`CudaBackendImpl`] (initializing CUDA device 0) and registers it via
2391/// [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
2392///
2393/// Calling this a second time returns an error (the backend is already
2394/// registered).
2395///
2396/// # Errors
2397///
2398/// - [`FerrotorchError::InvalidArgument`] if CUDA initialization fails.
2399/// - [`FerrotorchError::InvalidArgument`] if a GPU backend is already registered.
2400pub fn init_cuda_backend() -> FerrotorchResult<()> {
2401    // Idempotent: if already registered, return Ok silently.
2402    if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
2403        return Ok(());
2404    }
2405    let backend = CudaBackendImpl::new()?;
2406    // OnceLock::set can still race if two threads call init concurrently —
2407    // if that happens, the second set() fails but the backend is registered
2408    // by the first. We treat that as success.
2409    let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
2410    Ok(())
2411}
2412
2413// ---------------------------------------------------------------------------
2414// Tests
2415// ---------------------------------------------------------------------------
2416
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::gpu_dispatch;

    // Note: Because `register_gpu_backend` uses a `OnceLock`, only the first
    // test to call `init_cuda_backend()` will succeed at registration. The
    // others will see the backend as already registered. We handle this by
    // checking `has_gpu_backend()` before calling init.

    /// Ensure the backend can be initialized (or was already initialized).
    fn ensure_init() {
        if !gpu_dispatch::has_gpu_backend() {
            init_cuda_backend().expect("init_cuda_backend");
        }
    }

    /// Serialize f32 values to native-endian bytes.
    ///
    /// Safe replacement for the previous `unsafe slice::from_raw_parts`
    /// transmutes; native endianness matches the in-memory layout the
    /// transfer path expects.
    fn f32s_to_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    /// Reassemble native-endian bytes into f32 values.
    fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_ne_bytes(chunk.try_into().unwrap()))
            .collect()
    }

    #[test]
    fn test_init_cuda_backend() {
        // First call succeeds (or backend was already registered by another test).
        ensure_init();
        assert!(gpu_dispatch::has_gpu_backend());
    }

    #[test]
    fn test_gpu_backend_returns_some() {
        ensure_init();
        assert!(gpu_dispatch::gpu_backend().is_some());
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let bytes = f32s_to_bytes(&host);

        let handle = backend.cpu_to_gpu(&bytes, 4, 0).expect("cpu_to_gpu");
        assert_eq!(handle.len(), 5);
        assert_eq!(handle.device_ordinal(), 0);

        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        assert_eq!(bytes_to_f32s(&back_bytes), host);
    }

    #[test]
    fn test_add_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];

        let a_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&a_data), 4, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&b_data), 4, 0)
            .expect("cpu_to_gpu b");

        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
        assert_eq!(result.len(), 4);

        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32 = bytes_to_f32s(&result_bytes);

        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_matmul_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        // A = [[1, 2, 3],
        //      [4, 5, 6]]  (2x3)
        // B = [[7, 8],
        //      [9, 10],
        //      [11, 12]]   (3x2)
        // C = [[58, 64],
        //      [139, 154]] (2x2)
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];

        let a_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&a_data), 4, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&b_data), 4, 0)
            .expect("cpu_to_gpu b");

        let result = backend
            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_f32");
        assert_eq!(result.len(), 4);

        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32 = bytes_to_f32s(&result_bytes);

        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }
}