#![deny(missing_docs)]
//! Host-backed device arrays with elementwise operations and batched 1-D FIR
//! filtering, plus an optional CUDA driver-API backend behind the `cuda` feature.

use ndarray::{Array1, Array2, Axis};
use num_traits::NumAssign;
use std::error::Error;
use std::fmt;

/// Element type carried by a [`DeviceArray`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DType {
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
}

/// Device on which an array is placed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Device {
    /// Host CPU.
    Cpu,
    /// NVIDIA GPU via the CUDA driver API (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
}

/// Errors returned by GPU-backed operations.
#[derive(Debug)]
pub enum GpuError {
    /// The requested backend (or one of its driver calls) is not available.
    BackendUnavailable(&'static str),
    /// Operand shapes do not match.
    ShapeMismatch,
}

impl fmt::Display for GpuError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GpuError::BackendUnavailable(name) => write!(f, "backend not available: {name}"),
            GpuError::ShapeMismatch => write!(f, "shape mismatch"),
        }
    }
}

impl Error for GpuError {}

/// N-dimensional array with host-resident storage and a device placement tag.
#[derive(Clone, Debug)]
pub struct DeviceArray<T> {
    shape: Vec<usize>,
    dtype: DType,
    device: Device,
    host: Vec<T>,
}

impl<T: Copy> DeviceArray<T> {
    /// Creates a CPU-resident array from a shape and a flat, row-major data slice.
    ///
    /// Panics if `shape.iter().product()` does not equal `data.len()`.
    pub fn from_cpu_slice(shape: &[usize], dtype: DType, data: &[T]) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len());
        Self {
            shape: shape.to_vec(),
            dtype,
            device: Device::Cpu,
            host: data.to_vec(),
        }
    }

    /// Returns a copy of the data as a host `Vec`.
    pub fn to_cpu_vec(&self) -> Vec<T> {
        self.host.clone()
    }

    /// Returns the array shape.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the element type.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Returns the current device placement.
    pub fn device(&self) -> Device {
        self.device
    }
}

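// Note: the host `Vec` is always the source of truth. `to_device` only records
// the intended placement, and the `_auto` methods below consult that tag to
// decide whether to try the CUDA kernels or stay on the CPU path.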
impl<T: Copy> DeviceArray<T> {
    /// Moves the array to `device`.
    ///
    /// Data always stays host-resident; this only updates the placement tag.
    #[cfg(feature = "cuda")]
    pub fn to_device(&mut self, device: Device) -> Result<(), GpuError> {
        match device {
            Device::Cpu => {
                self.device = Device::Cpu;
                Ok(())
            }
            Device::Cuda => {
                self.device = Device::Cuda;
                Ok(())
            }
        }
    }

    /// Moves the array to `device`.
    ///
    /// Without the `cuda` feature only [`Device::Cpu`] exists, so this always succeeds.
    #[cfg(not(feature = "cuda"))]
    pub fn to_device(&mut self, device: Device) -> Result<(), GpuError> {
        match device {
            Device::Cpu => {
                self.device = Device::Cpu;
                Ok(())
            }
        }
    }
}

impl<T> DeviceArray<T>
where
    T: Copy + NumAssign,
{
    /// Returns a new array with `alpha` added to every element.
    pub fn add_scalar(&self, alpha: T) -> Self {
        let mut out = self.clone();
        for v in &mut out.host {
            *v += alpha;
        }
        out
    }

    /// Returns a new array with every element multiplied by `alpha`.
    pub fn mul_scalar(&self, alpha: T) -> Self {
        let mut out = self.clone();
        for v in &mut out.host {
            *v *= alpha;
        }
        out
    }
}

impl<T> DeviceArray<T>
where
    T: Copy + NumAssign,
{
    /// Returns the elementwise sum of `self` and `other`.
    ///
    /// Fails with [`GpuError::ShapeMismatch`] if the shapes differ.
    pub fn add(&self, other: &Self) -> Result<Self, GpuError> {
        if self.shape != other.shape {
            return Err(GpuError::ShapeMismatch);
        }
        let mut out = self.clone();
        for (o, r) in out.host.iter_mut().zip(other.host.iter()) {
            *o += *r;
        }
        Ok(out)
    }
}

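// Dispatch policy for the f32 `_auto` methods: with the `cuda` feature enabled
// and the array tagged `Device::Cuda`, try the CUDA kernel first and fall back
// to the CPU implementation if any driver call fails; otherwise run on the CPU.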
impl DeviceArray<f32> {
    /// Adds `alpha` to every element, using the CUDA kernel when possible.
    pub fn add_scalar_auto(&self, alpha: f32) -> Self {
        #[cfg(feature = "cuda")]
        {
            match self.device {
                Device::Cpu => self.add_scalar(alpha),
                Device::Cuda => {
                    let mut out = vec![0.0f32; self.host.len()];
                    if crate::add_scalar_f32_cuda(&self.host, alpha, &mut out).is_err() {
                        return self.add_scalar(alpha);
                    }
                    DeviceArray {
                        shape: self.shape.clone(),
                        dtype: self.dtype,
                        device: self.device,
                        host: out,
                    }
                }
            }
        }
        #[cfg(not(feature = "cuda"))]
        {
            self.add_scalar(alpha)
        }
    }

    /// Adds two arrays elementwise, using the CUDA kernel when both operands are on the GPU.
    pub fn add_auto(&self, other: &Self) -> Result<Self, GpuError> {
        if self.shape != other.shape {
            return Err(GpuError::ShapeMismatch);
        }
        #[cfg(feature = "cuda")]
        {
            match (self.device, other.device) {
                (Device::Cpu, Device::Cpu) => self.add(other),
                (Device::Cuda, Device::Cuda) => {
                    let mut out = vec![0.0f32; self.host.len()];
                    if crate::add_vec_f32_cuda(&self.host, &other.host, &mut out).is_err() {
                        return self.add(other);
                    }
                    Ok(DeviceArray {
                        shape: self.shape.clone(),
                        dtype: self.dtype,
                        device: self.device,
                        host: out,
                    })
                }
                _ => self.add(other),
            }
        }
        #[cfg(not(feature = "cuda"))]
        {
            self.add(other)
        }
    }

    /// Multiplies every element by `alpha`, using the CUDA kernel when possible.
    pub fn mul_scalar_auto(&self, alpha: f32) -> Self {
        #[cfg(feature = "cuda")]
        {
            match self.device {
                Device::Cpu => self.mul_scalar(alpha),
                Device::Cuda => {
                    let mut out = vec![0.0f32; self.host.len()];
                    if crate::mul_scalar_f32_cuda(&self.host, alpha, &mut out).is_err() {
                        return self.mul_scalar(alpha);
                    }
                    DeviceArray {
                        shape: self.shape.clone(),
                        dtype: self.dtype,
                        device: self.device,
                        host: out,
                    }
                }
            }
        }
        #[cfg(not(feature = "cuda"))]
        {
            self.mul_scalar(alpha)
        }
    }
}

/// Applies [`fir1d_batched_f32`] on the requested device, falling back to the
/// CPU implementation when CUDA is unavailable or the kernel launch fails.
pub fn fir1d_batched_f32_auto(x: &Array2<f32>, taps: &Array1<f32>, device: Device) -> Array2<f32> {
    #[cfg(feature = "cuda")]
    {
        return match device {
            Device::Cpu => fir1d_batched_f32(x, taps),
            Device::Cuda => match crate::fir1d_batched_f32_cuda(x, taps) {
                Ok(y) => y,
                Err(_) => fir1d_batched_f32(x, taps),
            },
        };
    }
    #[cfg(not(feature = "cuda"))]
    {
        let _ = device;
        fir1d_batched_f32(x, taps)
    }
}

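// The CUDA backend below talks to the driver API directly over FFI (no wrapper
// crate) and ships its kernels as an embedded PTX module, so building with the
// `cuda` feature only needs a linkable libcuda, not the CUDA toolkit compiler.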
#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use std::ffi::{c_void, CString};
    use std::ptr;

    type CUdevice = i32;
    type CUcontext = *mut c_void;
    type CUmodule = *mut c_void;
    type CUfunction = *mut c_void;
    type CUdeviceptr = u64;
    type CUresult = i32;

    const CUDA_SUCCESS: CUresult = 0;

    // Minimal hand-written bindings to the CUDA driver API. The context and
    // memory entry points are bound to their `_v2` symbols, the variants that
    // take 64-bit device pointers and `size_t` byte counts.
    #[link(name = "cuda")]
    extern "C" {
        fn cuInit(flags: u32) -> CUresult;
        fn cuDeviceGet(device: *mut CUdevice, ordinal: i32) -> CUresult;
        #[link_name = "cuCtxCreate_v2"]
        fn cuCtxCreate(ctx: *mut CUcontext, flags: u32, device: CUdevice) -> CUresult;
        #[link_name = "cuCtxDestroy_v2"]
        fn cuCtxDestroy(ctx: CUcontext) -> CUresult;
        fn cuModuleLoadData(module: *mut CUmodule, image: *const c_void) -> CUresult;
        fn cuModuleGetFunction(hfunc: *mut CUfunction, hmod: CUmodule, name: *const u8)
            -> CUresult;
        #[link_name = "cuMemAlloc_v2"]
        fn cuMemAlloc(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult;
        #[link_name = "cuMemFree_v2"]
        fn cuMemFree(dptr: CUdeviceptr) -> CUresult;
        #[link_name = "cuMemcpyHtoD_v2"]
        fn cuMemcpyHtoD(
            dstDevice: CUdeviceptr,
            srcHost: *const c_void,
            ByteCount: usize,
        ) -> CUresult;
        #[link_name = "cuMemcpyDtoH_v2"]
        fn cuMemcpyDtoH(dstHost: *mut c_void, srcDevice: CUdeviceptr, ByteCount: usize)
            -> CUresult;
        fn cuLaunchKernel(
            f: CUfunction,
            gridDimX: u32,
            gridDimY: u32,
            gridDimZ: u32,
            blockDimX: u32,
            blockDimY: u32,
            blockDimZ: u32,
            sharedMemBytes: u32,
            hStream: *mut c_void,
            kernelParams: *mut *mut c_void,
            extra: *mut *mut c_void,
        ) -> CUresult;
        fn cuCtxSynchronize() -> CUresult;
    }

    fn check(res: CUresult, msg: &'static str) -> Result<(), GpuError> {
        if res == CUDA_SUCCESS {
            Ok(())
        } else {
            Err(GpuError::BackendUnavailable(msg))
        }
    }

    pub fn cuda_available() -> bool {
        unsafe {
            cuInit(0) == CUDA_SUCCESS && {
                let mut d = 0;
                cuDeviceGet(&mut d as *mut _, 0) == CUDA_SUCCESS
            }
        }
    }

    struct CudaCtx {
        ctx: CUcontext,
    }
    impl CudaCtx {
        fn create_default() -> Result<Self, GpuError> {
            unsafe {
                check(cuInit(0), "cuInit")?;
                let mut dev: CUdevice = 0;
                check(cuDeviceGet(&mut dev as *mut _, 0), "cuDeviceGet")?;
                let mut ctx: CUcontext = ptr::null_mut();
                check(cuCtxCreate(&mut ctx as *mut _, 0, dev), "cuCtxCreate")?;
                Ok(Self { ctx })
            }
        }
    }
    impl Drop for CudaCtx {
        fn drop(&mut self) {
            unsafe {
                let _ = cuCtxDestroy(self.ctx);
            }
        }
    }

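    // Hand-written PTX for the four kernels used by the wrappers below
    // (add_vec_f32, add_scalar_f32, mul_scalar_f32, fir1d_batched_f32).
    // Each kernel mirrors the corresponding CPU implementation in this crate.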
    static PTX: &str = r#"
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec_f32(
    .param .u64 out,
    .param .u64 a,
    .param .u64 b,
    .param .u32 n)
{
    .reg .pred %p;
    .reg .b32 %r<6>;
    .reg .b64 %rd<10>;
    .reg .f32 %f<4>;

    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [a];
    ld.param.u64 %rd3, [b];
    ld.param.u32 %r1, [n];

    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %r5, %r3, %r4, %r2; // idx
    setp.ge.s32 %p, %r5, %r1;
    @%p ret;

    mul.wide.s32 %rd4, %r5, 4;
    add.s64 %rd5, %rd2, %rd4;
    add.s64 %rd6, %rd3, %rd4;
    add.s64 %rd7, %rd1, %rd4;
    ld.global.f32 %f1, [%rd5];
    ld.global.f32 %f2, [%rd6];
    add.f32 %f3, %f1, %f2;
    st.global.f32 [%rd7], %f3;
    ret;
}

.visible .entry add_scalar_f32(
    .param .u64 out,
    .param .u64 a,
    .param .f32 alpha,
    .param .u32 n)
{
    .reg .pred %p;
    .reg .b32 %r<6>;
    .reg .b64 %rd<10>;
    .reg .f32 %f<4>;

    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [a];
    ld.param.f32 %f1, [alpha];
    ld.param.u32 %r1, [n];

    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %r5, %r3, %r4, %r2; // idx
    setp.ge.s32 %p, %r5, %r1;
    @%p ret;

    mul.wide.s32 %rd4, %r5, 4;
    add.s64 %rd5, %rd2, %rd4;
    add.s64 %rd6, %rd1, %rd4;
    ld.global.f32 %f2, [%rd5];
    add.f32 %f3, %f2, %f1;
    st.global.f32 [%rd6], %f3;
    ret;
}

.visible .entry mul_scalar_f32(
    .param .u64 out,
    .param .u64 a,
    .param .f32 alpha,
    .param .u32 n)
{
    .reg .pred %p;
    .reg .b32 %r<6>;
    .reg .b64 %rd<10>;
    .reg .f32 %f<4>;

    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [a];
    ld.param.f32 %f1, [alpha];
    ld.param.u32 %r1, [n];

    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %r5, %r3, %r4, %r2; // idx
    setp.ge.s32 %p, %r5, %r1;
    @%p ret;

    mul.wide.s32 %rd4, %r5, 4;
    add.s64 %rd5, %rd2, %rd4;
    add.s64 %rd6, %rd1, %rd4;
    ld.global.f32 %f2, [%rd5];
    mul.f32 %f3, %f2, %f1;
    st.global.f32 [%rd6], %f3;
    ret;
}

.visible .entry fir1d_batched_f32(
    .param .u64 out,
    .param .u64 x,
    .param .u64 taps,
    .param .u32 b,
    .param .u32 n,
    .param .u32 k)
{
    .reg .pred %p<3>;
    .reg .b32 %r<20>;
    .reg .b32 %rB, %rN, %rK, %rIdx, %rTotal, %rBi, %rI, %rTmp, %rStart;
    .reg .b32 %rJ, %rTIdx, %rKminus1, %rTapIdx, %rRowOff, %rXIdx, %rOutIdxBase, %rOutIdx;
    .reg .b64 %rd<20>;
    .reg .b64 %rdTapOff, %rdTapPtr, %rdXOff, %rdXPtr, %rdOutOff, %rdOutPtr;
    .reg .f32 %f<6>;
    .reg .f32 %fAcc, %fTap, %fX, %fMul;

    // Load params
    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [x];
    ld.param.u64 %rd3, [taps];
    ld.param.u32 %rB, [b];
    ld.param.u32 %rN, [n];
    ld.param.u32 %rK, [k];

    // idx = blockIdx.x * blockDim.x + threadIdx.x
    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %rIdx, %r3, %r4, %r2;

    // total = b * n
    mul.lo.u32 %rTotal, %rB, %rN;
    setp.ge.u32 %p0, %rIdx, %rTotal;
    @%p0 ret;

    // bi = idx / n; i = idx % n
    div.u32 %rBi, %rIdx, %rN;
    rem.u32 %rI, %rIdx, %rN;

    // start = (i + 1 > k) ? (i + 1 - k) : 0
    add.u32 %rTmp, %rI, 1;
    setp.gt.u32 %p1, %rTmp, %rK;
    mov.u32 %rStart, 0;
    @%p1 sub.u32 %rStart, %rTmp, %rK;

    // acc = 0.0f; j = i; t_idx = 0
    mov.f32 %fAcc, 0f00000000; // 0.0
    mov.u32 %rJ, %rI;
    mov.u32 %rTIdx, 0;

L_LOOP:
    // if (j < start) break;
    setp.lt.u32 %p2, %rJ, %rStart;
    @%p2 bra L_DONE;

    // tap_index = k - 1 - t_idx
    sub.u32 %rKminus1, %rK, 1;
    sub.u32 %rTapIdx, %rKminus1, %rTIdx;
    mul.wide.u32 %rdTapOff, %rTapIdx, 4;
    add.s64 %rdTapPtr, %rd3, %rdTapOff;
    ld.global.f32 %fTap, [%rdTapPtr];

    // x index: bi*n + j
    mul.lo.u32 %rRowOff, %rBi, %rN;
    add.u32 %rXIdx, %rRowOff, %rJ;
    mul.wide.u32 %rdXOff, %rXIdx, 4;
    add.s64 %rdXPtr, %rd2, %rdXOff;
    ld.global.f32 %fX, [%rdXPtr];

    // acc += tap * x
    mul.f32 %fMul, %fTap, %fX;
    add.f32 %fAcc, %fAcc, %fMul;

    // j--, t_idx++
    sub.u32 %rJ, %rJ, 1;
    add.u32 %rTIdx, %rTIdx, 1;
    bra L_LOOP;

L_DONE:
    // out index: bi*n + i
    mul.lo.u32 %rOutIdxBase, %rBi, %rN;
    add.u32 %rOutIdx, %rOutIdxBase, %rI;
    mul.wide.u32 %rdOutOff, %rOutIdx, 4;
    add.s64 %rdOutPtr, %rd1, %rdOutOff;
    st.global.f32 [%rdOutPtr], %fAcc;
    ret;
}
"#;

    fn load_module() -> Result<(CudaCtx, CUmodule), GpuError> {
        unsafe {
            let ctx = CudaCtx::create_default()?;
            let mut module: CUmodule = ptr::null_mut();
            // The driver parses the image as PTX text, which must be NUL-terminated.
            let image = CString::new(PTX).map_err(|_| GpuError::BackendUnavailable("ptx"))?;
            check(
                cuModuleLoadData(&mut module as *mut _, image.as_ptr() as *const c_void),
                "cuModuleLoadData",
            )?;
            Ok((ctx, module))
        }
    }

    unsafe fn get_function(module: CUmodule, name: &'static str) -> Result<CUfunction, GpuError> {
        let mut func: CUfunction = ptr::null_mut();
        // Kernel names must also be passed as NUL-terminated C strings.
        let cname = CString::new(name).map_err(|_| GpuError::BackendUnavailable(name))?;
        check(
            cuModuleGetFunction(&mut func as *mut _, module, cname.as_ptr() as *const u8),
            name,
        )?;
        Ok(func)
    }

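    // Each wrapper below follows the same pattern: create a context, load the
    // PTX module, allocate device buffers, copy inputs host-to-device, launch
    // with 256-thread blocks, synchronize, copy the result back, and free the
    // buffers. Any driver failure surfaces as `GpuError::BackendUnavailable`.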
    /// Elementwise `out = a + b` on the GPU for `f32` slices.
    pub fn add_vec_f32_cuda(a: &[f32], b: &[f32], out: &mut [f32]) -> Result<(), GpuError> {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), out.len());
        let n = a.len() as u32;
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "add_vec_f32")?;
            let bytes = (n as usize) * std::mem::size_of::<f32>();

            let mut d_a: CUdeviceptr = 0;
            let mut d_b: CUdeviceptr = 0;
            let mut d_out: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_a as *mut _, bytes), "cuMemAlloc a")?;
            check(cuMemAlloc(&mut d_b as *mut _, bytes), "cuMemAlloc b")?;
            check(cuMemAlloc(&mut d_out as *mut _, bytes), "cuMemAlloc out")?;

            check(
                cuMemcpyHtoD(d_a, a.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD a",
            )?;
            check(
                cuMemcpyHtoD(d_b, b.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD b",
            )?;

            let mut out_ptr = d_out as *mut c_void;
            let mut a_ptr = d_a as *mut c_void;
            let mut b_ptr = d_b as *mut c_void;
            let mut n_val = n;
            let mut params = vec![
                &mut out_ptr as *mut _ as *mut c_void,
                &mut a_ptr as *mut _ as *mut c_void,
                &mut b_ptr as *mut _ as *mut c_void,
                &mut n_val as *mut _ as *mut c_void,
            ];

            let block = 256u32;
            let grid = (n + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    ptr::null_mut(),
                    params.as_mut_ptr(),
                    ptr::null_mut(),
                ),
                "cuLaunchKernel add_vec_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;

            check(
                cuMemcpyDtoH(out.as_mut_ptr() as *mut c_void, d_out, bytes),
                "cuMemcpyDtoH out",
            )?;

            let _ = cuMemFree(d_a);
            let _ = cuMemFree(d_b);
            let _ = cuMemFree(d_out);
            Ok(())
        }
    }

    /// Elementwise `out = a + alpha` on the GPU for `f32` slices.
    pub fn add_scalar_f32_cuda(a: &[f32], alpha: f32, out: &mut [f32]) -> Result<(), GpuError> {
        assert_eq!(a.len(), out.len());
        let n = a.len() as u32;
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "add_scalar_f32")?;
            let bytes = (n as usize) * std::mem::size_of::<f32>();

            let mut d_a: CUdeviceptr = 0;
            let mut d_out: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_a as *mut _, bytes), "cuMemAlloc a")?;
            check(cuMemAlloc(&mut d_out as *mut _, bytes), "cuMemAlloc out")?;
            check(
                cuMemcpyHtoD(d_a, a.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD a",
            )?;

            let mut out_ptr = d_out as *mut c_void;
            let mut a_ptr = d_a as *mut c_void;
            let mut alpha_val = alpha;
            let mut n_val = n;
            let mut params = vec![
                &mut out_ptr as *mut _ as *mut c_void,
                &mut a_ptr as *mut _ as *mut c_void,
                &mut alpha_val as *mut _ as *mut c_void,
                &mut n_val as *mut _ as *mut c_void,
            ];

            let block = 256u32;
            let grid = (n + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    ptr::null_mut(),
                    params.as_mut_ptr(),
                    ptr::null_mut(),
                ),
                "cuLaunchKernel add_scalar_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;

            check(
                cuMemcpyDtoH(out.as_mut_ptr() as *mut c_void, d_out, bytes),
                "cuMemcpyDtoH out",
            )?;
            let _ = cuMemFree(d_a);
            let _ = cuMemFree(d_out);
            Ok(())
        }
    }

    /// Elementwise `out = a * alpha` on the GPU for `f32` slices.
    pub fn mul_scalar_f32_cuda(a: &[f32], alpha: f32, out: &mut [f32]) -> Result<(), GpuError> {
        assert_eq!(a.len(), out.len());
        let n = a.len() as u32;
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "mul_scalar_f32")?;
            let bytes = (n as usize) * std::mem::size_of::<f32>();

            let mut d_a: CUdeviceptr = 0;
            let mut d_out: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_a as *mut _, bytes), "cuMemAlloc a")?;
            check(cuMemAlloc(&mut d_out as *mut _, bytes), "cuMemAlloc out")?;
            check(
                cuMemcpyHtoD(d_a, a.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD a",
            )?;

            let mut out_ptr = d_out as *mut c_void;
            let mut a_ptr = d_a as *mut c_void;
            let mut alpha_val = alpha;
            let mut n_val = n;
            let mut params = vec![
                &mut out_ptr as *mut _ as *mut c_void,
                &mut a_ptr as *mut _ as *mut c_void,
                &mut alpha_val as *mut _ as *mut c_void,
                &mut n_val as *mut _ as *mut c_void,
            ];

            let block = 256u32;
            let grid = (n + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    ptr::null_mut(),
                    params.as_mut_ptr(),
                    ptr::null_mut(),
                ),
                "cuLaunchKernel mul_scalar_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;

            check(
                cuMemcpyDtoH(out.as_mut_ptr() as *mut c_void, d_out, bytes),
                "cuMemcpyDtoH out",
            )?;
            let _ = cuMemFree(d_a);
            let _ = cuMemFree(d_out);
            Ok(())
        }
    }

    /// Batched 1-D FIR filtering on the GPU; mirrors [`fir1d_batched_f32`].
    pub fn fir1d_batched_f32_cuda(
        x: &Array2<f32>,
        taps: &Array1<f32>,
    ) -> Result<Array2<f32>, GpuError> {
        let (b, n) = x.dim();
        let k = taps.len();
        let x_host = x.to_owned().into_raw_vec();
        let taps_host = taps.as_slice().unwrap();
        let mut out_host = vec![0.0f32; b * n];
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "fir1d_batched_f32")?;

            let bytes_x = x_host.len() * std::mem::size_of::<f32>();
            let bytes_t = k * std::mem::size_of::<f32>();
            let bytes_y = out_host.len() * std::mem::size_of::<f32>();

            let mut d_x: CUdeviceptr = 0;
            let mut d_t: CUdeviceptr = 0;
            let mut d_y: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_y as *mut _, bytes_y), "cuMemAlloc y")?;
            check(cuMemAlloc(&mut d_x as *mut _, bytes_x), "cuMemAlloc x")?;
            check(cuMemAlloc(&mut d_t as *mut _, bytes_t), "cuMemAlloc t")?;
            check(
                cuMemcpyHtoD(d_x, x_host.as_ptr() as *const c_void, bytes_x),
                "HtoD x",
            )?;
            check(
                cuMemcpyHtoD(d_t, taps_host.as_ptr() as *const c_void, bytes_t),
                "HtoD t",
            )?;

            let mut y_ptr = d_y as *mut c_void;
            let mut x_ptr = d_x as *mut c_void;
            let mut t_ptr = d_t as *mut c_void;
            let mut b_u32 = b as u32;
            let mut n_u32 = n as u32;
            let mut k_u32 = k as u32;
            let mut params = vec![
                &mut y_ptr as *mut _ as *mut c_void,
                &mut x_ptr as *mut _ as *mut c_void,
                &mut t_ptr as *mut _ as *mut c_void,
                &mut b_u32 as *mut _ as *mut c_void,
                &mut n_u32 as *mut _ as *mut c_void,
                &mut k_u32 as *mut _ as *mut c_void,
            ];

            let total = (b * n) as u32;
            let block = 256u32;
            let grid = (total + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    ptr::null_mut(),
                    params.as_mut_ptr(),
                    ptr::null_mut(),
                ),
                "cuLaunchKernel fir1d_batched_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;
            check(
                cuMemcpyDtoH(out_host.as_mut_ptr() as *mut c_void, d_y, bytes_y),
                "DtoH y",
            )?;

            let _ = cuMemFree(d_x);
            let _ = cuMemFree(d_t);
            let _ = cuMemFree(d_y);
        }
        Ok(Array2::from_shape_vec((b, n), out_host).unwrap())
    }
}

#[cfg(feature = "cuda")]
pub use cuda::{add_scalar_f32_cuda, add_vec_f32_cuda, fir1d_batched_f32_cuda, mul_scalar_f32_cuda};

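/// Applies a causal FIR filter with coefficients `taps` to each row of `x`.
///
/// For each batch row `b`, `y[[b, i]] = Σ_j taps[j] * x[[b, i - (k - 1 - j)]]`,
/// where `k = taps.len()` and terms that would index before the start of the
/// row are treated as zero, so the output has the same length as the input.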
pub fn fir1d_batched_f32(x: &Array2<f32>, taps: &Array1<f32>) -> Array2<f32> {
    let (b, n) = x.dim();
    let k = taps.len();
    let mut y = Array2::<f32>::zeros((b, n));
    for bi in 0..b {
        let xin = x.index_axis(Axis(0), bi);
        let mut yout = y.index_axis_mut(Axis(0), bi);
        for i in 0..n {
            let mut acc = 0.0f32;
            let start = (i + 1).saturating_sub(k);
            for (t_idx, xi) in (start..=i).rev().enumerate() {
                let tap = taps[k - 1 - t_idx];
                acc += tap * xin[xi];
            }
            yout[i] = acc;
        }
    }
    y
}

/// `f64` counterpart of [`fir1d_batched_f32`], using the same causal, zero-padded convention.
pub fn fir1d_batched_f64(x: &Array2<f64>, taps: &Array1<f64>) -> Array2<f64> {
    let (b, n) = x.dim();
    let k = taps.len();
    let mut y = Array2::<f64>::zeros((b, n));
    for bi in 0..b {
        let xin = x.index_axis(Axis(0), bi);
        let mut yout = y.index_axis_mut(Axis(0), bi);
        for i in 0..n {
            let mut acc = 0.0f64;
            let start = (i + 1).saturating_sub(k);
            for (t_idx, xi) in (start..=i).rev().enumerate() {
                let tap = taps[k - 1 - t_idx];
                acc += tap * xin[xi];
            }
            yout[i] = acc;
        }
    }
    y
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2};
    use rand::Rng;
    use scir_core::assert_close;

    #[test]
    fn device_array_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let arr = DeviceArray::from_cpu_slice(&[2, 2], DType::F32, &data);
        assert_eq!(arr.shape(), &[2, 2]);
        assert_eq!(arr.dtype(), DType::F32);
        assert_eq!(arr.device(), Device::Cpu);
        assert_eq!(arr.to_cpu_vec(), data);
    }

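    // Illustrative check: the `_auto` helpers fall back to the plain CPU ops
    // when the array is placed on the CPU; uses only APIs defined in this file.
    #[test]
    fn auto_ops_fall_back_to_cpu() {
        let mut arr = DeviceArray::from_cpu_slice(&[3], DType::F32, &[1.0f32, 2.0, 3.0]);
        arr.to_device(Device::Cpu).unwrap();
        let shifted = arr.add_scalar_auto(1.0);
        assert_eq!(shifted.to_cpu_vec(), vec![2.0f32, 3.0, 4.0]);
        let scaled = arr.mul_scalar_auto(2.0);
        assert_eq!(scaled.to_cpu_vec(), vec![2.0f32, 4.0, 6.0]);
        let summed = arr.add_auto(&shifted).unwrap();
        assert_eq!(summed.to_cpu_vec(), vec![3.0f32, 5.0, 7.0]);
    }
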
    #[cfg(feature = "cuda")]
    #[test]
    fn cuda_add_and_add_scalar_f32() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![0.5f32, 1.5, 2.5, 3.5];
        let mut out = vec![0.0f32; 4];
        match crate::add_vec_f32_cuda(&a, &b, &mut out) {
            Ok(()) => {
                let out_f64: Vec<f64> = out.iter().copied().map(|v| v as f64).collect();
                assert_close!(&out_f64, &[1.5, 3.5, 5.5, 7.5], slice, tol = 1e-6);
            }
            Err(_) => {
                eprintln!("CUDA not available; skipping CUDA test");
                return;
            }
        }
        let mut out2 = vec![0.0f32; 4];
        crate::add_scalar_f32_cuda(&a, 1.0, &mut out2).unwrap();
        let out2_f64: Vec<f64> = out2.iter().copied().map(|v| v as f64).collect();
        assert_close!(&out2_f64, &[2.0, 3.0, 4.0, 5.0], slice, tol = 1e-6);

        let mut out3 = vec![0.0f32; 4];
        crate::mul_scalar_f32_cuda(&a, 2.0, &mut out3).unwrap();
        let out3_f64: Vec<f64> = out3.iter().copied().map(|v| v as f64).collect();
        assert_close!(&out3_f64, &[2.0, 4.0, 6.0, 8.0], slice, tol = 1e-6);
    }

    #[test]
    fn elementwise_ops_baseline() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0];
        let arr = DeviceArray::from_cpu_slice(&[4], DType::F64, &data);
        let add = arr.add_scalar(1.0);
        let mul = arr.mul_scalar(2.0);
        assert_close!(&add.to_cpu_vec(), &[2.0, 3.0, 4.0, 5.0], slice, tol = 0.0);
        assert_close!(&mul.to_cpu_vec(), &[2.0, 4.0, 6.0, 8.0], slice, tol = 0.0);
    }

    #[test]
    fn add_arrays() {
        let a = DeviceArray::from_cpu_slice(&[3], DType::F32, &[1.0, 2.0, 3.0]);
        let b = DeviceArray::from_cpu_slice(&[3], DType::F32, &[0.5, 1.5, 2.5]);
        let c = a.add(&b).unwrap();
        assert_eq!(c.to_cpu_vec(), vec![1.5f32, 3.5, 5.5]);
    }

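    // Illustrative check: mismatched shapes are rejected rather than silently broadcast.
    #[test]
    fn add_rejects_shape_mismatch() {
        let a = DeviceArray::from_cpu_slice(&[2], DType::F32, &[1.0f32, 2.0]);
        let b = DeviceArray::from_cpu_slice(&[3], DType::F32, &[1.0f32, 2.0, 3.0]);
        assert!(matches!(a.add(&b), Err(GpuError::ShapeMismatch)));
    }
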
    #[test]
    fn fir_batched_matches_naive_f32() {
        let x: Array2<f32> = array![[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -0.5, -1.0]];
        let taps: Array1<f32> = array![0.25, 0.5, 0.25];
        let y = fir1d_batched_f32(&x, &taps);
        let expected0_f64 = array![0.25f64, 1.0, 2.0, 3.0];
        let expected1_f64 = array![0.125f64, 0.25, 0.0, -0.5];
        let y0_f64 = y.index_axis(Axis(0), 0).to_owned().mapv(|v| v as f64);
        let y1_f64 = y.index_axis(Axis(0), 1).to_owned().mapv(|v| v as f64);
        assert_close!(&y0_f64, &expected0_f64, array, atol = 1e-7, rtol = 1e-7);
        assert_close!(&y1_f64, &expected1_f64, array, atol = 1e-7, rtol = 1e-7);
    }

    #[test]
    fn fir_batched_random_f64() {
        let mut rng = rand::thread_rng();
        let b = 3usize;
        let n = 32usize;
        let k = 5usize;
        let mut x = Array2::<f64>::zeros((b, n));
        for mut row in x.axis_iter_mut(Axis(0)) {
            for v in row.iter_mut() {
                *v = rng.gen::<f64>() * 2.0 - 1.0;
            }
        }
        let taps = Array1::from((0..k).map(|i| 1.0 / (i as f64 + 1.0)).collect::<Vec<_>>());
        let y = fir1d_batched_f64(&x, &taps);

        let mut y_ref = Array2::<f64>::zeros((b, n));
        for bi in 0..b {
            for i in 0..n {
                let mut acc = 0.0f64;
                let start = (i + 1).saturating_sub(k);
                for (t_idx, xi) in (start..=i).rev().enumerate() {
                    let tap = taps[k - 1 - t_idx];
                    acc += tap * x[[bi, xi]];
                }
                y_ref[[bi, i]] = acc;
            }
        }
        assert_close!(
            &y.into_raw_vec(),
            &y_ref.into_raw_vec(),
            slice,
            atol = 1e-12,
            rtol = 1e-12
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn cuda_fir1d_batched_f32_parity_small() {
        let x: Array2<f32> = array![[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -0.5, -1.0]];
        let taps: Array1<f32> = array![0.25, 0.5, 0.25];
        match crate::fir1d_batched_f32_cuda(&x, &taps) {
            Ok(y_cuda) => {
                let y_cpu = super::fir1d_batched_f32(&x, &taps);
                let y_cuda_f64: Vec<f64> = y_cuda
                    .into_raw_vec()
                    .into_iter()
                    .map(|v| v as f64)
                    .collect();
                let y_cpu_f64: Vec<f64> =
                    y_cpu.into_raw_vec().into_iter().map(|v| v as f64).collect();
                assert_close!(&y_cuda_f64, &y_cpu_f64, slice, atol = 1e-5, rtol = 1e-6);
            }
            Err(_) => {
                eprintln!("CUDA not available; skipping CUDA FIR test");
            }
        }
    }
}