1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use crate::{
    shapes::*,
    tensor::{launch_cfg, Cuda, Tensor},
};
use cudarc::{
    driver::{DeviceSlice, LaunchAsync},
    nvrtc::{compile_ptx_with_opts, CompileOptions},
    types::CudaTypeName,
};
use std::vec::Vec;

impl<E: Dtype + CudaTypeName> super::StackKernel<E> for Cuda {
    /// Stacks `inps` along a new leading axis of size `num`, producing one
    /// contiguous device buffer.
    ///
    /// All inputs must have identical strides (asserted below); their device
    /// buffers are concatenated back-to-back with `dtod_copy`, and the output
    /// strides are the input strides shifted right by one, with the new
    /// leading stride set to one item's element count.
    fn forward<S: Shape, Num: Dim>(
        &self,
        num: Num,
        inps: &[Tensor<S, E, Self>],
    ) -> Result<Tensor<S::Larger, E, Self>, Self::Err>
    where
        S: super::AddDim<Num>,
    {
        debug_assert_eq!(inps.len(), num.size());

        // check that all the strides are the same
        let item_strides = inps[0].strides;
        for i in inps.iter() {
            assert_eq!(i.strides, item_strides);
        }
        // Output shape = input shape with `num` added as a new dimension.
        let shape: S::Larger = inps[0].shape().add_dim(num);

        // build the new strides
        // The new axis is dim 0: its stride is one whole item's buffer length,
        // and every original stride shifts right by one position.
        // NOTE(review): strides[0] uses data.len(), not the shape's numel —
        // presumably these coincide for the tensors passed here (each item's
        // len is debug-asserted against it below); confirm for broadcasted /
        // non-contiguous inputs.
        let mut strides = shape.strides();
        strides[0] = inps[0].data.len();
        for d in 1..<S::Larger as Shape>::NUM_DIMS {
            strides[d] = item_strides[d - 1];
        }

        // copy the data
        // Allocate one uninitialized buffer big enough for all items, then
        // device-to-device copy each item into its slot at `offset`.
        let item_numel = strides[0];
        let mut data = unsafe { self.alloc_empty::<E>(num.size() * item_numel) }?;
        let mut offset = 0;
        for item in inps {
            debug_assert_eq!(item.data.len(), item_numel);
            self.dev.dtod_copy(
                item.data.as_ref(),
                &mut data.slice_mut(offset..offset + item_numel),
            )?;
            offset += item_numel;
        }
        // Every element of the output buffer must have been written exactly once.
        debug_assert_eq!(offset, data.len());
        Ok(self.build_tensor(shape, strides, data))
    }

    /// Scatters `grad_out` back into the per-input gradients: for each input
    /// gradient buffer, launches `stack_bwd` to accumulate (`+=`) the matching
    /// contiguous slice of `grad_out` into it.
    ///
    /// The CUDA kernel is JIT-compiled (once per dtype) from [`BWD_KERNEL`]
    /// and cached in the device's module table under `stack_bwd_<dtype>`.
    fn backward(
        &self,
        mut grad_inp: Vec<&mut Self::Vec>,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        // One module per element type, e.g. "stack_bwd_f32"; compile and load
        // only on first use for this dtype.
        let module_name = std::format!("stack_bwd_{}", E::NAME);
        if !self.dev.has_func(&module_name, "stack_bwd") {
            // Substitute the concrete element type into the kernel template.
            let src = BWD_KERNEL.replace("$Ty", E::NAME);
            // Arch & include paths are baked in at build time by the crate's
            // build script via these env vars.
            let opts = CompileOptions {
                arch: Some(env!("CUDA_COMPUTE_CAP")),
                include_paths: vec![env!("CUDA_INCLUDE_DIR").to_string()],
                ..Default::default()
            };
            // A compile failure here is a build/environment bug, not a
            // recoverable runtime condition — hence unwrap.
            let ptx = compile_ptx_with_opts(src, opts).unwrap();
            self.dev.load_ptx(ptx, &module_name, &["stack_bwd"])?;
        }

        // Walk grad_out in item-sized chunks, accumulating chunk i into
        // grad_inp[i]. Argument order (numel, inp, out) must match the
        // kernel's C signature.
        let mut offset = 0;
        for item in grad_inp.drain(..) {
            let f = self.dev.get_func(&module_name, "stack_bwd").unwrap();
            let numel: usize = item.len();
            let cfg = launch_cfg::<128>(numel as u32);
            let sub = grad_out.slice(offset..offset + numel);
            // SAFETY-relevant: launch args must match the kernel parameters
            // (const size_t numel, const $Ty *inp, $Ty *out).
            unsafe { f.launch(cfg, (numel, &sub, item)) }?;
            offset += numel;
        }
        // All of grad_out must have been consumed exactly once.
        debug_assert_eq!(offset, grad_out.len());
        Ok(())
    }
}

/// CUDA C source for the stack-backward kernel, compiled at runtime by NVRTC.
///
/// `$Ty` is a placeholder replaced with the concrete element type name (e.g.
/// `float`, `__half`) before compilation. The kernel accumulates `inp` into
/// `out` elementwise (`out[i] += inp[i]`) using a grid-stride loop, so any
/// launch configuration covers all `numel` elements. `cuda_fp16.h` is
/// included so the template also compiles for half-precision types.
///
/// NOTE: the string must stay in sync with the launch arguments passed in
/// `backward` above: (numel, inp, out).
const BWD_KERNEL: &str = "
#include \"cuda_fp16.h\"
extern \"C\" __global__ void stack_bwd(const size_t numel, const $Ty *inp, $Ty *out) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        out[i] += inp[i];
    }
}
";