ferrotorch_gpu/backend_impl.rs

//! CUDA implementation of the [`GpuBackend`] trait from ferrotorch-core.
//!
//! This module bridges the existing GPU operations (`gpu_add`, `gpu_matmul_f32`,
//! etc.) to the type-erased [`GpuBackend`] dispatch interface, enabling
//! ferrotorch-core to call GPU operations without depending on this crate
//! directly.
//!
//! # Initialization
//!
//! Call [`init_cuda_backend`] once at startup (typically via `ferrotorch::init()`).
//! This creates a [`CudaBackendImpl`], initializes CUDA device 0, and registers
//! it with [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
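//!
//! # Example
//!
//! A minimal sketch of the startup call. The module path and the exact return
//! type of `init_cuda_backend` are assumptions here; check the crate root for
//! the real re-exports.
//!
//! ```ignore
//! // Register the CUDA backend once, typically from `ferrotorch::init()`.
//! ferrotorch_gpu::backend_impl::init_cuda_backend()?;
//! // From this point on, ferrotorch-core dispatches GPU ops through the
//! // registered `GpuBackend` without linking against ferrotorch-gpu itself.
//! ```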

use std::sync::Arc;

use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};

use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;

// ---------------------------------------------------------------------------
// CudaBackendImpl
// ---------------------------------------------------------------------------

/// CUDA implementation of the [`GpuBackend`] trait.
///
/// Holds one or more [`GpuDevice`] handles (currently device 0 only) and
/// delegates every trait method to the corresponding function in
/// [`crate::kernels`], [`crate::blas`], or [`crate::transfer`].
pub struct CudaBackendImpl {
    devices: Vec<Arc<GpuDevice>>,
}

impl CudaBackendImpl {
    /// Create a new CUDA backend, initializing device 0.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if CUDA initialization fails
    /// (e.g. no GPU available, driver not loaded).
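    ///
    /// # Examples
    ///
    /// A construction sketch (assumes a CUDA-capable GPU is present at runtime;
    /// error handling is whatever the caller already uses):
    ///
    /// ```ignore
    /// let backend = CudaBackendImpl::new()?;
    /// // Device 0 was initialized eagerly by `new`, so it can be fetched
    /// // without any further setup.
    /// let device = backend.default_device()?;
    /// ```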
    pub fn new() -> FerrotorchResult<Self> {
        let device = Arc::new(
            GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA init failed: {e}"),
            })?,
        );
        Ok(Self {
            devices: vec![device],
        })
    }

    /// Get the device for ordinal 0 (the default device).
    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.device(0)
    }

    /// Look up a device by ordinal.
    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.devices
            .get(ordinal)
            .ok_or(FerrotorchError::InvalidArgument {
                message: format!("CUDA device {ordinal} not available"),
            })
    }

    /// Wrap a `CudaBuffer<f32>` into a type-erased [`GpuBufferHandle`].
    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    /// Wrap a `CudaBuffer<f64>` into a type-erased [`GpuBufferHandle`].
    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    /// Extract a `&CudaBuffer<f32>` from a [`GpuBufferHandle`].
    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
        handle
            .downcast_ref::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    /// Extract a `&mut CudaBuffer<f32>` from a [`GpuBufferHandle`].
    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
        handle
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    /// Extract a `&mut CudaBuffer<f64>` from a [`GpuBufferHandle`].
    fn unwrap_buffer_f64_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f64>> {
        handle
            .downcast_mut::<CudaBuffer<f64>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    /// Extract a `&CudaBuffer<f64>` from a [`GpuBufferHandle`].
    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
        handle
            .downcast_ref::<CudaBuffer<f64>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    /// Convert a [`crate::error::GpuError`] into a [`FerrotorchError`].
    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
        FerrotorchError::InvalidArgument {
            message: format!("{e}"),
        }
    }
}

// ---------------------------------------------------------------------------
// GpuBackend implementation
// ---------------------------------------------------------------------------

impl GpuBackend for CudaBackendImpl {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

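    /// Raw device pointer backing `handle`, or a null pointer if the handle's
    /// device ordinal is unknown or the handle holds neither `CudaBuffer<f32>`
    /// nor `CudaBuffer<f64>`; callers must check for null.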
    fn raw_device_ptr(&self, handle: &GpuBufferHandle) -> *const std::ffi::c_void {
        use cudarc::driver::DevicePtr;
        let dev = match self.device(handle.device_ordinal()) {
            Ok(d) => d,
            Err(_) => return std::ptr::null(),
        };
        let stream = dev.stream();
        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let (ptr, _sync) = buf.inner().device_ptr(&stream);
            ptr as *const std::ffi::c_void
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let (ptr, _sync) = buf.inner().device_ptr(&stream);
            ptr as *const std::ffi::c_void
        } else {
            std::ptr::null()
        }
    }

    fn raw_device_ptr_mut(&self, handle: &mut GpuBufferHandle) -> *mut std::ffi::c_void {
        use cudarc::driver::DevicePtrMut;
        let ordinal = handle.device_ordinal();
        let dev = match self.device(ordinal) {
            Ok(d) => d,
            Err(_) => return std::ptr::null_mut(),
        };
        let stream = dev.stream();
        if let Some(buf) = handle.downcast_mut::<CudaBuffer<f32>>() {
            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
            ptr as *mut std::ffi::c_void
        } else if let Some(buf) = handle.downcast_mut::<CudaBuffer<f64>>() {
            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
            ptr as *mut std::ffi::c_void
        } else {
            std::ptr::null_mut()
        }
    }

    fn buffer_elem_size(&self, handle: &GpuBufferHandle) -> usize {
        if Self::unwrap_buffer(handle).is_ok() {
            4 // f32
        } else if Self::unwrap_buffer_f64(handle).is_ok() {
            8 // f64
        } else {
            0
        }
    }

    fn cpu_to_gpu(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
                // was originally an f32 slice serialised to bytes.
                let count = data.len() / 4;
                let f32_data: &[f32] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
                let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                // SAFETY: The caller (ferrotorch-core) guarantees that `data`
                // was originally an f64 slice serialised to bytes.
                let count = data.len() / 8;
                let f64_data: &[f64] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
                let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    fn cpu_to_gpu_pinned(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let count = data.len() / 4;
                let f32_data: &[f32] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
                let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let count = data.len() / 8;
                let f64_data: &[f64] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
                let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
                ),
            }),
        }
    }

    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
        let dev = self.device(handle.device_ordinal())?;

        // Try f32 first, then f64.
        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;

            // Reinterpret Vec<f32> as Vec<u8> without copying.
            // SAFETY: f32 has alignment 4 and size 4. We adjust len and capacity
            // accordingly. The original Vec is consumed via ManuallyDrop so its
            // destructor won't free the allocation.
            let bytes = unsafe {
                let mut v = std::mem::ManuallyDrop::new(f32_data);
                let ptr = v.as_mut_ptr() as *mut u8;
                let len = v.len() * 4;
                let cap = v.capacity() * 4;
                Vec::from_raw_parts(ptr, len, cap)
            };
            Ok(bytes)
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;

            // Reinterpret Vec<f64> as Vec<u8> without copying.
            // SAFETY: f64 has alignment 8 and size 8. We adjust len and capacity
            // accordingly. The original Vec is consumed via ManuallyDrop so its
            // destructor won't free the allocation.
            let bytes = unsafe {
                let mut v = std::mem::ManuallyDrop::new(f64_data);
                let ptr = v.as_mut_ptr() as *mut u8;
                let len = v.len() * 8;
                let cap = v.capacity() * 8;
                Vec::from_raw_parts(ptr, len, cap)
            };
            Ok(bytes)
        } else {
            Err(FerrotorchError::InvalidArgument {
                message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
            })
        }
    }

    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        // Clone via GPU -> CPU -> GPU round-trip.
        // Correct but not optimal; a device-to-device memcpy would be better.
        let bytes = self.gpu_to_cpu(handle)?;
        // Determine elem_size from the concrete buffer type.
        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
            8
        } else {
            4
        };
        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
    }

    fn alloc_zeros(
        &self,
        len: usize,
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    // -- Elementwise f32 ------------------------------------------------------

    fn add_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sub_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn mul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn div_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    // -----------------------------------------------------------------------
    // f64 elementwise ops
    // -----------------------------------------------------------------------

    fn add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_add_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sub_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mul_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_div_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn neg_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_neg_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn relu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_relu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn scale_f64(&self, a: &GpuBufferHandle, scalar: f64) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_scale_f64(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn exp_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_exp_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn log_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_log_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn sqrt_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sqrt_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn pow_f64(&self, a: &GpuBufferHandle, exponent: f64) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_pow_f64(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn abs_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_abs_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn sigmoid_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // f64 backward ops
    fn relu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_relu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn sigmoid_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let o_buf = Self::unwrap_buffer_f64(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn tanh_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let o_buf = Self::unwrap_buffer_f64(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_tanh_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    // f64 activation forward ops

    fn gelu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn gelu_tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn gelu_erf_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_erf_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn silu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_silu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn elu_f64(&self, a: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_elu_f64(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn mish_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mish_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn clamp_f64(&self, a: &GpuBufferHandle, min_val: f64, max_val: f64) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_clamp_f64(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // f64 activation backward ops

    fn gelu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn gelu_backward_tanh_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_tanh_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn gelu_backward_erf_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_erf_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn silu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_silu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn elu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_elu_backward_f64(g_buf, i_buf, alpha, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn mish_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let g_buf = Self::unwrap_buffer_f64(grad)?;
        let i_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_mish_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    // f64 cumulative ops
    fn cumsum_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_cumsum_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn cumprod_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_cumprod_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn cummax_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummax_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
    }

    fn cummin_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummin_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
    }

    fn logcumsumexp_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_logcumsumexp_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // f64 shape ops
    fn transpose_2d_f64(&self, a: &GpuBufferHandle, m: usize, n: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_transpose_2d_f64(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn permute_0213_f64(&self, a: &GpuBufferHandle, d0: usize, d1: usize, d2: usize, d3: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_permute_0213_f64(a_buf, d0, d1, d2, d3, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // f64 broadcast ops
    fn broadcast_add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_broadcast_add_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn broadcast_sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_broadcast_sub_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn broadcast_mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_broadcast_mul_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn broadcast_div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_broadcast_div_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // f64 reduction ops
    fn sum_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_reduce_sum_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn sum_axis_f64(&self, a: &GpuBufferHandle, shape: &[usize], axis: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
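        // Decompose the shape around `axis`: for shape [2, 3, 4] and axis = 1,
        // outer = 2, axis_size = 3, inner = 4, so the kernel reduces 3 elements
        // for each of the 2 * 4 = 8 (outer, inner) positions.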
        let outer: usize = shape[..axis].iter().product();
        let axis_size = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();
        let result = crate::kernels::gpu_sum_axis_f64(a_buf, outer, axis_size, inner, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // f64 softmax / log-softmax / layernorm / rmsnorm

    fn softmax_f64(&self, a: &GpuBufferHandle, rows: usize, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_softmax_f64(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer_f64(grad)?;
        let output_buf = Self::unwrap_buffer_f64(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn log_softmax_f64(&self, a: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_log_softmax_f64(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    fn log_softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer_f64(grad)?;
        let output_buf = Self::unwrap_buffer_f64(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_log_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    fn layernorm_f64(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f64,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let w_buf = Self::unwrap_buffer_f64(weight)?;
        let b_buf = Self::unwrap_buffer_f64(bias)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_layernorm_f64(in_buf, w_buf, b_buf, rows, cols, eps, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }

    fn layernorm_backward_f64(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f64,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
        let w_buf = Self::unwrap_buffer_f64(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw, gb) =
            crate::kernels::gpu_layernorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((
            Self::wrap_buffer_f64(gi, ordinal),
            Self::wrap_buffer_f64(gw, ordinal),
            Self::wrap_buffer_f64(gb, ordinal),
        ))
    }

    fn rmsnorm_f64(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f64,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let w_buf = Self::unwrap_buffer_f64(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_rmsnorm_f64(in_buf, w_buf, rows, cols, eps, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }

    fn rmsnorm_backward_f64(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f64,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
        let w_buf = Self::unwrap_buffer_f64(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw) =
            crate::kernels::gpu_rmsnorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((Self::wrap_buffer_f64(gi, ordinal), Self::wrap_buffer_f64(gw, ordinal)))
    }

    // f64 embedding / scatter / indexing

    fn embed_lookup_f64(
        &self,
        idx: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        // indices are always f32-encoded
        let idx_buf = Self::unwrap_buffer(idx)?;
        let w_buf = Self::unwrap_buffer_f64(weight)?;
        let dev = self.device(idx.device_ordinal())?;
        let result =
            crate::kernels::gpu_embed_lookup_f64(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, idx.device_ordinal()))
    }

    fn embed_lookup_batch_f64(
        &self,
        indices: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        n: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        // indices are always f32-encoded
        let idx_buf = Self::unwrap_buffer(indices)?;
        let w_buf = Self::unwrap_buffer_f64(weight)?;
        let dev = self.device(indices.device_ordinal())?;
        let result = crate::kernels::gpu_embed_lookup_batch_f64(idx_buf, w_buf, n, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, indices.device_ordinal()))
    }

    fn scatter_add_rows_f64(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        num_embeddings: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
        // indices are always f32-encoded
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_rows_f64(go_buf, idx_buf, num_embeddings, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
    }

    // f64 masked fill / masked zero
    //
    // The f64 kernels expect CudaBuffer<u8> for the mask, but the trait
    // provides a GpuBufferHandle containing CudaBuffer<f32> (1.0/0.0 encoding).
    // We convert f32 mask -> u8 mask via a CPU roundtrip.

    fn masked_fill_f64(
        &self,
        input: &GpuBufferHandle,
        mask: &GpuBufferHandle,
        value: f64,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer_f64(input)?;
        let mask_f32 = Self::unwrap_buffer(mask)?;
        let dev = self.device(input.device_ordinal())?;
        // Build a u8 mask on the GPU from the f32 mask via a CPU round trip.
        let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
        let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
        let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
        let result = crate::kernels::gpu_masked_fill_f64(input_buf, &mask_gpu, value, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }

    fn masked_zero_f64(
        &self,
        grad: &GpuBufferHandle,
        mask: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer_f64(grad)?;
        let mask_f32 = Self::unwrap_buffer(mask)?;
        let dev = self.device(grad.device_ordinal())?;
        // Build a u8 mask on the GPU from the f32 mask via a CPU round trip.
        let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
        let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
        let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
        let result = crate::kernels::gpu_masked_zero_f64(grad_buf, &mask_gpu, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }

    // f64 slice ops

    fn slice_write_f64(
        &self,
        src: &GpuBufferHandle,
        dst: &mut GpuBufferHandle,
        n_batch: usize,
        d: usize,
        max_len: usize,
        pos: usize,
    ) -> FerrotorchResult<()> {
        let src_buf = Self::unwrap_buffer_f64(src)?;
        let dst_buf = Self::unwrap_buffer_f64_mut(dst)?;
        let dev = self.device(src.device_ordinal())?;
        crate::kernels::gpu_slice_write_f64(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    fn slice_read_f64(
        &self,
        src: &GpuBufferHandle,
        n_batch: usize,
        d: usize,
        len: usize,
        max_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let src_buf = Self::unwrap_buffer_f64(src)?;
        let dev = self.device(src.device_ordinal())?;
        let result = crate::kernels::gpu_slice_read_f64(src_buf, n_batch, d, len, max_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, src.device_ordinal()))
    }

    // f64 strided split / cat

    fn strided_split_f64(
        &self,
        input: &GpuBufferHandle,
        total_along_axis: usize,
        split_offset: usize,
        split_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_strided_split_f64(
            in_buf,
            total_along_axis,
            split_offset,
            split_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }

    fn strided_cat_f64(
        &self,
        input: &GpuBufferHandle,
        output: &mut GpuBufferHandle,
        total_along_axis: usize,
        cat_offset: usize,
        part_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<()> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(input.device_ordinal())?;
        let out_buf = Self::unwrap_buffer_f64_mut(output)?;
        crate::kernels::gpu_strided_cat_f64(
            in_buf,
            out_buf,
            total_along_axis,
            cat_offset,
            part_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    // f64 indexing ops

    fn index_select_1d_f64(
        &self,
        input: &GpuBufferHandle,
        indices: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer_f64(input)?;
        // indices are always f32-encoded
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_index_select_1d_f64(input_buf, idx_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }

    fn scatter_add_1d_f64(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        input_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer_f64(grad_output)?;
        // indices are always f32-encoded
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_1d_f64(go_buf, idx_buf, input_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
    }

    fn bmm_f64(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::blas::gpu_bmm_f64(a_buf, b_buf, batch, m, k, n, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    #[allow(clippy::too_many_arguments)]
    fn fused_adam_f32(
        &self,
        param: &mut GpuBufferHandle,
        grad: &GpuBufferHandle,
        exp_avg: &mut GpuBufferHandle,
        exp_avg_sq: &mut GpuBufferHandle,
        beta1: f32,
        beta2: f32,
        lr: f32,
        eps: f32,
        bc1: f32,
        bc2: f32,
        weight_decay: f32,
    ) -> FerrotorchResult<()> {
        let ordinal = param.device_ordinal();
        let dev = self.device(ordinal)?;
        let p_buf = Self::unwrap_buffer_mut(param)?;
        let g_buf = Self::unwrap_buffer(grad)?;
        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
        crate::kernels::gpu_fused_adam(
            p_buf,
            g_buf,
            m_buf,
            v_buf,
            beta1,
            beta2,
            lr,
            eps,
            bc1,
            bc2,
            weight_decay,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    fn maxpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_maxpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        ).map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }

    #[allow(clippy::too_many_arguments)]
    fn avgpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_avgpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        ).map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }

    #[allow(clippy::too_many_arguments)]
    fn conv2d_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: Option<&GpuBufferHandle>,
        input_shape: [usize; 4],
        weight_shape: [usize; 4],
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let input_buf = Self::unwrap_buffer(input)?;
        let weight_buf = Self::unwrap_buffer(weight)?;
        let bias_buf = match bias {
            Some(b) => Some(Self::unwrap_buffer(b)?),
            None => None,
        };
        let dev = self.device(input.device_ordinal())?;
        let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
            input_buf,
            weight_buf,
            bias_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
    }

    fn fused_gru_cell_f32(
        &self,
        input_gates: &GpuBufferHandle,
        hidden_gates: &GpuBufferHandle,
        bias_ih: &GpuBufferHandle,
        bias_hh: &GpuBufferHandle,
        hx: &GpuBufferHandle,
        hidden_size: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let ig = Self::unwrap_buffer(input_gates)?;
        let hg = Self::unwrap_buffer(hidden_gates)?;
        let bih = Self::unwrap_buffer(bias_ih)?;
        let bhh = Self::unwrap_buffer(bias_hh)?;
        let hx_buf = Self::unwrap_buffer(hx)?;
        let dev = self.device(input_gates.device_ordinal())?;
        let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
            ig, hg, bih, bhh, hx_buf, hidden_size, dev,
        )
        .map_err(Self::map_gpu_err)?;
        let ord = input_gates.device_ordinal();
        Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
    }

    fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
        let dev = self.device(device)?;
        dev.stream()
            .synchronize()
            .map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA synchronize failed: {e}"),
            })?;
        Ok(())
    }

    fn stream_count(&self, device: usize) -> usize {
        crate::stream::StreamPool::pool_size(device)
    }

    // -- Linalg f32 -----------------------------------------------------------

    fn matmul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    // -- Reduction f32 --------------------------------------------------------

    fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    // -- Linalg f64 (cuBLAS DGEMM) --------------------------------------------

    fn matmul_f64(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

    // -- Broadcast binary f32 -------------------------------------------------

    fn broadcast_add_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_sub_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
1333        let result =
1334            crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1335                .map_err(Self::map_gpu_err)?;
1336        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1337    }
1338
1339    fn broadcast_mul_f32(
1340        &self,
1341        a: &GpuBufferHandle,
1342        b: &GpuBufferHandle,
1343        a_shape: &[usize],
1344        b_shape: &[usize],
1345        out_shape: &[usize],
1346    ) -> FerrotorchResult<GpuBufferHandle> {
1347        let a_buf = Self::unwrap_buffer(a)?;
1348        let b_buf = Self::unwrap_buffer(b)?;
1349        let dev = self.device(a.device_ordinal())?;
1350        let result =
1351            crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1352                .map_err(Self::map_gpu_err)?;
1353        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1354    }
1355
1356    fn broadcast_div_f32(
1357        &self,
1358        a: &GpuBufferHandle,
1359        b: &GpuBufferHandle,
1360        a_shape: &[usize],
1361        b_shape: &[usize],
1362        out_shape: &[usize],
1363    ) -> FerrotorchResult<GpuBufferHandle> {
1364        let a_buf = Self::unwrap_buffer(a)?;
1365        let b_buf = Self::unwrap_buffer(b)?;
1366        let dev = self.device(a.device_ordinal())?;
1367        let result =
1368            crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1369                .map_err(Self::map_gpu_err)?;
1370        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1371    }
1372
1373    fn softmax_f32(
1374        &self,
1375        a: &GpuBufferHandle,
1376        rows: usize,
1377        cols: usize,
1378    ) -> FerrotorchResult<GpuBufferHandle> {
1379        let a_buf = Self::unwrap_buffer(a)?;
1380        let dev = self.device(a.device_ordinal())?;
1381        let result =
1382            crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
1383        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1384    }
1385
1386    fn dropout_f32(
1387        &self,
1388        a: &GpuBufferHandle,
1389        threshold: u32,
1390        scale: f32,
1391        seed: u32,
1392    ) -> FerrotorchResult<GpuBufferHandle> {
1393        let a_buf = Self::unwrap_buffer(a)?;
1394        let dev = self.device(a.device_ordinal())?;
1395        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
1396            .map_err(Self::map_gpu_err)?;
1397        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1398    }
1399
1400    fn dropout_philox_f32(
1401        &self,
1402        a: &GpuBufferHandle,
1403        threshold: u32,
1404        scale: f32,
1405    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
1406        let device_ordinal = a.device_ordinal();
1407        let n = a.len();
1408
1409        // Snapshot the current RNG state and advance it.
1410        let rng_state = {
1411            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1412                FerrotorchError::InvalidArgument {
1413                    message: "failed to lock CUDA RNG manager".into(),
1414                }
1415            })?;
1416            let philox_gen = mgr.generator(device_ordinal);
1417            let state = philox_gen.get_state();
1418            // Advance by ceil(n/4) counters (each counter produces 4 u32 values)
1419            let counters_needed = n.div_ceil(4);
1420            philox_gen.advance(counters_needed as u64);
1421            state
1422        };
1423
1424        // Use the Philox state as the seed for the dropout kernel.
1425        // A fully Philox-based implementation would generate the mask on the
1426        // GPU with the Philox uniform kernel and then apply it. Instead we
1427        // derive a deterministic u32 seed from the snapshotted Philox state
1428        // and reuse the existing dropout kernel, so the CPU backward pass can
1429        // regenerate exactly the same mask from the `GpuRngState` returned
1430        // below.
1431        let a_buf = Self::unwrap_buffer(a)?;
1432        let dev = self.device(device_ordinal)?;
1433
1434        // Use the Philox counter XOR seed as the dropout kernel's seed.
1435        // This gives us deterministic behavior tied to the Philox state.
1436        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
1437        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
1438            .map_err(Self::map_gpu_err)?;
1439
1440        let gpu_rng_state = GpuRngState {
1441            counter: rng_state.counter,
1442            seed: rng_state.seed,
1443            offset: rng_state.offset,
1444            device: device_ordinal,
1445        };
1446
1447        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
1448    }
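
    // Worked example of the seed derivation above (illustrative numbers, not
    // taken from a real run): if the snapshotted Philox state has
    // counter = 0x12 and seed = 0x34, then derived_seed = (0x12 ^ 0x34) as u32
    // = 0x26, and replaying the same `GpuRngState` reproduces the same mask.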
1449
1450    fn dropout_f64(
1451        &self,
1452        a: &GpuBufferHandle,
1453        threshold: u32,
1454        scale: f64,
1455        seed: u32,
1456    ) -> FerrotorchResult<GpuBufferHandle> {
1457        let a_buf = Self::unwrap_buffer_f64(a)?;
1458        let dev = self.device(a.device_ordinal())?;
1459        let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, seed, dev)
1460            .map_err(Self::map_gpu_err)?;
1461        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1462    }
1463
1464    fn dropout_philox_f64(
1465        &self,
1466        a: &GpuBufferHandle,
1467        threshold: u32,
1468        scale: f64,
1469    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
1470        let device_ordinal = a.device_ordinal();
1471        let n = a.len();
1472
1473        let rng_state = {
1474            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1475                FerrotorchError::InvalidArgument {
1476                    message: "failed to lock CUDA RNG manager".into(),
1477                }
1478            })?;
1479            let philox_gen = mgr.generator(device_ordinal);
1480            let state = philox_gen.get_state();
1481            let counters_needed = n.div_ceil(4);
1482            philox_gen.advance(counters_needed as u64);
1483            state
1484        };
1485
1486        let a_buf = Self::unwrap_buffer_f64(a)?;
1487        let dev = self.device(device_ordinal)?;
1488        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
1489        let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, derived_seed, dev)
1490            .map_err(Self::map_gpu_err)?;
1491
1492        let gpu_rng_state = GpuRngState {
1493            counter: rng_state.counter,
1494            seed: rng_state.seed,
1495            offset: rng_state.offset,
1496            device: device_ordinal,
1497        };
1498
1499        Ok((Self::wrap_buffer_f64(result, device_ordinal), gpu_rng_state))
1500    }
1501
1502    fn transpose_2d_f32(
1503        &self,
1504        a: &GpuBufferHandle,
1505        m: usize,
1506        n: usize,
1507    ) -> FerrotorchResult<GpuBufferHandle> {
1508        let a_buf = Self::unwrap_buffer(a)?;
1509        let dev = self.device(a.device_ordinal())?;
1510        let result =
1511            crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
1512        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1513    }
1514
1515    fn permute_0213_f32(
1516        &self,
1517        a: &GpuBufferHandle,
1518        d0: usize,
1519        d1: usize,
1520        d2: usize,
1521        d3: usize,
1522    ) -> FerrotorchResult<GpuBufferHandle> {
1523        let a_buf = Self::unwrap_buffer(a)?;
1524        let dev = self.device(a.device_ordinal())?;
1525        let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
1526            .map_err(Self::map_gpu_err)?;
1527        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1528    }
1529
1530    fn bmm_f32(
1531        &self,
1532        a: &GpuBufferHandle,
1533        b: &GpuBufferHandle,
1534        batch: usize,
1535        m: usize,
1536        k: usize,
1537        n: usize,
1538    ) -> FerrotorchResult<GpuBufferHandle> {
1539        let a_buf = Self::unwrap_buffer(a)?;
1540        let b_buf = Self::unwrap_buffer(b)?;
1541        let dev = self.device(a.device_ordinal())?;
1542        let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
1543            .map_err(Self::map_gpu_err)?;
1544        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1545    }
1546
1547    fn bmm_f16_f32(
1548        &self,
1549        a: &GpuBufferHandle,
1550        b: &GpuBufferHandle,
1551        batch: usize,
1552        m: usize,
1553        k: usize,
1554        n: usize,
1555    ) -> FerrotorchResult<GpuBufferHandle> {
1556        let a_buf = Self::unwrap_buffer(a)?;
1557        let b_buf = Self::unwrap_buffer(b)?;
1558        let dev = self.device(a.device_ordinal())?;
1559        let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
1560            .map_err(Self::map_gpu_err)?;
1561        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1562    }
1563
1564    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1565        let a_buf = Self::unwrap_buffer(a)?;
1566        let dev = self.device(a.device_ordinal())?;
1567        let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
1568        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1569    }
1570
1571    fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1572        let a_buf = Self::unwrap_buffer(a)?;
1573        let dev = self.device(a.device_ordinal())?;
1574        let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
1575        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1576    }
1577
1578    fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1579        let a_buf = Self::unwrap_buffer(a)?;
1580        let dev = self.device(a.device_ordinal())?;
1581        let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
1582        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1583    }
1584
1585    fn layernorm_f32(
1586        &self,
1587        input: &GpuBufferHandle,
1588        weight: &GpuBufferHandle,
1589        bias: &GpuBufferHandle,
1590        rows: usize,
1591        cols: usize,
1592        eps: f32,
1593    ) -> FerrotorchResult<GpuBufferHandle> {
1594        let in_buf = Self::unwrap_buffer(input)?;
1595        let w_buf = Self::unwrap_buffer(weight)?;
1596        let b_buf = Self::unwrap_buffer(bias)?;
1597        let dev = self.device(input.device_ordinal())?;
1598        let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
1599            .map_err(Self::map_gpu_err)?;
1600        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1601    }
1602
1603    fn rmsnorm_f32(
1604        &self,
1605        input: &GpuBufferHandle,
1606        weight: &GpuBufferHandle,
1607        rows: usize,
1608        cols: usize,
1609        eps: f32,
1610    ) -> FerrotorchResult<GpuBufferHandle> {
1611        let in_buf = Self::unwrap_buffer(input)?;
1612        let w_buf = Self::unwrap_buffer(weight)?;
1613        let dev = self.device(input.device_ordinal())?;
1614        let result = crate::kernels::gpu_rmsnorm(in_buf, w_buf, rows, cols, eps, dev)
1615            .map_err(Self::map_gpu_err)?;
1616        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1617    }
1618
1619    fn rmsnorm_backward_f32(
1620        &self,
1621        input: &GpuBufferHandle,
1622        grad_output: &GpuBufferHandle,
1623        weight: &GpuBufferHandle,
1624        rows: usize,
1625        cols: usize,
1626        eps: f32,
1627    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1628        let in_buf = Self::unwrap_buffer(input)?;
1629        let go_buf = Self::unwrap_buffer(grad_output)?;
1630        let w_buf = Self::unwrap_buffer(weight)?;
1631        let dev = self.device(input.device_ordinal())?;
1632        let (gi, gw) =
1633            crate::kernels::gpu_rmsnorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
1634                .map_err(Self::map_gpu_err)?;
1635        let ordinal = input.device_ordinal();
1636        Ok((Self::wrap_buffer(gi, ordinal), Self::wrap_buffer(gw, ordinal)))
1637    }
1638
1639    fn slice_write_f32(
1640        &self,
1641        src: &GpuBufferHandle,
1642        dst: &mut GpuBufferHandle,
1643        n_batch: usize,
1644        d: usize,
1645        max_len: usize,
1646        pos: usize,
1647    ) -> FerrotorchResult<()> {
1648        let src_buf = Self::unwrap_buffer(src)?;
1649        let dst_buf =
1650            dst.downcast_mut::<CudaBuffer<f32>>()
1651                .ok_or(FerrotorchError::InvalidArgument {
1652                    message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
1653                })?;
1654        let dev = self.device(src.device_ordinal())?;
1655        crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
1656            .map_err(Self::map_gpu_err)?;
1657        Ok(())
1658    }
1659
1660    fn slice_read_f32(
1661        &self,
1662        src: &GpuBufferHandle,
1663        n_batch: usize,
1664        d: usize,
1665        len: usize,
1666        max_len: usize,
1667    ) -> FerrotorchResult<GpuBufferHandle> {
1668        let src_buf = Self::unwrap_buffer(src)?;
1669        let dev = self.device(src.device_ordinal())?;
1670        let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
1671            .map_err(Self::map_gpu_err)?;
1672        Ok(Self::wrap_buffer(result, src.device_ordinal()))
1673    }
1674
1675    fn embed_lookup_f32(
1676        &self,
1677        idx: &GpuBufferHandle,
1678        weight: &GpuBufferHandle,
1679        d: usize,
1680    ) -> FerrotorchResult<GpuBufferHandle> {
1681        let idx_buf = Self::unwrap_buffer(idx)?;
1682        let w_buf = Self::unwrap_buffer(weight)?;
1683        let dev = self.device(idx.device_ordinal())?;
1684        let result =
1685            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
1686        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
1687    }
1688
1689    fn embed_lookup_batch_f32(
1690        &self,
1691        indices: &GpuBufferHandle,
1692        weight: &GpuBufferHandle,
1693        n: usize,
1694        d: usize,
1695    ) -> FerrotorchResult<GpuBufferHandle> {
1696        let idx_buf = Self::unwrap_buffer(indices)?;
1697        let w_buf = Self::unwrap_buffer(weight)?;
1698        let dev = self.device(indices.device_ordinal())?;
1699        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
1700            .map_err(Self::map_gpu_err)?;
1701        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
1702    }
1703
1704    fn scatter_add_rows_f32(
1705        &self,
1706        grad_output: &GpuBufferHandle,
1707        indices: &GpuBufferHandle,
1708        num_embeddings: usize,
1709        d: usize,
1710    ) -> FerrotorchResult<GpuBufferHandle> {
1711        let go_buf = Self::unwrap_buffer(grad_output)?;
1712        let idx_buf = Self::unwrap_buffer(indices)?;
1713        let dev = self.device(grad_output.device_ordinal())?;
1714        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
1715            .map_err(Self::map_gpu_err)?;
1716        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
1717    }
1718
1719    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
1720        let a_buf = Self::unwrap_buffer(a)?;
1721        let dev = self.device(a.device_ordinal())?;
1722        let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
1723        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1724    }
1725
1726    fn relu_backward_f32(
1727        &self,
1728        grad: &GpuBufferHandle,
1729        input: &GpuBufferHandle,
1730    ) -> FerrotorchResult<GpuBufferHandle> {
1731        let grad_buf = Self::unwrap_buffer(grad)?;
1732        let input_buf = Self::unwrap_buffer(input)?;
1733        let dev = self.device(grad.device_ordinal())?;
1734        let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
1735            .map_err(Self::map_gpu_err)?;
1736        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1737    }
1738
1739    fn abs_backward_f32(
1740        &self,
1741        grad: &GpuBufferHandle,
1742        input: &GpuBufferHandle,
1743    ) -> FerrotorchResult<GpuBufferHandle> {
1744        let grad_buf = Self::unwrap_buffer(grad)?;
1745        let input_buf = Self::unwrap_buffer(input)?;
1746        let dev = self.device(grad.device_ordinal())?;
1747        let result = crate::kernels::gpu_abs_backward(grad_buf, input_buf, dev)
1748            .map_err(Self::map_gpu_err)?;
1749        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1750    }
1751
1752    fn fill_f32(
1753        &self,
1754        n: usize,
1755        scalar: f32,
1756        ordinal: usize,
1757    ) -> FerrotorchResult<GpuBufferHandle> {
1758        let dev = self.device(ordinal)?;
1759        let result = crate::kernels::gpu_fill_f32(n, scalar, dev).map_err(Self::map_gpu_err)?;
1760        Ok(Self::wrap_buffer(result, ordinal))
1761    }
1762
1763    fn gelu_backward_f32(
1764        &self,
1765        grad: &GpuBufferHandle,
1766        input: &GpuBufferHandle,
1767    ) -> FerrotorchResult<GpuBufferHandle> {
1768        let grad_buf = Self::unwrap_buffer(grad)?;
1769        let input_buf = Self::unwrap_buffer(input)?;
1770        let dev = self.device(grad.device_ordinal())?;
1771        let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
1772            .map_err(Self::map_gpu_err)?;
1773        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1774    }
1775
1776    fn gelu_backward_tanh_f32(
1777        &self,
1778        grad: &GpuBufferHandle,
1779        input: &GpuBufferHandle,
1780    ) -> FerrotorchResult<GpuBufferHandle> {
1781        let grad_buf = Self::unwrap_buffer(grad)?;
1782        let input_buf = Self::unwrap_buffer(input)?;
1783        let dev = self.device(grad.device_ordinal())?;
1784        let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
1785            .map_err(Self::map_gpu_err)?;
1786        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1787    }
1788
1789    fn gelu_backward_erf_f32(
1790        &self,
1791        grad: &GpuBufferHandle,
1792        input: &GpuBufferHandle,
1793    ) -> FerrotorchResult<GpuBufferHandle> {
1794        let grad_buf = Self::unwrap_buffer(grad)?;
1795        let input_buf = Self::unwrap_buffer(input)?;
1796        let dev = self.device(grad.device_ordinal())?;
1797        let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
1798            .map_err(Self::map_gpu_err)?;
1799        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1800    }
1801
1802    fn cumsum_f32(
1803        &self,
1804        a: &GpuBufferHandle,
1805        outer: usize,
1806        dim_size: usize,
1807        inner: usize,
1808    ) -> FerrotorchResult<GpuBufferHandle> {
1809        let a_buf = Self::unwrap_buffer(a)?;
1810        let dev = self.device(a.device_ordinal())?;
1811        let result = crate::kernels::gpu_cumsum(a_buf, outer, dim_size, inner, dev)
1812            .map_err(Self::map_gpu_err)?;
1813        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1814    }
1815
1816    fn cumprod_f32(
1817        &self,
1818        a: &GpuBufferHandle,
1819        outer: usize,
1820        dim_size: usize,
1821        inner: usize,
1822    ) -> FerrotorchResult<GpuBufferHandle> {
1823        let a_buf = Self::unwrap_buffer(a)?;
1824        let dev = self.device(a.device_ordinal())?;
1825        let result = crate::kernels::gpu_cumprod(a_buf, outer, dim_size, inner, dev)
1826            .map_err(Self::map_gpu_err)?;
1827        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1828    }
1829
1830    fn cummax_f32(
1831        &self,
1832        a: &GpuBufferHandle,
1833        outer: usize,
1834        dim_size: usize,
1835        inner: usize,
1836    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1837        let a_buf = Self::unwrap_buffer(a)?;
1838        let dev = self.device(a.device_ordinal())?;
1839        let (vals, idxs) = crate::kernels::gpu_cummax(a_buf, outer, dim_size, inner, dev)
1840            .map_err(Self::map_gpu_err)?;
1841        let ord = a.device_ordinal();
1842        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
1843    }
1844
1845    fn cummin_f32(
1846        &self,
1847        a: &GpuBufferHandle,
1848        outer: usize,
1849        dim_size: usize,
1850        inner: usize,
1851    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1852        let a_buf = Self::unwrap_buffer(a)?;
1853        let dev = self.device(a.device_ordinal())?;
1854        let (vals, idxs) = crate::kernels::gpu_cummin(a_buf, outer, dim_size, inner, dev)
1855            .map_err(Self::map_gpu_err)?;
1856        let ord = a.device_ordinal();
1857        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
1858    }
1859
1860    fn logcumsumexp_f32(
1861        &self,
1862        a: &GpuBufferHandle,
1863        outer: usize,
1864        dim_size: usize,
1865        inner: usize,
1866    ) -> FerrotorchResult<GpuBufferHandle> {
1867        let a_buf = Self::unwrap_buffer(a)?;
1868        let dev = self.device(a.device_ordinal())?;
1869        let result = crate::kernels::gpu_logcumsumexp(a_buf, outer, dim_size, inner, dev)
1870            .map_err(Self::map_gpu_err)?;
1871        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1872    }
1873
1874    fn clamp_f32(
1875        &self,
1876        a: &GpuBufferHandle,
1877        min_val: f32,
1878        max_val: f32,
1879    ) -> FerrotorchResult<GpuBufferHandle> {
1880        let a_buf = Self::unwrap_buffer(a)?;
1881        let dev = self.device(a.device_ordinal())?;
1882        let result =
1883            crate::kernels::gpu_clamp(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
1884        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1885    }
1886
1887    fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1888        let a_buf = Self::unwrap_buffer(a)?;
1889        let dev = self.device(a.device_ordinal())?;
1890        let result = crate::kernels::gpu_silu(a_buf, dev).map_err(Self::map_gpu_err)?;
1891        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1892    }
1893
1894    fn silu_backward_f32(
1895        &self,
1896        grad: &GpuBufferHandle,
1897        input: &GpuBufferHandle,
1898    ) -> FerrotorchResult<GpuBufferHandle> {
1899        let grad_buf = Self::unwrap_buffer(grad)?;
1900        let input_buf = Self::unwrap_buffer(input)?;
1901        let dev = self.device(grad.device_ordinal())?;
1902        let result = crate::kernels::gpu_silu_backward(grad_buf, input_buf, dev)
1903            .map_err(Self::map_gpu_err)?;
1904        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1905    }
1906
1907    fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
1908        let a_buf = Self::unwrap_buffer(a)?;
1909        let dev = self.device(a.device_ordinal())?;
1910        let result = crate::kernels::gpu_elu(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
1911        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1912    }
1913
1914    fn elu_backward_f32(
1915        &self,
1916        grad: &GpuBufferHandle,
1917        input: &GpuBufferHandle,
1918        alpha: f32,
1919    ) -> FerrotorchResult<GpuBufferHandle> {
1920        let grad_buf = Self::unwrap_buffer(grad)?;
1921        let input_buf = Self::unwrap_buffer(input)?;
1922        let dev = self.device(grad.device_ordinal())?;
1923        let result = crate::kernels::gpu_elu_backward(grad_buf, input_buf, alpha, dev)
1924            .map_err(Self::map_gpu_err)?;
1925        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1926    }
1927
1928    fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1929        let a_buf = Self::unwrap_buffer(a)?;
1930        let dev = self.device(a.device_ordinal())?;
1931        let result = crate::kernels::gpu_mish(a_buf, dev).map_err(Self::map_gpu_err)?;
1932        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1933    }
1934
1935    fn mish_backward_f32(
1936        &self,
1937        grad: &GpuBufferHandle,
1938        input: &GpuBufferHandle,
1939    ) -> FerrotorchResult<GpuBufferHandle> {
1940        let grad_buf = Self::unwrap_buffer(grad)?;
1941        let input_buf = Self::unwrap_buffer(input)?;
1942        let dev = self.device(grad.device_ordinal())?;
1943        let result = crate::kernels::gpu_mish_backward(grad_buf, input_buf, dev)
1944            .map_err(Self::map_gpu_err)?;
1945        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1946    }
1947
1948    fn log_softmax_f32(
1949        &self,
1950        a: &GpuBufferHandle,
1951        cols: usize,
1952    ) -> FerrotorchResult<GpuBufferHandle> {
1953        let a_buf = Self::unwrap_buffer(a)?;
1954        let dev = self.device(a.device_ordinal())?;
1955        let result =
1956            crate::kernels::gpu_log_softmax(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
1957        Ok(Self::wrap_buffer(result, a.device_ordinal()))
1958    }
1959
1960    fn log_softmax_backward_f32(
1961        &self,
1962        grad: &GpuBufferHandle,
1963        output: &GpuBufferHandle,
1964        cols: usize,
1965    ) -> FerrotorchResult<GpuBufferHandle> {
1966        let grad_buf = Self::unwrap_buffer(grad)?;
1967        let output_buf = Self::unwrap_buffer(output)?;
1968        let dev = self.device(grad.device_ordinal())?;
1969        let result =
1970            crate::kernels::gpu_log_softmax_backward(grad_buf, output_buf, cols, dev)
1971                .map_err(Self::map_gpu_err)?;
1972        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1973    }
1974
1975    fn index_select_1d_f32(
1976        &self,
1977        input: &GpuBufferHandle,
1978        indices: &GpuBufferHandle,
1979    ) -> FerrotorchResult<GpuBufferHandle> {
1980        let input_buf = Self::unwrap_buffer(input)?;
1981        let idx_buf = Self::unwrap_buffer(indices)?;
1982        let dev = self.device(input.device_ordinal())?;
1983        let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
1984            .map_err(Self::map_gpu_err)?;
1985        Ok(Self::wrap_buffer(result, input.device_ordinal()))
1986    }
1987
1988    fn scatter_add_1d_f32(
1989        &self,
1990        grad_output: &GpuBufferHandle,
1991        indices: &GpuBufferHandle,
1992        input_len: usize,
1993    ) -> FerrotorchResult<GpuBufferHandle> {
1994        let go_buf = Self::unwrap_buffer(grad_output)?;
1995        let idx_buf = Self::unwrap_buffer(indices)?;
1996        let dev = self.device(grad_output.device_ordinal())?;
1997        let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
1998            .map_err(Self::map_gpu_err)?;
1999        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
2000    }
2001
2002    fn masked_fill_f32(
2003        &self,
2004        input: &GpuBufferHandle,
2005        mask: &GpuBufferHandle,
2006        value: f32,
2007    ) -> FerrotorchResult<GpuBufferHandle> {
2008        let input_buf = Self::unwrap_buffer(input)?;
2009        let mask_buf = Self::unwrap_buffer(mask)?;
2010        let dev = self.device(input.device_ordinal())?;
2011        let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
2012            .map_err(Self::map_gpu_err)?;
2013        Ok(Self::wrap_buffer(result, input.device_ordinal()))
2014    }
2015
2016    fn masked_zero_f32(
2017        &self,
2018        grad: &GpuBufferHandle,
2019        mask: &GpuBufferHandle,
2020    ) -> FerrotorchResult<GpuBufferHandle> {
2021        let grad_buf = Self::unwrap_buffer(grad)?;
2022        let mask_buf = Self::unwrap_buffer(mask)?;
2023        let dev = self.device(grad.device_ordinal())?;
2024        let result =
2025            crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
2026        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2027    }
2028
2029    fn sigmoid_backward_f32(
2030        &self,
2031        grad: &GpuBufferHandle,
2032        output: &GpuBufferHandle,
2033    ) -> FerrotorchResult<GpuBufferHandle> {
2034        let grad_buf = Self::unwrap_buffer(grad)?;
2035        let output_buf = Self::unwrap_buffer(output)?;
2036        let dev = self.device(grad.device_ordinal())?;
2037        let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
2038            .map_err(Self::map_gpu_err)?;
2039        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2040    }
2041
2042    fn tanh_backward_f32(
2043        &self,
2044        grad: &GpuBufferHandle,
2045        output: &GpuBufferHandle,
2046    ) -> FerrotorchResult<GpuBufferHandle> {
2047        let grad_buf = Self::unwrap_buffer(grad)?;
2048        let output_buf = Self::unwrap_buffer(output)?;
2049        let dev = self.device(grad.device_ordinal())?;
2050        let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
2051            .map_err(Self::map_gpu_err)?;
2052        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2053    }
2054
2055    fn softmax_backward_f32(
2056        &self,
2057        grad: &GpuBufferHandle,
2058        output: &GpuBufferHandle,
2059        cols: usize,
2060    ) -> FerrotorchResult<GpuBufferHandle> {
2061        let grad_buf = Self::unwrap_buffer(grad)?;
2062        let output_buf = Self::unwrap_buffer(output)?;
2063        let dev = self.device(grad.device_ordinal())?;
2064        let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
2065            .map_err(Self::map_gpu_err)?;
2066        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2067    }
2068
2069    fn layernorm_backward_f32(
2070        &self,
2071        input: &GpuBufferHandle,
2072        grad_output: &GpuBufferHandle,
2073        weight: &GpuBufferHandle,
2074        rows: usize,
2075        cols: usize,
2076        eps: f32,
2077    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2078        let in_buf = Self::unwrap_buffer(input)?;
2079        let go_buf = Self::unwrap_buffer(grad_output)?;
2080        let w_buf = Self::unwrap_buffer(weight)?;
2081        let dev = self.device(input.device_ordinal())?;
2082        let (gi, gw, gb) =
2083            crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
2084                .map_err(Self::map_gpu_err)?;
2085        let ordinal = input.device_ordinal();
2086        Ok((
2087            Self::wrap_buffer(gi, ordinal),
2088            Self::wrap_buffer(gw, ordinal),
2089            Self::wrap_buffer(gb, ordinal),
2090        ))
2091    }
2092
2093    fn sum_axis_f32(
2094        &self,
2095        a: &GpuBufferHandle,
2096        shape: &[usize],
2097        axis: usize,
2098    ) -> FerrotorchResult<GpuBufferHandle> {
2099        let a_buf = Self::unwrap_buffer(a)?;
2100        let dev = self.device(a.device_ordinal())?;
2101        let outer: usize = shape[..axis].iter().product();
2102        let axis_size = shape[axis];
2103        let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
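        // Example: shape = [2, 3, 4] with axis = 1 gives outer = 2,
        // axis_size = 3, inner = 4; the reduced buffer then holds
        // outer * inner = 8 values (the summed axis is collapsed away).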
2104        let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
2105            .map_err(Self::map_gpu_err)?;
2106        Ok(Self::wrap_buffer(result, a.device_ordinal()))
2107    }
2108
2109    fn matmul_f16_f32(
2110        &self,
2111        a: &GpuBufferHandle,
2112        b: &GpuBufferHandle,
2113        m: usize,
2114        k: usize,
2115        n: usize,
2116    ) -> FerrotorchResult<GpuBufferHandle> {
2117        let a_buf = Self::unwrap_buffer(a)?;
2118        let b_buf = Self::unwrap_buffer(b)?;
2119        let dev = self.device(a.device_ordinal())?;
2120        let result =
2121            crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
2122        Ok(Self::wrap_buffer(result, a.device_ordinal()))
2123    }
2124
2125    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
2126        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
2127            FerrotorchError::InvalidArgument {
2128                message: "failed to lock CUDA RNG manager".into(),
2129            }
2130        })?;
2131        let state = mgr.get_rng_state(device);
2132        Ok(GpuRngState {
2133            counter: state.counter,
2134            seed: state.seed,
2135            offset: state.offset,
2136            device,
2137        })
2138    }
2139
2140    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
2141        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
2142            FerrotorchError::InvalidArgument {
2143                message: "failed to lock CUDA RNG manager".into(),
2144            }
2145        })?;
2146        mgr.set_rng_state(
2147            state.device,
2148            crate::rng::PhiloxState {
2149                counter: state.counter,
2150                seed: state.seed,
2151                offset: state.offset,
2152            },
2153        );
2154        Ok(())
2155    }
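
    // Typical save/restore round-trip (a sketch; `x`, `threshold`, and `scale`
    // are placeholder values):
    //
    //     let saved = backend.save_rng_state(0)?;
    //     let (out, _state) = backend.dropout_philox_f32(&x, threshold, scale)?;
    //     backend.restore_rng_state(saved)?; // replay from here reuses the same state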
2156
2157    fn strided_split_f32(
2158        &self,
2159        input: &GpuBufferHandle,
2160        total_along_axis: usize,
2161        split_offset: usize,
2162        split_size: usize,
2163        inner_size: usize,
2164        n: usize,
2165    ) -> FerrotorchResult<GpuBufferHandle> {
2166        let in_buf = Self::unwrap_buffer(input)?;
2167        let dev = self.device(input.device_ordinal())?;
2168        let result = crate::kernels::gpu_strided_split(
2169            in_buf,
2170            total_along_axis,
2171            split_offset,
2172            split_size,
2173            inner_size,
2174            n,
2175            dev,
2176        )
2177        .map_err(Self::map_gpu_err)?;
2178        Ok(Self::wrap_buffer(result, input.device_ordinal()))
2179    }
2180
2181    fn strided_copy_f32(
2182        &self,
2183        input: &GpuBufferHandle,
2184        out_shape: &[usize],
2185        src_strides: &[isize],
2186        src_offset: usize,
2187    ) -> FerrotorchResult<GpuBufferHandle> {
2188        let in_buf = Self::unwrap_buffer(input)?;
2189        let dev = self.device(input.device_ordinal())?;
2190        let result =
2191            crate::kernels::gpu_strided_copy(in_buf, out_shape, src_strides, src_offset, dev)
2192                .map_err(Self::map_gpu_err)?;
2193        Ok(Self::wrap_buffer(result, input.device_ordinal()))
2194    }
2195
2196    fn strided_copy_f64(
2197        &self,
2198        input: &GpuBufferHandle,
2199        out_shape: &[usize],
2200        src_strides: &[isize],
2201        src_offset: usize,
2202    ) -> FerrotorchResult<GpuBufferHandle> {
2203        let in_buf = Self::unwrap_buffer_f64(input)?;
2204        let dev = self.device(input.device_ordinal())?;
2205        let result = crate::kernels::gpu_strided_copy_f64(
2206            in_buf,
2207            out_shape,
2208            src_strides,
2209            src_offset,
2210            dev,
2211        )
2212        .map_err(Self::map_gpu_err)?;
2213        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
2214    }
2215
2216    fn strided_cat_f32(
2217        &self,
2218        input: &GpuBufferHandle,
2219        output: &mut GpuBufferHandle,
2220        total_along_axis: usize,
2221        cat_offset: usize,
2222        part_size: usize,
2223        inner_size: usize,
2224        n: usize,
2225    ) -> FerrotorchResult<()> {
2226        let in_buf = Self::unwrap_buffer(input)?;
2227        let dev = self.device(input.device_ordinal())?;
2228        let out_buf =
2229            output
2230                .downcast_mut::<CudaBuffer<f32>>()
2231                .ok_or(FerrotorchError::InvalidArgument {
2232                    message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
2233                })?;
2234        crate::kernels::gpu_strided_cat(
2235            in_buf,
2236            out_buf,
2237            total_along_axis,
2238            cat_offset,
2239            part_size,
2240            inner_size,
2241            n,
2242            dev,
2243        )
2244        .map_err(Self::map_gpu_err)?;
2245        Ok(())
2246    }
2247
2248    // -- cuSOLVER linear algebra -------------------------------------------------
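    //
    // Note: these wrappers currently stage data through host memory. The input
    // buffer is downloaded with `crate::transfer::gpu_to_cpu`, handed to the
    // `crate::cusolver` helpers, and the resulting factors are uploaded back
    // with `crate::transfer::cpu_to_gpu` (see `svd_f32` below for the pattern).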
2249
2250    fn svd_f32(
2251        &self,
2252        a: &GpuBufferHandle,
2253        m: usize,
2254        n: usize,
2255    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2256        let a_buf = Self::unwrap_buffer(a)?;
2257        let dev = self.device(a.device_ordinal())?;
2258        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2259        let (u, s, vt) =
2260            crate::cusolver::gpu_svd_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2261        let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
2262        let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
2263        let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
2264        let ord = a.device_ordinal();
2265        Ok((
2266            Self::wrap_buffer(u_buf, ord),
2267            Self::wrap_buffer(s_buf, ord),
2268            Self::wrap_buffer(vt_buf, ord),
2269        ))
2270    }
2271
2272    fn svd_f64(
2273        &self,
2274        a: &GpuBufferHandle,
2275        m: usize,
2276        n: usize,
2277    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2278        let a_buf = Self::unwrap_buffer_f64(a)?;
2279        let dev = self.device(a.device_ordinal())?;
2280        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2281        let (u, s, vt) =
2282            crate::cusolver::gpu_svd_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2283        let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
2284        let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
2285        let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
2286        let ord = a.device_ordinal();
2287        Ok((
2288            Self::wrap_buffer_f64(u_buf, ord),
2289            Self::wrap_buffer_f64(s_buf, ord),
2290            Self::wrap_buffer_f64(vt_buf, ord),
2291        ))
2292    }
2293
2294    fn cholesky_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
2295        let a_buf = Self::unwrap_buffer(a)?;
2296        let dev = self.device(a.device_ordinal())?;
2297        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2298        let l = crate::cusolver::gpu_cholesky_f32(&a_host, n, dev).map_err(Self::map_gpu_err)?;
2299        let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
2300        Ok(Self::wrap_buffer(l_buf, a.device_ordinal()))
2301    }
2302
2303    fn cholesky_f64(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
2304        let a_buf = Self::unwrap_buffer_f64(a)?;
2305        let dev = self.device(a.device_ordinal())?;
2306        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2307        let l = crate::cusolver::gpu_cholesky_f64(&a_host, n, dev).map_err(Self::map_gpu_err)?;
2308        let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
2309        Ok(Self::wrap_buffer_f64(l_buf, a.device_ordinal()))
2310    }
2311
2312    fn solve_f32(
2313        &self,
2314        a: &GpuBufferHandle,
2315        b: &GpuBufferHandle,
2316        n: usize,
2317        nrhs: usize,
2318    ) -> FerrotorchResult<GpuBufferHandle> {
2319        let a_buf = Self::unwrap_buffer(a)?;
2320        let b_buf = Self::unwrap_buffer(b)?;
2321        let dev = self.device(a.device_ordinal())?;
2322        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2323        let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
2324        let x =
2325            crate::cusolver::gpu_solve_f32(&a_host, &b_host, n, nrhs, dev)
2326                .map_err(Self::map_gpu_err)?;
2327        let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
2328        Ok(Self::wrap_buffer(x_buf, a.device_ordinal()))
2329    }
2330
2331    fn solve_f64(
2332        &self,
2333        a: &GpuBufferHandle,
2334        b: &GpuBufferHandle,
2335        n: usize,
2336        nrhs: usize,
2337    ) -> FerrotorchResult<GpuBufferHandle> {
2338        let a_buf = Self::unwrap_buffer_f64(a)?;
2339        let b_buf = Self::unwrap_buffer_f64(b)?;
2340        let dev = self.device(a.device_ordinal())?;
2341        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2342        let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
2343        let x =
2344            crate::cusolver::gpu_solve_f64(&a_host, &b_host, n, nrhs, dev)
2345                .map_err(Self::map_gpu_err)?;
2346        let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
2347        Ok(Self::wrap_buffer_f64(x_buf, a.device_ordinal()))
2348    }
2349
2350    fn qr_f32(
2351        &self,
2352        a: &GpuBufferHandle,
2353        m: usize,
2354        n: usize,
2355    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
2356        let a_buf = Self::unwrap_buffer(a)?;
2357        let dev = self.device(a.device_ordinal())?;
2358        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2359        let (q, r) =
2360            crate::cusolver::gpu_qr_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2361        let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
2362        let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
2363        let ord = a.device_ordinal();
2364        Ok((Self::wrap_buffer(q_buf, ord), Self::wrap_buffer(r_buf, ord)))
2365    }
2366
2367    fn qr_f64(
2368        &self,
2369        a: &GpuBufferHandle,
2370        m: usize,
2371        n: usize,
2372    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
2373        let a_buf = Self::unwrap_buffer_f64(a)?;
2374        let dev = self.device(a.device_ordinal())?;
2375        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2376        let (q, r) =
2377            crate::cusolver::gpu_qr_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2378        let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
2379        let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
2380        let ord = a.device_ordinal();
2381        Ok((
2382            Self::wrap_buffer_f64(q_buf, ord),
2383            Self::wrap_buffer_f64(r_buf, ord),
2384        ))
2385    }
2386}
2387
2388// ---------------------------------------------------------------------------
2389// Registration
2390// ---------------------------------------------------------------------------
2391
2392/// Get the `GpuDevice` from the registered CUDA backend.
2393///
2394/// This retrieves the device that was created during [`init_cuda_backend`],
2395/// ensuring all kernel modules and cuBLAS handles are shared. Creating a
2396/// second `GpuDevice` via `GpuDevice::new(0)` would create a separate
2397/// CUDA context with its own module cache that cannot interoperate with the registered one.
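///
/// A minimal usage sketch (the kernel-launch helper is hypothetical and only
/// illustrates passing the shared device to custom GPU code):
///
/// ```ignore
/// let dev = get_cuda_device()?;
/// my_custom_kernel_launch(&dev)?; // hypothetical helper taking &Arc<GpuDevice>
/// ```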
2398pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
2399    let backend =
2400        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
2401    // The global backend is a &dyn GpuBackend. We know it's CudaBackendImpl
2402    // because init_cuda_backend registered it. Downcast via Any.
2403    let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
2404        FerrotorchError::InvalidArgument {
2405            message: "registered GPU backend is not CudaBackendImpl".into(),
2406        },
2407    )?;
2408    Ok(Arc::clone(cuda_backend.default_device()?))
2409}
2410
2411/// Initialize the CUDA backend and register it with ferrotorch-core.
2412///
2413/// This must be called before any GPU tensor operations. It creates a
2414/// [`CudaBackendImpl`] (initializing CUDA device 0) and registers it via
2415/// [`ferrotorch_core::gpu_dispatch::register_gpu_backend`].
2416///
2417/// Calling this a second time is a no-op: if a GPU backend is already
2418/// registered, the function returns `Ok(())` without creating a new one.
2419///
2420/// # Errors
2421///
2422/// - [`FerrotorchError::InvalidArgument`] if CUDA initialization fails
2423///   (e.g. no GPU available, driver not loaded).
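///
/// # Example
///
/// A minimal sketch of calling the initializer directly:
///
/// ```ignore
/// init_cuda_backend()?;
/// assert!(ferrotorch_core::gpu_dispatch::has_gpu_backend());
/// ```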
2424pub fn init_cuda_backend() -> FerrotorchResult<()> {
2425    // Idempotent: if already registered, return Ok silently.
2426    if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
2427        return Ok(());
2428    }
2429    let backend = CudaBackendImpl::new()?;
2430    // OnceLock::set can still race if two threads call init concurrently —
2431    // if that happens, the second set() fails but the backend is registered
2432    // by the first. We treat that as success.
2433    let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
2434    Ok(())
2435}
2436
2437// ---------------------------------------------------------------------------
2438// Tests
2439// ---------------------------------------------------------------------------
2440
2441#[cfg(test)]
2442#[cfg(feature = "cuda")]
2443mod tests {
2444    use super::*;
2445    use ferrotorch_core::gpu_dispatch;
2446
2447    // Note: Because `register_gpu_backend` uses a `OnceLock`, only the first
2448    // test to call `init_cuda_backend()` will succeed at registration. The
2449    // others will see the backend as already registered. We handle this by
2450    // checking `has_gpu_backend()` before calling init.
2451
2452    /// Ensure the backend can be initialized (or was already initialized).
2453    fn ensure_init() {
2454        if !gpu_dispatch::has_gpu_backend() {
2455            init_cuda_backend().expect("init_cuda_backend");
2456        }
2457    }
2458
2459    #[test]
2460    fn test_init_cuda_backend() {
2461        // First call succeeds (or backend was already registered by another test).
2462        ensure_init();
2463        assert!(gpu_dispatch::has_gpu_backend());
2464    }
2465
2466    #[test]
2467    fn test_gpu_backend_returns_some() {
2468        ensure_init();
2469        assert!(gpu_dispatch::gpu_backend().is_some());
2470    }
2471
2472    #[test]
2473    fn test_roundtrip_cpu_gpu_cpu() {
2474        ensure_init();
2475        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
2476
2477        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2478        let bytes: &[u8] = unsafe {
2479            std::slice::from_raw_parts(
2480                host.as_ptr() as *const u8,
2481                host.len() * std::mem::size_of::<f32>(),
2482            )
2483        };
2484
2485        let handle = backend.cpu_to_gpu(bytes, 4, 0).expect("cpu_to_gpu");
2486        assert_eq!(handle.len(), 5);
2487        assert_eq!(handle.device_ordinal(), 0);
2488
2489        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
2490        let back: &[f32] = unsafe {
2491            std::slice::from_raw_parts(back_bytes.as_ptr() as *const f32, back_bytes.len() / 4)
2492        };
2493        assert_eq!(back, &host[..]);
2494    }
2495
2496    #[test]
2497    fn test_add_f32() {
2498        ensure_init();
2499        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
2500
2501        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2502        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
2503        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];
2504
2505        let a_bytes: &[u8] =
2506            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
2507        let b_bytes: &[u8] =
2508            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };
2509
2510        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
2511        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");
2512
2513        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
2514        assert_eq!(result.len(), 4);
2515
2516        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
2517        let result_f32: &[f32] = unsafe {
2518            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
2519        };
2520
2521        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
2522            assert!(
2523                (got - exp).abs() < 1e-6,
2524                "element {i}: got {got}, expected {exp}",
2525            );
2526        }
2527    }
2528
2529    #[test]
2530    fn test_matmul_f32() {
2531        ensure_init();
2532        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
2533
2534        // A = [[1, 2, 3],
2535        //      [4, 5, 6]]  (2x3)
2536        // B = [[7, 8],
2537        //      [9, 10],
2538        //      [11, 12]]   (3x2)
2539        // C = [[58, 64],
2540        //      [139, 154]] (2x2)
2541        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
2542        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
2543        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
2544
2545        let a_bytes: &[u8] =
2546            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
2547        let b_bytes: &[u8] =
2548            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };
2549
2550        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
2551        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");
2552
2553        let result = backend
2554            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
2555            .expect("matmul_f32");
2556        assert_eq!(result.len(), 4);
2557
2558        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
2559        let result_f32: &[f32] = unsafe {
2560            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
2561        };
2562
2563        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
2564            assert!(
2565                (got - exp).abs() < 1e-3,
2566                "element {i}: got {got}, expected {exp}",
2567            );
2568        }
2569    }
2570}