Skip to main content

axonml_core/backends/
cuda.rs

//! CUDA Backend - NVIDIA GPU Operations
//!
//! # File
//! `crates/axonml-core/src/backends/cuda.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17#[cfg(feature = "cuda")]
18use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, sys::cublasOperation_t};
19#[cfg(feature = "cudnn")]
20use cudarc::cudnn::Cudnn;
21#[cfg(feature = "cuda")]
22use cudarc::driver::{
23    CudaContext, CudaSlice, CudaStream, DeviceRepr, DeviceSlice, LaunchConfig, PushKernelArg,
24    ValidAsZeroBits,
25};
26
27use super::Backend;
28#[cfg(feature = "cuda")]
29use super::cuda_kernels::{self, BLOCK_SIZE, CudaKernels};
30use crate::device::DeviceCapabilities;
31#[cfg(feature = "cuda")]
32use std::sync::Arc;
33#[cfg(feature = "cuda")]
34use std::sync::OnceLock;
35
36// =============================================================================
37// Global CUDA Backend Singleton
38// =============================================================================
39
// Lazily-initialized process-wide backend for GPU 0.
// Note: `OnceLock` caches whatever the init closure returns — including `None`
// when CUDA initialization fails — so a failed init is never retried.
#[cfg(feature = "cuda")]
static CUDA_BACKEND: OnceLock<Option<CudaBackend>> = OnceLock::new();
42
/// Get the global CUDA backend singleton (initialized lazily on first call).
///
/// The first caller pays the initialization cost; later callers get the cached
/// result. Returns `None` when no CUDA backend could be created.
#[cfg(feature = "cuda")]
pub fn get_cuda_backend() -> Option<&'static CudaBackend> {
    let slot = CUDA_BACKEND.get_or_init(|| {
        CudaBackend::new(0).map(|backend| {
            // Log exactly once, on successful first-time initialization.
            eprintln!("[AxonML] CUDA backend initialized (GPU 0)");
            backend
        })
    });
    slot.as_ref()
}
56
/// Get the global CUDA backend singleton (stub when cuda feature disabled).
///
/// Always returns `None`: without the `cuda` feature no backend can exist.
#[cfg(not(feature = "cuda"))]
pub fn get_cuda_backend() -> Option<&'static CudaBackend> {
    None
}
62
63// =============================================================================
64// CUDA Backend Struct
65// =============================================================================
66
/// CUDA backend for tensor operations on NVIDIA GPUs.
///
/// Holds the device context, its default stream, a cuBLAS handle, and the
/// pre-loaded element-wise kernels. All work is issued on the default stream.
/// (An earlier comment here claimed the stream is not stored; the struct does
/// in fact hold `stream` behind an `Arc`.)
#[cfg(feature = "cuda")]
pub struct CudaBackend {
    // 0-based CUDA device ordinal this backend drives.
    device_index: usize,
    // Owning handle to the CUDA context for `device_index`.
    ctx: Arc<CudaContext>,
    // Default stream of `ctx`; kernel launches and copies go through it.
    stream: Arc<CudaStream>,
    // cuBLAS handle bound to `stream` (GEMM operations).
    blas: CudaBlas,
    // Loaded PTX kernels for element-wise / activation ops.
    kernels: CudaKernels,
    // cuDNN handle; `None` when init failed (conv falls back to im2col+GEMM).
    #[cfg(feature = "cudnn")]
    cudnn_handle: Option<Arc<Cudnn>>,
}
81
/// CUDA backend stub when the `cuda` feature is disabled.
///
/// Keeps the type name available so downstream code compiles; every operation
/// is a no-op (see the `Backend` impl below).
#[cfg(not(feature = "cuda"))]
#[derive(Debug)]
pub struct CudaBackend {
    // Requested device ordinal, retained only for diagnostics.
    device_index: usize,
}
88
// SAFETY: the original author asserts that the cudarc handles held here
// (CudaContext, CudaStream, CudaBlas) are internally synchronized, making the
// backend safe to share across threads.
// NOTE(review): this blanket impl also covers `kernels` and the optional
// cuDNN handle — confirm those are thread-safe before relying on this.
#[cfg(feature = "cuda")]
unsafe impl Send for CudaBackend {}
// SAFETY: same internal-synchronization argument as the `Send` impl above.
#[cfg(feature = "cuda")]
unsafe impl Sync for CudaBackend {}
95
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaBackend {
    // Hand-written because the cudarc handle fields are not Debug; only the
    // device index is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("CudaBackend");
        dbg.field("device_index", &self.device_index);
        dbg.finish()
    }
}
104
impl CudaBackend {
    /// Creates a new CUDA backend for the specified device.
    ///
    /// Returns `None` if the context, cuBLAS handle, or PTX kernel module
    /// cannot be created/loaded. cuDNN (when the feature is enabled) is
    /// optional: a failure there only logs and leaves `cudnn_handle` as None.
    #[cfg(feature = "cuda")]
    pub fn new(device_index: usize) -> Option<Self> {
        let ctx = CudaContext::new(device_index).ok()?;
        let stream = ctx.default_stream();
        let blas = CudaBlas::new(stream.clone()).ok()?;
        // Kernel loading failure is logged (unlike the silent `.ok()?` above)
        // because it usually indicates a build/packaging problem, not a
        // missing GPU.
        let kernels = match CudaKernels::load(ctx.clone()) {
            Ok(k) => k,
            Err(e) => {
                eprintln!("[AxonML CUDA] Kernel loading failed: {:?}", e);
                return None;
            }
        };

        // cuDNN is best-effort: absence downgrades conv to im2col+GEMM.
        #[cfg(feature = "cudnn")]
        let cudnn_handle = match Cudnn::new(stream.clone()) {
            Ok(handle) => {
                eprintln!("[AxonML] cuDNN handle initialized");
                Some(handle)
            }
            Err(e) => {
                eprintln!(
                    "[AxonML CUDA] cuDNN init failed: {:?} (falling back to im2col+GEMM)",
                    e
                );
                None
            }
        };

        Some(Self {
            device_index,
            ctx,
            stream,
            blas,
            kernels,
            #[cfg(feature = "cudnn")]
            cudnn_handle,
        })
    }

    /// Creates a new CUDA backend (stub, always returns None without the `cuda` feature).
    #[cfg(not(feature = "cuda"))]
    pub fn new(device_index: usize) -> Option<Self> {
        let _ = device_index;
        None // CUDA not available without feature
    }

    /// Returns the device index.
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Returns the underlying CUDA context.
    #[cfg(feature = "cuda")]
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Returns the underlying CUDA stream (the device's default stream).
    #[cfg(feature = "cuda")]
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Returns the cuBLAS handle.
    #[cfg(feature = "cuda")]
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// Returns the cuDNN handle, if available.
    #[cfg(feature = "cudnn")]
    pub fn cudnn(&self) -> Option<&Arc<Cudnn>> {
        self.cudnn_handle.as_ref()
    }

    /// Allocates a typed buffer on the GPU initialized to zeros.
    ///
    /// # Errors
    /// Returns `CudaError::DriverError` when the allocation fails.
    #[cfg(feature = "cuda")]
    pub fn alloc<T: DeviceRepr + ValidAsZeroBits>(
        &self,
        len: usize,
    ) -> Result<CudaSlice<T>, CudaError> {
        self.stream.alloc_zeros(len).map_err(CudaError::from)
    }

    /// Allocates uninitialized memory on the GPU.
    ///
    /// Contents are undefined until written; callers must not read first.
    #[cfg(feature = "cuda")]
    pub fn alloc_uninit<T: DeviceRepr>(&self, len: usize) -> Result<CudaSlice<T>, CudaError> {
        // SAFETY: `alloc` is unsafe because the buffer starts uninitialized;
        // the caller contract above covers that.
        unsafe { self.stream.alloc(len).map_err(CudaError::from) }
    }

    /// Copies data from host to device, returning a new device buffer.
    // NOTE(review): `clone_htod` appears to be a blocking host->device copy on
    // this stream — confirm against the cudarc version in use.
    #[cfg(feature = "cuda")]
    pub fn htod_copy<T: DeviceRepr>(&self, src: &[T]) -> Result<CudaSlice<T>, CudaError> {
        self.stream.clone_htod(src).map_err(CudaError::from)
    }

    /// Copies data from device to host into a freshly allocated `Vec`.
    #[cfg(feature = "cuda")]
    pub fn dtoh_copy<T: DeviceRepr>(&self, src: &CudaSlice<T>) -> Result<Vec<T>, CudaError> {
        self.stream.clone_dtoh(src).map_err(CudaError::from)
    }
}
209
210// =============================================================================
211// Backend Trait Implementation
212// =============================================================================
213
#[cfg(feature = "cuda")]
impl Backend for CudaBackend {
    fn name(&self) -> &'static str {
        "cuda"
    }

    fn is_available(&self) -> bool {
        // Constructing the backend already proved the device exists.
        true
    }

    /// Reports device capabilities, using live driver queries where cheap.
    fn capabilities(&self) -> DeviceCapabilities {
        let name = format!("CUDA Device {}", self.device_index);

        // Free/total memory via the CUDA driver API; (0, 0) if the query fails.
        // NOTE(review): mem_get_info reports the *current* context's device,
        // which may differ from `self.device_index` on multi-GPU systems —
        // confirm before trusting these numbers for device selection.
        let (free, total) = cudarc::driver::result::mem_get_info().unwrap_or((0, 0));

        DeviceCapabilities {
            name,
            total_memory: total,
            available_memory: free,
            supports_f16: true,
            supports_f64: true,
            max_threads_per_block: 1024,
            compute_capability: None, // Would need a device-attribute query.
        }
    }

    /// Allocates zeroed device memory and returns the raw device pointer
    /// (NOT host-dereferenceable). Returns null on allocation failure.
    fn allocate(&self, size: usize) -> *mut u8 {
        match self.stream.alloc_zeros::<u8>(size) {
            // `leak` releases RAII ownership; `deallocate` reconstructs the
            // slice from the raw pointer to free it. (Fixed let-and-return.)
            Ok(slice) => slice.leak() as *mut u8,
            Err(_) => std::ptr::null_mut(),
        }
    }

    /// Frees a pointer previously returned by `allocate`. Null is a no-op.
    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if !ptr.is_null() {
            // SAFETY: `ptr`/`size` must come from `allocate` on this backend,
            // so upgrading back to a CudaSlice and dropping it frees exactly
            // the leaked allocation.
            unsafe {
                let slice: CudaSlice<u8> = self
                    .stream
                    .upgrade_device_ptr(ptr as cudarc::driver::sys::CUdeviceptr, size);
                drop(slice);
            }
        }
    }

    /// Synchronous host-to-device copy. Errors are silently ignored, matching
    /// the void interface of the `Backend` trait.
    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        // SAFETY: caller guarantees `src` points to `size` readable host bytes
        // and `dst` is a device pointer with at least `size` bytes.
        unsafe {
            let src_slice = std::slice::from_raw_parts(src, size);
            let _ = cudarc::driver::result::memcpy_htod_sync(
                dst as cudarc::driver::sys::CUdeviceptr,
                src_slice,
            );
        }
    }

    /// Synchronous device-to-host copy. Errors are silently ignored.
    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        // SAFETY: caller guarantees `dst` is `size` writable host bytes and
        // `src` is a device pointer with at least `size` bytes.
        unsafe {
            let dst_slice = std::slice::from_raw_parts_mut(dst, size);
            let _ = cudarc::driver::result::memcpy_dtoh_sync(
                dst_slice,
                src as cudarc::driver::sys::CUdeviceptr,
            );
        }
    }

    /// Synchronous device-to-device copy. Errors are silently ignored.
    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        // SAFETY: caller guarantees both pointers are device pointers with at
        // least `size` bytes; ranges must not be undefined per CUDA semantics.
        unsafe {
            let _ = cudarc::driver::result::memcpy_dtod_sync(
                dst as cudarc::driver::sys::CUdeviceptr,
                src as cudarc::driver::sys::CUdeviceptr,
                size,
            );
        }
    }

    /// Blocks until all work queued on the default stream has completed.
    fn synchronize(&self) {
        let _ = self.stream.synchronize();
    }
}
308
/// Synchronize the CUDA device (wait for all GPU operations to complete).
/// Returns true if sync was performed, false if CUDA is not available.
#[cfg(feature = "cuda")]
pub fn cuda_sync() -> bool {
    match get_cuda_backend() {
        Some(backend) => {
            // Best-effort: a failed synchronize still counts as "performed".
            let _ = backend.stream.synchronize();
            true
        }
        None => false,
    }
}
320
/// Synchronize the CUDA device (no-op without the `cuda` feature).
///
/// Always returns `false` to signal that no synchronization happened.
#[cfg(not(feature = "cuda"))]
pub fn cuda_sync() -> bool {
    false
}
326
327#[cfg(not(feature = "cuda"))]
328impl Backend for CudaBackend {
329    fn name(&self) -> &'static str {
330        "cuda"
331    }
332
333    fn is_available(&self) -> bool {
334        false
335    }
336
337    fn capabilities(&self) -> DeviceCapabilities {
338        DeviceCapabilities {
339            name: format!("CUDA Device {} (unavailable)", self.device_index),
340            total_memory: 0,
341            available_memory: 0,
342            supports_f16: false,
343            supports_f64: false,
344            max_threads_per_block: 0,
345            compute_capability: None,
346        }
347    }
348
349    fn allocate(&self, _size: usize) -> *mut u8 {
350        std::ptr::null_mut()
351    }
352
353    fn deallocate(&self, _ptr: *mut u8, _size: usize) {}
354
355    fn copy_to_device(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}
356
357    fn copy_to_host(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}
358
359    fn copy_device_to_device(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}
360
361    fn synchronize(&self) {}
362}
363
364// =============================================================================
365// CUDA Error Type
366// =============================================================================
367
/// CUDA-specific error type
#[derive(Debug)]
pub enum CudaError {
    /// CUDA device was not found
    DeviceNotFound,
    /// Memory allocation on the GPU failed
    AllocationFailed,
    /// Memory copy operation failed
    CopyFailed,
    /// CUDA kernel launch failed
    KernelLaunchFailed,
    /// cuBLAS operation error
    BlasError(String),
    /// CUDA driver error
    DriverError(String),
    /// PTX module loading failed
    ModuleLoadFailed(String),
    /// Kernel function not found in module
    KernelNotFound(String),
}

impl std::fmt::Display for CudaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Unit variants map to fixed messages; payload variants append detail.
        match self {
            Self::DeviceNotFound => f.write_str("CUDA device not found"),
            Self::AllocationFailed => f.write_str("CUDA memory allocation failed"),
            Self::CopyFailed => f.write_str("CUDA memory copy failed"),
            Self::KernelLaunchFailed => f.write_str("CUDA kernel launch failed"),
            Self::BlasError(detail) => write!(f, "cuBLAS error: {detail}"),
            Self::DriverError(detail) => write!(f, "CUDA driver error: {detail}"),
            Self::ModuleLoadFailed(detail) => write!(f, "CUDA module load failed: {detail}"),
            Self::KernelNotFound(detail) => write!(f, "CUDA kernel not found: {detail}"),
        }
    }
}

// No wrapped source error, so the default `Error` methods suffice.
impl std::error::Error for CudaError {}
405
// Driver errors are stringified eagerly so `CudaError` carries no cudarc types
// and can exist without the `cuda` feature.
#[cfg(feature = "cuda")]
impl From<cudarc::driver::DriverError> for CudaError {
    fn from(e: cudarc::driver::DriverError) -> Self {
        CudaError::DriverError(e.to_string())
    }
}
412
// cuBLAS errors use the Debug rendering; the status enum has no Display impl.
#[cfg(feature = "cuda")]
impl From<cudarc::cublas::result::CublasError> for CudaError {
    fn from(e: cudarc::cublas::result::CublasError) -> Self {
        CudaError::BlasError(format!("{:?}", e))
    }
}
419
420// =============================================================================
421// CUDA Runtime Functions
422// =============================================================================
423
/// Returns whether CUDA is available on this system.
///
/// Probes by creating (and immediately dropping) a context on device 0.
// NOTE(review): a full context creation is a heavyweight probe; a driver
// device-count query may be cheaper — confirm before changing semantics.
pub fn is_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        CudaContext::new(0).is_ok()
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
435
/// Returns the number of available CUDA devices.
///
/// Returns 0 when the driver query fails or the `cuda` feature is disabled.
pub fn device_count() -> usize {
    #[cfg(feature = "cuda")]
    {
        cudarc::driver::result::device::get_count().unwrap_or(0) as usize
    }
    #[cfg(not(feature = "cuda"))]
    {
        0
    }
}
447
448/// Returns whether a specific CUDA device is available.
449pub fn is_device_available(index: usize) -> bool {
450    index < device_count()
451}
452
/// Returns the capabilities of a CUDA device.
///
/// With the `cuda` feature, constructs a throwaway backend to query live
/// device info; otherwise (or if construction fails) returns a fallback.
// NOTE(review): the fallback optimistically claims f16/f64 support and 1024
// threads per block for a device that could not even be initialized — verify
// callers treat zero total_memory as "unavailable".
pub fn get_capabilities(index: usize) -> DeviceCapabilities {
    #[cfg(feature = "cuda")]
    {
        if let Some(backend) = CudaBackend::new(index) {
            return backend.capabilities();
        }
    }
    #[allow(unreachable_code)]
    DeviceCapabilities {
        name: format!("CUDA Device {}", index),
        total_memory: 0,
        available_memory: 0,
        supports_f16: true,
        supports_f64: true,
        max_threads_per_block: 1024,
        compute_capability: None,
    }
}
472
/// Synchronizes a CUDA stream by handle.
///
/// # Design Note
/// This function exists for API compatibility with the `GpuStream` abstraction.
/// However, AxonML's CUDA backend uses the device's default stream exclusively
/// (CudaStream is not Send+Sync, so explicit stream management is avoided).
///
/// For proper synchronization:
/// - Use `CudaBackend::synchronize()` which calls `cudaDeviceSynchronize()`
/// - This synchronizes all pending operations on the device
///
/// The handle parameter is accepted but not used because cudarc manages
/// streams internally and doesn't expose raw stream handles.
///
/// # Arguments
/// * `_handle` - Stream handle (unused, kept for API compatibility)
#[cfg(feature = "cuda")]
pub fn stream_synchronize(_handle: usize) {
    // Deliberately empty — see the design note above. AxonML uses the default
    // stream for all operations, and without a global device registry there is
    // nothing to synchronize against here. Callers needing a real barrier must
    // use CudaBackend::synchronize().
}
498
/// Synchronize a CUDA stream (no-op without the `cuda` feature).
#[cfg(not(feature = "cuda"))]
pub fn stream_synchronize(_handle: usize) {
    // No-op: there is no device, so there is nothing to wait on.
}
504
505// =============================================================================
506// cuBLAS Operations
507// =============================================================================
508
509#[cfg(feature = "cuda")]
510impl CudaBackend {
511    /// Performs matrix multiplication using cuBLAS: C = alpha * A @ B + beta * C
512    pub fn gemm_f32(
513        &self,
514        transa: bool,
515        transb: bool,
516        m: usize,
517        n: usize,
518        k: usize,
519        alpha: f32,
520        a: &CudaSlice<f32>,
521        lda: usize,
522        b: &CudaSlice<f32>,
523        ldb: usize,
524        beta: f32,
525        c: &mut CudaSlice<f32>,
526        ldc: usize,
527    ) -> Result<(), CudaError> {
528        let cfg = GemmConfig {
529            transa: if transa {
530                cublasOperation_t::CUBLAS_OP_T
531            } else {
532                cublasOperation_t::CUBLAS_OP_N
533            },
534            transb: if transb {
535                cublasOperation_t::CUBLAS_OP_T
536            } else {
537                cublasOperation_t::CUBLAS_OP_N
538            },
539            m: m as i32,
540            n: n as i32,
541            k: k as i32,
542            alpha,
543            lda: lda as i32,
544            ldb: ldb as i32,
545            beta,
546            ldc: ldc as i32,
547        };
548
549        unsafe { self.blas.gemm(cfg, a, b, c).map_err(CudaError::from) }
550    }
551
552    /// Performs batched matrix multiplication.
553    pub fn gemm_batched_f32(
554        &self,
555        transa: bool,
556        transb: bool,
557        m: usize,
558        n: usize,
559        k: usize,
560        alpha: f32,
561        a_array: &[&CudaSlice<f32>],
562        lda: usize,
563        b_array: &[&CudaSlice<f32>],
564        ldb: usize,
565        beta: f32,
566        c_array: &mut [&mut CudaSlice<f32>],
567        ldc: usize,
568        batch_count: usize,
569    ) -> Result<(), CudaError> {
570        // Execute batched gemm by iterating (cudarc doesn't expose batched directly)
571        for i in 0..batch_count {
572            let cfg = GemmConfig {
573                transa: if transa {
574                    cublasOperation_t::CUBLAS_OP_T
575                } else {
576                    cublasOperation_t::CUBLAS_OP_N
577                },
578                transb: if transb {
579                    cublasOperation_t::CUBLAS_OP_T
580                } else {
581                    cublasOperation_t::CUBLAS_OP_N
582                },
583                m: m as i32,
584                n: n as i32,
585                k: k as i32,
586                alpha,
587                lda: lda as i32,
588                ldb: ldb as i32,
589                beta,
590                ldc: ldc as i32,
591            };
592
593            unsafe {
594                self.blas
595                    .gemm(cfg, a_array[i], b_array[i], c_array[i])
596                    .map_err(CudaError::from)?;
597            }
598        }
599        Ok(())
600    }
601
    /// Strided batched GEMM using cublasSgemmStridedBatched.
    /// All batch data in contiguous GPU memory with fixed strides between batches.
    /// C[i] = alpha * A[i] @ B[i] + beta * C[i] for i in 0..batch_count
    ///
    /// # Errors
    /// Returns `CudaError::BlasError` when the cuBLAS call fails.
    pub fn gemm_strided_batched_f32(
        &self,
        transa: bool,
        transb: bool,
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        a: &CudaSlice<f32>,
        lda: usize,
        stride_a: i64,
        b: &CudaSlice<f32>,
        ldb: usize,
        stride_b: i64,
        beta: f32,
        c: &mut CudaSlice<f32>,
        ldc: usize,
        stride_c: i64,
        batch_count: usize,
    ) -> Result<(), CudaError> {
        use cudarc::cublas::result::sgemm_strided_batched;
        use cudarc::driver::DevicePtr as _;
        use cudarc::driver::DevicePtrMut as _;

        // Map the boolean transpose flags onto cuBLAS operation enums.
        let op_a = if transa {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        };
        let op_b = if transb {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        };

        // Raw device pointers for the FFI call. NOTE(review): the `_ga`/`_gb`/
        // `_gc` guards presumably record the borrow on the stream for the
        // duration of this scope — confirm cudarc's guard semantics.
        let (a_devptr, _ga) = a.device_ptr(&self.stream);
        let (b_devptr, _gb) = b.device_ptr(&self.stream);
        let (c_devptr, _gc) = c.device_ptr_mut(&self.stream);
        let a_ptr = a_devptr as *const f32;
        let b_ptr = b_devptr as *const f32;
        let c_ptr = c_devptr as *mut f32;

        // SAFETY: pointers come from live CudaSlices owned by this backend;
        // the caller must supply dimensions, leading dims, and batch strides
        // that describe those buffers (cublasSgemmStridedBatched contract).
        unsafe {
            sgemm_strided_batched(
                *self.blas.handle(),
                op_a,
                op_b,
                m as i32,
                n as i32,
                k as i32,
                &alpha as *const f32,
                a_ptr,
                lda as i32,
                stride_a,
                b_ptr,
                ldb as i32,
                stride_b,
                &beta as *const f32,
                c_ptr,
                ldc as i32,
                stride_c,
                batch_count as i32,
            )
            .map_err(CudaError::from)
        }
    }
671
672    /// Element-wise addition using CUDA kernel.
673    pub fn add_f32(
674        &self,
675        dst: &mut CudaSlice<f32>,
676        a: &CudaSlice<f32>,
677        b: &CudaSlice<f32>,
678        len: usize,
679    ) -> Result<(), CudaError> {
680        let func = self
681            .kernels
682            .get("add_f32")
683            .ok_or_else(|| CudaError::KernelNotFound("add_f32".to_string()))?;
684
685        let cfg = cuda_kernels::launch_config(len);
686        unsafe {
687            self.stream
688                .launch_builder(func)
689                .arg(a)
690                .arg(b)
691                .arg(dst)
692                .arg(&(len as u32))
693                .launch(cfg)
694                .map(|_| ())
695                .map_err(|e| CudaError::DriverError(e.to_string()))?;
696        }
697        Ok(())
698    }
699
700    /// Scalar multiplication using CUDA kernel.
701    pub fn scale_f32(
702        &self,
703        dst: &mut CudaSlice<f32>,
704        alpha: f32,
705        len: usize,
706    ) -> Result<(), CudaError> {
707        let func = self
708            .kernels
709            .get("scale_f32")
710            .ok_or_else(|| CudaError::KernelNotFound("scale_f32".to_string()))?;
711
712        let cfg = cuda_kernels::launch_config(len);
713        unsafe {
714            self.stream
715                .launch_builder(func)
716                .arg(dst)
717                .arg(&alpha)
718                .arg(&(len as u32))
719                .launch(cfg)
720                .map(|_| ())
721                .map_err(|e| CudaError::DriverError(e.to_string()))?;
722        }
723        Ok(())
724    }
725
726    /// Element-wise multiplication using CUDA kernel.
727    pub fn mul_f32(
728        &self,
729        dst: &mut CudaSlice<f32>,
730        a: &CudaSlice<f32>,
731        b: &CudaSlice<f32>,
732        len: usize,
733    ) -> Result<(), CudaError> {
734        let func = self
735            .kernels
736            .get("mul_f32")
737            .ok_or_else(|| CudaError::KernelNotFound("mul_f32".to_string()))?;
738
739        let cfg = cuda_kernels::launch_config(len);
740        unsafe {
741            self.stream
742                .launch_builder(func)
743                .arg(a)
744                .arg(b)
745                .arg(dst)
746                .arg(&(len as u32))
747                .launch(cfg)
748                .map(|_| ())
749                .map_err(|e| CudaError::DriverError(e.to_string()))?;
750        }
751        Ok(())
752    }
753
754    /// ReLU activation using CUDA kernel.
755    pub fn relu_f32(
756        &self,
757        dst: &mut CudaSlice<f32>,
758        src: &CudaSlice<f32>,
759        len: usize,
760    ) -> Result<(), CudaError> {
761        let func = self
762            .kernels
763            .get("relu_f32")
764            .ok_or_else(|| CudaError::KernelNotFound("relu_f32".to_string()))?;
765
766        let cfg = cuda_kernels::launch_config(len);
767        unsafe {
768            self.stream
769                .launch_builder(func)
770                .arg(src)
771                .arg(dst)
772                .arg(&(len as u32))
773                .launch(cfg)
774                .map(|_| ())
775                .map_err(|e| CudaError::DriverError(e.to_string()))?;
776        }
777        Ok(())
778    }
779
780    /// Sigmoid activation using CUDA kernel.
781    pub fn sigmoid_f32(
782        &self,
783        dst: &mut CudaSlice<f32>,
784        src: &CudaSlice<f32>,
785        len: usize,
786    ) -> Result<(), CudaError> {
787        let func = self
788            .kernels
789            .get("sigmoid_f32")
790            .ok_or_else(|| CudaError::KernelNotFound("sigmoid_f32".to_string()))?;
791
792        let cfg = cuda_kernels::launch_config(len);
793        unsafe {
794            self.stream
795                .launch_builder(func)
796                .arg(src)
797                .arg(dst)
798                .arg(&(len as u32))
799                .launch(cfg)
800                .map(|_| ())
801                .map_err(|e| CudaError::DriverError(e.to_string()))?;
802        }
803        Ok(())
804    }
805
806    /// Tanh activation using CUDA kernel.
807    pub fn tanh_f32(
808        &self,
809        dst: &mut CudaSlice<f32>,
810        src: &CudaSlice<f32>,
811        len: usize,
812    ) -> Result<(), CudaError> {
813        let func = self
814            .kernels
815            .get("tanh_f32")
816            .ok_or_else(|| CudaError::KernelNotFound("tanh_f32".to_string()))?;
817
818        let cfg = cuda_kernels::launch_config(len);
819        unsafe {
820            self.stream
821                .launch_builder(func)
822                .arg(src)
823                .arg(dst)
824                .arg(&(len as u32))
825                .launch(cfg)
826                .map(|_| ())
827                .map_err(|e| CudaError::DriverError(e.to_string()))?;
828        }
829        Ok(())
830    }
831
832    /// Element-wise subtraction using CUDA kernel.
833    pub fn sub_f32(
834        &self,
835        dst: &mut CudaSlice<f32>,
836        a: &CudaSlice<f32>,
837        b: &CudaSlice<f32>,
838        len: usize,
839    ) -> Result<(), CudaError> {
840        let func = self
841            .kernels
842            .get("sub_f32")
843            .ok_or_else(|| CudaError::KernelNotFound("sub_f32".to_string()))?;
844        let cfg = cuda_kernels::launch_config(len);
845        unsafe {
846            self.stream
847                .launch_builder(func)
848                .arg(a)
849                .arg(b)
850                .arg(dst)
851                .arg(&(len as u32))
852                .launch(cfg)
853                .map(|_| ())
854                .map_err(|e| CudaError::DriverError(e.to_string()))?;
855        }
856        Ok(())
857    }
858
859    /// Element-wise division using CUDA kernel.
860    pub fn div_f32(
861        &self,
862        dst: &mut CudaSlice<f32>,
863        a: &CudaSlice<f32>,
864        b: &CudaSlice<f32>,
865        len: usize,
866    ) -> Result<(), CudaError> {
867        let func = self
868            .kernels
869            .get("div_f32")
870            .ok_or_else(|| CudaError::KernelNotFound("div_f32".to_string()))?;
871        let cfg = cuda_kernels::launch_config(len);
872        unsafe {
873            self.stream
874                .launch_builder(func)
875                .arg(a)
876                .arg(b)
877                .arg(dst)
878                .arg(&(len as u32))
879                .launch(cfg)
880                .map(|_| ())
881                .map_err(|e| CudaError::DriverError(e.to_string()))?;
882        }
883        Ok(())
884    }
885
886    // =========================================================================
887    // Broadcast Element-wise Operations
888    // =========================================================================
889
890    /// Broadcast addition: out[i] = a[i] + b[i % b_len]
891    /// `a` is the larger tensor (n elements), `b` is broadcast (b_len elements).
892    pub fn broadcast_add_f32(
893        &self,
894        dst: &mut CudaSlice<f32>,
895        a: &CudaSlice<f32>,
896        b: &CudaSlice<f32>,
897        n: usize,
898        b_len: usize,
899    ) -> Result<(), CudaError> {
900        let func = self
901            .kernels
902            .get("broadcast_add_f32")
903            .ok_or_else(|| CudaError::KernelNotFound("broadcast_add_f32".to_string()))?;
904        let cfg = cuda_kernels::launch_config(n);
905        unsafe {
906            self.stream
907                .launch_builder(func)
908                .arg(a)
909                .arg(b)
910                .arg(dst)
911                .arg(&(n as u32))
912                .arg(&(b_len as u32))
913                .launch(cfg)
914                .map(|_| ())
915                .map_err(|e| CudaError::DriverError(e.to_string()))?;
916        }
917        Ok(())
918    }
919
920    /// Broadcast subtraction: out[i] = a[i] - b[i % b_len]
921    pub fn broadcast_sub_f32(
922        &self,
923        dst: &mut CudaSlice<f32>,
924        a: &CudaSlice<f32>,
925        b: &CudaSlice<f32>,
926        n: usize,
927        b_len: usize,
928    ) -> Result<(), CudaError> {
929        let func = self
930            .kernels
931            .get("broadcast_sub_f32")
932            .ok_or_else(|| CudaError::KernelNotFound("broadcast_sub_f32".to_string()))?;
933        let cfg = cuda_kernels::launch_config(n);
934        unsafe {
935            self.stream
936                .launch_builder(func)
937                .arg(a)
938                .arg(b)
939                .arg(dst)
940                .arg(&(n as u32))
941                .arg(&(b_len as u32))
942                .launch(cfg)
943                .map(|_| ())
944                .map_err(|e| CudaError::DriverError(e.to_string()))?;
945        }
946        Ok(())
947    }
948
949    /// Broadcast multiplication: out[i] = a[i] * b[i % b_len]
950    pub fn broadcast_mul_f32(
951        &self,
952        dst: &mut CudaSlice<f32>,
953        a: &CudaSlice<f32>,
954        b: &CudaSlice<f32>,
955        n: usize,
956        b_len: usize,
957    ) -> Result<(), CudaError> {
958        let func = self
959            .kernels
960            .get("broadcast_mul_f32")
961            .ok_or_else(|| CudaError::KernelNotFound("broadcast_mul_f32".to_string()))?;
962        let cfg = cuda_kernels::launch_config(n);
963        unsafe {
964            self.stream
965                .launch_builder(func)
966                .arg(a)
967                .arg(b)
968                .arg(dst)
969                .arg(&(n as u32))
970                .arg(&(b_len as u32))
971                .launch(cfg)
972                .map(|_| ())
973                .map_err(|e| CudaError::DriverError(e.to_string()))?;
974        }
975        Ok(())
976    }
977
978    /// Broadcast division: out[i] = a[i] / b[i % b_len]
979    pub fn broadcast_div_f32(
980        &self,
981        dst: &mut CudaSlice<f32>,
982        a: &CudaSlice<f32>,
983        b: &CudaSlice<f32>,
984        n: usize,
985        b_len: usize,
986    ) -> Result<(), CudaError> {
987        let func = self
988            .kernels
989            .get("broadcast_div_f32")
990            .ok_or_else(|| CudaError::KernelNotFound("broadcast_div_f32".to_string()))?;
991        let cfg = cuda_kernels::launch_config(n);
992        unsafe {
993            self.stream
994                .launch_builder(func)
995                .arg(a)
996                .arg(b)
997                .arg(dst)
998                .arg(&(n as u32))
999                .arg(&(b_len as u32))
1000                .launch(cfg)
1001                .map(|_| ())
1002                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1003        }
1004        Ok(())
1005    }
1006
1007    /// Reverse broadcast addition: out[i] = a[i % a_len] + b[i]
1008    pub fn broadcast_add_rev_f32(
1009        &self,
1010        dst: &mut CudaSlice<f32>,
1011        a: &CudaSlice<f32>,
1012        b: &CudaSlice<f32>,
1013        n: usize,
1014        a_len: usize,
1015    ) -> Result<(), CudaError> {
1016        let func = self
1017            .kernels
1018            .get("broadcast_add_rev_f32")
1019            .ok_or_else(|| CudaError::KernelNotFound("broadcast_add_rev_f32".to_string()))?;
1020        let cfg = cuda_kernels::launch_config(n);
1021        unsafe {
1022            self.stream
1023                .launch_builder(func)
1024                .arg(a)
1025                .arg(b)
1026                .arg(dst)
1027                .arg(&(n as u32))
1028                .arg(&(a_len as u32))
1029                .launch(cfg)
1030                .map(|_| ())
1031                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1032        }
1033        Ok(())
1034    }
1035
1036    /// Reverse broadcast subtraction: out[i] = a[i % a_len] - b[i]
1037    pub fn broadcast_sub_rev_f32(
1038        &self,
1039        dst: &mut CudaSlice<f32>,
1040        a: &CudaSlice<f32>,
1041        b: &CudaSlice<f32>,
1042        n: usize,
1043        a_len: usize,
1044    ) -> Result<(), CudaError> {
1045        let func = self
1046            .kernels
1047            .get("broadcast_sub_rev_f32")
1048            .ok_or_else(|| CudaError::KernelNotFound("broadcast_sub_rev_f32".to_string()))?;
1049        let cfg = cuda_kernels::launch_config(n);
1050        unsafe {
1051            self.stream
1052                .launch_builder(func)
1053                .arg(a)
1054                .arg(b)
1055                .arg(dst)
1056                .arg(&(n as u32))
1057                .arg(&(a_len as u32))
1058                .launch(cfg)
1059                .map(|_| ())
1060                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1061        }
1062        Ok(())
1063    }
1064
1065    /// Reverse broadcast multiplication: out[i] = a[i % a_len] * b[i]
1066    pub fn broadcast_mul_rev_f32(
1067        &self,
1068        dst: &mut CudaSlice<f32>,
1069        a: &CudaSlice<f32>,
1070        b: &CudaSlice<f32>,
1071        n: usize,
1072        a_len: usize,
1073    ) -> Result<(), CudaError> {
1074        let func = self
1075            .kernels
1076            .get("broadcast_mul_rev_f32")
1077            .ok_or_else(|| CudaError::KernelNotFound("broadcast_mul_rev_f32".to_string()))?;
1078        let cfg = cuda_kernels::launch_config(n);
1079        unsafe {
1080            self.stream
1081                .launch_builder(func)
1082                .arg(a)
1083                .arg(b)
1084                .arg(dst)
1085                .arg(&(n as u32))
1086                .arg(&(a_len as u32))
1087                .launch(cfg)
1088                .map(|_| ())
1089                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1090        }
1091        Ok(())
1092    }
1093
1094    /// Reverse broadcast division: out[i] = a[i % a_len] / b[i]
1095    pub fn broadcast_div_rev_f32(
1096        &self,
1097        dst: &mut CudaSlice<f32>,
1098        a: &CudaSlice<f32>,
1099        b: &CudaSlice<f32>,
1100        n: usize,
1101        a_len: usize,
1102    ) -> Result<(), CudaError> {
1103        let func = self
1104            .kernels
1105            .get("broadcast_div_rev_f32")
1106            .ok_or_else(|| CudaError::KernelNotFound("broadcast_div_rev_f32".to_string()))?;
1107        let cfg = cuda_kernels::launch_config(n);
1108        unsafe {
1109            self.stream
1110                .launch_builder(func)
1111                .arg(a)
1112                .arg(b)
1113                .arg(dst)
1114                .arg(&(n as u32))
1115                .arg(&(a_len as u32))
1116                .launch(cfg)
1117                .map(|_| ())
1118                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1119        }
1120        Ok(())
1121    }
1122
1123    /// Element-wise negation using CUDA kernel.
1124    pub fn neg_f32(
1125        &self,
1126        dst: &mut CudaSlice<f32>,
1127        src: &CudaSlice<f32>,
1128        len: usize,
1129    ) -> Result<(), CudaError> {
1130        let func = self
1131            .kernels
1132            .get("neg_f32")
1133            .ok_or_else(|| CudaError::KernelNotFound("neg_f32".to_string()))?;
1134        let cfg = cuda_kernels::launch_config(len);
1135        unsafe {
1136            self.stream
1137                .launch_builder(func)
1138                .arg(src)
1139                .arg(dst)
1140                .arg(&(len as u32))
1141                .launch(cfg)
1142                .map(|_| ())
1143                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1144        }
1145        Ok(())
1146    }
1147
1148    /// Element-wise power using CUDA kernel: dst[i] = a[i] ^ b[i].
1149    pub fn pow_f32(
1150        &self,
1151        dst: &mut CudaSlice<f32>,
1152        a: &CudaSlice<f32>,
1153        b: &CudaSlice<f32>,
1154        len: usize,
1155    ) -> Result<(), CudaError> {
1156        let func = self
1157            .kernels
1158            .get("pow_f32")
1159            .ok_or_else(|| CudaError::KernelNotFound("pow_f32".to_string()))?;
1160        let cfg = cuda_kernels::launch_config(len);
1161        unsafe {
1162            self.stream
1163                .launch_builder(func)
1164                .arg(a)
1165                .arg(b)
1166                .arg(dst)
1167                .arg(&(len as u32))
1168                .launch(cfg)
1169                .map(|_| ())
1170                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1171        }
1172        Ok(())
1173    }
1174
1175    /// Element-wise power with scalar exponent: dst[i] = src[i] ^ exp.
1176    pub fn pow_scalar_f32(
1177        &self,
1178        dst: &mut CudaSlice<f32>,
1179        src: &CudaSlice<f32>,
1180        exp: f32,
1181        len: usize,
1182    ) -> Result<(), CudaError> {
1183        let func = self
1184            .kernels
1185            .get("pow_scalar_f32")
1186            .ok_or_else(|| CudaError::KernelNotFound("pow_scalar_f32".to_string()))?;
1187        let cfg = cuda_kernels::launch_config(len);
1188        unsafe {
1189            self.stream
1190                .launch_builder(func)
1191                .arg(src)
1192                .arg(&exp)
1193                .arg(dst)
1194                .arg(&(len as u32))
1195                .launch(cfg)
1196                .map(|_| ())
1197                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1198        }
1199        Ok(())
1200    }
1201
1202    /// Element-wise exp using CUDA kernel.
1203    pub fn exp_f32(
1204        &self,
1205        dst: &mut CudaSlice<f32>,
1206        src: &CudaSlice<f32>,
1207        len: usize,
1208    ) -> Result<(), CudaError> {
1209        let func = self
1210            .kernels
1211            .get("exp_f32")
1212            .ok_or_else(|| CudaError::KernelNotFound("exp_f32".to_string()))?;
1213        let cfg = cuda_kernels::launch_config(len);
1214        unsafe {
1215            self.stream
1216                .launch_builder(func)
1217                .arg(src)
1218                .arg(dst)
1219                .arg(&(len as u32))
1220                .launch(cfg)
1221                .map(|_| ())
1222                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1223        }
1224        Ok(())
1225    }
1226
1227    /// Element-wise log using CUDA kernel.
1228    pub fn log_f32(
1229        &self,
1230        dst: &mut CudaSlice<f32>,
1231        src: &CudaSlice<f32>,
1232        len: usize,
1233    ) -> Result<(), CudaError> {
1234        let func = self
1235            .kernels
1236            .get("log_f32")
1237            .ok_or_else(|| CudaError::KernelNotFound("log_f32".to_string()))?;
1238        let cfg = cuda_kernels::launch_config(len);
1239        unsafe {
1240            self.stream
1241                .launch_builder(func)
1242                .arg(src)
1243                .arg(dst)
1244                .arg(&(len as u32))
1245                .launch(cfg)
1246                .map(|_| ())
1247                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1248        }
1249        Ok(())
1250    }
1251
1252    /// Element-wise sqrt using CUDA kernel.
1253    pub fn sqrt_f32(
1254        &self,
1255        dst: &mut CudaSlice<f32>,
1256        src: &CudaSlice<f32>,
1257        len: usize,
1258    ) -> Result<(), CudaError> {
1259        let func = self
1260            .kernels
1261            .get("sqrt_f32")
1262            .ok_or_else(|| CudaError::KernelNotFound("sqrt_f32".to_string()))?;
1263        let cfg = cuda_kernels::launch_config(len);
1264        unsafe {
1265            self.stream
1266                .launch_builder(func)
1267                .arg(src)
1268                .arg(dst)
1269                .arg(&(len as u32))
1270                .launch(cfg)
1271                .map(|_| ())
1272                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1273        }
1274        Ok(())
1275    }
1276
1277    /// GELU activation using CUDA kernel.
1278    pub fn gelu_f32(
1279        &self,
1280        dst: &mut CudaSlice<f32>,
1281        src: &CudaSlice<f32>,
1282        len: usize,
1283    ) -> Result<(), CudaError> {
1284        let func = self
1285            .kernels
1286            .get("gelu_f32")
1287            .ok_or_else(|| CudaError::KernelNotFound("gelu_f32".to_string()))?;
1288        let cfg = cuda_kernels::launch_config(len);
1289        unsafe {
1290            self.stream
1291                .launch_builder(func)
1292                .arg(src)
1293                .arg(dst)
1294                .arg(&(len as u32))
1295                .launch(cfg)
1296                .map(|_| ())
1297                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1298        }
1299        Ok(())
1300    }
1301
1302    /// SiLU activation using CUDA kernel.
1303    pub fn silu_f32(
1304        &self,
1305        dst: &mut CudaSlice<f32>,
1306        src: &CudaSlice<f32>,
1307        len: usize,
1308    ) -> Result<(), CudaError> {
1309        let func = self
1310            .kernels
1311            .get("silu_f32")
1312            .ok_or_else(|| CudaError::KernelNotFound("silu_f32".to_string()))?;
1313        let cfg = cuda_kernels::launch_config(len);
1314        unsafe {
1315            self.stream
1316                .launch_builder(func)
1317                .arg(src)
1318                .arg(dst)
1319                .arg(&(len as u32))
1320                .launch(cfg)
1321                .map(|_| ())
1322                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1323        }
1324        Ok(())
1325    }
1326
1327    /// Scalar addition: dst[i] = src[i] + scalar.
1328    pub fn add_scalar_f32(
1329        &self,
1330        dst: &mut CudaSlice<f32>,
1331        src: &CudaSlice<f32>,
1332        scalar: f32,
1333        len: usize,
1334    ) -> Result<(), CudaError> {
1335        let func = self
1336            .kernels
1337            .get("add_scalar_f32")
1338            .ok_or_else(|| CudaError::KernelNotFound("add_scalar_f32".to_string()))?;
1339        let cfg = cuda_kernels::launch_config(len);
1340        unsafe {
1341            self.stream
1342                .launch_builder(func)
1343                .arg(src)
1344                .arg(&scalar)
1345                .arg(dst)
1346                .arg(&(len as u32))
1347                .launch(cfg)
1348                .map(|_| ())
1349                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1350        }
1351        Ok(())
1352    }
1353
1354    /// ReLU backward using CUDA kernel.
1355    pub fn relu_backward_f32(
1356        &self,
1357        dst: &mut CudaSlice<f32>,
1358        grad_output: &CudaSlice<f32>,
1359        input: &CudaSlice<f32>,
1360        len: usize,
1361    ) -> Result<(), CudaError> {
1362        let func = self
1363            .kernels
1364            .get("relu_backward_f32")
1365            .ok_or_else(|| CudaError::KernelNotFound("relu_backward_f32".to_string()))?;
1366        let cfg = cuda_kernels::launch_config(len);
1367        unsafe {
1368            self.stream
1369                .launch_builder(func)
1370                .arg(grad_output)
1371                .arg(input)
1372                .arg(dst)
1373                .arg(&(len as u32))
1374                .launch(cfg)
1375                .map(|_| ())
1376                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1377        }
1378        Ok(())
1379    }
1380
1381    /// Sigmoid backward using CUDA kernel.
1382    pub fn sigmoid_backward_f32(
1383        &self,
1384        dst: &mut CudaSlice<f32>,
1385        grad_output: &CudaSlice<f32>,
1386        output: &CudaSlice<f32>,
1387        len: usize,
1388    ) -> Result<(), CudaError> {
1389        let func = self
1390            .kernels
1391            .get("sigmoid_backward_f32")
1392            .ok_or_else(|| CudaError::KernelNotFound("sigmoid_backward_f32".to_string()))?;
1393        let cfg = cuda_kernels::launch_config(len);
1394        unsafe {
1395            self.stream
1396                .launch_builder(func)
1397                .arg(grad_output)
1398                .arg(output)
1399                .arg(dst)
1400                .arg(&(len as u32))
1401                .launch(cfg)
1402                .map(|_| ())
1403                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1404        }
1405        Ok(())
1406    }
1407
1408    /// Tanh backward using CUDA kernel.
1409    pub fn tanh_backward_f32(
1410        &self,
1411        dst: &mut CudaSlice<f32>,
1412        grad_output: &CudaSlice<f32>,
1413        output: &CudaSlice<f32>,
1414        len: usize,
1415    ) -> Result<(), CudaError> {
1416        let func = self
1417            .kernels
1418            .get("tanh_backward_f32")
1419            .ok_or_else(|| CudaError::KernelNotFound("tanh_backward_f32".to_string()))?;
1420        let cfg = cuda_kernels::launch_config(len);
1421        unsafe {
1422            self.stream
1423                .launch_builder(func)
1424                .arg(grad_output)
1425                .arg(output)
1426                .arg(dst)
1427                .arg(&(len as u32))
1428                .launch(cfg)
1429                .map(|_| ())
1430                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1431        }
1432        Ok(())
1433    }
1434
1435    /// Sum along a dimension. Tensor viewed as [outer_size, dim_size, inner_size].
1436    /// Output has outer_size * inner_size elements.
1437    pub fn sum_dim_f32(
1438        &self,
1439        dst: &mut CudaSlice<f32>,
1440        src: &CudaSlice<f32>,
1441        outer_size: usize,
1442        dim_size: usize,
1443        inner_size: usize,
1444    ) -> Result<(), CudaError> {
1445        let func = self
1446            .kernels
1447            .get("sum_dim_f32")
1448            .ok_or_else(|| CudaError::KernelNotFound("sum_dim_f32".to_string()))?;
1449        let out_len = outer_size * inner_size;
1450        let cfg = cuda_kernels::launch_config(out_len);
1451        unsafe {
1452            self.stream
1453                .launch_builder(func)
1454                .arg(src)
1455                .arg(dst)
1456                .arg(&(outer_size as u32))
1457                .arg(&(dim_size as u32))
1458                .arg(&(inner_size as u32))
1459                .launch(cfg)
1460                .map(|_| ())
1461                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1462        }
1463        Ok(())
1464    }
1465
1466    /// Softmax along last dimension, in-place.
1467    /// Data layout: num_rows x row_size, each row gets softmax independently.
1468    /// One block per row, 256 threads per block.
1469    pub fn softmax_row_f32(
1470        &self,
1471        data: &mut CudaSlice<f32>,
1472        num_rows: usize,
1473        row_size: usize,
1474    ) -> Result<(), CudaError> {
1475        let func = self
1476            .kernels
1477            .get("softmax_row_f32")
1478            .ok_or_else(|| CudaError::KernelNotFound("softmax_row_f32".to_string()))?;
1479        // One block per row
1480        let cfg = LaunchConfig {
1481            grid_dim: (num_rows as u32, 1, 1),
1482            block_dim: (BLOCK_SIZE, 1, 1),
1483            shared_mem_bytes: BLOCK_SIZE * 4,
1484        };
1485        unsafe {
1486            self.stream
1487                .launch_builder(func)
1488                .arg(data)
1489                .arg(&(num_rows as u32))
1490                .arg(&(row_size as u32))
1491                .launch(cfg)
1492                .map(|_| ())
1493                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1494        }
1495        Ok(())
1496    }
1497
1498    /// Broadcast copy: out[i] = src[i % src_len], for n output elements.
1499    pub fn broadcast_copy_f32(
1500        &self,
1501        dst: &mut CudaSlice<f32>,
1502        src: &CudaSlice<f32>,
1503        n: usize,
1504        src_len: usize,
1505    ) -> Result<(), CudaError> {
1506        let func = self
1507            .kernels
1508            .get("broadcast_copy_f32")
1509            .ok_or_else(|| CudaError::KernelNotFound("broadcast_copy_f32".to_string()))?;
1510        let cfg = cuda_kernels::launch_config(n);
1511        unsafe {
1512            self.stream
1513                .launch_builder(func)
1514                .arg(src)
1515                .arg(dst)
1516                .arg(&(n as u32))
1517                .arg(&(src_len as u32))
1518                .launch(cfg)
1519                .map(|_| ())
1520                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1521        }
1522        Ok(())
1523    }
1524
1525    /// LayerNorm: per-row normalization with affine transform on GPU.
1526    /// One block per row, 256 threads. Computes mean, variance, normalize, apply gamma/beta.
1527    pub fn layer_norm_f32(
1528        &self,
1529        dst: &mut CudaSlice<f32>,
1530        input: &CudaSlice<f32>,
1531        gamma: &CudaSlice<f32>,
1532        beta: &CudaSlice<f32>,
1533        norm_size: usize,
1534        eps: f32,
1535        num_rows: usize,
1536    ) -> Result<(), CudaError> {
1537        let func = self
1538            .kernels
1539            .get("layer_norm_f32")
1540            .ok_or_else(|| CudaError::KernelNotFound("layer_norm_f32".to_string()))?;
1541        let cfg = LaunchConfig {
1542            grid_dim: (num_rows as u32, 1, 1),
1543            block_dim: (BLOCK_SIZE, 1, 1),
1544            shared_mem_bytes: BLOCK_SIZE * 4,
1545        };
1546        unsafe {
1547            self.stream
1548                .launch_builder(func)
1549                .arg(input)
1550                .arg(gamma)
1551                .arg(beta)
1552                .arg(dst)
1553                .arg(&(norm_size as u32))
1554                .arg(&eps)
1555                .arg(&(num_rows as u32))
1556                .launch(cfg)
1557                .map(|_| ())
1558                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1559        }
1560        Ok(())
1561    }
1562
1563    /// Softmax backward: per-row backward pass.
1564    /// result[i] = softmax[i] * (grad[i] - dot), where dot = sum(softmax * grad) per row.
1565    /// One block per row, 256 threads.
1566    pub fn softmax_backward_row_f32(
1567        &self,
1568        dst: &mut CudaSlice<f32>,
1569        softmax_output: &CudaSlice<f32>,
1570        grad_output: &CudaSlice<f32>,
1571        num_rows: usize,
1572        row_size: usize,
1573    ) -> Result<(), CudaError> {
1574        let func = self
1575            .kernels
1576            .get("softmax_backward_row_f32")
1577            .ok_or_else(|| CudaError::KernelNotFound("softmax_backward_row_f32".to_string()))?;
1578        let cfg = LaunchConfig {
1579            grid_dim: (num_rows as u32, 1, 1),
1580            block_dim: (BLOCK_SIZE, 1, 1),
1581            shared_mem_bytes: BLOCK_SIZE * 4,
1582        };
1583        unsafe {
1584            self.stream
1585                .launch_builder(func)
1586                .arg(softmax_output)
1587                .arg(grad_output)
1588                .arg(dst)
1589                .arg(&(num_rows as u32))
1590                .arg(&(row_size as u32))
1591                .launch(cfg)
1592                .map(|_| ())
1593                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1594        }
1595        Ok(())
1596    }
1597
1598    /// LayerNorm backward: compute d_input on GPU.
1599    /// One block per row, 256 threads. Computes mean, var, sum_dy, sum_dy_xhat, then d_input.
1600    pub fn layer_norm_backward_dinput_f32(
1601        &self,
1602        d_input: &mut CudaSlice<f32>,
1603        grad_output: &CudaSlice<f32>,
1604        input: &CudaSlice<f32>,
1605        gamma: &CudaSlice<f32>,
1606        norm_size: usize,
1607        eps: f32,
1608        num_rows: usize,
1609    ) -> Result<(), CudaError> {
1610        let func = self
1611            .kernels
1612            .get("layer_norm_backward_dinput_f32")
1613            .ok_or_else(|| {
1614                CudaError::KernelNotFound("layer_norm_backward_dinput_f32".to_string())
1615            })?;
1616        let cfg = LaunchConfig {
1617            grid_dim: (num_rows as u32, 1, 1),
1618            block_dim: (BLOCK_SIZE, 1, 1),
1619            shared_mem_bytes: BLOCK_SIZE * 4 * 2, // two shared arrays
1620        };
1621        unsafe {
1622            self.stream
1623                .launch_builder(func)
1624                .arg(grad_output)
1625                .arg(input)
1626                .arg(gamma)
1627                .arg(d_input)
1628                .arg(&(norm_size as u32))
1629                .arg(&eps)
1630                .arg(&(num_rows as u32))
1631                .launch(cfg)
1632                .map(|_| ())
1633                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1634        }
1635        Ok(())
1636    }
1637
1638    /// LayerNorm backward: compute d_weight and d_bias on GPU.
1639    /// One thread per element in norm_size. Each thread loops over all rows.
1640    pub fn layer_norm_backward_dweight_dbias_f32(
1641        &self,
1642        d_weight: &mut CudaSlice<f32>,
1643        d_bias: &mut CudaSlice<f32>,
1644        grad_output: &CudaSlice<f32>,
1645        input: &CudaSlice<f32>,
1646        norm_size: usize,
1647        eps: f32,
1648        num_rows: usize,
1649    ) -> Result<(), CudaError> {
1650        let func = self
1651            .kernels
1652            .get("layer_norm_backward_dweight_dbias_f32")
1653            .ok_or_else(|| {
1654                CudaError::KernelNotFound("layer_norm_backward_dweight_dbias_f32".to_string())
1655            })?;
1656        let cfg = cuda_kernels::launch_config(norm_size);
1657        unsafe {
1658            self.stream
1659                .launch_builder(func)
1660                .arg(grad_output)
1661                .arg(input)
1662                .arg(d_weight)
1663                .arg(d_bias)
1664                .arg(&(norm_size as u32))
1665                .arg(&eps)
1666                .arg(&(num_rows as u32))
1667                .launch(cfg)
1668                .map(|_| ())
1669                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1670        }
1671        Ok(())
1672    }
1673
1674    /// Gather elements from src using index array: out[i] = src[indices[i]]
1675    pub fn gather_contiguous_f32(
1676        &self,
1677        dst: &mut CudaSlice<f32>,
1678        src: &CudaSlice<f32>,
1679        indices: &CudaSlice<u32>,
1680        n: usize,
1681    ) -> Result<(), CudaError> {
1682        let func = self
1683            .kernels
1684            .get("gather_contiguous_f32")
1685            .ok_or_else(|| CudaError::KernelNotFound("gather_contiguous_f32".to_string()))?;
1686        let cfg = cuda_kernels::launch_config(n);
1687        unsafe {
1688            self.stream
1689                .launch_builder(func)
1690                .arg(src)
1691                .arg(indices)
1692                .arg(dst)
1693                .arg(&(n as u32))
1694                .launch(cfg)
1695                .map(|_| ())
1696                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1697        }
1698        Ok(())
1699    }
1700
1701    /// Embedding scatter-add: atomically accumulates gradients into weight_grad.
1702    /// Each thread handles one element of grad_src (total = num_indices * emb_dim).
1703    pub fn embedding_scatter_add_f32(
1704        &self,
1705        grad_src: &CudaSlice<f32>,
1706        indices: &CudaSlice<u32>,
1707        weight_grad: &mut CudaSlice<f32>,
1708        total_n: usize,
1709        emb_dim: usize,
1710    ) -> Result<(), CudaError> {
1711        let func = self
1712            .kernels
1713            .get("embedding_scatter_add_f32")
1714            .ok_or_else(|| CudaError::KernelNotFound("embedding_scatter_add_f32".to_string()))?;
1715        let cfg = cuda_kernels::launch_config(total_n);
1716        unsafe {
1717            self.stream
1718                .launch_builder(func)
1719                .arg(grad_src)
1720                .arg(indices)
1721                .arg(weight_grad)
1722                .arg(&(total_n as u32))
1723                .arg(&(emb_dim as u32))
1724                .launch(cfg)
1725                .map(|_| ())
1726                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1727        }
1728        Ok(())
1729    }
1730
1731    /// Fused Adam optimizer step: updates param, exp_avg, exp_avg_sq in-place on GPU.
1732    /// Eliminates the GPU->CPU->GPU copy in standard Adam.
1733    #[allow(clippy::too_many_arguments)]
1734    pub fn adam_step_f32(
1735        &self,
1736        param: &mut CudaSlice<f32>,
1737        grad: &CudaSlice<f32>,
1738        exp_avg: &mut CudaSlice<f32>,
1739        exp_avg_sq: &mut CudaSlice<f32>,
1740        n: usize,
1741        lr: f32,
1742        beta1: f32,
1743        beta2: f32,
1744        eps: f32,
1745        weight_decay: f32,
1746        bias_correction1: f32,
1747        bias_correction2: f32,
1748    ) -> Result<(), CudaError> {
1749        let func = self
1750            .kernels
1751            .get("adam_step_f32")
1752            .ok_or_else(|| CudaError::KernelNotFound("adam_step_f32".to_string()))?;
1753        let cfg = cuda_kernels::launch_config(n);
1754        unsafe {
1755            self.stream
1756                .launch_builder(func)
1757                .arg(param)
1758                .arg(grad)
1759                .arg(exp_avg)
1760                .arg(exp_avg_sq)
1761                .arg(&(n as u32))
1762                .arg(&lr)
1763                .arg(&beta1)
1764                .arg(&beta2)
1765                .arg(&eps)
1766                .arg(&weight_decay)
1767                .arg(&bias_correction1)
1768                .arg(&bias_correction2)
1769                .launch(cfg)
1770                .map(|_| ())
1771                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1772        }
1773        Ok(())
1774    }
1775
1776    /// Compute sum of squares of all elements (for gradient norm).
1777    /// Result is atomically accumulated into output[0].
1778    pub fn grad_norm_sq_f32(
1779        &self,
1780        data: &CudaSlice<f32>,
1781        output: &mut CudaSlice<f32>,
1782        n: usize,
1783    ) -> Result<(), CudaError> {
1784        let func = self
1785            .kernels
1786            .get("grad_norm_sq_f32")
1787            .ok_or_else(|| CudaError::KernelNotFound("grad_norm_sq_f32".to_string()))?;
1788        let cfg = cuda_kernels::launch_config(n);
1789        unsafe {
1790            self.stream
1791                .launch_builder(func)
1792                .arg(data)
1793                .arg(output)
1794                .arg(&(n as u32))
1795                .launch(cfg)
1796                .map(|_| ())
1797                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1798        }
1799        Ok(())
1800    }
1801
1802    /// Scale all elements in-place: data[i] *= scale
1803    pub fn grad_scale_f32(
1804        &self,
1805        data: &mut CudaSlice<f32>,
1806        n: usize,
1807        scale: f32,
1808    ) -> Result<(), CudaError> {
1809        let func = self
1810            .kernels
1811            .get("grad_scale_f32")
1812            .ok_or_else(|| CudaError::KernelNotFound("grad_scale_f32".to_string()))?;
1813        let cfg = cuda_kernels::launch_config(n);
1814        unsafe {
1815            self.stream
1816                .launch_builder(func)
1817                .arg(data)
1818                .arg(&(n as u32))
1819                .arg(&scale)
1820                .launch(cfg)
1821                .map(|_| ())
1822                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1823        }
1824        Ok(())
1825    }
1826
1827    /// CrossEntropy forward: fused softmax + NLL loss.
1828    /// One block per batch item, 256 threads per block.
1829    /// Returns per-sample losses and softmax probabilities (for backward).
1830    pub fn cross_entropy_fwd_f32(
1831        &self,
1832        logits: &CudaSlice<f32>,
1833        targets: &CudaSlice<f32>,
1834        losses: &mut CudaSlice<f32>,
1835        softmax_out: &mut CudaSlice<f32>,
1836        batch_size: usize,
1837        num_classes: usize,
1838    ) -> Result<(), CudaError> {
1839        let func = self
1840            .kernels
1841            .get("cross_entropy_fwd_f32")
1842            .ok_or_else(|| CudaError::KernelNotFound("cross_entropy_fwd_f32".to_string()))?;
1843        let cfg = LaunchConfig {
1844            grid_dim: (batch_size as u32, 1, 1),
1845            block_dim: (BLOCK_SIZE, 1, 1),
1846            shared_mem_bytes: BLOCK_SIZE * 4,
1847        };
1848        unsafe {
1849            self.stream
1850                .launch_builder(func)
1851                .arg(logits)
1852                .arg(targets)
1853                .arg(losses)
1854                .arg(softmax_out)
1855                .arg(&(num_classes as u32))
1856                .launch(cfg)
1857                .map(|_| ())
1858                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1859        }
1860        Ok(())
1861    }
1862
1863    /// CrossEntropy backward: grad = (softmax - one_hot(target)) * grad_output.
1864    /// Elementwise kernel, one thread per element.
1865    pub fn cross_entropy_bwd_f32(
1866        &self,
1867        softmax_probs: &CudaSlice<f32>,
1868        targets: &CudaSlice<f32>,
1869        grad_output: &CudaSlice<f32>,
1870        grad_input: &mut CudaSlice<f32>,
1871        batch_size: usize,
1872        num_classes: usize,
1873    ) -> Result<(), CudaError> {
1874        let func = self
1875            .kernels
1876            .get("cross_entropy_bwd_f32")
1877            .ok_or_else(|| CudaError::KernelNotFound("cross_entropy_bwd_f32".to_string()))?;
1878        let total = batch_size * num_classes;
1879        let cfg = cuda_kernels::launch_config(total);
1880        unsafe {
1881            self.stream
1882                .launch_builder(func)
1883                .arg(softmax_probs)
1884                .arg(targets)
1885                .arg(grad_output)
1886                .arg(grad_input)
1887                .arg(&(batch_size as u32))
1888                .arg(&(num_classes as u32))
1889                .launch(cfg)
1890                .map(|_| ())
1891                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1892        }
1893        Ok(())
1894    }
1895
1896    /// Zero-fills a GPU allocation using cudaMemset.
1897    #[cfg(feature = "cuda")]
1898    pub fn memset_zeros_f32(&self, dst: &mut CudaSlice<f32>) -> Result<(), CudaError> {
1899        self.stream
1900            .memset_zeros(dst)
1901            .map_err(|e| CudaError::DriverError(e.to_string()))
1902    }
1903
1904    /// Device-to-device copy of `count` f32 elements with source and destination offsets.
1905    /// Copies src[src_offset..src_offset+count] → dst[dst_offset..dst_offset+count].
1906    #[cfg(feature = "cuda")]
1907    pub fn memcpy_dtod_f32(
1908        &self,
1909        dst: &mut CudaSlice<f32>,
1910        dst_offset: usize,
1911        src: &CudaSlice<f32>,
1912        src_offset: usize,
1913        count: usize,
1914    ) -> Result<(), CudaError> {
1915        use cudarc::driver::DevicePtr as _;
1916        let (src_ptr, _guard_s) = src.device_ptr(&self.stream);
1917        let src_ptr =
1918            src_ptr + (src_offset * std::mem::size_of::<f32>()) as cudarc::driver::sys::CUdeviceptr;
1919        use cudarc::driver::DevicePtrMut as _;
1920        let (dst_ptr, _guard_d) = dst.device_ptr_mut(&self.stream);
1921        let dst_ptr =
1922            dst_ptr + (dst_offset * std::mem::size_of::<f32>()) as cudarc::driver::sys::CUdeviceptr;
1923        let size = count * std::mem::size_of::<f32>();
1924        unsafe {
1925            cudarc::driver::result::memcpy_dtod_sync(dst_ptr, src_ptr, size)
1926                .map_err(|e| CudaError::DriverError(e.to_string()))?;
1927        }
1928        Ok(())
1929    }
1930}
1931
1932// =============================================================================
1933// Attention Mask Expansion GPU Operations
1934// =============================================================================
1935
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Expand causal mask [T, S] → [B, H, T, S] with 0→-1e9 conversion, entirely on GPU.
    ///
    /// `total_n` is the element count of the expanded output buffer.
    pub fn mask_expand_causal_f32(
        &self,
        mask: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        total_n: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "mask_expand_causal_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        let n = total_n as u32;
        let t = tgt_len as u32;
        let s = src_len as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(mask).arg(output).arg(&n).arg(&t).arg(&s);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    /// Expand padding mask [B, S] → [B, H, T, S] with 0→-1e9 conversion, entirely on GPU.
    ///
    /// `total_n` is the element count of the expanded output buffer.
    pub fn mask_expand_padding_f32(
        &self,
        mask: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        total_n: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "mask_expand_padding_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        let n = total_n as u32;
        let heads = num_heads as u32;
        let t = tgt_len as u32;
        let s = src_len as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(mask)
                .arg(output)
                .arg(&n)
                .arg(&heads)
                .arg(&t)
                .arg(&s);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }
}
1998
1999// =============================================================================
2000// Strided Gather (GPU-native contiguous)
2001// =============================================================================
2002
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Gather elements from a strided tensor layout into a contiguous output on GPU.
    /// Replaces the CPU index computation in contiguous_gpu().
    ///
    /// - `strides`: per-dimension element strides (device-resident i64)
    /// - `shape`: per-dimension extents (device-resident u32)
    /// - `offset`: starting element offset into `src`
    /// - `total_n`: number of elements to gather
    pub fn strided_gather_f32(
        &self,
        src: &CudaSlice<f32>,
        dst: &mut CudaSlice<f32>,
        strides: &CudaSlice<i64>,
        shape: &CudaSlice<u32>,
        ndim: usize,
        offset: usize,
        total_n: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "strided_gather_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        let dims = ndim as u32;
        let start = offset as u32;
        let n = total_n as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(src)
                .arg(dst)
                .arg(strides)
                .arg(shape)
                .arg(&dims)
                .arg(&start)
                .arg(&n);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    // =========================================================================
    // Fused LSTM Gate Kernel
    // =========================================================================

    /// Fused LSTM gate computation on GPU.
    ///
    /// Takes pre-computed gates (ih + hh from cuBLAS GEMM) and c_prev, applies
    /// sigmoid/tanh activations and the cell/hidden state update in one launch.
    ///
    /// - `gates`: [batch, 4*hidden] = x@W_ih^T + b_ih + h@W_hh^T + b_hh
    /// - `c_prev`: [batch, hidden]
    /// - `h_new`, `c_new`: [batch, hidden] outputs
    pub fn lstm_gates_f32(
        &self,
        gates: &CudaSlice<f32>,
        c_prev: &CudaSlice<f32>,
        h_new: &mut CudaSlice<f32>,
        c_new: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "lstm_gates_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        let hidden = hidden_size as u32;
        let n = total as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(gates)
                .arg(c_prev)
                .arg(h_new)
                .arg(c_new)
                .arg(&hidden)
                .arg(&n);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    // =========================================================================
    // Fused LSTM Gate Backward Kernel
    // =========================================================================

    /// Fused LSTM gate backward computation on GPU.
    ///
    /// Given saved forward state and incoming gradients, computes gate gradients
    /// and the cell gradient flowing to the previous timestep in one launch.
    ///
    /// - `gates`: [batch, 4*hidden] pre-activation gates from forward
    /// - `c_prev`: [batch, hidden] previous cell state
    /// - `c_new`: [batch, hidden] cell state from forward
    /// - `grad_h`: [batch, hidden] gradient from output
    /// - `grad_c_next`: [batch, hidden] gradient from next timestep cell
    /// - `grad_gates`: [batch, 4*hidden] output gate gradients
    /// - `grad_c_prev`: [batch, hidden] output cell gradient to prev timestep
    pub fn lstm_gates_backward_f32(
        &self,
        gates: &CudaSlice<f32>,
        c_prev: &CudaSlice<f32>,
        c_new: &CudaSlice<f32>,
        grad_h: &CudaSlice<f32>,
        grad_c_next: &CudaSlice<f32>,
        grad_gates: &mut CudaSlice<f32>,
        grad_c_prev: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "lstm_gates_backward_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        let hidden = hidden_size as u32;
        let n = total as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(gates)
                .arg(c_prev)
                .arg(c_new)
                .arg(grad_h)
                .arg(grad_c_next)
                .arg(grad_gates)
                .arg(grad_c_prev)
                .arg(&hidden)
                .arg(&n);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    // =========================================================================
    // Fused GRU Gate Kernel
    // =========================================================================

    /// Fused GRU gate computation on GPU.
    ///
    /// - `gates_ih`: [batch, 3*hidden] = x@W_ih^T + b_ih
    /// - `gates_hh`: [batch, 3*hidden] = h@W_hh^T + b_hh
    /// - `h_prev`: [batch, hidden]
    /// - `h_new`: [batch, hidden] output
    pub fn gru_gates_f32(
        &self,
        gates_ih: &CudaSlice<f32>,
        gates_hh: &CudaSlice<f32>,
        h_prev: &CudaSlice<f32>,
        h_new: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "gru_gates_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        let hidden = hidden_size as u32;
        let n = total as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(gates_ih)
                .arg(gates_hh)
                .arg(h_prev)
                .arg(h_new)
                .arg(&hidden)
                .arg(&n);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    // =========================================================================
    // Fused GRU Gate Backward Kernel
    // =========================================================================

    /// Fused GRU gate backward computation on GPU.
    ///
    /// Given saved forward state and incoming gradient, computes ih/hh gate
    /// gradients and the hidden-state gradient to the previous timestep.
    ///
    /// - `gates_ih`: [batch, 3*hidden] pre-activation ih gates from forward
    /// - `gates_hh`: [batch, 3*hidden] pre-activation hh gates from forward
    /// - `h_prev`: [batch, hidden] previous hidden state
    /// - `grad_h_new`: [batch, hidden] gradient from output
    /// - `grad_gates_ih`: [batch, 3*hidden] output ih gate gradients
    /// - `grad_gates_hh`: [batch, 3*hidden] output hh gate gradients
    /// - `grad_h_prev`: [batch, hidden] output gradient to prev hidden
    pub fn gru_gates_backward_f32(
        &self,
        gates_ih: &CudaSlice<f32>,
        gates_hh: &CudaSlice<f32>,
        h_prev: &CudaSlice<f32>,
        grad_h_new: &CudaSlice<f32>,
        grad_gates_ih: &mut CudaSlice<f32>,
        grad_gates_hh: &mut CudaSlice<f32>,
        grad_h_prev: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "gru_gates_backward_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        let hidden = hidden_size as u32;
        let n = total as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(gates_ih)
                .arg(gates_hh)
                .arg(h_prev)
                .arg(grad_h_new)
                .arg(grad_gates_ih)
                .arg(grad_gates_hh)
                .arg(grad_h_prev)
                .arg(&hidden)
                .arg(&n);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    // =========================================================================
    // Fused BatchNorm Forward Kernels
    // =========================================================================

    /// BatchNorm pass 1: accumulate per-channel sum and sum-of-squares via atomics.
    ///
    /// `sum_out` / `sum_sq_out` receive one accumulator per channel.
    pub fn batchnorm_stats_f32(
        &self,
        x: &CudaSlice<f32>,
        sum_out: &mut CudaSlice<f32>,
        sum_sq_out: &mut CudaSlice<f32>,
        n: usize,
        c: usize,
        spatial: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "batchnorm_stats_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        // One thread per input element across [N, C, spatial].
        let cfg = cuda_kernels::launch_config(n * c * spatial);
        let batch = n as u32;
        let channels = c as u32;
        let sp = spatial as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(x)
                .arg(sum_out)
                .arg(sum_sq_out)
                .arg(&batch)
                .arg(&channels)
                .arg(&sp);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }

    /// BatchNorm pass 2: normalize + affine transform using pre-computed mean/var.
    ///
    /// `total` is the number of elements of `x`/`y`; `mean`, `var`, `gamma`,
    /// `beta` are indexed per channel.
    pub fn batchnorm_norm_f32(
        &self,
        x: &CudaSlice<f32>,
        mean: &CudaSlice<f32>,
        var: &CudaSlice<f32>,
        gamma: &CudaSlice<f32>,
        beta: &CudaSlice<f32>,
        y: &mut CudaSlice<f32>,
        eps: f32,
        c: usize,
        spatial: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        const NAME: &str = "batchnorm_norm_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        let channels = c as u32;
        let sp = spatial as u32;
        let n = total as u32;
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(x)
                .arg(mean)
                .arg(var)
                .arg(gamma)
                .arg(beta)
                .arg(y)
                .arg(&eps)
                .arg(&channels)
                .arg(&sp)
                .arg(&n);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }
}
2303
2304// =============================================================================
2305// Fused Scaled Dot-Product Attention
2306// =============================================================================
2307
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Fused attention forward: softmax(Q @ K^T * scale) @ V without
    /// materializing the full attention matrix.
    ///
    /// Q: [B, H, Tq, D], K: [B, H, Tk, D], V: [B, H, Tk, D]
    /// Output: [B, H, Tq, D]
    ///
    /// The grid is sized over `B * H * Tq` query rows.
    pub fn fused_attention_fwd_f32(
        &self,
        q: &CudaSlice<f32>,
        k: &CudaSlice<f32>,
        v: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        scale: f32,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
        head_dim: usize,
        is_causal: bool,
    ) -> Result<(), CudaError> {
        const NAME: &str = "fused_attention_fwd_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        let rows = batch_size * num_heads * tgt_len;
        let cfg = cuda_kernels::launch_config(rows);
        let b = batch_size as u32;
        let h = num_heads as u32;
        let tq = tgt_len as u32;
        let tk = src_len as u32;
        let d = head_dim as u32;
        let causal = u32::from(is_causal);
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(q)
                .arg(k)
                .arg(v)
                .arg(output)
                .arg(&scale)
                .arg(&b)
                .arg(&h)
                .arg(&tq)
                .arg(&tk)
                .arg(&d)
                .arg(&causal);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }
}
2357
2358// =============================================================================
2359// Fused Attention Backward (recomputation-based, memory-efficient)
2360// =============================================================================
2361
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Fused attention backward: recomputes attention weights from Q, K, O
    /// and computes grad_Q, grad_K, grad_V without materializing the full
    /// attention matrix.
    ///
    /// Q, K, V: [B, H, Tq/Tk, D]
    /// O: forward output [B, H, Tq, D]
    /// grad_O: gradient of loss w.r.t. output [B, H, Tq, D]
    /// grad_Q, grad_K, grad_V: output buffers — callers MUST zero-initialize them.
    pub fn fused_attention_bwd_f32(
        &self,
        q: &CudaSlice<f32>,
        k: &CudaSlice<f32>,
        v: &CudaSlice<f32>,
        o: &CudaSlice<f32>,
        grad_o: &CudaSlice<f32>,
        grad_q: &mut CudaSlice<f32>,
        grad_k: &mut CudaSlice<f32>,
        grad_v: &mut CudaSlice<f32>,
        scale: f32,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
        head_dim: usize,
        is_causal: bool,
    ) -> Result<(), CudaError> {
        const NAME: &str = "fused_attention_bwd_f32";
        let kernel = self
            .kernels
            .get(NAME)
            .ok_or_else(|| CudaError::KernelNotFound(NAME.to_string()))?;
        // Grid sized over B * H * Tq query rows, mirroring the forward pass.
        let rows = batch_size * num_heads * tgt_len;
        let cfg = cuda_kernels::launch_config(rows);
        let b = batch_size as u32;
        let h = num_heads as u32;
        let tq = tgt_len as u32;
        let tk = src_len as u32;
        let d = head_dim as u32;
        let causal = u32::from(is_causal);
        // SAFETY: argument order/types must match the compiled kernel's signature.
        let launched = unsafe {
            let mut call = self.stream.launch_builder(kernel);
            call.arg(q)
                .arg(k)
                .arg(v)
                .arg(o)
                .arg(grad_o)
                .arg(grad_q)
                .arg(grad_k)
                .arg(grad_v)
                .arg(&scale)
                .arg(&b)
                .arg(&h)
                .arg(&tq)
                .arg(&tk)
                .arg(&d)
                .arg(&causal);
            call.launch(cfg)
        };
        launched
            .map(|_| ())
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }
}
2421
2422// =============================================================================
2423// Conv2d GPU Operations (im2col + GEMM)
2424// =============================================================================
2425
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launch the GPU im2col kernel.
    ///
    /// Unfolds one batch element's input patches into a column matrix.
    /// - `input`: device buffer for one batch element [C_in, H, W]
    /// - `col`: output device buffer [C_in*kH*kW, out_H*out_W]
    /// - `params`: device buffer with u32[10] = {H, W, kH, kW, pH, pW, sH, sW, oH, oW}
    /// - `n`: number of col elements to produce (one thread each)
    pub fn im2col_f32(
        &self,
        input: &CudaSlice<f32>,
        col: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("im2col_f32")
            .ok_or_else(|| CudaError::KernelNotFound("im2col_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(input)
                .arg(col)
                .arg(params)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launch the GPU col2im kernel (reverse of im2col).
    ///
    /// Scatters the column matrix back to input spatial positions using atomicAdd.
    /// The output buffer MUST be zero-initialized before calling this.
    pub fn col2im_f32(
        &self,
        col: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("col2im_f32")
            .ok_or_else(|| CudaError::KernelNotFound("col2im_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(col)
                .arg(output)
                .arg(params)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launch the GPU bias_add_channels kernel (in-place).
    ///
    /// Adds bias per output channel: data[i] += bias[i / spatial_size]
    pub fn bias_add_channels_f32(
        &self,
        data: &mut CudaSlice<f32>,
        bias: &CudaSlice<f32>,
        spatial: usize,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("bias_add_channels_f32")
            .ok_or_else(|| CudaError::KernelNotFound("bias_add_channels_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(data)
                .arg(bias)
                .arg(&(spatial as u32))
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Full GPU conv2d forward: im2col on GPU → cuBLAS GEMM → bias add on GPU.
    ///
    /// Handles groups=1 only. Returns output as flat Vec<f32> in NCHW layout.
    /// Returns None if any GPU operation fails OR if the geometry is invalid
    /// (zero stride/kernel, or kernel larger than the padded input) — the
    /// caller falls back to the CPU path in either case.
    pub fn conv2d_forward(
        &self,
        input: &[f32],
        weight: &[f32],
        bias: Option<&[f32]>,
        batch_size: usize,
        in_channels: usize,
        in_height: usize,
        in_width: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        pad_h: usize,
        pad_w: usize,
    ) -> Option<Vec<f32>> {
        // Guard degenerate geometry before the output-size arithmetic below:
        // a zero stride would divide by zero, and a kernel larger than the
        // padded input would underflow usize (wrapping silently in release).
        if stride_h == 0 || stride_w == 0 || kernel_h == 0 || kernel_w == 0 {
            return None;
        }
        let padded_h = in_height + 2 * pad_h;
        let padded_w = in_width + 2 * pad_w;
        if padded_h < kernel_h || padded_w < kernel_w {
            return None;
        }
        let out_h = (padded_h - kernel_h) / stride_h + 1;
        let out_w = (padded_w - kernel_w) / stride_w + 1;
        let col_h = in_channels * kernel_h * kernel_w;
        let col_w = out_h * out_w;
        let col_n = col_h * col_w;
        let spatial = out_h * out_w;
        let out_per_batch = out_channels * spatial;
        let in_per_batch = in_channels * in_height * in_width;

        use super::cuda_pool::pool_alloc;

        // Upload weight [out_channels, col_h] to GPU (once for all batches)
        let weight_gpu = self.htod_copy(weight).ok()?;

        // Upload bias if present
        let bias_gpu = bias.and_then(|b| self.htod_copy(b).ok());

        // Upload im2col parameters as u32 buffer (reused across batches)
        let im2col_params: [u32; 10] = [
            in_height as u32,
            in_width as u32,
            kernel_h as u32,
            kernel_w as u32,
            pad_h as u32,
            pad_w as u32,
            stride_h as u32,
            stride_w as u32,
            out_h as u32,
            out_w as u32,
        ];
        let params_gpu = self.htod_copy(&im2col_params[..]).ok()?;

        // Pool-allocate col buffer on GPU (reused across batches)
        let mut col_gpu = pool_alloc(col_n).ok()?;

        // Pool-allocate output buffer on GPU
        let mut batch_out_gpu = pool_alloc(out_per_batch).ok()?;

        let mut output = vec![0.0f32; batch_size * out_per_batch];

        for b in 0..batch_size {
            // Upload input for this batch element
            let input_slice = &input[b * in_per_batch..(b + 1) * in_per_batch];
            let input_gpu = self.htod_copy(input_slice).ok()?;

            // GPU im2col: input [C_in, H, W] → col [col_h, col_w]
            self.im2col_f32(&input_gpu, &mut col_gpu, &params_gpu, col_n)
                .ok()?;

            // GPU GEMM: out = weight @ col
            // weight: [out_channels, col_h] (row-major)
            // col: [col_h, col_w] (row-major)
            // result: [out_channels, col_w] (row-major)
            //
            // cuBLAS column-major: C^T = B^T @ A^T
            // m=col_w, n=out_channels, k=col_h
            self.gemm_f32(
                false,
                false,
                col_w,
                out_channels,
                col_h,
                1.0,
                &col_gpu,
                col_w,
                &weight_gpu,
                col_h,
                0.0,
                &mut batch_out_gpu,
                col_w,
            )
            .ok()?;

            // GPU bias add (in-place on batch_out_gpu)
            if let Some(ref bg) = bias_gpu {
                self.bias_add_channels_f32(&mut batch_out_gpu, bg, spatial, out_per_batch)
                    .ok()?;
            }

            // Download output for this batch
            let batch_result = self.dtoh_copy(&batch_out_gpu).ok()?;
            output[b * out_per_batch..(b + 1) * out_per_batch]
                .copy_from_slice(&batch_result[..out_per_batch]);
        }

        Some(output)
    }
}
2631
/// Public GPU conv2d forward — callable from other crates.
///
/// Returns Some(output_vec) on success, None if CUDA is unavailable or any
/// GPU operation fails. Only handles groups=1; callers should fall back to
/// the CPU path for grouped convolution.
#[cfg(feature = "cuda")]
pub fn cuda_conv2d_forward(
    input: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Option<Vec<f32>> {
    get_cuda_backend().and_then(|backend| {
        backend.conv2d_forward(
            input,
            weight,
            bias,
            batch_size,
            in_channels,
            in_height,
            in_width,
            out_channels,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
        )
    })
}
2671
/// Stub when the CUDA feature is disabled: always returns `None` so callers
/// unconditionally take their CPU fallback path. Signature mirrors the
/// CUDA-enabled variant above so call sites compile under either feature set.
#[cfg(not(feature = "cuda"))]
pub fn cuda_conv2d_forward(
    _input: &[f32],
    _weight: &[f32],
    _bias: Option<&[f32]>,
    _batch_size: usize,
    _in_channels: usize,
    _in_height: usize,
    _in_width: usize,
    _out_channels: usize,
    _kernel_h: usize,
    _kernel_w: usize,
    _stride_h: usize,
    _stride_w: usize,
    _pad_h: usize,
    _pad_w: usize,
) -> Option<Vec<f32>> {
    None
}
2692
2693// =============================================================================
2694// Pooling GPU Operations (MaxPool2d + AvgPool2d)
2695// =============================================================================
2696
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launch MaxPool2d forward kernel on GPU (device-resident).
    ///
    /// - `input`: GPU slice [N*C*H*W]
    /// - `output`: GPU slice [N*C*out_h*out_w] (pre-allocated, zero-init)
    /// - `indices`: GPU slice [N*C*out_h*out_w] (pre-allocated, i32)
    /// - `params`: GPU u32[8] = {H, W, kH, kW, sH, sW, pH, pW}
    ///
    /// `total` is the number of output elements (N*C*out_h*out_w); one GPU
    /// thread is launched per output element.
    ///
    /// # Errors
    /// Returns `CudaError::KernelNotFound` if the compiled kernel set lacks
    /// `maxpool2d_fwd_f32`, or `CudaError::DriverError` if the launch fails.
    pub fn maxpool2d_fwd_f32(
        &self,
        input: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        indices: &mut CudaSlice<i32>,
        params: &CudaSlice<u32>,
        channels: usize,
        out_h: usize,
        out_w: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        // Look up the pre-loaded kernel; fail fast if it is missing.
        let func = self
            .kernels
            .get("maxpool2d_fwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("maxpool2d_fwd_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: the argument count, order, and types below must match the
        // kernel's C signature exactly; all slices are device-resident and
        // are sized by the caller per the doc comment (not re-validated here).
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(input)
                .arg(output)
                .arg(indices)
                .arg(params)
                .arg(&(channels as u32))
                .arg(&(out_h as u32))
                .arg(&(out_w as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launch MaxPool2d backward kernel on GPU (device-resident).
    ///
    /// Scatters grad_output to grad_input at max index positions using atomicAdd.
    /// `grad_input` must be zero-initialized.
    ///
    /// `indices` are the argmax positions produced by the forward pass;
    /// `total` is the number of grad_output elements.
    ///
    /// # Errors
    /// Returns `CudaError::KernelNotFound` if the compiled kernel set lacks
    /// `maxpool2d_bwd_f32`, or `CudaError::DriverError` if the launch fails.
    pub fn maxpool2d_bwd_f32(
        &self,
        grad_output: &CudaSlice<f32>,
        indices: &CudaSlice<i32>,
        grad_input: &mut CudaSlice<f32>,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("maxpool2d_bwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("maxpool2d_bwd_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order matches the kernel ABI; `indices` entries
        // are assumed in-bounds for `grad_input` (produced by the forward
        // pass) — the kernel dereferences them without further checks.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(grad_output)
                .arg(indices)
                .arg(grad_input)
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launch AvgPool2d forward kernel on GPU (device-resident).
    ///
    /// - `params`: GPU u32[9] = {H, W, kH, kW, sH, sW, pH, pW, count_include_pad}
    ///
    /// `total` is the number of output elements (N*C*out_h*out_w).
    ///
    /// # Errors
    /// Returns `CudaError::KernelNotFound` if the compiled kernel set lacks
    /// `avgpool2d_fwd_f32`, or `CudaError::DriverError` if the launch fails.
    pub fn avgpool2d_fwd_f32(
        &self,
        input: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        channels: usize,
        out_h: usize,
        out_w: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("avgpool2d_fwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("avgpool2d_fwd_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order matches the kernel ABI; slices are
        // device-resident and caller-sized per the doc comment.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(input)
                .arg(output)
                .arg(params)
                .arg(&(channels as u32))
                .arg(&(out_h as u32))
                .arg(&(out_w as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launch AvgPool2d backward kernel on GPU (device-resident).
    ///
    /// `grad_input` must be zero-initialized.
    ///
    /// `params` uses the same u32[9] layout as [`Self::avgpool2d_fwd_f32`];
    /// `total` is the number of grad_output elements.
    ///
    /// # Errors
    /// Returns `CudaError::KernelNotFound` if the compiled kernel set lacks
    /// `avgpool2d_bwd_f32`, or `CudaError::DriverError` if the launch fails.
    pub fn avgpool2d_bwd_f32(
        &self,
        grad_output: &CudaSlice<f32>,
        grad_input: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        channels: usize,
        out_h: usize,
        out_w: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("avgpool2d_bwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("avgpool2d_bwd_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(total);
        // SAFETY: argument order matches the kernel ABI; slices are
        // device-resident and caller-sized per the doc comment.
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(grad_output)
                .arg(grad_input)
                .arg(params)
                .arg(&(channels as u32))
                .arg(&(out_h as u32))
                .arg(&(out_w as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
2843
2844// =============================================================================
2845// Pinned (Page-Locked) Host Memory
2846// =============================================================================
2847
/// A page-locked (pinned) host memory buffer for fast CPU-to-GPU transfers.
///
/// Pinned memory is allocated via `cuMemAllocHost` and is not subject to
/// OS paging, enabling the GPU to DMA directly from the host buffer. This
/// typically provides 2-3x faster host-to-device transfer compared to
/// pageable (regular) memory.
///
/// # Invariant
/// `ptr` is either null (empty buffer, `len == 0`) or a live
/// `cuMemAllocHost` allocation holding at least `len` f32 values; the
/// allocation is released exactly once in `Drop`.
///
/// # Usage
/// ```ignore
/// use axonml_core::backends::cuda::PinnedBuffer;
///
/// let data = vec![1.0f32; 1024];
/// let pinned = PinnedBuffer::from_slice(&data).expect("pin failed");
/// // Use pinned.as_slice() as the source for htod transfers
/// ```
#[cfg(feature = "cuda")]
pub struct PinnedBuffer {
    /// Raw pointer to the pinned host allocation (from cuMemAllocHost).
    ptr: *mut f32,
    /// Number of f32 elements in the buffer.
    len: usize,
}
2870
// SAFETY: `PinnedBuffer` exclusively owns its allocation; the pointee is
// plain `f32` host memory with no interior mutability or thread-affine
// state visible here, so moving the buffer between threads is sound.
#[cfg(feature = "cuda")]
unsafe impl Send for PinnedBuffer {}
// SAFETY: shared references expose read-only access (`as_slice`, `as_ptr`);
// all mutation paths require `&mut self`, so concurrent `&PinnedBuffer`
// use cannot data-race.
#[cfg(feature = "cuda")]
unsafe impl Sync for PinnedBuffer {}
2875
#[cfg(feature = "cuda")]
impl PinnedBuffer {
    /// Allocates a pinned host buffer and copies `data` into it.
    ///
    /// The returned buffer can be used as a source for fast CPU-to-GPU
    /// transfers. The memory is page-locked so the GPU can DMA from it
    /// without going through the OS page cache.
    ///
    /// An empty `data` slice yields a null-pointer, zero-length buffer
    /// without touching the driver.
    ///
    /// # Errors
    /// Returns `CudaError` if pinned memory allocation fails (e.g., out of
    /// lockable memory, CUDA not initialized).
    pub fn from_slice(data: &[f32]) -> Result<Self, CudaError> {
        use std::ptr;

        if data.is_empty() {
            return Ok(Self {
                ptr: ptr::null_mut(),
                len: 0,
            });
        }

        let byte_size = data.len() * std::mem::size_of::<f32>();
        let mut host_ptr: *mut std::ffi::c_void = ptr::null_mut();

        // Ensure CUDA is initialized before calling driver API
        let _ = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;

        // SAFETY: on CUDA_SUCCESS, `host_ptr` points to at least `byte_size`
        // bytes of page-locked memory, i.e. room for `data.len()` f32 values;
        // source and destination are distinct allocations, so
        // `copy_nonoverlapping` is valid.
        unsafe {
            let result = cudarc::driver::sys::cuMemAllocHost_v2(&mut host_ptr, byte_size);
            if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(CudaError::AllocationFailed);
            }

            // Copy data into pinned buffer
            ptr::copy_nonoverlapping(data.as_ptr(), host_ptr as *mut f32, data.len());
        }

        Ok(Self {
            ptr: host_ptr as *mut f32,
            len: data.len(),
        })
    }

    /// Allocates an uninitialized pinned host buffer of the given length.
    ///
    /// # Safety
    /// The contents are uninitialized. Caller must write to the buffer
    /// before reading from it.
    ///
    /// # Errors
    /// Returns `CudaError` if pinned memory allocation fails.
    pub fn alloc(len: usize) -> Result<Self, CudaError> {
        use std::ptr;

        if len == 0 {
            return Ok(Self {
                ptr: ptr::null_mut(),
                len: 0,
            });
        }

        let byte_size = len * std::mem::size_of::<f32>();
        let mut host_ptr: *mut std::ffi::c_void = ptr::null_mut();

        // Ensure CUDA is initialized before calling driver API.
        let _ = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;

        // SAFETY: only the FFI allocation call; on CUDA_SUCCESS the pointer
        // is valid for `byte_size` bytes. Contents are left uninitialized
        // per this method's documented contract.
        unsafe {
            let result = cudarc::driver::sys::cuMemAllocHost_v2(&mut host_ptr, byte_size);
            if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(CudaError::AllocationFailed);
            }
        }

        Ok(Self {
            ptr: host_ptr as *mut f32,
            len,
        })
    }

    /// Returns a slice view of the pinned buffer.
    ///
    /// NOTE(review): reading a buffer obtained from [`PinnedBuffer::alloc`]
    /// before it has been fully written observes uninitialized memory;
    /// callers must uphold `alloc`'s contract.
    pub fn as_slice(&self) -> &[f32] {
        if self.ptr.is_null() || self.len == 0 {
            return &[];
        }
        // SAFETY: `ptr` is non-null here and (per the struct invariant) the
        // allocation holds `len` f32 values, live until `Drop`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Returns a mutable slice view of the pinned buffer.
    pub fn as_slice_mut(&mut self) -> &mut [f32] {
        if self.ptr.is_null() || self.len == 0 {
            return &mut [];
        }
        // SAFETY: as in `as_slice`, plus `&mut self` guarantees exclusive
        // access for the returned lifetime.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Returns the number of elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns true if the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the raw host pointer.
    ///
    /// May be null for an empty buffer.
    pub fn as_ptr(&self) -> *const f32 {
        self.ptr
    }

    /// Returns a mutable raw host pointer.
    ///
    /// May be null for an empty buffer.
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.ptr
    }

    /// Transfers the pinned buffer contents to a GPU `CudaSlice`.
    ///
    /// This is the fast path: since the source memory is pinned, the GPU
    /// can DMA directly without staging through pageable memory.
    ///
    /// # Errors
    /// Returns `CudaError::DeviceNotFound` if no CUDA backend is available,
    /// or the error from the host-to-device copy.
    pub fn to_gpu(&self) -> Result<CudaSlice<f32>, CudaError> {
        let backend = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;
        backend.htod_copy(self.as_slice())
    }
}
3000
#[cfg(feature = "cuda")]
impl Drop for PinnedBuffer {
    /// Releases the pinned allocation back to the CUDA driver.
    ///
    /// Any error from `cuMemFreeHost` is deliberately ignored: panicking in
    /// `drop` is not an option, and the pointer is nulled regardless.
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: `ptr` came from `cuMemAllocHost_v2` and has not been
            // freed before (it is nulled immediately after this call).
            unsafe {
                let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut std::ffi::c_void);
            }
            self.ptr = std::ptr::null_mut();
        }
    }
}
3012
/// Copies `data` into a freshly allocated page-locked host buffer.
///
/// Thin convenience wrapper around [`PinnedBuffer::from_slice`]; see that
/// method for the full contract.
///
/// # Errors
/// Propagates any `CudaError` from the underlying pinned allocation
/// (CUDA unavailable, out of lockable memory, etc.).
#[cfg(feature = "cuda")]
pub fn pin_memory(data: &[f32]) -> Result<PinnedBuffer, CudaError> {
    PinnedBuffer::from_slice(data)
}
3023
3024/// Stub when CUDA is not enabled - pinned memory is not available.
3025#[cfg(not(feature = "cuda"))]
3026pub fn pin_memory(_data: &[f32]) -> Result<(), CudaError> {
3027    Err(CudaError::DeviceNotFound)
3028}
3029
3030// =============================================================================
3031// Tests
3032// =============================================================================
3033
#[cfg(test)]
mod tests {
    //! Backend smoke tests. All GPU-touching tests early-return when no
    //! CUDA device is available, so the suite passes on CPU-only machines.

    use super::*;

    // Reports device availability; never fails.
    #[test]
    fn test_cuda_availability() {
        let available = is_available();
        println!("CUDA available: {}", available);
    }

    #[test]
    fn test_device_count() {
        let count = device_count();
        println!("CUDA device count: {}", count);
        // Loose sanity bound only; 0 is valid on machines without GPUs.
        assert!(count <= 16);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_backend_creation() {
        if is_available() {
            let backend = CudaBackend::new(0);
            assert!(backend.is_some());
        }
    }

    // Round-trip: host -> device -> host must be lossless.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_memory_operations() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        // Test allocation
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let gpu_data = backend.htod_copy(&data).unwrap();

        // Test copy back
        let result = backend.dtoh_copy(&gpu_data).unwrap();
        assert_eq!(data, result);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_gemm() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        // cuBLAS uses column-major order
        // To compute C = A @ B where:
        //   A is 2x3 (m=2, k=3) and B is 3x2 (k=3, n=2), C is 2x2 (m=2, n=2)
        // In column-major: lda >= m, ldb >= k, ldc >= m
        //
        // A in column-major (2x3):
        // | a00 a01 a02 |    stored as: [a00, a10, a01, a11, a02, a12]
        // | a10 a11 a12 |
        let a: Vec<f32> = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; // column-major 2x3
        // B in column-major (3x2):
        // | b00 b01 |    stored as: [b00, b10, b20, b01, b11, b21]
        // | b10 b11 |
        // | b20 b21 |
        let b: Vec<f32> = vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]; // column-major 3x2
        let c: Vec<f32> = vec![0.0; 4]; // 2x2

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.htod_copy(&c).unwrap();

        // C = A @ B
        // m=2 (rows of A, rows of C)
        // n=2 (cols of B, cols of C)
        // k=3 (cols of A, rows of B)
        // lda=2 (leading dimension of A, >= m)
        // ldb=3 (leading dimension of B, >= k)
        // ldc=2 (leading dimension of C, >= m)
        backend
            .gemm_f32(
                false, false, 2, 2, 3,   // m, n, k
                1.0, // alpha
                &a_gpu, 2, // A, lda
                &b_gpu, 3,   // B, ldb
                0.0, // beta
                &mut c_gpu, 2, // C, ldc
            )
            .unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();
        // C = A @ B (in matrix form, row-major interpretation):
        // A = [[1,2,3],[4,5,6]], B = [[1,2],[3,4],[5,6]]
        // C[0,0] = 1*1 + 2*3 + 3*5 = 1 + 6 + 15 = 22
        // C[1,0] = 4*1 + 5*3 + 6*5 = 4 + 15 + 30 = 49
        // C[0,1] = 1*2 + 2*4 + 3*6 = 2 + 8 + 18 = 28
        // C[1,1] = 4*2 + 5*4 + 6*6 = 8 + 20 + 36 = 64
        // Column-major result: [22, 49, 28, 64]
        assert!((result[0] - 22.0).abs() < 1e-5, "result[0] = {}", result[0]);
        assert!((result[1] - 49.0).abs() < 1e-5, "result[1] = {}", result[1]);
        assert!((result[2] - 28.0).abs() < 1e-5, "result[2] = {}", result[2]);
        assert!((result[3] - 64.0).abs() < 1e-5, "result[3] = {}", result[3]);
    }

    // Elementwise c = a + b on the GPU.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_add_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(4).unwrap();

        backend.add_f32(&mut c_gpu, &a_gpu, &b_gpu, 4).unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 6.0).abs() < 1e-5);
        assert!((result[1] - 8.0).abs() < 1e-5);
        assert!((result[2] - 10.0).abs() < 1e-5);
        assert!((result[3] - 12.0).abs() < 1e-5);
    }

    // Elementwise c = a * b on the GPU.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_mul_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![2.0, 3.0, 4.0, 5.0];

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(4).unwrap();

        backend.mul_f32(&mut c_gpu, &a_gpu, &b_gpu, 4).unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-5);
        assert!((result[1] - 6.0).abs() < 1e-5);
        assert!((result[2] - 12.0).abs() < 1e-5);
        assert!((result[3] - 20.0).abs() < 1e-5);
    }

    // In-place scalar multiply on the GPU.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_scale_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let mut data_gpu = backend.htod_copy(&data).unwrap();

        backend.scale_f32(&mut data_gpu, 2.5, 4).unwrap();

        let result = backend.dtoh_copy(&data_gpu).unwrap();
        assert!((result[0] - 2.5).abs() < 1e-5);
        assert!((result[1] - 5.0).abs() < 1e-5);
        assert!((result[2] - 7.5).abs() < 1e-5);
        assert!((result[3] - 10.0).abs() < 1e-5);
    }

    // ReLU: negatives clamp to zero, non-negatives pass through.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_relu_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let input: Vec<f32> = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(5).unwrap();

        backend.relu_f32(&mut output_gpu, &input_gpu, 5).unwrap();

        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.0).abs() < 1e-5);
        assert!((result[1] - 0.0).abs() < 1e-5);
        assert!((result[2] - 0.0).abs() < 1e-5);
        assert!((result[3] - 1.0).abs() < 1e-5);
        assert!((result[4] - 2.0).abs() < 1e-5);
    }

    // Sigmoid at a few reference points (looser tolerance: GPU transcendentals).
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_sigmoid_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let input: Vec<f32> = vec![0.0, 1.0, -1.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(3).unwrap();

        backend.sigmoid_f32(&mut output_gpu, &input_gpu, 3).unwrap();

        let result = backend.dtoh_copy(&output_gpu).unwrap();
        // sigmoid(0) = 0.5
        assert!((result[0] - 0.5).abs() < 1e-4);
        // sigmoid(1) ≈ 0.7311
        assert!((result[1] - 0.7311).abs() < 1e-3);
        // sigmoid(-1) ≈ 0.2689
        assert!((result[2] - 0.2689).abs() < 1e-3);
    }

    // Tanh at a few reference points (looser tolerance: GPU transcendentals).
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tanh_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let input: Vec<f32> = vec![0.0, 1.0, -1.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(3).unwrap();

        backend.tanh_f32(&mut output_gpu, &input_gpu, 3).unwrap();

        let result = backend.dtoh_copy(&output_gpu).unwrap();
        // tanh(0) = 0
        assert!((result[0] - 0.0).abs() < 1e-5);
        // tanh(1) ≈ 0.7616
        assert!((result[1] - 0.7616).abs() < 1e-3);
        // tanh(-1) ≈ -0.7616
        assert!((result[2] - (-0.7616)).abs() < 1e-3);
    }

    // Exercises multi-block launches: 1M elements spans many thread blocks.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_large_tensor_add() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        // Test with a large tensor (1M elements)
        let n = 1_000_000;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(n).unwrap();

        backend.add_f32(&mut c_gpu, &a_gpu, &b_gpu, n).unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();

        // Each element should equal n (i + (n-i) = n)
        assert!((result[0] - n as f32).abs() < 1e-3);
        assert!((result[n / 2] - n as f32).abs() < 1e-3);
        assert!((result[n - 1] - n as f32).abs() < 1e-3);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_conv2d_forward() {
        if !is_available() {
            return;
        }

        // 1x1 conv: 3 in_channels → 2 out_channels, input 4x4
        let input = vec![1.0f32; 1 * 3 * 4 * 4]; // all ones
        // Weight layout is [out_ch, in_ch, kH, kW] = [2, 3, 1, 1].
        let mut weight = vec![0.0f32; 2 * 3 * 1 * 1];
        // out_ch0 = in_ch0 (weight[0]=1), out_ch1 = in_ch1 (weight[4]=1)
        weight[0] = 1.0;
        weight[4] = 1.0;
        let bias = vec![0.5f32; 2];

        let result = cuda_conv2d_forward(
            &input,
            &weight,
            Some(&bias),
            1,
            3,
            4,
            4,
            2,
            1,
            1,
            1,
            1,
            0,
            0,
        );

        let out = result.expect("CUDA conv2d should succeed");
        assert_eq!(out.len(), 2 * 4 * 4);
        // out_ch0 = 1.0*1 + 0.5 = 1.5
        assert!(
            (out[0] - 1.5).abs() < 0.01,
            "1x1 conv ch0: expected 1.5, got {}",
            out[0]
        );
        // out_ch1 = 1.0*1 + 0.5 = 1.5
        assert!(
            (out[16] - 1.5).abs() < 0.01,
            "1x1 conv ch1: expected 1.5, got {}",
            out[16]
        );

        // 3x3 conv with padding=1: all-ones input, all-ones weight
        let input2 = vec![1.0f32; 1 * 3 * 8 * 8];
        let weight2 = vec![1.0f32; 2 * 3 * 3 * 3]; // all 1s → each output = sum of 27 inputs
        let bias2 = vec![0.0f32; 2];

        let result2 = cuda_conv2d_forward(
            &input2,
            &weight2,
            Some(&bias2),
            1,
            3,
            8,
            8,
            2,
            3,
            3,
            1,
            1,
            1,
            1,
        );

        let out2 = result2.expect("CUDA 3x3 conv should succeed");
        assert_eq!(out2.len(), 2 * 8 * 8);
        // Center pixel (row 4, col 4) = 3 channels * 9 kernel positions * 1.0 = 27.0
        let center = 4 * 8 + 4;
        assert!(
            (out2[center] - 27.0).abs() < 0.1,
            "3x3 conv center: expected 27.0, got {}",
            out2[center]
        );
        // Corner pixel (0,0) with pad=1: only 2x2x3 = 12 valid positions
        assert!(
            (out2[0] - 12.0).abs() < 0.1,
            "3x3 conv corner: expected 12.0, got {}",
            out2[0]
        );
    }
}