1#[cfg(feature = "cuda")]
18use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, sys::cublasOperation_t};
19#[cfg(feature = "cudnn")]
20use cudarc::cudnn::Cudnn;
21#[cfg(feature = "cuda")]
22use cudarc::driver::{
23 CudaContext, CudaSlice, CudaStream, DeviceRepr, DeviceSlice, LaunchConfig, PushKernelArg,
24 ValidAsZeroBits,
25};
26
27use super::Backend;
28#[cfg(feature = "cuda")]
29use super::cuda_kernels::{self, BLOCK_SIZE, CudaKernels};
30use crate::device::DeviceCapabilities;
31#[cfg(feature = "cuda")]
32use std::sync::Arc;
33#[cfg(feature = "cuda")]
34use std::sync::OnceLock;
35
#[cfg(feature = "cuda")]
/// Process-wide, lazily-initialized CUDA backend for GPU 0. The inner
/// `Option` caches a failed initialization so device setup is attempted
/// at most once per process.
static CUDA_BACKEND: OnceLock<Option<CudaBackend>> = OnceLock::new();
42
/// Returns the shared CUDA backend, initializing it on first call.
///
/// Initialization targets GPU 0; on failure `None` is cached and returned
/// on every subsequent call without retrying.
#[cfg(feature = "cuda")]
pub fn get_cuda_backend() -> Option<&'static CudaBackend> {
    let slot = CUDA_BACKEND.get_or_init(|| {
        let backend = CudaBackend::new(0);
        if backend.is_some() {
            eprintln!("[AxonML] CUDA backend initialized (GPU 0)");
        }
        backend
    });
    slot.as_ref()
}
56
#[cfg(not(feature = "cuda"))]
/// CUDA support compiled out: there is never a backend.
pub fn get_cuda_backend() -> Option<&'static CudaBackend> {
    None
}
62
#[cfg(feature = "cuda")]
/// CUDA execution backend: owns the driver context, the default stream,
/// a cuBLAS handle, the compiled custom kernels, and (optionally) a
/// cuDNN handle.
pub struct CudaBackend {
    // Ordinal of the GPU this backend is bound to.
    device_index: usize,
    // Driver context; keeps the device usable for the backend's lifetime.
    ctx: Arc<CudaContext>,
    // Default stream; all launches and copies in this module go through it.
    stream: Arc<CudaStream>,
    // cuBLAS handle created on `stream`.
    blas: CudaBlas,
    // Pre-loaded custom kernels (element-wise ops, reductions, …).
    kernels: CudaKernels,
    // cuDNN handle; `None` when cuDNN init failed (conv falls back to
    // im2col+GEMM, see `new`).
    #[cfg(feature = "cudnn")]
    cudnn_handle: Option<Arc<Cudnn>>,
}
81
#[cfg(not(feature = "cuda"))]
/// Stub backend used when CUDA support is compiled out; never available.
#[derive(Debug)]
pub struct CudaBackend {
    // Kept so the stub reports the same index callers asked for.
    device_index: usize,
}
88
#[cfg(feature = "cuda")]
// SAFETY: NOTE(review) — this asserts the wrapped CUDA handles may be moved
// across threads. The cudarc context/stream/cuBLAS wrappers target the
// thread-safe driver API, but confirm `CudaKernels` carries no thread-affine
// state before relying on these impls.
unsafe impl Send for CudaBackend {}
#[cfg(feature = "cuda")]
// SAFETY: see the note on the `Send` impl above.
unsafe impl Sync for CudaBackend {}
95
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaBackend {
    /// Manual impl: the CUDA/cuBLAS handle fields are not `Debug`, so only
    /// the device index is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("CudaBackend");
        dbg.field("device_index", &self.device_index);
        dbg.finish()
    }
}
104
impl CudaBackend {
    #[cfg(feature = "cuda")]
    /// Builds a backend bound to GPU `device_index`.
    ///
    /// Returns `None` when the driver context, cuBLAS handle, or custom
    /// kernel module cannot be created. A cuDNN initialization failure is
    /// non-fatal: it is logged and `cudnn_handle` stays `None`.
    pub fn new(device_index: usize) -> Option<Self> {
        let ctx = CudaContext::new(device_index).ok()?;
        let stream = ctx.default_stream();
        let blas = CudaBlas::new(stream.clone()).ok()?;
        // Kernel loading failure is reported loudly: without the custom
        // kernels the backend would be unusable.
        let kernels = match CudaKernels::load(ctx.clone()) {
            Ok(k) => k,
            Err(e) => {
                eprintln!("[AxonML CUDA] Kernel loading failed: {:?}", e);
                return None;
            }
        };

        #[cfg(feature = "cudnn")]
        let cudnn_handle = match Cudnn::new(stream.clone()) {
            Ok(handle) => {
                eprintln!("[AxonML] cuDNN handle initialized");
                Some(handle)
            }
            Err(e) => {
                eprintln!(
                    "[AxonML CUDA] cuDNN init failed: {:?} (falling back to im2col+GEMM)",
                    e
                );
                None
            }
        };

        Some(Self {
            device_index,
            ctx,
            stream,
            blas,
            kernels,
            #[cfg(feature = "cudnn")]
            cudnn_handle,
        })
    }

    #[cfg(not(feature = "cuda"))]
    /// CUDA compiled out: construction always fails.
    pub fn new(device_index: usize) -> Option<Self> {
        let _ = device_index;
        None
    }

    /// Ordinal of the GPU this backend was created for.
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    #[cfg(feature = "cuda")]
    /// Shared driver context handle.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    #[cfg(feature = "cuda")]
    /// Default stream used for every launch and copy in this module.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    #[cfg(feature = "cuda")]
    /// cuBLAS handle bound to the default stream.
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    #[cfg(feature = "cudnn")]
    /// cuDNN handle, or `None` when cuDNN initialization failed.
    pub fn cudnn(&self) -> Option<&Arc<Cudnn>> {
        self.cudnn_handle.as_ref()
    }

    #[cfg(feature = "cuda")]
    /// Allocates `len` zero-initialized elements on the device.
    pub fn alloc<T: DeviceRepr + ValidAsZeroBits>(
        &self,
        len: usize,
    ) -> Result<CudaSlice<T>, CudaError> {
        self.stream.alloc_zeros(len).map_err(CudaError::from)
    }

    #[cfg(feature = "cuda")]
    /// Allocates `len` elements without zeroing; contents are undefined
    /// until written (hence the `unsafe` allocation call).
    pub fn alloc_uninit<T: DeviceRepr>(&self, len: usize) -> Result<CudaSlice<T>, CudaError> {
        unsafe { self.stream.alloc(len).map_err(CudaError::from) }
    }

    #[cfg(feature = "cuda")]
    /// Copies a host slice into a freshly allocated device buffer.
    pub fn htod_copy<T: DeviceRepr>(&self, src: &[T]) -> Result<CudaSlice<T>, CudaError> {
        self.stream.clone_htod(src).map_err(CudaError::from)
    }

    #[cfg(feature = "cuda")]
    /// Copies a device buffer back into a new host `Vec`.
    pub fn dtoh_copy<T: DeviceRepr>(&self, src: &CudaSlice<T>) -> Result<Vec<T>, CudaError> {
        self.stream.clone_dtoh(src).map_err(CudaError::from)
    }
}
209
#[cfg(feature = "cuda")]
impl Backend for CudaBackend {
    fn name(&self) -> &'static str {
        "cuda"
    }

    fn is_available(&self) -> bool {
        true
    }

    fn capabilities(&self) -> DeviceCapabilities {
        let name = format!("CUDA Device {}", self.device_index);

        // Free/total memory of the current context; (0, 0) if the query fails.
        let (free, total) = cudarc::driver::result::mem_get_info().unwrap_or((0, 0));

        DeviceCapabilities {
            name,
            total_memory: total,
            available_memory: free,
            supports_f16: true,
            supports_f64: true,
            // NOTE(review): hard-coded rather than queried from the device.
            max_threads_per_block: 1024,
            compute_capability: None,
        }
    }

    /// Raw allocation for the trait's pointer-based API. The owning
    /// `CudaSlice` is leaked so the pointer stays valid; ownership is
    /// reclaimed in `deallocate` via `upgrade_device_ptr`.
    fn allocate(&self, size: usize) -> *mut u8 {
        match self.stream.alloc_zeros::<u8>(size) {
            Ok(slice) => {
                let ptr = slice.leak() as *mut u8;
                ptr
            }
            Err(_) => std::ptr::null_mut(),
        }
    }

    fn deallocate(&self, ptr: *mut u8, size: usize) {
        if !ptr.is_null() {
            unsafe {
                // Rebuild the owning slice from the pointer leaked in
                // `allocate`; dropping it releases the device memory.
                let slice: CudaSlice<u8> = self
                    .stream
                    .upgrade_device_ptr(ptr as cudarc::driver::sys::CUdeviceptr, size);
                drop(slice);
            }
        }
    }

    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        unsafe {
            let src_slice = std::slice::from_raw_parts(src, size);
            // Synchronous copy; errors are deliberately swallowed because
            // the `Backend` trait offers no error channel.
            let _ = cudarc::driver::result::memcpy_htod_sync(
                dst as cudarc::driver::sys::CUdeviceptr,
                src_slice,
            );
        }
    }

    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        unsafe {
            let dst_slice = std::slice::from_raw_parts_mut(dst, size);
            // Synchronous device-to-host copy; errors swallowed (see above).
            let _ = cudarc::driver::result::memcpy_dtoh_sync(
                dst_slice,
                src as cudarc::driver::sys::CUdeviceptr,
            );
        }
    }

    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize) {
        if dst.is_null() || src.is_null() || size == 0 {
            return;
        }
        unsafe {
            // Synchronous device-to-device copy; errors swallowed (see above).
            let _ = cudarc::driver::result::memcpy_dtod_sync(
                dst as cudarc::driver::sys::CUdeviceptr,
                src as cudarc::driver::sys::CUdeviceptr,
                size,
            );
        }
    }

    fn synchronize(&self) {
        // Waits for all queued work on the default stream; errors ignored.
        let _ = self.stream.synchronize();
    }
}
308
/// Synchronizes the global CUDA backend's default stream.
///
/// Returns `true` when a backend exists (a sync was attempted; sync errors
/// are ignored), `false` when no backend could be initialized.
#[cfg(feature = "cuda")]
pub fn cuda_sync() -> bool {
    match get_cuda_backend() {
        Some(backend) => {
            let _ = backend.stream.synchronize();
            true
        }
        None => false,
    }
}
320
#[cfg(not(feature = "cuda"))]
/// CUDA compiled out: nothing to synchronize.
pub fn cuda_sync() -> bool {
    false
}
326
#[cfg(not(feature = "cuda"))]
impl Backend for CudaBackend {
    fn name(&self) -> &'static str {
        "cuda"
    }

    fn is_available(&self) -> bool {
        false
    }

    /// Zeroed capability report for the compiled-out stub.
    fn capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities {
            name: format!("CUDA Device {} (unavailable)", self.device_index),
            total_memory: 0,
            available_memory: 0,
            supports_f16: false,
            supports_f64: false,
            max_threads_per_block: 0,
            compute_capability: None,
        }
    }

    // Every memory operation below is an inert no-op: allocation yields a
    // null pointer, and copies/synchronization do nothing.
    fn allocate(&self, _size: usize) -> *mut u8 {
        std::ptr::null_mut()
    }

    fn deallocate(&self, _ptr: *mut u8, _size: usize) {}

    fn copy_to_device(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}

    fn copy_to_host(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}

    fn copy_device_to_device(&self, _dst: *mut u8, _src: *const u8, _size: usize) {}

    fn synchronize(&self) {}
}
363
/// Errors produced by the CUDA backend.
#[derive(Debug)]
pub enum CudaError {
    /// No usable CUDA device was found.
    DeviceNotFound,
    /// Device memory allocation failed.
    AllocationFailed,
    /// A host/device memory copy failed.
    CopyFailed,
    /// Launching a kernel failed.
    KernelLaunchFailed,
    /// Error reported by cuBLAS.
    BlasError(String),
    /// Error reported by the CUDA driver.
    DriverError(String),
    /// Loading a compiled kernel module failed.
    ModuleLoadFailed(String),
    /// A named kernel was missing from the loaded module.
    KernelNotFound(String),
}

impl std::fmt::Display for CudaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DeviceNotFound => f.write_str("CUDA device not found"),
            Self::AllocationFailed => f.write_str("CUDA memory allocation failed"),
            Self::CopyFailed => f.write_str("CUDA memory copy failed"),
            Self::KernelLaunchFailed => f.write_str("CUDA kernel launch failed"),
            Self::BlasError(msg) => write!(f, "cuBLAS error: {}", msg),
            Self::DriverError(msg) => write!(f, "CUDA driver error: {}", msg),
            Self::ModuleLoadFailed(msg) => write!(f, "CUDA module load failed: {}", msg),
            Self::KernelNotFound(msg) => write!(f, "CUDA kernel not found: {}", msg),
        }
    }
}

impl std::error::Error for CudaError {}
405
#[cfg(feature = "cuda")]
impl From<cudarc::driver::DriverError> for CudaError {
    /// Wraps a driver error, preserving its `Display` rendering.
    fn from(e: cudarc::driver::DriverError) -> Self {
        CudaError::DriverError(e.to_string())
    }
}
412
#[cfg(feature = "cuda")]
impl From<cudarc::cublas::result::CublasError> for CudaError {
    /// Wraps a cuBLAS error.
    // NOTE(review): uses the `Debug` rendering (`{:?}`), unlike the
    // `DriverError` conversion which uses `Display` — consider unifying.
    fn from(e: cudarc::cublas::result::CublasError) -> Self {
        CudaError::BlasError(format!("{:?}", e))
    }
}
419
/// Reports whether a usable CUDA device exists.
///
/// Probes by creating (and immediately dropping) a context on device 0,
/// which is comparatively expensive — prefer caching the result.
pub fn is_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        CudaContext::new(0).is_ok()
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
435
/// Number of CUDA devices visible to the driver; 0 on any driver failure
/// or when the `cuda` feature is disabled.
pub fn device_count() -> usize {
    #[cfg(feature = "cuda")]
    {
        cudarc::driver::result::device::get_count().unwrap_or(0) as usize
    }
    #[cfg(not(feature = "cuda"))]
    {
        0
    }
}
447
448pub fn is_device_available(index: usize) -> bool {
450 index < device_count()
451}
452
/// Queries capabilities for device `index`, falling back to optimistic
/// defaults when a backend cannot be constructed (or CUDA is compiled out).
// NOTE(review): the fallback advertises f16/f64 support and 1024
// threads/block for a device that could not even be opened — this is
// inconsistent with the stub `Backend::capabilities`, which reports
// everything unsupported. Confirm which behavior callers rely on.
pub fn get_capabilities(index: usize) -> DeviceCapabilities {
    #[cfg(feature = "cuda")]
    {
        if let Some(backend) = CudaBackend::new(index) {
            return backend.capabilities();
        }
    }
    #[allow(unreachable_code)]
    DeviceCapabilities {
        name: format!("CUDA Device {}", index),
        total_memory: 0,
        available_memory: 0,
        supports_f16: true,
        supports_f64: true,
        max_threads_per_block: 1024,
        compute_capability: None,
    }
}
472
#[cfg(feature = "cuda")]
/// Stream synchronization by opaque handle — currently a no-op.
// NOTE(review): `_handle` is ignored, so callers get no synchronization
// here; use `cuda_sync()` or `CudaBackend::synchronize` for a real barrier.
pub fn stream_synchronize(_handle: usize) {}
498
#[cfg(not(feature = "cuda"))]
/// CUDA compiled out: nothing to synchronize.
pub fn stream_synchronize(_handle: usize) {}
504
505#[cfg(feature = "cuda")]
510impl CudaBackend {
511 pub fn gemm_f32(
513 &self,
514 transa: bool,
515 transb: bool,
516 m: usize,
517 n: usize,
518 k: usize,
519 alpha: f32,
520 a: &CudaSlice<f32>,
521 lda: usize,
522 b: &CudaSlice<f32>,
523 ldb: usize,
524 beta: f32,
525 c: &mut CudaSlice<f32>,
526 ldc: usize,
527 ) -> Result<(), CudaError> {
528 let cfg = GemmConfig {
529 transa: if transa {
530 cublasOperation_t::CUBLAS_OP_T
531 } else {
532 cublasOperation_t::CUBLAS_OP_N
533 },
534 transb: if transb {
535 cublasOperation_t::CUBLAS_OP_T
536 } else {
537 cublasOperation_t::CUBLAS_OP_N
538 },
539 m: m as i32,
540 n: n as i32,
541 k: k as i32,
542 alpha,
543 lda: lda as i32,
544 ldb: ldb as i32,
545 beta,
546 ldc: ldc as i32,
547 };
548
549 unsafe { self.blas.gemm(cfg, a, b, c).map_err(CudaError::from) }
550 }
551
552 pub fn gemm_batched_f32(
554 &self,
555 transa: bool,
556 transb: bool,
557 m: usize,
558 n: usize,
559 k: usize,
560 alpha: f32,
561 a_array: &[&CudaSlice<f32>],
562 lda: usize,
563 b_array: &[&CudaSlice<f32>],
564 ldb: usize,
565 beta: f32,
566 c_array: &mut [&mut CudaSlice<f32>],
567 ldc: usize,
568 batch_count: usize,
569 ) -> Result<(), CudaError> {
570 for i in 0..batch_count {
572 let cfg = GemmConfig {
573 transa: if transa {
574 cublasOperation_t::CUBLAS_OP_T
575 } else {
576 cublasOperation_t::CUBLAS_OP_N
577 },
578 transb: if transb {
579 cublasOperation_t::CUBLAS_OP_T
580 } else {
581 cublasOperation_t::CUBLAS_OP_N
582 },
583 m: m as i32,
584 n: n as i32,
585 k: k as i32,
586 alpha,
587 lda: lda as i32,
588 ldb: ldb as i32,
589 beta,
590 ldc: ldc as i32,
591 };
592
593 unsafe {
594 self.blas
595 .gemm(cfg, a_array[i], b_array[i], c_array[i])
596 .map_err(CudaError::from)?;
597 }
598 }
599 Ok(())
600 }
601
    /// Strided-batched SGEMM via the raw cuBLAS FFI: for each
    /// `i < batch_count`, `C_i = alpha * op(A_i) * op(B_i) + beta * C_i`,
    /// where batch member `i` starts `i * stride_{a,b,c}` elements into the
    /// corresponding buffer. Dimensions follow cuBLAS column-major rules.
    pub fn gemm_strided_batched_f32(
        &self,
        transa: bool,
        transb: bool,
        m: usize,
        n: usize,
        k: usize,
        alpha: f32,
        a: &CudaSlice<f32>,
        lda: usize,
        stride_a: i64,
        b: &CudaSlice<f32>,
        ldb: usize,
        stride_b: i64,
        beta: f32,
        c: &mut CudaSlice<f32>,
        ldc: usize,
        stride_c: i64,
        batch_count: usize,
    ) -> Result<(), CudaError> {
        use cudarc::cublas::result::sgemm_strided_batched;
        use cudarc::driver::DevicePtr as _;
        use cudarc::driver::DevicePtrMut as _;

        let op_a = if transa {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        };
        let op_b = if transb {
            cublasOperation_t::CUBLAS_OP_T
        } else {
            cublasOperation_t::CUBLAS_OP_N
        };

        // Raw device pointers for the FFI call. The `_ga`/`_gb`/`_gc`
        // companions are guard values returned with the pointers; they live
        // to the end of this function, covering the cuBLAS call.
        let (a_devptr, _ga) = a.device_ptr(&self.stream);
        let (b_devptr, _gb) = b.device_ptr(&self.stream);
        let (c_devptr, _gc) = c.device_ptr_mut(&self.stream);
        let a_ptr = a_devptr as *const f32;
        let b_ptr = b_devptr as *const f32;
        let c_ptr = c_devptr as *mut f32;

        // SAFETY: the pointers come from live device allocations obtained
        // above. NOTE(review): nothing checks that every batch member fits
        // inside its buffer — callers must guarantee strides/dimensions.
        unsafe {
            sgemm_strided_batched(
                *self.blas.handle(),
                op_a,
                op_b,
                m as i32,
                n as i32,
                k as i32,
                &alpha as *const f32,
                a_ptr,
                lda as i32,
                stride_a,
                b_ptr,
                ldb as i32,
                stride_b,
                &beta as *const f32,
                c_ptr,
                ldc as i32,
                stride_c,
                batch_count as i32,
            )
            .map_err(CudaError::from)
        }
    }
671
672 pub fn add_f32(
674 &self,
675 dst: &mut CudaSlice<f32>,
676 a: &CudaSlice<f32>,
677 b: &CudaSlice<f32>,
678 len: usize,
679 ) -> Result<(), CudaError> {
680 let func = self
681 .kernels
682 .get("add_f32")
683 .ok_or_else(|| CudaError::KernelNotFound("add_f32".to_string()))?;
684
685 let cfg = cuda_kernels::launch_config(len);
686 unsafe {
687 self.stream
688 .launch_builder(func)
689 .arg(a)
690 .arg(b)
691 .arg(dst)
692 .arg(&(len as u32))
693 .launch(cfg)
694 .map(|_| ())
695 .map_err(|e| CudaError::DriverError(e.to_string()))?;
696 }
697 Ok(())
698 }
699
700 pub fn scale_f32(
702 &self,
703 dst: &mut CudaSlice<f32>,
704 alpha: f32,
705 len: usize,
706 ) -> Result<(), CudaError> {
707 let func = self
708 .kernels
709 .get("scale_f32")
710 .ok_or_else(|| CudaError::KernelNotFound("scale_f32".to_string()))?;
711
712 let cfg = cuda_kernels::launch_config(len);
713 unsafe {
714 self.stream
715 .launch_builder(func)
716 .arg(dst)
717 .arg(&alpha)
718 .arg(&(len as u32))
719 .launch(cfg)
720 .map(|_| ())
721 .map_err(|e| CudaError::DriverError(e.to_string()))?;
722 }
723 Ok(())
724 }
725
726 pub fn mul_f32(
728 &self,
729 dst: &mut CudaSlice<f32>,
730 a: &CudaSlice<f32>,
731 b: &CudaSlice<f32>,
732 len: usize,
733 ) -> Result<(), CudaError> {
734 let func = self
735 .kernels
736 .get("mul_f32")
737 .ok_or_else(|| CudaError::KernelNotFound("mul_f32".to_string()))?;
738
739 let cfg = cuda_kernels::launch_config(len);
740 unsafe {
741 self.stream
742 .launch_builder(func)
743 .arg(a)
744 .arg(b)
745 .arg(dst)
746 .arg(&(len as u32))
747 .launch(cfg)
748 .map(|_| ())
749 .map_err(|e| CudaError::DriverError(e.to_string()))?;
750 }
751 Ok(())
752 }
753
754 pub fn relu_f32(
756 &self,
757 dst: &mut CudaSlice<f32>,
758 src: &CudaSlice<f32>,
759 len: usize,
760 ) -> Result<(), CudaError> {
761 let func = self
762 .kernels
763 .get("relu_f32")
764 .ok_or_else(|| CudaError::KernelNotFound("relu_f32".to_string()))?;
765
766 let cfg = cuda_kernels::launch_config(len);
767 unsafe {
768 self.stream
769 .launch_builder(func)
770 .arg(src)
771 .arg(dst)
772 .arg(&(len as u32))
773 .launch(cfg)
774 .map(|_| ())
775 .map_err(|e| CudaError::DriverError(e.to_string()))?;
776 }
777 Ok(())
778 }
779
780 pub fn sigmoid_f32(
782 &self,
783 dst: &mut CudaSlice<f32>,
784 src: &CudaSlice<f32>,
785 len: usize,
786 ) -> Result<(), CudaError> {
787 let func = self
788 .kernels
789 .get("sigmoid_f32")
790 .ok_or_else(|| CudaError::KernelNotFound("sigmoid_f32".to_string()))?;
791
792 let cfg = cuda_kernels::launch_config(len);
793 unsafe {
794 self.stream
795 .launch_builder(func)
796 .arg(src)
797 .arg(dst)
798 .arg(&(len as u32))
799 .launch(cfg)
800 .map(|_| ())
801 .map_err(|e| CudaError::DriverError(e.to_string()))?;
802 }
803 Ok(())
804 }
805
806 pub fn tanh_f32(
808 &self,
809 dst: &mut CudaSlice<f32>,
810 src: &CudaSlice<f32>,
811 len: usize,
812 ) -> Result<(), CudaError> {
813 let func = self
814 .kernels
815 .get("tanh_f32")
816 .ok_or_else(|| CudaError::KernelNotFound("tanh_f32".to_string()))?;
817
818 let cfg = cuda_kernels::launch_config(len);
819 unsafe {
820 self.stream
821 .launch_builder(func)
822 .arg(src)
823 .arg(dst)
824 .arg(&(len as u32))
825 .launch(cfg)
826 .map(|_| ())
827 .map_err(|e| CudaError::DriverError(e.to_string()))?;
828 }
829 Ok(())
830 }
831
832 pub fn sub_f32(
834 &self,
835 dst: &mut CudaSlice<f32>,
836 a: &CudaSlice<f32>,
837 b: &CudaSlice<f32>,
838 len: usize,
839 ) -> Result<(), CudaError> {
840 let func = self
841 .kernels
842 .get("sub_f32")
843 .ok_or_else(|| CudaError::KernelNotFound("sub_f32".to_string()))?;
844 let cfg = cuda_kernels::launch_config(len);
845 unsafe {
846 self.stream
847 .launch_builder(func)
848 .arg(a)
849 .arg(b)
850 .arg(dst)
851 .arg(&(len as u32))
852 .launch(cfg)
853 .map(|_| ())
854 .map_err(|e| CudaError::DriverError(e.to_string()))?;
855 }
856 Ok(())
857 }
858
859 pub fn div_f32(
861 &self,
862 dst: &mut CudaSlice<f32>,
863 a: &CudaSlice<f32>,
864 b: &CudaSlice<f32>,
865 len: usize,
866 ) -> Result<(), CudaError> {
867 let func = self
868 .kernels
869 .get("div_f32")
870 .ok_or_else(|| CudaError::KernelNotFound("div_f32".to_string()))?;
871 let cfg = cuda_kernels::launch_config(len);
872 unsafe {
873 self.stream
874 .launch_builder(func)
875 .arg(a)
876 .arg(b)
877 .arg(dst)
878 .arg(&(len as u32))
879 .launch(cfg)
880 .map(|_| ())
881 .map_err(|e| CudaError::DriverError(e.to_string()))?;
882 }
883 Ok(())
884 }
885
886 pub fn broadcast_add_f32(
893 &self,
894 dst: &mut CudaSlice<f32>,
895 a: &CudaSlice<f32>,
896 b: &CudaSlice<f32>,
897 n: usize,
898 b_len: usize,
899 ) -> Result<(), CudaError> {
900 let func = self
901 .kernels
902 .get("broadcast_add_f32")
903 .ok_or_else(|| CudaError::KernelNotFound("broadcast_add_f32".to_string()))?;
904 let cfg = cuda_kernels::launch_config(n);
905 unsafe {
906 self.stream
907 .launch_builder(func)
908 .arg(a)
909 .arg(b)
910 .arg(dst)
911 .arg(&(n as u32))
912 .arg(&(b_len as u32))
913 .launch(cfg)
914 .map(|_| ())
915 .map_err(|e| CudaError::DriverError(e.to_string()))?;
916 }
917 Ok(())
918 }
919
920 pub fn broadcast_sub_f32(
922 &self,
923 dst: &mut CudaSlice<f32>,
924 a: &CudaSlice<f32>,
925 b: &CudaSlice<f32>,
926 n: usize,
927 b_len: usize,
928 ) -> Result<(), CudaError> {
929 let func = self
930 .kernels
931 .get("broadcast_sub_f32")
932 .ok_or_else(|| CudaError::KernelNotFound("broadcast_sub_f32".to_string()))?;
933 let cfg = cuda_kernels::launch_config(n);
934 unsafe {
935 self.stream
936 .launch_builder(func)
937 .arg(a)
938 .arg(b)
939 .arg(dst)
940 .arg(&(n as u32))
941 .arg(&(b_len as u32))
942 .launch(cfg)
943 .map(|_| ())
944 .map_err(|e| CudaError::DriverError(e.to_string()))?;
945 }
946 Ok(())
947 }
948
949 pub fn broadcast_mul_f32(
951 &self,
952 dst: &mut CudaSlice<f32>,
953 a: &CudaSlice<f32>,
954 b: &CudaSlice<f32>,
955 n: usize,
956 b_len: usize,
957 ) -> Result<(), CudaError> {
958 let func = self
959 .kernels
960 .get("broadcast_mul_f32")
961 .ok_or_else(|| CudaError::KernelNotFound("broadcast_mul_f32".to_string()))?;
962 let cfg = cuda_kernels::launch_config(n);
963 unsafe {
964 self.stream
965 .launch_builder(func)
966 .arg(a)
967 .arg(b)
968 .arg(dst)
969 .arg(&(n as u32))
970 .arg(&(b_len as u32))
971 .launch(cfg)
972 .map(|_| ())
973 .map_err(|e| CudaError::DriverError(e.to_string()))?;
974 }
975 Ok(())
976 }
977
978 pub fn broadcast_div_f32(
980 &self,
981 dst: &mut CudaSlice<f32>,
982 a: &CudaSlice<f32>,
983 b: &CudaSlice<f32>,
984 n: usize,
985 b_len: usize,
986 ) -> Result<(), CudaError> {
987 let func = self
988 .kernels
989 .get("broadcast_div_f32")
990 .ok_or_else(|| CudaError::KernelNotFound("broadcast_div_f32".to_string()))?;
991 let cfg = cuda_kernels::launch_config(n);
992 unsafe {
993 self.stream
994 .launch_builder(func)
995 .arg(a)
996 .arg(b)
997 .arg(dst)
998 .arg(&(n as u32))
999 .arg(&(b_len as u32))
1000 .launch(cfg)
1001 .map(|_| ())
1002 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1003 }
1004 Ok(())
1005 }
1006
1007 pub fn broadcast_add_rev_f32(
1009 &self,
1010 dst: &mut CudaSlice<f32>,
1011 a: &CudaSlice<f32>,
1012 b: &CudaSlice<f32>,
1013 n: usize,
1014 a_len: usize,
1015 ) -> Result<(), CudaError> {
1016 let func = self
1017 .kernels
1018 .get("broadcast_add_rev_f32")
1019 .ok_or_else(|| CudaError::KernelNotFound("broadcast_add_rev_f32".to_string()))?;
1020 let cfg = cuda_kernels::launch_config(n);
1021 unsafe {
1022 self.stream
1023 .launch_builder(func)
1024 .arg(a)
1025 .arg(b)
1026 .arg(dst)
1027 .arg(&(n as u32))
1028 .arg(&(a_len as u32))
1029 .launch(cfg)
1030 .map(|_| ())
1031 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1032 }
1033 Ok(())
1034 }
1035
1036 pub fn broadcast_sub_rev_f32(
1038 &self,
1039 dst: &mut CudaSlice<f32>,
1040 a: &CudaSlice<f32>,
1041 b: &CudaSlice<f32>,
1042 n: usize,
1043 a_len: usize,
1044 ) -> Result<(), CudaError> {
1045 let func = self
1046 .kernels
1047 .get("broadcast_sub_rev_f32")
1048 .ok_or_else(|| CudaError::KernelNotFound("broadcast_sub_rev_f32".to_string()))?;
1049 let cfg = cuda_kernels::launch_config(n);
1050 unsafe {
1051 self.stream
1052 .launch_builder(func)
1053 .arg(a)
1054 .arg(b)
1055 .arg(dst)
1056 .arg(&(n as u32))
1057 .arg(&(a_len as u32))
1058 .launch(cfg)
1059 .map(|_| ())
1060 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1061 }
1062 Ok(())
1063 }
1064
1065 pub fn broadcast_mul_rev_f32(
1067 &self,
1068 dst: &mut CudaSlice<f32>,
1069 a: &CudaSlice<f32>,
1070 b: &CudaSlice<f32>,
1071 n: usize,
1072 a_len: usize,
1073 ) -> Result<(), CudaError> {
1074 let func = self
1075 .kernels
1076 .get("broadcast_mul_rev_f32")
1077 .ok_or_else(|| CudaError::KernelNotFound("broadcast_mul_rev_f32".to_string()))?;
1078 let cfg = cuda_kernels::launch_config(n);
1079 unsafe {
1080 self.stream
1081 .launch_builder(func)
1082 .arg(a)
1083 .arg(b)
1084 .arg(dst)
1085 .arg(&(n as u32))
1086 .arg(&(a_len as u32))
1087 .launch(cfg)
1088 .map(|_| ())
1089 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1090 }
1091 Ok(())
1092 }
1093
1094 pub fn broadcast_div_rev_f32(
1096 &self,
1097 dst: &mut CudaSlice<f32>,
1098 a: &CudaSlice<f32>,
1099 b: &CudaSlice<f32>,
1100 n: usize,
1101 a_len: usize,
1102 ) -> Result<(), CudaError> {
1103 let func = self
1104 .kernels
1105 .get("broadcast_div_rev_f32")
1106 .ok_or_else(|| CudaError::KernelNotFound("broadcast_div_rev_f32".to_string()))?;
1107 let cfg = cuda_kernels::launch_config(n);
1108 unsafe {
1109 self.stream
1110 .launch_builder(func)
1111 .arg(a)
1112 .arg(b)
1113 .arg(dst)
1114 .arg(&(n as u32))
1115 .arg(&(a_len as u32))
1116 .launch(cfg)
1117 .map(|_| ())
1118 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1119 }
1120 Ok(())
1121 }
1122
1123 pub fn neg_f32(
1125 &self,
1126 dst: &mut CudaSlice<f32>,
1127 src: &CudaSlice<f32>,
1128 len: usize,
1129 ) -> Result<(), CudaError> {
1130 let func = self
1131 .kernels
1132 .get("neg_f32")
1133 .ok_or_else(|| CudaError::KernelNotFound("neg_f32".to_string()))?;
1134 let cfg = cuda_kernels::launch_config(len);
1135 unsafe {
1136 self.stream
1137 .launch_builder(func)
1138 .arg(src)
1139 .arg(dst)
1140 .arg(&(len as u32))
1141 .launch(cfg)
1142 .map(|_| ())
1143 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1144 }
1145 Ok(())
1146 }
1147
1148 pub fn pow_f32(
1150 &self,
1151 dst: &mut CudaSlice<f32>,
1152 a: &CudaSlice<f32>,
1153 b: &CudaSlice<f32>,
1154 len: usize,
1155 ) -> Result<(), CudaError> {
1156 let func = self
1157 .kernels
1158 .get("pow_f32")
1159 .ok_or_else(|| CudaError::KernelNotFound("pow_f32".to_string()))?;
1160 let cfg = cuda_kernels::launch_config(len);
1161 unsafe {
1162 self.stream
1163 .launch_builder(func)
1164 .arg(a)
1165 .arg(b)
1166 .arg(dst)
1167 .arg(&(len as u32))
1168 .launch(cfg)
1169 .map(|_| ())
1170 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1171 }
1172 Ok(())
1173 }
1174
1175 pub fn pow_scalar_f32(
1177 &self,
1178 dst: &mut CudaSlice<f32>,
1179 src: &CudaSlice<f32>,
1180 exp: f32,
1181 len: usize,
1182 ) -> Result<(), CudaError> {
1183 let func = self
1184 .kernels
1185 .get("pow_scalar_f32")
1186 .ok_or_else(|| CudaError::KernelNotFound("pow_scalar_f32".to_string()))?;
1187 let cfg = cuda_kernels::launch_config(len);
1188 unsafe {
1189 self.stream
1190 .launch_builder(func)
1191 .arg(src)
1192 .arg(&exp)
1193 .arg(dst)
1194 .arg(&(len as u32))
1195 .launch(cfg)
1196 .map(|_| ())
1197 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1198 }
1199 Ok(())
1200 }
1201
1202 pub fn exp_f32(
1204 &self,
1205 dst: &mut CudaSlice<f32>,
1206 src: &CudaSlice<f32>,
1207 len: usize,
1208 ) -> Result<(), CudaError> {
1209 let func = self
1210 .kernels
1211 .get("exp_f32")
1212 .ok_or_else(|| CudaError::KernelNotFound("exp_f32".to_string()))?;
1213 let cfg = cuda_kernels::launch_config(len);
1214 unsafe {
1215 self.stream
1216 .launch_builder(func)
1217 .arg(src)
1218 .arg(dst)
1219 .arg(&(len as u32))
1220 .launch(cfg)
1221 .map(|_| ())
1222 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1223 }
1224 Ok(())
1225 }
1226
1227 pub fn log_f32(
1229 &self,
1230 dst: &mut CudaSlice<f32>,
1231 src: &CudaSlice<f32>,
1232 len: usize,
1233 ) -> Result<(), CudaError> {
1234 let func = self
1235 .kernels
1236 .get("log_f32")
1237 .ok_or_else(|| CudaError::KernelNotFound("log_f32".to_string()))?;
1238 let cfg = cuda_kernels::launch_config(len);
1239 unsafe {
1240 self.stream
1241 .launch_builder(func)
1242 .arg(src)
1243 .arg(dst)
1244 .arg(&(len as u32))
1245 .launch(cfg)
1246 .map(|_| ())
1247 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1248 }
1249 Ok(())
1250 }
1251
1252 pub fn sqrt_f32(
1254 &self,
1255 dst: &mut CudaSlice<f32>,
1256 src: &CudaSlice<f32>,
1257 len: usize,
1258 ) -> Result<(), CudaError> {
1259 let func = self
1260 .kernels
1261 .get("sqrt_f32")
1262 .ok_or_else(|| CudaError::KernelNotFound("sqrt_f32".to_string()))?;
1263 let cfg = cuda_kernels::launch_config(len);
1264 unsafe {
1265 self.stream
1266 .launch_builder(func)
1267 .arg(src)
1268 .arg(dst)
1269 .arg(&(len as u32))
1270 .launch(cfg)
1271 .map(|_| ())
1272 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1273 }
1274 Ok(())
1275 }
1276
1277 pub fn gelu_f32(
1279 &self,
1280 dst: &mut CudaSlice<f32>,
1281 src: &CudaSlice<f32>,
1282 len: usize,
1283 ) -> Result<(), CudaError> {
1284 let func = self
1285 .kernels
1286 .get("gelu_f32")
1287 .ok_or_else(|| CudaError::KernelNotFound("gelu_f32".to_string()))?;
1288 let cfg = cuda_kernels::launch_config(len);
1289 unsafe {
1290 self.stream
1291 .launch_builder(func)
1292 .arg(src)
1293 .arg(dst)
1294 .arg(&(len as u32))
1295 .launch(cfg)
1296 .map(|_| ())
1297 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1298 }
1299 Ok(())
1300 }
1301
1302 pub fn silu_f32(
1304 &self,
1305 dst: &mut CudaSlice<f32>,
1306 src: &CudaSlice<f32>,
1307 len: usize,
1308 ) -> Result<(), CudaError> {
1309 let func = self
1310 .kernels
1311 .get("silu_f32")
1312 .ok_or_else(|| CudaError::KernelNotFound("silu_f32".to_string()))?;
1313 let cfg = cuda_kernels::launch_config(len);
1314 unsafe {
1315 self.stream
1316 .launch_builder(func)
1317 .arg(src)
1318 .arg(dst)
1319 .arg(&(len as u32))
1320 .launch(cfg)
1321 .map(|_| ())
1322 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1323 }
1324 Ok(())
1325 }
1326
1327 pub fn add_scalar_f32(
1329 &self,
1330 dst: &mut CudaSlice<f32>,
1331 src: &CudaSlice<f32>,
1332 scalar: f32,
1333 len: usize,
1334 ) -> Result<(), CudaError> {
1335 let func = self
1336 .kernels
1337 .get("add_scalar_f32")
1338 .ok_or_else(|| CudaError::KernelNotFound("add_scalar_f32".to_string()))?;
1339 let cfg = cuda_kernels::launch_config(len);
1340 unsafe {
1341 self.stream
1342 .launch_builder(func)
1343 .arg(src)
1344 .arg(&scalar)
1345 .arg(dst)
1346 .arg(&(len as u32))
1347 .launch(cfg)
1348 .map(|_| ())
1349 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1350 }
1351 Ok(())
1352 }
1353
1354 pub fn relu_backward_f32(
1356 &self,
1357 dst: &mut CudaSlice<f32>,
1358 grad_output: &CudaSlice<f32>,
1359 input: &CudaSlice<f32>,
1360 len: usize,
1361 ) -> Result<(), CudaError> {
1362 let func = self
1363 .kernels
1364 .get("relu_backward_f32")
1365 .ok_or_else(|| CudaError::KernelNotFound("relu_backward_f32".to_string()))?;
1366 let cfg = cuda_kernels::launch_config(len);
1367 unsafe {
1368 self.stream
1369 .launch_builder(func)
1370 .arg(grad_output)
1371 .arg(input)
1372 .arg(dst)
1373 .arg(&(len as u32))
1374 .launch(cfg)
1375 .map(|_| ())
1376 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1377 }
1378 Ok(())
1379 }
1380
1381 pub fn sigmoid_backward_f32(
1383 &self,
1384 dst: &mut CudaSlice<f32>,
1385 grad_output: &CudaSlice<f32>,
1386 output: &CudaSlice<f32>,
1387 len: usize,
1388 ) -> Result<(), CudaError> {
1389 let func = self
1390 .kernels
1391 .get("sigmoid_backward_f32")
1392 .ok_or_else(|| CudaError::KernelNotFound("sigmoid_backward_f32".to_string()))?;
1393 let cfg = cuda_kernels::launch_config(len);
1394 unsafe {
1395 self.stream
1396 .launch_builder(func)
1397 .arg(grad_output)
1398 .arg(output)
1399 .arg(dst)
1400 .arg(&(len as u32))
1401 .launch(cfg)
1402 .map(|_| ())
1403 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1404 }
1405 Ok(())
1406 }
1407
1408 pub fn tanh_backward_f32(
1410 &self,
1411 dst: &mut CudaSlice<f32>,
1412 grad_output: &CudaSlice<f32>,
1413 output: &CudaSlice<f32>,
1414 len: usize,
1415 ) -> Result<(), CudaError> {
1416 let func = self
1417 .kernels
1418 .get("tanh_backward_f32")
1419 .ok_or_else(|| CudaError::KernelNotFound("tanh_backward_f32".to_string()))?;
1420 let cfg = cuda_kernels::launch_config(len);
1421 unsafe {
1422 self.stream
1423 .launch_builder(func)
1424 .arg(grad_output)
1425 .arg(output)
1426 .arg(dst)
1427 .arg(&(len as u32))
1428 .launch(cfg)
1429 .map(|_| ())
1430 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1431 }
1432 Ok(())
1433 }
1434
1435 pub fn sum_dim_f32(
1438 &self,
1439 dst: &mut CudaSlice<f32>,
1440 src: &CudaSlice<f32>,
1441 outer_size: usize,
1442 dim_size: usize,
1443 inner_size: usize,
1444 ) -> Result<(), CudaError> {
1445 let func = self
1446 .kernels
1447 .get("sum_dim_f32")
1448 .ok_or_else(|| CudaError::KernelNotFound("sum_dim_f32".to_string()))?;
1449 let out_len = outer_size * inner_size;
1450 let cfg = cuda_kernels::launch_config(out_len);
1451 unsafe {
1452 self.stream
1453 .launch_builder(func)
1454 .arg(src)
1455 .arg(dst)
1456 .arg(&(outer_size as u32))
1457 .arg(&(dim_size as u32))
1458 .arg(&(inner_size as u32))
1459 .launch(cfg)
1460 .map(|_| ())
1461 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1462 }
1463 Ok(())
1464 }
1465
1466 pub fn softmax_row_f32(
1470 &self,
1471 data: &mut CudaSlice<f32>,
1472 num_rows: usize,
1473 row_size: usize,
1474 ) -> Result<(), CudaError> {
1475 let func = self
1476 .kernels
1477 .get("softmax_row_f32")
1478 .ok_or_else(|| CudaError::KernelNotFound("softmax_row_f32".to_string()))?;
1479 let cfg = LaunchConfig {
1481 grid_dim: (num_rows as u32, 1, 1),
1482 block_dim: (BLOCK_SIZE, 1, 1),
1483 shared_mem_bytes: BLOCK_SIZE * 4,
1484 };
1485 unsafe {
1486 self.stream
1487 .launch_builder(func)
1488 .arg(data)
1489 .arg(&(num_rows as u32))
1490 .arg(&(row_size as u32))
1491 .launch(cfg)
1492 .map(|_| ())
1493 .map_err(|e| CudaError::DriverError(e.to_string()))?;
1494 }
1495 Ok(())
1496 }
1497
    /// Launches the `broadcast_copy_f32` kernel: fills the `n`-element `dst`
    /// from the `src_len`-element `src` (presumably by repeating `src`
    /// cyclically — confirm against the kernel source).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn broadcast_copy_f32(
        &self,
        dst: &mut CudaSlice<f32>,
        src: &CudaSlice<f32>,
        n: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("broadcast_copy_f32")
            .ok_or_else(|| CudaError::KernelNotFound("broadcast_copy_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(src)
                .arg(dst)
                .arg(&(n as u32))
                .arg(&(src_len as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1524
    /// Launches the `layer_norm_f32` kernel: normalizes `input` (viewed as
    /// `num_rows` rows of `norm_size`) into `dst`, using `gamma`/`beta` as the
    /// affine parameters and `eps` for numerical stability. One block per row.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn layer_norm_f32(
        &self,
        dst: &mut CudaSlice<f32>,
        input: &CudaSlice<f32>,
        gamma: &CudaSlice<f32>,
        beta: &CudaSlice<f32>,
        norm_size: usize,
        eps: f32,
        num_rows: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("layer_norm_f32")
            .ok_or_else(|| CudaError::KernelNotFound("layer_norm_f32".to_string()))?;
        let cfg = LaunchConfig {
            grid_dim: (num_rows as u32, 1, 1),
            block_dim: (BLOCK_SIZE, 1, 1),
            shared_mem_bytes: BLOCK_SIZE * 4,
        };
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(input)
                .arg(gamma)
                .arg(beta)
                .arg(dst)
                .arg(&(norm_size as u32))
                .arg(&eps)
                .arg(&(num_rows as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1562
    /// Launches the `softmax_backward_row_f32` kernel: computes the softmax
    /// gradient into `dst` from the saved forward output and `grad_output`,
    /// row by row (`num_rows` rows of `row_size`). One block per row.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn softmax_backward_row_f32(
        &self,
        dst: &mut CudaSlice<f32>,
        softmax_output: &CudaSlice<f32>,
        grad_output: &CudaSlice<f32>,
        num_rows: usize,
        row_size: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("softmax_backward_row_f32")
            .ok_or_else(|| CudaError::KernelNotFound("softmax_backward_row_f32".to_string()))?;
        let cfg = LaunchConfig {
            grid_dim: (num_rows as u32, 1, 1),
            block_dim: (BLOCK_SIZE, 1, 1),
            shared_mem_bytes: BLOCK_SIZE * 4,
        };
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(softmax_output)
                .arg(grad_output)
                .arg(dst)
                .arg(&(num_rows as u32))
                .arg(&(row_size as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1597
    /// Launches the `layer_norm_backward_dinput_f32` kernel: computes the
    /// layer-norm input gradient into `d_input` from `grad_output`, the saved
    /// forward `input`, and `gamma`. One block per row; shared memory is twice
    /// the usual per-thread f32 scratch (two reduction buffers).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn layer_norm_backward_dinput_f32(
        &self,
        d_input: &mut CudaSlice<f32>,
        grad_output: &CudaSlice<f32>,
        input: &CudaSlice<f32>,
        gamma: &CudaSlice<f32>,
        norm_size: usize,
        eps: f32,
        num_rows: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("layer_norm_backward_dinput_f32")
            .ok_or_else(|| {
                CudaError::KernelNotFound("layer_norm_backward_dinput_f32".to_string())
            })?;
        let cfg = LaunchConfig {
            grid_dim: (num_rows as u32, 1, 1),
            block_dim: (BLOCK_SIZE, 1, 1),
            // 2 shared f32 arrays of BLOCK_SIZE elements each.
            shared_mem_bytes: BLOCK_SIZE * 4 * 2, };
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(grad_output)
                .arg(input)
                .arg(gamma)
                .arg(d_input)
                .arg(&(norm_size as u32))
                .arg(&eps)
                .arg(&(num_rows as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1637
    /// Launches the `layer_norm_backward_dweight_dbias_f32` kernel: accumulates
    /// the gamma/beta gradients into `d_weight`/`d_bias` from `grad_output` and
    /// the saved forward `input`. One thread per `norm_size` feature (each
    /// thread presumably loops over `num_rows` — confirm against the kernel).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn layer_norm_backward_dweight_dbias_f32(
        &self,
        d_weight: &mut CudaSlice<f32>,
        d_bias: &mut CudaSlice<f32>,
        grad_output: &CudaSlice<f32>,
        input: &CudaSlice<f32>,
        norm_size: usize,
        eps: f32,
        num_rows: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("layer_norm_backward_dweight_dbias_f32")
            .ok_or_else(|| {
                CudaError::KernelNotFound("layer_norm_backward_dweight_dbias_f32".to_string())
            })?;
        let cfg = cuda_kernels::launch_config(norm_size);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(grad_output)
                .arg(input)
                .arg(d_weight)
                .arg(d_bias)
                .arg(&(norm_size as u32))
                .arg(&eps)
                .arg(&(num_rows as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1673
    /// Launches the `gather_contiguous_f32` kernel: fills the `n`-element `dst`
    /// by indexing `src` with `indices` (one thread per output element).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn gather_contiguous_f32(
        &self,
        dst: &mut CudaSlice<f32>,
        src: &CudaSlice<f32>,
        indices: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("gather_contiguous_f32")
            .ok_or_else(|| CudaError::KernelNotFound("gather_contiguous_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(src)
                .arg(indices)
                .arg(dst)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1700
    /// Launches the `embedding_scatter_add_f32` kernel: scatter-adds the
    /// `total_n`-element `grad_src` into `weight_grad` rows selected by
    /// `indices`, with `emb_dim` columns per row (embedding backward pass).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn embedding_scatter_add_f32(
        &self,
        grad_src: &CudaSlice<f32>,
        indices: &CudaSlice<u32>,
        weight_grad: &mut CudaSlice<f32>,
        total_n: usize,
        emb_dim: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("embedding_scatter_add_f32")
            .ok_or_else(|| CudaError::KernelNotFound("embedding_scatter_add_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(grad_src)
                .arg(indices)
                .arg(weight_grad)
                .arg(&(total_n as u32))
                .arg(&(emb_dim as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1730
    /// Launches the `adam_step_f32` kernel: performs one fused Adam optimizer
    /// update in place on `param`, updating the first/second moment buffers
    /// `exp_avg`/`exp_avg_sq` from `grad`. Bias corrections are precomputed by
    /// the caller and passed in, so the kernel does not need the step count.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    #[allow(clippy::too_many_arguments)]
    pub fn adam_step_f32(
        &self,
        param: &mut CudaSlice<f32>,
        grad: &CudaSlice<f32>,
        exp_avg: &mut CudaSlice<f32>,
        exp_avg_sq: &mut CudaSlice<f32>,
        n: usize,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        bias_correction1: f32,
        bias_correction2: f32,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("adam_step_f32")
            .ok_or_else(|| CudaError::KernelNotFound("adam_step_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(param)
                .arg(grad)
                .arg(exp_avg)
                .arg(exp_avg_sq)
                .arg(&(n as u32))
                .arg(&lr)
                .arg(&beta1)
                .arg(&beta2)
                .arg(&eps)
                .arg(&weight_decay)
                .arg(&bias_correction1)
                .arg(&bias_correction2)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1775
    /// Launches the `grad_norm_sq_f32` kernel: accumulates the squared L2 norm
    /// of the `n`-element `data` into `output` (used for gradient clipping).
    /// NOTE(review): `output` does not appear to be zeroed here — confirm the
    /// caller clears it (or the kernel does) before accumulation.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn grad_norm_sq_f32(
        &self,
        data: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("grad_norm_sq_f32")
            .ok_or_else(|| CudaError::KernelNotFound("grad_norm_sq_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(data)
                .arg(output)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1801
    /// Launches the `grad_scale_f32` kernel: multiplies the `n`-element `data`
    /// by `scale` in place (gradient clipping / loss scaling).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn grad_scale_f32(
        &self,
        data: &mut CudaSlice<f32>,
        n: usize,
        scale: f32,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("grad_scale_f32")
            .ok_or_else(|| CudaError::KernelNotFound("grad_scale_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(data)
                .arg(&(n as u32))
                .arg(&scale)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1826
    /// Launches the `cross_entropy_fwd_f32` kernel: computes per-sample losses
    /// into `losses` and saves the softmax probabilities into `softmax_out`
    /// (reused by the backward pass). One block per sample in the batch.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn cross_entropy_fwd_f32(
        &self,
        logits: &CudaSlice<f32>,
        targets: &CudaSlice<f32>,
        losses: &mut CudaSlice<f32>,
        softmax_out: &mut CudaSlice<f32>,
        batch_size: usize,
        num_classes: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("cross_entropy_fwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("cross_entropy_fwd_f32".to_string()))?;
        let cfg = LaunchConfig {
            grid_dim: (batch_size as u32, 1, 1),
            block_dim: (BLOCK_SIZE, 1, 1),
            shared_mem_bytes: BLOCK_SIZE * 4,
        };
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(logits)
                .arg(targets)
                .arg(losses)
                .arg(softmax_out)
                .arg(&(num_classes as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1862
    /// Launches the `cross_entropy_bwd_f32` kernel: computes the logits
    /// gradient into `grad_input` from the saved `softmax_probs`, the
    /// `targets`, and the upstream `grad_output`. One thread per logit
    /// (`batch_size * num_classes` total).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn cross_entropy_bwd_f32(
        &self,
        softmax_probs: &CudaSlice<f32>,
        targets: &CudaSlice<f32>,
        grad_output: &CudaSlice<f32>,
        grad_input: &mut CudaSlice<f32>,
        batch_size: usize,
        num_classes: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("cross_entropy_bwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("cross_entropy_bwd_f32".to_string()))?;
        let total = batch_size * num_classes;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(softmax_probs)
                .arg(targets)
                .arg(grad_output)
                .arg(grad_input)
                .arg(&(batch_size as u32))
                .arg(&(num_classes as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1895
    /// Zero-fills `dst` on the backend's stream.
    ///
    /// # Errors
    /// `DriverError` if the memset fails.
    #[cfg(feature = "cuda")]
    pub fn memset_zeros_f32(&self, dst: &mut CudaSlice<f32>) -> Result<(), CudaError> {
        self.stream
            .memset_zeros(dst)
            .map_err(|e| CudaError::DriverError(e.to_string()))
    }
1903
    /// Copies `count` f32 elements device-to-device from `src[src_offset..]`
    /// to `dst[dst_offset..]` using a *synchronous* driver memcpy.
    /// NOTE(review): `memcpy_dtod_sync` does not queue on `self.stream`, so it
    /// may not be ordered with respect to pending kernel launches — confirm
    /// callers synchronize first, or switch to the async stream-ordered copy.
    ///
    /// # Errors
    /// `DriverError` if the copy fails.
    #[cfg(feature = "cuda")]
    pub fn memcpy_dtod_f32(
        &self,
        dst: &mut CudaSlice<f32>,
        dst_offset: usize,
        src: &CudaSlice<f32>,
        src_offset: usize,
        count: usize,
    ) -> Result<(), CudaError> {
        use cudarc::driver::DevicePtr as _;
        let (src_ptr, _guard_s) = src.device_ptr(&self.stream);
        // Offsets are in elements; convert to a byte offset on the raw pointer.
        let src_ptr =
            src_ptr + (src_offset * std::mem::size_of::<f32>()) as cudarc::driver::sys::CUdeviceptr;
        use cudarc::driver::DevicePtrMut as _;
        let (dst_ptr, _guard_d) = dst.device_ptr_mut(&self.stream);
        let dst_ptr =
            dst_ptr + (dst_offset * std::mem::size_of::<f32>()) as cudarc::driver::sys::CUdeviceptr;
        let size = count * std::mem::size_of::<f32>();
        unsafe {
            cudarc::driver::result::memcpy_dtod_sync(dst_ptr, src_ptr, size)
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
1930}
1931
// Attention-mask expansion kernels.
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the `mask_expand_causal_f32` kernel: expands `mask` into the
    /// `total_n`-element `output` laid out as `(tgt_len, src_len)` tiles,
    /// presumably combining it with a causal (lower-triangular) pattern —
    /// confirm against the kernel source.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn mask_expand_causal_f32(
        &self,
        mask: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        total_n: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("mask_expand_causal_f32")
            .ok_or_else(|| CudaError::KernelNotFound("mask_expand_causal_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(mask)
                .arg(output)
                .arg(&(total_n as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `mask_expand_padding_f32` kernel: broadcasts a padding
    /// mask into the `total_n`-element `output` across `num_heads` heads and
    /// `(tgt_len, src_len)` attention tiles.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn mask_expand_padding_f32(
        &self,
        mask: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        total_n: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("mask_expand_padding_f32")
            .ok_or_else(|| CudaError::KernelNotFound("mask_expand_padding_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total_n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(mask)
                .arg(output)
                .arg(&(total_n as u32))
                .arg(&(num_heads as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
1998
// Strided tensor materialization, RNN cell (LSTM/GRU) fused gate kernels,
// and batch-norm statistics/normalization kernels.
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the `strided_gather_f32` kernel: materializes a strided view
    /// of `src` into the contiguous `total_n`-element `dst`, using per-dim
    /// `strides` (in elements, i64 to allow negatives) and `shape`, starting
    /// at `offset` elements into `src`.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn strided_gather_f32(
        &self,
        src: &CudaSlice<f32>,
        dst: &mut CudaSlice<f32>,
        strides: &CudaSlice<i64>,
        shape: &CudaSlice<u32>,
        ndim: usize,
        offset: usize,
        total_n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("strided_gather_f32")
            .ok_or_else(|| CudaError::KernelNotFound("strided_gather_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(total_n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(src)
                .arg(dst)
                .arg(strides)
                .arg(shape)
                .arg(&(ndim as u32))
                .arg(&(offset as u32))
                .arg(&(total_n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `lstm_gates_f32` kernel: applies the fused LSTM gate
    /// nonlinearities to the pre-activation `gates` and previous cell state
    /// `c_prev`, producing `h_new`/`c_new`. `total` is the number of
    /// (batch, hidden) elements processed; one thread each.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn lstm_gates_f32(
        &self,
        gates: &CudaSlice<f32>,
        c_prev: &CudaSlice<f32>,
        h_new: &mut CudaSlice<f32>,
        c_new: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("lstm_gates_f32")
            .ok_or_else(|| CudaError::KernelNotFound("lstm_gates_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates)
                .arg(c_prev)
                .arg(h_new)
                .arg(c_new)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `lstm_gates_backward_f32` kernel: given the saved forward
    /// activations (`gates`, `c_prev`, `c_new`) and incoming gradients
    /// (`grad_h`, `grad_c_next`), produces the pre-activation gate gradients
    /// `grad_gates` and the previous-cell gradient `grad_c_prev`.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn lstm_gates_backward_f32(
        &self,
        gates: &CudaSlice<f32>,
        c_prev: &CudaSlice<f32>,
        c_new: &CudaSlice<f32>,
        grad_h: &CudaSlice<f32>,
        grad_c_next: &CudaSlice<f32>,
        grad_gates: &mut CudaSlice<f32>,
        grad_c_prev: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("lstm_gates_backward_f32")
            .ok_or_else(|| CudaError::KernelNotFound("lstm_gates_backward_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates)
                .arg(c_prev)
                .arg(c_new)
                .arg(grad_h)
                .arg(grad_c_next)
                .arg(grad_gates)
                .arg(grad_c_prev)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `gru_gates_f32` kernel: applies the fused GRU gate
    /// nonlinearities to the input-side (`gates_ih`) and hidden-side
    /// (`gates_hh`) pre-activations plus `h_prev`, producing `h_new`.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn gru_gates_f32(
        &self,
        gates_ih: &CudaSlice<f32>,
        gates_hh: &CudaSlice<f32>,
        h_prev: &CudaSlice<f32>,
        h_new: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("gru_gates_f32")
            .ok_or_else(|| CudaError::KernelNotFound("gru_gates_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates_ih)
                .arg(gates_hh)
                .arg(h_prev)
                .arg(h_new)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `gru_gates_backward_f32` kernel: given the saved forward
    /// pre-activations and `h_prev`, plus the incoming `grad_h_new`, produces
    /// gradients for both pre-activation buffers and for `h_prev`.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn gru_gates_backward_f32(
        &self,
        gates_ih: &CudaSlice<f32>,
        gates_hh: &CudaSlice<f32>,
        h_prev: &CudaSlice<f32>,
        grad_h_new: &CudaSlice<f32>,
        grad_gates_ih: &mut CudaSlice<f32>,
        grad_gates_hh: &mut CudaSlice<f32>,
        grad_h_prev: &mut CudaSlice<f32>,
        hidden_size: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("gru_gates_backward_f32")
            .ok_or_else(|| CudaError::KernelNotFound("gru_gates_backward_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(gates_ih)
                .arg(gates_hh)
                .arg(h_prev)
                .arg(grad_h_new)
                .arg(grad_gates_ih)
                .arg(grad_gates_hh)
                .arg(grad_h_prev)
                .arg(&(hidden_size as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `batchnorm_stats_f32` kernel: accumulates per-channel sum
    /// and sum-of-squares of `x` (layout `(n, c, spatial)`) into
    /// `sum_out`/`sum_sq_out`. One thread per input element.
    /// NOTE(review): the output buffers do not appear to be zeroed here —
    /// confirm the caller clears them before accumulation.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn batchnorm_stats_f32(
        &self,
        x: &CudaSlice<f32>,
        sum_out: &mut CudaSlice<f32>,
        sum_sq_out: &mut CudaSlice<f32>,
        n: usize,
        c: usize,
        spatial: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("batchnorm_stats_f32")
            .ok_or_else(|| CudaError::KernelNotFound("batchnorm_stats_f32".to_string()))?;
        let total = n * c * spatial;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(x)
                .arg(sum_out)
                .arg(sum_sq_out)
                .arg(&(n as u32))
                .arg(&(c as u32))
                .arg(&(spatial as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `batchnorm_norm_f32` kernel: normalizes `x` into `y`
    /// using precomputed per-channel `mean`/`var` and the affine parameters
    /// `gamma`/`beta`; `eps` guards the variance. One thread per element.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn batchnorm_norm_f32(
        &self,
        x: &CudaSlice<f32>,
        mean: &CudaSlice<f32>,
        var: &CudaSlice<f32>,
        gamma: &CudaSlice<f32>,
        beta: &CudaSlice<f32>,
        y: &mut CudaSlice<f32>,
        eps: f32,
        c: usize,
        spatial: usize,
        total: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("batchnorm_norm_f32")
            .ok_or_else(|| CudaError::KernelNotFound("batchnorm_norm_f32".to_string()))?;
        let cfg = cuda_kernels::launch_config(total);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(x)
                .arg(mean)
                .arg(var)
                .arg(gamma)
                .arg(beta)
                .arg(y)
                .arg(&eps)
                .arg(&(c as u32))
                .arg(&(spatial as u32))
                .arg(&(total as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
2303
// Fused attention forward kernel.
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the `fused_attention_fwd_f32` kernel: computes attention
    /// output from `q`/`k`/`v` with score scaling `scale` and an optional
    /// causal mask. One kernel "row" per (batch, head, target-position)
    /// triple, i.e. `batch_size * num_heads * tgt_len` units of work.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    #[allow(clippy::too_many_arguments)]
    pub fn fused_attention_fwd_f32(
        &self,
        q: &CudaSlice<f32>,
        k: &CudaSlice<f32>,
        v: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        scale: f32,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
        head_dim: usize,
        is_causal: bool,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("fused_attention_fwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("fused_attention_fwd_f32".to_string()))?;
        let total_rows = batch_size * num_heads * tgt_len;
        let cfg = cuda_kernels::launch_config(total_rows);
        // The kernel ABI takes the flag as u32; `u32::from` is the idiomatic
        // lossless bool -> integer conversion.
        let is_causal_u32 = u32::from(is_causal);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(q)
                .arg(k)
                .arg(v)
                .arg(output)
                .arg(&scale)
                .arg(&(batch_size as u32))
                .arg(&(num_heads as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .arg(&(head_dim as u32))
                .arg(&is_causal_u32)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
2357
// Fused attention backward kernel.
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the `fused_attention_bwd_f32` kernel: given the forward
    /// inputs `q`/`k`/`v`, the forward output `o`, and the upstream gradient
    /// `grad_o`, produces `grad_q`/`grad_k`/`grad_v`. Work decomposition
    /// mirrors the forward pass: one unit per (batch, head, target-position).
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    #[allow(clippy::too_many_arguments)]
    pub fn fused_attention_bwd_f32(
        &self,
        q: &CudaSlice<f32>,
        k: &CudaSlice<f32>,
        v: &CudaSlice<f32>,
        o: &CudaSlice<f32>,
        grad_o: &CudaSlice<f32>,
        grad_q: &mut CudaSlice<f32>,
        grad_k: &mut CudaSlice<f32>,
        grad_v: &mut CudaSlice<f32>,
        scale: f32,
        batch_size: usize,
        num_heads: usize,
        tgt_len: usize,
        src_len: usize,
        head_dim: usize,
        is_causal: bool,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("fused_attention_bwd_f32")
            .ok_or_else(|| CudaError::KernelNotFound("fused_attention_bwd_f32".to_string()))?;
        let total_rows = batch_size * num_heads * tgt_len;
        let cfg = cuda_kernels::launch_config(total_rows);
        // Idiomatic bool -> u32 conversion for the kernel ABI.
        let is_causal_u32 = u32::from(is_causal);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(q)
                .arg(k)
                .arg(v)
                .arg(o)
                .arg(grad_o)
                .arg(grad_q)
                .arg(grad_k)
                .arg(grad_v)
                .arg(&scale)
                .arg(&(batch_size as u32))
                .arg(&(num_heads as u32))
                .arg(&(tgt_len as u32))
                .arg(&(src_len as u32))
                .arg(&(head_dim as u32))
                .arg(&is_causal_u32)
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }
}
2421
// Convolution support: im2col/col2im lowering, channel bias add, and a
// host-orchestrated conv2d forward built on GEMM.
#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Launches the `im2col_f32` kernel: unfolds `input` into the column
    /// matrix `col` using the packed `params` buffer (see `conv2d_forward`
    /// for its layout). `n` is the number of column-matrix elements.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn im2col_f32(
        &self,
        input: &CudaSlice<f32>,
        col: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("im2col_f32")
            .ok_or_else(|| CudaError::KernelNotFound("im2col_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(input)
                .arg(col)
                .arg(params)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `col2im_f32` kernel: folds the column matrix `col` back
    /// into image layout in `output` (inverse of `im2col_f32`), using the
    /// same packed `params` buffer.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn col2im_f32(
        &self,
        col: &CudaSlice<f32>,
        output: &mut CudaSlice<f32>,
        params: &CudaSlice<u32>,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("col2im_f32")
            .ok_or_else(|| CudaError::KernelNotFound("col2im_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(col)
                .arg(output)
                .arg(params)
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Launches the `bias_add_channels_f32` kernel: adds one bias value per
    /// channel to `data` in place, where each channel spans `spatial`
    /// contiguous elements and `n` is the total element count.
    ///
    /// # Errors
    /// `KernelNotFound` if the kernel is missing; `DriverError` on launch failure.
    pub fn bias_add_channels_f32(
        &self,
        data: &mut CudaSlice<f32>,
        bias: &CudaSlice<f32>,
        spatial: usize,
        n: usize,
    ) -> Result<(), CudaError> {
        let func = self
            .kernels
            .get("bias_add_channels_f32")
            .ok_or_else(|| CudaError::KernelNotFound("bias_add_channels_f32".to_string()))?;

        let cfg = cuda_kernels::launch_config(n);
        unsafe {
            self.stream
                .launch_builder(func)
                .arg(data)
                .arg(bias)
                .arg(&(spatial as u32))
                .arg(&(n as u32))
                .launch(cfg)
                .map(|_| ())
                .map_err(|e| CudaError::DriverError(e.to_string()))?;
        }
        Ok(())
    }

    /// Host-orchestrated 2D convolution forward pass via im2col + GEMM.
    ///
    /// Copies `weight` (and optional `bias`) to the device once, then for
    /// each batch element uploads the input, unfolds it with `im2col_f32`,
    /// multiplies against the weight matrix with `gemm_f32`, applies the bias,
    /// and copies the result back to the host output buffer.
    ///
    /// Returns `None` on any CUDA failure, on degenerate geometry (kernel
    /// larger than the padded input), or on a zero stride.
    #[allow(clippy::too_many_arguments)]
    pub fn conv2d_forward(
        &self,
        input: &[f32],
        weight: &[f32],
        bias: Option<&[f32]>,
        batch_size: usize,
        in_channels: usize,
        in_height: usize,
        in_width: usize,
        out_channels: usize,
        kernel_h: usize,
        kernel_w: usize,
        stride_h: usize,
        stride_w: usize,
        pad_h: usize,
        pad_w: usize,
    ) -> Option<Vec<f32>> {
        // Standard conv output-size formula, written with checked arithmetic:
        // the unchecked form `(in + 2*pad - kernel) / stride + 1` panics on
        // usize underflow when the kernel exceeds the padded input, and on
        // division by zero when a stride is 0. Return None instead.
        let out_h = (in_height + 2 * pad_h)
            .checked_sub(kernel_h)?
            .checked_div(stride_h)?
            + 1;
        let out_w = (in_width + 2 * pad_w)
            .checked_sub(kernel_w)?
            .checked_div(stride_w)?
            + 1;
        let col_h = in_channels * kernel_h * kernel_w;
        let col_w = out_h * out_w;
        let col_n = col_h * col_w;
        let spatial = out_h * out_w;
        let out_per_batch = out_channels * spatial;
        let in_per_batch = in_channels * in_height * in_width;

        use super::cuda_pool::pool_alloc;

        // Weight and bias are uploaded once and reused for every batch element.
        let weight_gpu = self.htod_copy(weight).ok()?;

        let bias_gpu = bias.and_then(|b| self.htod_copy(b).ok());

        // Geometry parameters packed for the im2col kernel; the order must
        // match the kernel's expectation.
        let im2col_params: [u32; 10] = [
            in_height as u32,
            in_width as u32,
            kernel_h as u32,
            kernel_w as u32,
            pad_h as u32,
            pad_w as u32,
            stride_h as u32,
            stride_w as u32,
            out_h as u32,
            out_w as u32,
        ];
        let params_gpu = self.htod_copy(&im2col_params[..]).ok()?;

        // Scratch buffers reused across batch iterations.
        let mut col_gpu = pool_alloc(col_n).ok()?;

        let mut batch_out_gpu = pool_alloc(out_per_batch).ok()?;

        let mut output = vec![0.0f32; batch_size * out_per_batch];

        for b in 0..batch_size {
            let input_slice = &input[b * in_per_batch..(b + 1) * in_per_batch];
            let input_gpu = self.htod_copy(input_slice).ok()?;

            self.im2col_f32(&input_gpu, &mut col_gpu, &params_gpu, col_n)
                .ok()?;

            // out = weight (out_channels x col_h) * col (col_h x col_w),
            // expressed through the column-major GEMM convention.
            self.gemm_f32(
                false,
                false,
                col_w,
                out_channels,
                col_h,
                1.0,
                &col_gpu,
                col_w,
                &weight_gpu,
                col_h,
                0.0,
                &mut batch_out_gpu,
                col_w,
            )
            .ok()?;

            if let Some(ref bg) = bias_gpu {
                self.bias_add_channels_f32(&mut batch_out_gpu, bg, spatial, out_per_batch)
                    .ok()?;
            }

            let batch_result = self.dtoh_copy(&batch_out_gpu).ok()?;
            output[b * out_per_batch..(b + 1) * out_per_batch]
                .copy_from_slice(&batch_result[..out_per_batch]);
        }

        Some(output)
    }
}
2631
/// Free-function entry point for conv2d on the GPU: fetches the global CUDA
/// backend and forwards all arguments to [`CudaBackend::conv2d_forward`].
/// Returns `None` when no CUDA device/backend is available or the
/// convolution fails.
#[cfg(feature = "cuda")]
pub fn cuda_conv2d_forward(
    input: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Option<Vec<f32>> {
    let cuda = get_cuda_backend()?;
    cuda.conv2d_forward(
        input,
        weight,
        bias,
        batch_size,
        in_channels,
        in_height,
        in_width,
        out_channels,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
    )
}
2671
/// Stub used when the `cuda` feature is disabled: always returns `None` so
/// callers fall back to the CPU convolution path.
#[cfg(not(feature = "cuda"))]
pub fn cuda_conv2d_forward(
    _input: &[f32],
    _weight: &[f32],
    _bias: Option<&[f32]>,
    _batch_size: usize,
    _in_channels: usize,
    _in_height: usize,
    _in_width: usize,
    _out_channels: usize,
    _kernel_h: usize,
    _kernel_w: usize,
    _stride_h: usize,
    _stride_w: usize,
    _pad_h: usize,
    _pad_w: usize,
) -> Option<Vec<f32>> {
    None
}
2692
2693#[cfg(feature = "cuda")]
2698impl CudaBackend {
2699 pub fn maxpool2d_fwd_f32(
2706 &self,
2707 input: &CudaSlice<f32>,
2708 output: &mut CudaSlice<f32>,
2709 indices: &mut CudaSlice<i32>,
2710 params: &CudaSlice<u32>,
2711 channels: usize,
2712 out_h: usize,
2713 out_w: usize,
2714 total: usize,
2715 ) -> Result<(), CudaError> {
2716 let func = self
2717 .kernels
2718 .get("maxpool2d_fwd_f32")
2719 .ok_or_else(|| CudaError::KernelNotFound("maxpool2d_fwd_f32".to_string()))?;
2720
2721 let cfg = cuda_kernels::launch_config(total);
2722 unsafe {
2723 self.stream
2724 .launch_builder(func)
2725 .arg(input)
2726 .arg(output)
2727 .arg(indices)
2728 .arg(params)
2729 .arg(&(channels as u32))
2730 .arg(&(out_h as u32))
2731 .arg(&(out_w as u32))
2732 .arg(&(total as u32))
2733 .launch(cfg)
2734 .map(|_| ())
2735 .map_err(|e| CudaError::DriverError(e.to_string()))?;
2736 }
2737 Ok(())
2738 }
2739
2740 pub fn maxpool2d_bwd_f32(
2745 &self,
2746 grad_output: &CudaSlice<f32>,
2747 indices: &CudaSlice<i32>,
2748 grad_input: &mut CudaSlice<f32>,
2749 total: usize,
2750 ) -> Result<(), CudaError> {
2751 let func = self
2752 .kernels
2753 .get("maxpool2d_bwd_f32")
2754 .ok_or_else(|| CudaError::KernelNotFound("maxpool2d_bwd_f32".to_string()))?;
2755
2756 let cfg = cuda_kernels::launch_config(total);
2757 unsafe {
2758 self.stream
2759 .launch_builder(func)
2760 .arg(grad_output)
2761 .arg(indices)
2762 .arg(grad_input)
2763 .arg(&(total as u32))
2764 .launch(cfg)
2765 .map(|_| ())
2766 .map_err(|e| CudaError::DriverError(e.to_string()))?;
2767 }
2768 Ok(())
2769 }
2770
2771 pub fn avgpool2d_fwd_f32(
2775 &self,
2776 input: &CudaSlice<f32>,
2777 output: &mut CudaSlice<f32>,
2778 params: &CudaSlice<u32>,
2779 channels: usize,
2780 out_h: usize,
2781 out_w: usize,
2782 total: usize,
2783 ) -> Result<(), CudaError> {
2784 let func = self
2785 .kernels
2786 .get("avgpool2d_fwd_f32")
2787 .ok_or_else(|| CudaError::KernelNotFound("avgpool2d_fwd_f32".to_string()))?;
2788
2789 let cfg = cuda_kernels::launch_config(total);
2790 unsafe {
2791 self.stream
2792 .launch_builder(func)
2793 .arg(input)
2794 .arg(output)
2795 .arg(params)
2796 .arg(&(channels as u32))
2797 .arg(&(out_h as u32))
2798 .arg(&(out_w as u32))
2799 .arg(&(total as u32))
2800 .launch(cfg)
2801 .map(|_| ())
2802 .map_err(|e| CudaError::DriverError(e.to_string()))?;
2803 }
2804 Ok(())
2805 }
2806
2807 pub fn avgpool2d_bwd_f32(
2811 &self,
2812 grad_output: &CudaSlice<f32>,
2813 grad_input: &mut CudaSlice<f32>,
2814 params: &CudaSlice<u32>,
2815 channels: usize,
2816 out_h: usize,
2817 out_w: usize,
2818 total: usize,
2819 ) -> Result<(), CudaError> {
2820 let func = self
2821 .kernels
2822 .get("avgpool2d_bwd_f32")
2823 .ok_or_else(|| CudaError::KernelNotFound("avgpool2d_bwd_f32".to_string()))?;
2824
2825 let cfg = cuda_kernels::launch_config(total);
2826 unsafe {
2827 self.stream
2828 .launch_builder(func)
2829 .arg(grad_output)
2830 .arg(grad_input)
2831 .arg(params)
2832 .arg(&(channels as u32))
2833 .arg(&(out_h as u32))
2834 .arg(&(out_w as u32))
2835 .arg(&(total as u32))
2836 .launch(cfg)
2837 .map(|_| ())
2838 .map_err(|e| CudaError::DriverError(e.to_string()))?;
2839 }
2840 Ok(())
2841 }
2842}
2843
/// Page-locked (pinned) host memory buffer, allocated through the CUDA
/// driver (`cuMemAllocHost_v2`) and freed in `Drop`.
#[cfg(feature = "cuda")]
pub struct PinnedBuffer {
    // Raw pinned host allocation; null when `len == 0`.
    ptr: *mut f32,
    // Number of f32 elements in the allocation.
    len: usize,
}
2870
#[cfg(feature = "cuda")]
// SAFETY: `PinnedBuffer` uniquely owns its allocation; the raw pointer is
// freed only in `Drop`, so moving the value to another thread is sound.
unsafe impl Send for PinnedBuffer {}
#[cfg(feature = "cuda")]
// SAFETY: shared references only expose read access (`as_slice`, `as_ptr`,
// `len`); mutation goes through `&mut self`, so `&PinnedBuffer` cannot race.
unsafe impl Sync for PinnedBuffer {}
2875
#[cfg(feature = "cuda")]
impl PinnedBuffer {
    /// Allocates pinned host memory sized for `data` and copies `data` in.
    ///
    /// Delegates to [`Self::alloc`] so the allocation path (empty-input
    /// handling, backend check, driver call) lives in exactly one place.
    ///
    /// # Errors
    /// `DeviceNotFound` if no CUDA backend could be initialized;
    /// `AllocationFailed` if the driver cannot pin the memory.
    pub fn from_slice(data: &[f32]) -> Result<Self, CudaError> {
        let buf = Self::alloc(data.len())?;
        if !data.is_empty() {
            // SAFETY: `alloc` returned a pinned allocation of exactly
            // `data.len()` f32 elements; source and destination are distinct
            // allocations, so the ranges cannot overlap.
            unsafe {
                std::ptr::copy_nonoverlapping(data.as_ptr(), buf.ptr, data.len());
            }
        }
        Ok(buf)
    }

    /// Allocates pinned host memory for `len` f32 elements.
    ///
    /// The contents are uninitialized; callers are expected to fill the
    /// buffer before reading from it. A `len` of 0 yields a null, zero-length
    /// buffer without touching the driver.
    ///
    /// # Errors
    /// `DeviceNotFound` if no CUDA backend could be initialized;
    /// `AllocationFailed` if the driver cannot pin the memory.
    pub fn alloc(len: usize) -> Result<Self, CudaError> {
        use std::ptr;

        if len == 0 {
            return Ok(Self {
                ptr: ptr::null_mut(),
                len: 0,
            });
        }

        let byte_size = len * std::mem::size_of::<f32>();
        let mut host_ptr: *mut std::ffi::c_void = ptr::null_mut();

        // Ensure the CUDA context exists before calling into the driver.
        let _ = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;

        unsafe {
            let result = cudarc::driver::sys::cuMemAllocHost_v2(&mut host_ptr, byte_size);
            if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
                return Err(CudaError::AllocationFailed);
            }
        }

        Ok(Self {
            ptr: host_ptr as *mut f32,
            len,
        })
    }

    /// Read-only view of the pinned memory (empty slice when unallocated).
    pub fn as_slice(&self) -> &[f32] {
        if self.ptr.is_null() || self.len == 0 {
            return &[];
        }
        // SAFETY: `ptr` is a live allocation of `len` f32 elements.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    /// Mutable view of the pinned memory (empty slice when unallocated).
    pub fn as_slice_mut(&mut self) -> &mut [f32] {
        if self.ptr.is_null() || self.len == 0 {
            return &mut [];
        }
        // SAFETY: `ptr` is a live allocation of `len` f32 elements and we
        // hold `&mut self`, so no other reference aliases it.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Number of f32 elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Raw const pointer to the pinned memory (null when unallocated).
    pub fn as_ptr(&self) -> *const f32 {
        self.ptr
    }

    /// Raw mutable pointer to the pinned memory (null when unallocated).
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.ptr
    }

    /// Copies the buffer's contents to a new device allocation.
    ///
    /// # Errors
    /// `DeviceNotFound` if no CUDA backend is available, or whatever
    /// `htod_copy` reports on transfer failure.
    pub fn to_gpu(&self) -> Result<CudaSlice<f32>, CudaError> {
        let backend = get_cuda_backend().ok_or(CudaError::DeviceNotFound)?;
        backend.htod_copy(self.as_slice())
    }
}
3000
#[cfg(feature = "cuda")]
impl Drop for PinnedBuffer {
    /// Releases the pinned host allocation. The driver's return code is
    /// ignored because `drop` has no way to report failure.
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: `ptr` came from `cuMemAllocHost_v2` and is freed exactly
            // once; it is nulled below so a second drop cannot double-free.
            unsafe {
                let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut std::ffi::c_void);
            }
            self.ptr = std::ptr::null_mut();
        }
    }
}
3012
/// Convenience wrapper: copies `data` into freshly allocated pinned host
/// memory (see [`PinnedBuffer::from_slice`]).
#[cfg(feature = "cuda")]
pub fn pin_memory(data: &[f32]) -> Result<PinnedBuffer, CudaError> {
    PinnedBuffer::from_slice(data)
}
3023
/// Stub used when the `cuda` feature is disabled; always fails.
// NOTE(review): this variant returns `Result<(), _>` while the cuda-enabled
// variant returns `Result<PinnedBuffer, _>`, so callers cannot be written
// feature-agnostically — confirm whether a non-cuda `PinnedBuffer` stub type
// is wanted instead.
#[cfg(not(feature = "cuda"))]
pub fn pin_memory(_data: &[f32]) -> Result<(), CudaError> {
    Err(CudaError::DeviceNotFound)
}
3029
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: probing availability must never panic, GPU or not.
    #[test]
    fn test_cuda_availability() {
        let available = is_available();
        println!("CUDA available: {}", available);
    }

    /// Device enumeration sanity check.
    #[test]
    fn test_device_count() {
        let count = device_count();
        println!("CUDA device count: {}", count);
        // Loose upper bound: anything above 16 almost certainly means a
        // garbage value came back from the driver.
        assert!(count <= 16);
    }

    /// Backend construction should succeed whenever a device is present.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_backend_creation() {
        if is_available() {
            let backend = CudaBackend::new(0);
            assert!(backend.is_some());
        }
    }

    /// Round-trip host -> device -> host copy must preserve the data.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_memory_operations() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let gpu_data = backend.htod_copy(&data).unwrap();

        let result = backend.dtoh_copy(&gpu_data).unwrap();
        assert_eq!(data, result);
    }

    /// 2x3 * 3x2 matrix product through cuBLAS.
    ///
    /// Data is column-major (cuBLAS convention):
    /// A = [[1,2,3],[4,5,6]], B = [[1,2],[3,4],[5,6]],
    /// so C = A*B = [[22,28],[49,64]], stored as [22, 49, 28, 64].
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_gemm() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let a: Vec<f32> = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; let b: Vec<f32> = vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]; let c: Vec<f32> = vec![0.0; 4]; let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.htod_copy(&c).unwrap();

        // C = 1.0 * A * B + 0.0 * C, with m=2, n=2, k=3 and leading
        // dimensions lda=2, ldb=3, ldc=2 (no transposes).
        backend
            .gemm_f32(
                false, false, 2, 2, 3, 1.0, &a_gpu, 2, &b_gpu, 3, 0.0, &mut c_gpu, 2, )
            .unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 22.0).abs() < 1e-5, "result[0] = {}", result[0]);
        assert!((result[1] - 49.0).abs() < 1e-5, "result[1] = {}", result[1]);
        assert!((result[2] - 28.0).abs() < 1e-5, "result[2] = {}", result[2]);
        assert!((result[3] - 64.0).abs() < 1e-5, "result[3] = {}", result[3]);
    }

    /// Elementwise add kernel: c[i] = a[i] + b[i].
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_add_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(4).unwrap();

        backend.add_f32(&mut c_gpu, &a_gpu, &b_gpu, 4).unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 6.0).abs() < 1e-5);
        assert!((result[1] - 8.0).abs() < 1e-5);
        assert!((result[2] - 10.0).abs() < 1e-5);
        assert!((result[3] - 12.0).abs() < 1e-5);
    }

    /// Elementwise multiply kernel: c[i] = a[i] * b[i].
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_mul_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![2.0, 3.0, 4.0, 5.0];

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(4).unwrap();

        backend.mul_f32(&mut c_gpu, &a_gpu, &b_gpu, 4).unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-5);
        assert!((result[1] - 6.0).abs() < 1e-5);
        assert!((result[2] - 12.0).abs() < 1e-5);
        assert!((result[3] - 20.0).abs() < 1e-5);
    }

    /// In-place scalar multiply: x[i] *= 2.5.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_scale_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let mut data_gpu = backend.htod_copy(&data).unwrap();

        backend.scale_f32(&mut data_gpu, 2.5, 4).unwrap();

        let result = backend.dtoh_copy(&data_gpu).unwrap();
        assert!((result[0] - 2.5).abs() < 1e-5);
        assert!((result[1] - 5.0).abs() < 1e-5);
        assert!((result[2] - 7.5).abs() < 1e-5);
        assert!((result[3] - 10.0).abs() < 1e-5);
    }

    /// ReLU kernel: negatives clamp to 0, non-negatives pass through.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_relu_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let input: Vec<f32> = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(5).unwrap();

        backend.relu_f32(&mut output_gpu, &input_gpu, 5).unwrap();

        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.0).abs() < 1e-5);
        assert!((result[1] - 0.0).abs() < 1e-5);
        assert!((result[2] - 0.0).abs() < 1e-5);
        assert!((result[3] - 1.0).abs() < 1e-5);
        assert!((result[4] - 2.0).abs() < 1e-5);
    }

    /// Sigmoid kernel against known values:
    /// sigma(0) = 0.5, sigma(1) ~= 0.7311, sigma(-1) ~= 0.2689.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_sigmoid_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let input: Vec<f32> = vec![0.0, 1.0, -1.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(3).unwrap();

        backend.sigmoid_f32(&mut output_gpu, &input_gpu, 3).unwrap();

        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.5).abs() < 1e-4);
        assert!((result[1] - 0.7311).abs() < 1e-3);
        assert!((result[2] - 0.2689).abs() < 1e-3);
    }

    /// Tanh kernel against known values: tanh(0) = 0, tanh(+/-1) ~= +/-0.7616.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tanh_kernel() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let input: Vec<f32> = vec![0.0, 1.0, -1.0];
        let input_gpu = backend.htod_copy(&input).unwrap();
        let mut output_gpu = backend.alloc::<f32>(3).unwrap();

        backend.tanh_f32(&mut output_gpu, &input_gpu, 3).unwrap();

        let result = backend.dtoh_copy(&output_gpu).unwrap();
        assert!((result[0] - 0.0).abs() < 1e-5);
        assert!((result[1] - 0.7616).abs() < 1e-3);
        assert!((result[2] - (-0.7616)).abs() < 1e-3);
    }

    /// 1M-element add: a[i] = i, b[i] = n - i, so every c[i] must equal n.
    /// Spot-checks start, middle, and end rather than all elements.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_large_tensor_add() {
        if !is_available() {
            return;
        }

        let backend = CudaBackend::new(0).unwrap();

        let n = 1_000_000;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();

        let a_gpu = backend.htod_copy(&a).unwrap();
        let b_gpu = backend.htod_copy(&b).unwrap();
        let mut c_gpu = backend.alloc::<f32>(n).unwrap();

        backend.add_f32(&mut c_gpu, &a_gpu, &b_gpu, n).unwrap();

        let result = backend.dtoh_copy(&c_gpu).unwrap();

        assert!((result[0] - n as f32).abs() < 1e-3);
        assert!((result[n / 2] - n as f32).abs() < 1e-3);
        assert!((result[n - 1] - n as f32).abs() < 1e-3);
    }

    /// Conv2d forward: a 1x1 "channel-select" conv, then a padded 3x3 conv.
    ///
    /// NOTE(review): argument order is presumably (input, weight, bias,
    /// batch, in_c, in_h, in_w, out_c, k_h, k_w, stride_h, stride_w,
    /// pad_h, pad_w) — confirm against `cuda_conv2d_forward`'s signature.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_conv2d_forward() {
        if !is_available() {
            return;
        }

        // All-ones 1x3x4x4 input; 2x3x1x1 weights zero except weight[0]
        // (out ch0 <- in ch0) and weight[4] (out ch1 <- in ch1), so each
        // output pixel is 1.0 + 0.5 bias = 1.5.
        let input = vec![1.0f32; 1 * 3 * 4 * 4]; let mut weight = vec![0.0f32; 2 * 3 * 1 * 1];
        weight[0] = 1.0;
        weight[4] = 1.0;
        let bias = vec![0.5f32; 2];

        let result = cuda_conv2d_forward(
            &input,
            &weight,
            Some(&bias),
            1,
            3,
            4,
            4,
            2,
            1,
            1,
            1,
            1,
            0,
            0,
        );

        let out = result.expect("CUDA conv2d should succeed");
        assert_eq!(out.len(), 2 * 4 * 4);
        assert!(
            (out[0] - 1.5).abs() < 0.01,
            "1x1 conv ch0: expected 1.5, got {}",
            out[0]
        );
        assert!(
            (out[16] - 1.5).abs() < 0.01,
            "1x1 conv ch1: expected 1.5, got {}",
            out[16]
        );

        // All-ones input and all-ones 3x3 weights over 3 channels with
        // padding 1: interior pixels sum 3*3*3 = 27; a corner sees only a
        // 2x2 window per channel, 3*4 = 12.
        let input2 = vec![1.0f32; 1 * 3 * 8 * 8];
        let weight2 = vec![1.0f32; 2 * 3 * 3 * 3]; let bias2 = vec![0.0f32; 2];

        let result2 = cuda_conv2d_forward(
            &input2,
            &weight2,
            Some(&bias2),
            1,
            3,
            8,
            8,
            2,
            3,
            3,
            1,
            1,
            1,
            1,
        );

        let out2 = result2.expect("CUDA 3x3 conv should succeed");
        assert_eq!(out2.len(), 2 * 8 * 8);
        let center = 4 * 8 + 4;
        assert!(
            (out2[center] - 27.0).abs() < 0.1,
            "3x3 conv center: expected 27.0, got {}",
            out2[center]
        );
        assert!(
            (out2[0] - 12.0).abs() < 0.1,
            "3x3 conv corner: expected 12.0, got {}",
            out2[0]
        );
    }
}