scir_gpu/
lib.rs

//! GPU Foundations: device array abstraction and CPU-backed baseline ops.
#![deny(missing_docs)]
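//!
//! # Examples
//!
//! A minimal, CPU-only usage sketch of the array API (CUDA paths are gated behind the
//! `cuda` feature and fall back to these CPU baselines):
//!
//! ```
//! let data = vec![1.0f32, 2.0, 3.0, 4.0];
//! let arr = scir_gpu::DeviceArray::from_cpu_slice(&[2, 2], scir_gpu::DType::F32, &data);
//! let doubled = arr.mul_scalar(2.0f32);
//! assert_eq!(doubled.to_cpu_vec(), vec![2.0f32, 4.0, 6.0, 8.0]);
//! ```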

use ndarray::{Array1, Array2, Axis};
use num_traits::NumAssign;
use std::error::Error;
use std::fmt;

/// Supported data types for device arrays.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DType {
    /// 32-bit floating point
    F32,
    /// 64-bit floating point
    F64,
}

/// Execution device selection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Device {
    /// Host CPU device
    Cpu,
    /// NVIDIA CUDA device (feature `cuda`)
    #[cfg(feature = "cuda")]
    Cuda,
}

/// GPU-related error types.
#[derive(Debug)]
pub enum GpuError {
    /// Backend is not available on this build or platform.
    BackendUnavailable(&'static str),
    /// Operation failed due to incompatible shapes.
    ShapeMismatch,
}

impl fmt::Display for GpuError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GpuError::BackendUnavailable(name) => write!(f, "backend not available: {name}"),
            GpuError::ShapeMismatch => write!(f, "shape mismatch"),
        }
    }
}

impl Error for GpuError {}

/// A minimal, shaped device array with dtype. Currently CPU-backed.
#[derive(Clone, Debug)]
pub struct DeviceArray<T> {
    shape: Vec<usize>,
    dtype: DType,
    device: Device,
    // CPU storage; future backends will switch to an enum for storage.
    host: Vec<T>,
}

impl<T: Copy> DeviceArray<T> {
    /// Create a `DeviceArray` from a CPU slice and explicit shape/dtype.
    ///
    /// # Examples
    /// ```
    /// let data = vec![1.0f32, 2.0, 3.0, 4.0];
    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[2,2], scir_gpu::DType::F32, &data);
    /// assert_eq!(arr.shape(), &[2,2]);
    /// ```
    pub fn from_cpu_slice(shape: &[usize], dtype: DType, data: &[T]) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len());
        Self {
            shape: shape.to_vec(),
            dtype,
            device: Device::Cpu,
            host: data.to_vec(),
        }
    }

    /// Copy data back to a CPU-owned `Vec<T>`.
    ///
    /// # Examples
    /// ```
    /// let data = vec![1.0f32, 2.0, 3.0];
    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &data);
    /// let back = arr.to_cpu_vec();
    /// assert_eq!(back, vec![1.0f32, 2.0, 3.0]);
    /// ```
    pub fn to_cpu_vec(&self) -> Vec<T> {
        self.host.clone()
    }

    /// Return the logical shape of the array.
    ///
    /// # Examples
    /// ```
    /// let data = vec![0.0f32; 6];
    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[2,3], scir_gpu::DType::F32, &data);
    /// assert_eq!(arr.shape(), &[2,3]);
    /// ```
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Return the element data type.
    ///
    /// # Examples
    /// ```
    /// let data = vec![0.0f32; 4];
    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[4], scir_gpu::DType::F32, &data);
    /// assert!(matches!(arr.dtype(), scir_gpu::DType::F32));
    /// ```
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Return the current device of this array.
    ///
    /// # Examples
    /// ```
    /// let data = vec![1.0f32, 2.0, 3.0, 4.0];
    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[4], scir_gpu::DType::F32, &data);
    /// assert!(matches!(arr.device(), scir_gpu::Device::Cpu));
    /// ```
    pub fn device(&self) -> Device {
        self.device
    }
}

impl<T: Copy> DeviceArray<T> {
    /// Move the array to a device (CPU or CUDA if enabled).
    ///
    /// # Examples
    /// ```
    /// let data = vec![1.0f32, 2.0, 3.0, 4.0];
    /// let mut arr = scir_gpu::DeviceArray::from_cpu_slice(&[4], scir_gpu::DType::F32, &data);
    /// // Always available
    /// arr.to_device(scir_gpu::Device::Cpu).unwrap();
    /// ```
    #[cfg(feature = "cuda")]
    pub fn to_device(&mut self, device: Device) -> Result<(), GpuError> {
        match device {
            Device::Cpu => {
                self.device = Device::Cpu;
                Ok(())
            }
            Device::Cuda => {
                // Placeholder: actual upload would allocate device memory and copy.
                self.device = Device::Cuda;
                Ok(())
            }
        }
    }

    /// Move the array to a device (CPU only, when CUDA is disabled).
    #[cfg(not(feature = "cuda"))]
    pub fn to_device(&mut self, device: Device) -> Result<(), GpuError> {
        match device {
            Device::Cpu => {
                self.device = Device::Cpu;
                Ok(())
            }
        }
    }
}
163
164// Elementwise ops (CPU baseline)
165impl<T> DeviceArray<T>
166where
167    T: Copy + NumAssign,
168{
169    /// Add a scalar to each element (CPU baseline).
170    ///
171    /// # Examples
172    /// ```
173    /// let data = vec![1.0f32, 2.0, 3.0];
174    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &data);
175    /// let out = arr.add_scalar(1.0f32);
176    /// assert_eq!(out.to_cpu_vec(), vec![2.0f32, 3.0, 4.0]);
177    /// ```
178    pub fn add_scalar(&self, alpha: T) -> Self {
179        let mut out = self.clone();
180        for v in &mut out.host {
181            *v += alpha;
182        }
183        out
184    }
185
186    /// Multiply each element by a scalar (CPU baseline).
187    ///
188    /// # Examples
189    /// ```
190    /// let data = vec![1.0f32, 2.0, 3.0];
191    /// let arr = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &data);
192    /// let out = arr.mul_scalar(2.0f32);
193    /// assert_eq!(out.to_cpu_vec(), vec![2.0f32, 4.0, 6.0]);
194    /// ```
195    pub fn mul_scalar(&self, alpha: T) -> Self {
196        let mut out = self.clone();
197        for v in &mut out.host {
198            *v *= alpha;
199        }
200        out
201    }
202}
203
204impl<T> DeviceArray<T>
205where
206    T: Copy + NumAssign,
207{
208    /// Elementwise addition between arrays (CPU baseline).
209    ///
210    /// # Examples
211    /// ```
212    /// let a = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &[1.0f32,2.0,3.0]);
213    /// let b = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &[0.5f32,1.5,2.5]);
214    /// let c = a.add(&b).unwrap();
215    /// assert_eq!(c.to_cpu_vec(), vec![1.5f32, 3.5, 5.5]);
216    /// ```
217    pub fn add(&self, other: &Self) -> Result<Self, GpuError> {
218        if self.shape != other.shape {
219            return Err(GpuError::ShapeMismatch);
220        }
221        let mut out = self.clone();
222        for (o, r) in out.host.iter_mut().zip(other.host.iter()) {
223            *o += *r;
224        }
225        Ok(out)
226    }
227}

impl DeviceArray<f32> {
    /// Elementwise add-scalar, with CUDA dispatch when available.
    ///
    /// # Examples
    /// ```
    /// let data = vec![1.0f32, 2.0, 3.0];
    /// let mut a = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &data);
    /// a.to_device(scir_gpu::Device::Cpu).unwrap();
    /// let out = a.add_scalar_auto(1.0);
    /// assert_eq!(out.to_cpu_vec(), vec![2.0f32, 3.0, 4.0]);
    /// ```
    pub fn add_scalar_auto(&self, alpha: f32) -> Self {
        #[cfg(feature = "cuda")]
        {
            match self.device {
                Device::Cpu => self.add_scalar(alpha), // reuse CPU path
                Device::Cuda => {
                    let mut out = vec![0.0f32; self.host.len()];
                    if crate::add_scalar_f32_cuda(&self.host, alpha, &mut out).is_err() {
                        // Fallback to CPU on failure
                        return self.add_scalar(alpha);
                    }
                    DeviceArray {
                        shape: self.shape.clone(),
                        dtype: self.dtype,
                        device: self.device,
                        host: out,
                    }
                }
            }
        }
        #[cfg(not(feature = "cuda"))]
        {
            self.add_scalar(alpha)
        }
    }

    /// Elementwise add of two arrays, with CUDA dispatch when available.
    ///
    /// # Examples
    /// ```
    /// let a = scir_gpu::DeviceArray::from_cpu_slice(&[2], scir_gpu::DType::F32, &[1.0f32,2.0]);
    /// let b = scir_gpu::DeviceArray::from_cpu_slice(&[2], scir_gpu::DType::F32, &[0.5f32,1.5]);
    /// let c = a.add_auto(&b).unwrap();
    /// assert_eq!(c.to_cpu_vec(), vec![1.5f32, 3.5]);
    /// ```
    pub fn add_auto(&self, other: &Self) -> Result<Self, GpuError> {
        if self.shape != other.shape {
            return Err(GpuError::ShapeMismatch);
        }
        #[cfg(feature = "cuda")]
        {
            match (self.device, other.device) {
                (Device::Cpu, Device::Cpu) => self.add(other),
                (Device::Cuda, Device::Cuda) => {
                    let mut out = vec![0.0f32; self.host.len()];
                    if crate::add_vec_f32_cuda(&self.host, &other.host, &mut out).is_err() {
                        // Fallback to CPU on failure
                        return self.add(other);
                    }
                    Ok(DeviceArray {
                        shape: self.shape.clone(),
                        dtype: self.dtype,
                        device: self.device,
                        host: out,
                    })
                }
                _ => self.add(other),
            }
        }
        #[cfg(not(feature = "cuda"))]
        {
            self.add(other)
        }
    }

    /// Elementwise mul-scalar with device dispatch.
    ///
    /// # Examples
    /// ```
    /// let data = vec![1.0f32, 2.0, 3.0];
    /// let a = scir_gpu::DeviceArray::from_cpu_slice(&[3], scir_gpu::DType::F32, &data);
    /// let out = a.mul_scalar_auto(2.0);
    /// assert_eq!(out.to_cpu_vec(), vec![2.0f32, 4.0, 6.0]);
    /// ```
    pub fn mul_scalar_auto(&self, alpha: f32) -> Self {
        #[cfg(feature = "cuda")]
        {
            match self.device {
                Device::Cpu => self.mul_scalar(alpha),
                Device::Cuda => {
                    let mut out = vec![0.0f32; self.host.len()];
                    if crate::mul_scalar_f32_cuda(&self.host, alpha, &mut out).is_err() {
                        return self.mul_scalar(alpha);
                    }
                    DeviceArray {
                        shape: self.shape.clone(),
                        dtype: self.dtype,
                        device: self.device,
                        host: out,
                    }
                }
            }
        }
        #[cfg(not(feature = "cuda"))]
        {
            self.mul_scalar(alpha)
        }
    }
}

/// FIR over each row of `x` using `taps`, with device dispatch (CUDA when available).
///
/// This function accepts a `Device` hint and will attempt to run on
/// CUDA when compiled with the `cuda` feature and a device is present;
/// otherwise it falls back to the CPU baseline.
///
/// # Examples
/// ```
/// use ndarray::{array, Array1, Array2};
/// let x: Array2<f32> = array![[1.0, 2.0, 3.0, 4.0]];
/// let taps: Array1<f32> = array![0.25, 0.5, 0.25];
/// // Explicitly run on CPU; returns shape-identical output
/// let y = scir_gpu::fir1d_batched_f32_auto(&x, &taps, scir_gpu::Device::Cpu);
/// assert_eq!(y.shape(), &[1, 4]);
/// ```
pub fn fir1d_batched_f32_auto(x: &Array2<f32>, taps: &Array1<f32>, device: Device) -> Array2<f32> {
    #[cfg(feature = "cuda")]
    {
        return match device {
            Device::Cpu => fir1d_batched_f32(x, taps),
            Device::Cuda => match crate::fir1d_batched_f32_cuda(x, taps) {
                Ok(y) => y,
                Err(_) => fir1d_batched_f32(x, taps),
            },
        };
    }
    #[cfg(not(feature = "cuda"))]
    {
        let _ = device;
        fir1d_batched_f32(x, taps)
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use std::ffi::{c_void, CString};
    use std::os::raw::c_char;
    use std::ptr;

    // Minimal CUDA Driver API bindings
    type CUdevice = i32;
    type CUcontext = *mut c_void;
    type CUmodule = *mut c_void;
    type CUfunction = *mut c_void;
    type CUdeviceptr = u64;
    type CUresult = i32;

    const CUDA_SUCCESS: CUresult = 0;

    #[link(name = "cuda")]
    extern "C" {
        fn cuInit(flags: u32) -> CUresult;
        fn cuDeviceGet(device: *mut CUdevice, ordinal: i32) -> CUresult;
        fn cuCtxCreate(ctx: *mut CUcontext, flags: u32, device: CUdevice) -> CUresult;
        fn cuCtxDestroy(ctx: CUcontext) -> CUresult;
        fn cuModuleLoadData(module: *mut CUmodule, image: *const c_void) -> CUresult;
        fn cuModuleGetFunction(hfunc: *mut CUfunction, hmod: CUmodule, name: *const c_char)
            -> CUresult;
        fn cuMemAlloc(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult;
        fn cuMemFree(dptr: CUdeviceptr) -> CUresult;
        fn cuMemcpyHtoD(
            dstDevice: CUdeviceptr,
            srcHost: *const c_void,
            ByteCount: usize,
        ) -> CUresult;
        fn cuMemcpyDtoH(dstHost: *mut c_void, srcDevice: CUdeviceptr, ByteCount: usize)
            -> CUresult;
        fn cuLaunchKernel(
            f: CUfunction,
            gridDimX: u32,
            gridDimY: u32,
            gridDimZ: u32,
            blockDimX: u32,
            blockDimY: u32,
            blockDimZ: u32,
            sharedMemBytes: u32,
            hStream: *mut c_void,
            kernelParams: *mut *mut c_void,
            extra: *mut *mut c_void,
        ) -> CUresult;
        fn cuCtxSynchronize() -> CUresult;
    }

    fn check(res: CUresult, msg: &'static str) -> Result<(), GpuError> {
        if res == CUDA_SUCCESS {
            Ok(())
        } else {
            Err(GpuError::BackendUnavailable(msg))
        }
    }

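    /// Return `true` if the CUDA driver initializes and at least one device is visible.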
    pub fn cuda_available() -> bool {
        unsafe {
            cuInit(0) == CUDA_SUCCESS && {
                let mut d = 0;
                cuDeviceGet(&mut d as *mut _, 0) == CUDA_SUCCESS
            }
        }
    }

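    // RAII wrapper around a CUDA driver context for device 0; destroyed on drop.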
    struct CudaCtx {
        ctx: CUcontext,
    }
    impl CudaCtx {
        fn create_default() -> Result<Self, GpuError> {
            unsafe {
                check(cuInit(0), "cuInit")?;
                let mut dev: CUdevice = 0;
                check(cuDeviceGet(&mut dev as *mut _, 0), "cuDeviceGet")?;
                let mut ctx: CUcontext = ptr::null_mut();
                check(cuCtxCreate(&mut ctx as *mut _, 0, dev), "cuCtxCreate")?;
                Ok(Self { ctx })
            }
        }
    }
    impl Drop for CudaCtx {
        fn drop(&mut self) {
            unsafe {
                let _ = cuCtxDestroy(self.ctx);
            }
        }
    }

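    // PTX source JIT-compiled by the driver in `load_module`; it defines the add_vec_f32,
    // add_scalar_f32, mul_scalar_f32 and fir1d_batched_f32 kernels launched below.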
    static PTX: &str = r#"
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec_f32(
    .param .u64 out,
    .param .u64 a,
    .param .u64 b,
    .param .u32 n)
{
    .reg .pred %p;
    .reg .b32 %r<6>;
    .reg .b64 %rd<10>;
    .reg .f32 %f<4>;

    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [a];
    ld.param.u64 %rd3, [b];
    ld.param.u32 %r1, [n];

    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %r5, %r3, %r4, %r2; // idx
    setp.ge.s32 %p, %r5, %r1;
    @%p ret;

    mul.wide.s32 %rd4, %r5, 4;
    add.s64 %rd5, %rd2, %rd4;
    add.s64 %rd6, %rd3, %rd4;
    add.s64 %rd7, %rd1, %rd4;
    ld.global.f32 %f1, [%rd5];
    ld.global.f32 %f2, [%rd6];
    add.f32 %f3, %f1, %f2;
    st.global.f32 [%rd7], %f3;
    ret;
}

.visible .entry add_scalar_f32(
    .param .u64 out,
    .param .u64 a,
    .param .f32 alpha,
    .param .u32 n)
{
    .reg .pred %p;
    .reg .b32 %r<6>;
    .reg .b64 %rd<10>;
    .reg .f32 %f<4>;

    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [a];
    ld.param.f32 %f1, [alpha];
    ld.param.u32 %r1, [n];

    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %r5, %r3, %r4, %r2; // idx
    setp.ge.s32 %p, %r5, %r1;
    @%p ret;

    mul.wide.s32 %rd4, %r5, 4;
    add.s64 %rd5, %rd2, %rd4;
    add.s64 %rd6, %rd1, %rd4;
    ld.global.f32 %f2, [%rd5];
    add.f32 %f3, %f2, %f1;
    st.global.f32 [%rd6], %f3;
    ret;
}

.visible .entry mul_scalar_f32(
    .param .u64 out,
    .param .u64 a,
    .param .f32 alpha,
    .param .u32 n)
{
    .reg .pred %p;
    .reg .b32 %r<6>;
    .reg .b64 %rd<10>;
    .reg .f32 %f<4>;

    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [a];
    ld.param.f32 %f1, [alpha];
    ld.param.u32 %r1, [n];

    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %r5, %r3, %r4, %r2; // idx
    setp.ge.s32 %p, %r5, %r1;
    @%p ret;

    mul.wide.s32 %rd4, %r5, 4;
    add.s64 %rd5, %rd2, %rd4;
    add.s64 %rd6, %rd1, %rd4;
    ld.global.f32 %f2, [%rd5];
    mul.f32 %f3, %f2, %f1;
    st.global.f32 [%rd6], %f3;
    ret;
}

.visible .entry fir1d_batched_f32(
    .param .u64 out,
    .param .u64 x,
    .param .u64 taps,
    .param .u32 b,
    .param .u32 n,
    .param .u32 k)
{
    .reg .pred %p<3>;
    .reg .b32 %r<20>;
    .reg .b64 %rd<20>;
    .reg .f32 %f<6>;
    // Named registers must be declared before use.
    .reg .b32 %rB, %rN, %rK, %rIdx, %rTotal, %rBi, %rI, %rTmp, %rStart;
    .reg .b32 %rJ, %rTIdx, %rKminus1, %rTapIdx, %rRowOff, %rXIdx, %rOutIdxBase, %rOutIdx;
    .reg .b64 %rdTapOff, %rdTapPtr, %rdXOff, %rdXPtr, %rdOutOff, %rdOutPtr;
    .reg .f32 %fAcc, %fTap, %fX, %fMul;

    // Load params
    ld.param.u64 %rd1, [out];
    ld.param.u64 %rd2, [x];
    ld.param.u64 %rd3, [taps];
    ld.param.u32 %rB, [b];
    ld.param.u32 %rN, [n];
    ld.param.u32 %rK, [k];

    // idx = blockIdx.x * blockDim.x + threadIdx.x
    mov.u32 %r2, %tid.x;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ntid.x;
    mad.lo.s32 %rIdx, %r3, %r4, %r2;

    // total = b*n
    mul.lo.u32 %rTotal, %rB, %rN;
    setp.ge.u32 %p0, %rIdx, %rTotal;
    @%p0 ret;

    // bi = idx / n; i = idx % n
    div.u32 %rBi, %rIdx, %rN;
    rem.u32 %rI, %rIdx, %rN;

    // start = (i + 1 > k) ? (i + 1 - k) : 0
    add.u32 %rTmp, %rI, 1;
    setp.gt.u32 %p1, %rTmp, %rK;
    mov.u32 %rStart, 0;
    @%p1 sub.u32 %rStart, %rTmp, %rK;

    // acc = 0.0f; j = i; t_idx = 0
    mov.f32 %fAcc, 0f00000000; // 0.0
    mov.u32 %rJ, %rI;
    mov.u32 %rTIdx, 0;

L_LOOP:
    // if (j < start) break; signed compare so j == -1 terminates when start == 0
    setp.lt.s32 %p2, %rJ, %rStart;
    @%p2 bra L_DONE;

    // tap_index = k - 1 - t_idx
    mov.u32 %rKminus1, 0;
    add.u32 %rKminus1, %rK, 0xffffffff; // k-1
    sub.u32 %rTapIdx, %rKminus1, %rTIdx;
    mul.wide.u32 %rdTapOff, %rTapIdx, 4;
    add.s64 %rdTapPtr, %rd3, %rdTapOff;
    ld.global.f32 %fTap, [%rdTapPtr];

    // x index: bi*n + j
    mul.lo.u32 %rRowOff, %rBi, %rN;
    add.u32 %rXIdx, %rRowOff, %rJ;
    mul.wide.u32 %rdXOff, %rXIdx, 4;
    add.s64 %rdXPtr, %rd2, %rdXOff;
    ld.global.f32 %fX, [%rdXPtr];

    // acc += tap * x
    mul.f32 %fMul, %fTap, %fX;
    add.f32 %fAcc, %fAcc, %fMul;

    // j--, t_idx++
    add.u32 %rJ, %rJ, 0xffffffff; // j-1
    add.u32 %rTIdx, %rTIdx, 1;
    bra L_LOOP;

L_DONE:
    // out index: bi*n + i
    mul.lo.u32 %rOutIdxBase, %rBi, %rN;
    add.u32 %rOutIdx, %rOutIdxBase, %rI;
    mul.wide.u32 %rdOutOff, %rOutIdx, 4;
    add.s64 %rdOutPtr, %rd1, %rdOutOff;
    st.global.f32 [%rdOutPtr], %fAcc;
    ret;
}
"#;

    fn load_module() -> Result<(CudaCtx, CUmodule), GpuError> {
        unsafe {
            let ctx = CudaCtx::create_default()?;
            // `cuModuleLoadData` expects a NUL-terminated image when handed PTX text.
            let image = CString::new(PTX)
                .map_err(|_| GpuError::BackendUnavailable("PTX contains NUL"))?;
            let mut module: CUmodule = ptr::null_mut();
            check(
                cuModuleLoadData(&mut module as *mut _, image.as_ptr() as *const c_void),
                "cuModuleLoadData",
            )?;
            Ok((ctx, module))
        }
    }

    unsafe fn get_function(module: CUmodule, name: &'static str) -> Result<CUfunction, GpuError> {
        let mut func: CUfunction = ptr::null_mut();
        // Kernel names must be passed as NUL-terminated C strings.
        let cname = CString::new(name)
            .map_err(|_| GpuError::BackendUnavailable("invalid kernel name"))?;
        check(
            cuModuleGetFunction(&mut func as *mut _, module, cname.as_ptr()),
            name,
        )?;
        Ok(func)
    }

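    /// Elementwise `out[i] = a[i] + b[i]` for `f32` slices on the first CUDA device.
    ///
    /// All slices must have equal length; inputs are copied host-to-device and the
    /// result is copied back into `out`.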
    pub fn add_vec_f32_cuda(a: &[f32], b: &[f32], out: &mut [f32]) -> Result<(), GpuError> {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), out.len());
        let n = a.len() as u32;
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "add_vec_f32")?;
            let bytes = (n as usize) * std::mem::size_of::<f32>();

            let mut d_a: CUdeviceptr = 0;
            let mut d_b: CUdeviceptr = 0;
            let mut d_out: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_a as *mut _, bytes), "cuMemAlloc a")?;
            check(cuMemAlloc(&mut d_b as *mut _, bytes), "cuMemAlloc b")?;
            check(cuMemAlloc(&mut d_out as *mut _, bytes), "cuMemAlloc out")?;

            check(
                cuMemcpyHtoD(d_a, a.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD a",
            )?;
            check(
                cuMemcpyHtoD(d_b, b.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD b",
            )?;

            let mut out_ptr = d_out as *mut c_void;
            let mut a_ptr = d_a as *mut c_void;
            let mut b_ptr = d_b as *mut c_void;
            let mut n_val = n;
            let mut params = vec![
                &mut out_ptr as *mut _ as *mut c_void,
                &mut a_ptr as *mut _ as *mut c_void,
                &mut b_ptr as *mut _ as *mut c_void,
                &mut n_val as *mut _ as *mut c_void,
            ];

            let block = 256u32;
            let grid = (n + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    ptr::null_mut(),
                    params.as_mut_ptr(),
                    ptr::null_mut(),
                ),
                "cuLaunchKernel add_vec_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;

            check(
                cuMemcpyDtoH(out.as_mut_ptr() as *mut c_void, d_out, bytes),
                "cuMemcpyDtoH out",
            )?;

            let _ = cuMemFree(d_a);
            let _ = cuMemFree(d_b);
            let _ = cuMemFree(d_out);
            Ok(())
        }
    }

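    /// Elementwise `out[i] = a[i] + alpha` for `f32` slices on the first CUDA device.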
    pub fn add_scalar_f32_cuda(a: &[f32], alpha: f32, out: &mut [f32]) -> Result<(), GpuError> {
        assert_eq!(a.len(), out.len());
        let n = a.len() as u32;
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "add_scalar_f32")?;
            let bytes = (n as usize) * std::mem::size_of::<f32>();

            let mut d_a: CUdeviceptr = 0;
            let mut d_out: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_a as *mut _, bytes), "cuMemAlloc a")?;
            check(cuMemAlloc(&mut d_out as *mut _, bytes), "cuMemAlloc out")?;
            check(
                cuMemcpyHtoD(d_a, a.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD a",
            )?;

            let mut out_ptr = d_out as *mut c_void;
            let mut a_ptr = d_a as *mut c_void;
            let mut alpha_val = alpha;
            let mut n_val = n;
            let mut params = vec![
                &mut out_ptr as *mut _ as *mut c_void,
                &mut a_ptr as *mut _ as *mut c_void,
                &mut alpha_val as *mut _ as *mut c_void,
                &mut n_val as *mut _ as *mut c_void,
            ];

            let block = 256u32;
            let grid = (n + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    ptr::null_mut(),
                    params.as_mut_ptr(),
                    ptr::null_mut(),
                ),
                "cuLaunchKernel add_scalar_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;

            check(
                cuMemcpyDtoH(out.as_mut_ptr() as *mut c_void, d_out, bytes),
                "cuMemcpyDtoH out",
            )?;
            let _ = cuMemFree(d_a);
            let _ = cuMemFree(d_out);
            Ok(())
        }
    }

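    /// Elementwise `out[i] = a[i] * alpha` for `f32` slices on the first CUDA device.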
    pub fn mul_scalar_f32_cuda(a: &[f32], alpha: f32, out: &mut [f32]) -> Result<(), GpuError> {
        assert_eq!(a.len(), out.len());
        let n = a.len() as u32;
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "mul_scalar_f32")?;
            let bytes = (n as usize) * std::mem::size_of::<f32>();

            let mut d_a: CUdeviceptr = 0;
            let mut d_out: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_a as *mut _, bytes), "cuMemAlloc a")?;
            check(cuMemAlloc(&mut d_out as *mut _, bytes), "cuMemAlloc out")?;
            check(
                cuMemcpyHtoD(d_a, a.as_ptr() as *const c_void, bytes),
                "cuMemcpyHtoD a",
            )?;

            let mut out_ptr = d_out as *mut c_void;
            let mut a_ptr = d_a as *mut c_void;
            let mut alpha_val = alpha;
            let mut n_val = n;
            let mut params = vec![
                &mut out_ptr as *mut _ as *mut c_void,
                &mut a_ptr as *mut _ as *mut c_void,
                &mut alpha_val as *mut _ as *mut c_void,
                &mut n_val as *mut _ as *mut c_void,
            ];

            let block = 256u32;
            let grid = (n + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    std::ptr::null_mut(),
                    params.as_mut_ptr(),
                    std::ptr::null_mut(),
                ),
                "cuLaunchKernel mul_scalar_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;

            check(
                cuMemcpyDtoH(out.as_mut_ptr() as *mut c_void, d_out, bytes),
                "cuMemcpyDtoH out",
            )?;
            let _ = cuMemFree(d_a);
            let _ = cuMemFree(d_out);
            Ok(())
        }
    }

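    /// Batched causal FIR over the rows of `x` on the first CUDA device.
    ///
    /// Mirrors the CPU baseline `fir1d_batched_f32`; returns an error when the driver,
    /// device, or kernel launch is unavailable so callers can fall back to the CPU path.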
    pub fn fir1d_batched_f32_cuda(
        x: &Array2<f32>,
        taps: &Array1<f32>,
    ) -> Result<Array2<f32>, GpuError> {
        let (b, n) = x.dim();
        let k = taps.len();
        let x_host = x.to_owned().into_raw_vec();
        let taps_host = taps.as_slice().unwrap();
        let mut out_host = vec![0.0f32; b * n];
        unsafe {
            let (_ctx, module) = load_module()?;
            let func = get_function(module, "fir1d_batched_f32")?;

            let bytes_x = x_host.len() * std::mem::size_of::<f32>();
            let bytes_t = k * std::mem::size_of::<f32>();
            let bytes_y = out_host.len() * std::mem::size_of::<f32>();

            let mut d_x: CUdeviceptr = 0;
            let mut d_t: CUdeviceptr = 0;
            let mut d_y: CUdeviceptr = 0;
            check(cuMemAlloc(&mut d_y as *mut _, bytes_y), "cuMemAlloc y")?;
            check(cuMemAlloc(&mut d_x as *mut _, bytes_x), "cuMemAlloc x")?;
            check(cuMemAlloc(&mut d_t as *mut _, bytes_t), "cuMemAlloc t")?;
            check(
                cuMemcpyHtoD(d_x, x_host.as_ptr() as *const c_void, bytes_x),
                "HtoD x",
            )?;
            check(
                cuMemcpyHtoD(d_t, taps_host.as_ptr() as *const c_void, bytes_t),
                "HtoD t",
            )?;

            let mut y_ptr = d_y as *mut c_void;
            let mut x_ptr = d_x as *mut c_void;
            let mut t_ptr = d_t as *mut c_void;
            let mut b_u32 = b as u32;
            let mut n_u32 = n as u32;
            let mut k_u32 = k as u32;
            let mut params = vec![
                &mut y_ptr as *mut _ as *mut c_void,
                &mut x_ptr as *mut _ as *mut c_void,
                &mut t_ptr as *mut _ as *mut c_void,
                &mut b_u32 as *mut _ as *mut c_void,
                &mut n_u32 as *mut _ as *mut c_void,
                &mut k_u32 as *mut _ as *mut c_void,
            ];

            let total = (b * n) as u32;
            let block = 256u32;
            let grid = (total + block - 1) / block;
            check(
                cuLaunchKernel(
                    func,
                    grid,
                    1,
                    1,
                    block,
                    1,
                    1,
                    0,
                    std::ptr::null_mut(),
                    params.as_mut_ptr(),
                    std::ptr::null_mut(),
                ),
                "cuLaunchKernel fir1d_batched_f32",
            )?;
            check(cuCtxSynchronize(), "cuCtxSynchronize")?;
            check(
                cuMemcpyDtoH(out_host.as_mut_ptr() as *mut c_void, d_y, bytes_y),
                "DtoH y",
            )?;

            let _ = cuMemFree(d_x);
            let _ = cuMemFree(d_t);
            let _ = cuMemFree(d_y);
        }
        Ok(Array2::from_shape_vec((b, n), out_host).unwrap())
    }
}

#[cfg(feature = "cuda")]
pub use cuda::{
    add_scalar_f32_cuda, add_vec_f32_cuda, cuda_available, fir1d_batched_f32_cuda,
    mul_scalar_f32_cuda,
};

/// Causal FIR over each row of `x` using `taps` (CPU baseline, f32).
///
/// Input shape is `(batch, n)` and the same shape is returned. Within each row,
/// the last tap weights the current sample and earlier taps weight progressively
/// older samples.
///
/// # Examples
/// ```
/// use ndarray::{array, Array1, Array2};
/// // Two rows (batch=2), four samples each
/// let x: Array2<f32> = array![[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -0.5, -1.0]];
/// let taps: Array1<f32> = array![0.25, 0.5, 0.25];
/// let y = scir_gpu::fir1d_batched_f32(&x, &taps);
/// assert_eq!(y.shape(), &[2, 4]);
/// ```
pub fn fir1d_batched_f32(x: &Array2<f32>, taps: &Array1<f32>) -> Array2<f32> {
    let (b, n) = x.dim();
    let k = taps.len();
    let mut y = Array2::<f32>::zeros((b, n));
    for bi in 0..b {
        let xin = x.index_axis(Axis(0), bi);
        let mut yout = y.index_axis_mut(Axis(0), bi);
        for i in 0..n {
            let mut acc = 0.0f32;
            let start = (i + 1).saturating_sub(k);
            for (t_idx, xi) in (start..=i).rev().enumerate() {
                let tap = taps[k - 1 - t_idx];
                acc += tap * xin[xi];
            }
            yout[i] = acc;
        }
    }
    y
}

/// Causal FIR over each row of `x` using `taps` (CPU baseline, f64).
///
/// Input shape is `(batch, n)` and the same shape is returned; the tap convention
/// matches [`fir1d_batched_f32`].
///
/// # Examples
/// ```
/// use ndarray::{array, Array1, Array2};
/// let x: Array2<f64> = array![[1.0, 2.0, 3.0, 4.0]];
/// let taps: Array1<f64> = array![0.25, 0.5, 0.25];
/// let y = scir_gpu::fir1d_batched_f64(&x, &taps);
/// assert_eq!(y.shape(), &[1, 4]);
/// ```
pub fn fir1d_batched_f64(x: &Array2<f64>, taps: &Array1<f64>) -> Array2<f64> {
    let (b, n) = x.dim();
    let k = taps.len();
    let mut y = Array2::<f64>::zeros((b, n));
    for bi in 0..b {
        let xin = x.index_axis(Axis(0), bi);
        let mut yout = y.index_axis_mut(Axis(0), bi);
        for i in 0..n {
            let mut acc = 0.0f64;
            let start = (i + 1).saturating_sub(k);
            for (t_idx, xi) in (start..=i).rev().enumerate() {
                let tap = taps[k - 1 - t_idx];
                acc += tap * xin[xi];
            }
            yout[i] = acc;
        }
    }
    y
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2};
    use rand::Rng;
    use scir_core::assert_close;

    #[test]
    fn device_array_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let arr = DeviceArray::from_cpu_slice(&[2, 2], DType::F32, &data);
        assert_eq!(arr.shape(), &[2, 2]);
        assert_eq!(arr.dtype(), DType::F32);
        assert_eq!(arr.device(), Device::Cpu);
        assert_eq!(arr.to_cpu_vec(), data);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn cuda_add_and_add_scalar_f32() {
        // Attempt GPU op; if unavailable, treat as skip
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![0.5f32, 1.5, 2.5, 3.5];
        let mut out = vec![0.0f32; 4];
        match crate::add_vec_f32_cuda(&a, &b, &mut out) {
            Ok(()) => {
                let out_f64: Vec<f64> = out.iter().copied().map(|v| v as f64).collect();
                assert_close!(&out_f64, &[1.5, 3.5, 5.5, 7.5], slice, tol = 1e-6);
            }
            Err(_) => {
                eprintln!("CUDA not available; skipping CUDA test");
                return;
            }
        }
        let mut out2 = vec![0.0f32; 4];
        crate::add_scalar_f32_cuda(&a, 1.0, &mut out2).unwrap();
        let out2_f64: Vec<f64> = out2.iter().copied().map(|v| v as f64).collect();
        assert_close!(&out2_f64, &[2.0, 3.0, 4.0, 5.0], slice, tol = 1e-6);

        let mut out3 = vec![0.0f32; 4];
        crate::mul_scalar_f32_cuda(&a, 2.0, &mut out3).unwrap();
        let out3_f64: Vec<f64> = out3.iter().copied().map(|v| v as f64).collect();
        assert_close!(&out3_f64, &[2.0, 4.0, 6.0, 8.0], slice, tol = 1e-6);
    }

    #[test]
    fn elementwise_ops_baseline() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0];
        let arr = DeviceArray::from_cpu_slice(&[4], DType::F64, &data);
        let add = arr.add_scalar(1.0);
        let mul = arr.mul_scalar(2.0);
        assert_close!(&add.to_cpu_vec(), &[2.0, 3.0, 4.0, 5.0], slice, tol = 0.0);
        assert_close!(&mul.to_cpu_vec(), &[2.0, 4.0, 6.0, 8.0], slice, tol = 0.0);
    }

    #[test]
    fn add_arrays() {
        let a = DeviceArray::from_cpu_slice(&[3], DType::F32, &[1.0, 2.0, 3.0]);
        let b = DeviceArray::from_cpu_slice(&[3], DType::F32, &[0.5, 1.5, 2.5]);
        let c = a.add(&b).unwrap();
        assert_eq!(c.to_cpu_vec(), vec![1.5f32, 3.5, 5.5]);
    }

    #[test]
    fn fir_batched_matches_naive_f32() {
        let x: Array2<f32> = array![[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -0.5, -1.0]];
        let taps: Array1<f32> = array![0.25, 0.5, 0.25];
        let y = fir1d_batched_f32(&x, &taps);
        // Manually computed expectations for both rows (compared as f64)
        let expected0_f64 = array![0.25f64, 1.0, 2.0, 3.0];
        let expected1_f64 = array![0.125f64, 0.25, 0.0, -0.5];
        let y0_f64 = y.index_axis(Axis(0), 0).to_owned().mapv(|v| v as f64);
        let y1_f64 = y.index_axis(Axis(0), 1).to_owned().mapv(|v| v as f64);
        assert_close!(&y0_f64, &expected0_f64, array, atol = 1e-7, rtol = 1e-7);
        assert_close!(&y1_f64, &expected1_f64, array, atol = 1e-7, rtol = 1e-7);
    }

    #[test]
    fn fir_batched_random_f64() {
        let mut rng = rand::thread_rng();
        let b = 3usize;
        let n = 32usize;
        let k = 5usize;
        let mut x = Array2::<f64>::zeros((b, n));
        for mut row in x.axis_iter_mut(Axis(0)) {
            for v in row.iter_mut() {
                *v = rng.gen::<f64>() * 2.0 - 1.0;
            }
        }
        let taps = Array1::from((0..k).map(|i| 1.0 / (i as f64 + 1.0)).collect::<Vec<_>>());
        let y = fir1d_batched_f64(&x, &taps);

        // Compare against a slow scalar reference
        let mut y_ref = Array2::<f64>::zeros((b, n));
        for bi in 0..b {
            for i in 0..n {
                let mut acc = 0.0f64;
                let start = (i + 1).saturating_sub(k);
                for (t_idx, xi) in (start..=i).rev().enumerate() {
                    let tap = taps[k - 1 - t_idx];
                    acc += tap * x[[bi, xi]];
                }
                y_ref[[bi, i]] = acc;
            }
        }
        assert_close!(
            &y.into_raw_vec(),
            &y_ref.into_raw_vec(),
            slice,
            atol = 1e-12,
            rtol = 1e-12
        );
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn cuda_fir1d_batched_f32_parity_small() {
        // Small deterministic input
        let x: Array2<f32> = array![[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -0.5, -1.0]];
        let taps: Array1<f32> = array![0.25, 0.5, 0.25];
        // Try CUDA; if not available, skip
        match crate::fir1d_batched_f32_cuda(&x, &taps) {
            Ok(y_cuda) => {
                let y_cpu = super::fir1d_batched_f32(&x, &taps);
                let y_cuda_f64: Vec<f64> = y_cuda
                    .into_raw_vec()
                    .into_iter()
                    .map(|v| v as f64)
                    .collect();
                let y_cpu_f64: Vec<f64> =
                    y_cpu.into_raw_vec().into_iter().map(|v| v as f64).collect();
                assert_close!(&y_cuda_f64, &y_cpu_f64, slice, atol = 1e-5, rtol = 1e-6);
            }
            Err(_) => {
                eprintln!("CUDA not available; skipping CUDA FIR test");
            }
        }
    }
}