1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use crate::{
    shapes::*,
    tensor::{launch_cfg, Cuda, Tensor},
};
use cudarc::{
    driver::{DeviceSlice, LaunchAsync},
    nvrtc::{compile_ptx_with_opts, CompileOptions},
    types::CudaTypeName,
};

use std::vec::Vec;

impl<E: Dtype + CudaTypeName> super::ReshapeKernel<E> for Cuda {
    fn forward<Src: Shape, Dst: Shape>(
        &self,
        dst: &Dst,
        inp: &Tensor<Src, E, Self>,
    ) -> Result<Tensor<Dst, E, Self>, Self::Err> {
        let module = std::format!("reshape_fwd_{}", E::NAME);
        if !self.dev.has_func(&module, "reshape_fwd") {
            let src = FWD_KERNEL.replace("$T", E::NAME);
            let opts = CompileOptions {
                arch: Some(env!("CUDA_COMPUTE_CAP")),
                include_paths: vec![
                    env!("CUDA_INCLUDE_DIR").to_string(),
                    env!("OUT_DIR").to_string(),
                ],
                ..Default::default()
            };
            let ptx = compile_ptx_with_opts(src, opts).unwrap();
            self.dev.load_ptx(ptx, &module, &["reshape_fwd"])?;
        }
        let fwd_fn = self.dev.get_func(&module, "reshape_fwd").unwrap();

        let numel = inp.shape.num_elements();
        let mut storage = unsafe { self.alloc_empty::<E>(numel) }?;

        let mut info = Vec::with_capacity(Src::NUM_DIMS * 2 + Dst::NUM_DIMS * 2);
        info.extend(inp.shape.concrete());
        info.extend(inp.strides);
        info.extend(dst.concrete());
        info.extend(dst.strides());
        let info = self.dev.htod_copy(info)?;

        let cfg = launch_cfg::<128>(numel as u32);
        let params = (
            numel,             // const size_t numel,
            Src::NUM_DIMS,     // const size_t inp_num_dims,
            Dst::NUM_DIMS,     // const size_t out_num_dims,
            &info,             // const size_t *info,
            inp.data.as_ref(), // const float *inp,
            &mut storage,      // float *out
        );
        unsafe { fwd_fn.launch(cfg, params) }?;

        Ok(self.build_tensor(*dst, dst.strides(), storage))
    }

    fn backward<Src: Shape, Dst: Shape>(
        &self,
        dst: &Dst,
        inp: &Tensor<Src, E, Self>,
        grad_inp: &mut Self::Vec,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let module = std::format!("reshape_bwd_{}", E::NAME);
        if !self.dev.has_func(&module, "reshape_bwd") {
            let src = BWD_KERNEL.replace("$T", E::NAME);
            let opts = CompileOptions {
                arch: Some(env!("CUDA_COMPUTE_CAP")),
                include_paths: vec![
                    env!("CUDA_INCLUDE_DIR").to_string(),
                    env!("OUT_DIR").to_string(),
                ],
                ..Default::default()
            };
            let ptx = compile_ptx_with_opts(src, opts).unwrap();
            self.dev.load_ptx(ptx, &module, &["reshape_bwd"])?;
        }
        let bwd_fn = self.dev.get_func(&module, "reshape_bwd").unwrap();

        let numel = grad_inp.len();

        let mut info = Vec::with_capacity(Src::NUM_DIMS * 2 + Dst::NUM_DIMS * 2);
        info.extend(inp.shape.concrete());
        info.extend(inp.strides);
        info.extend(dst.concrete());
        info.extend(dst.strides());
        let info = self.dev.htod_copy(info)?;

        let cfg = launch_cfg::<128>(numel as u32);
        let params = (
            numel,         // const size_t numel,
            Src::NUM_DIMS, // const size_t inp_num_dims,
            Dst::NUM_DIMS, // const size_t out_num_dims,
            &info,         // const size_t *info,
            grad_inp,      // float *grad_inp,
            grad_out,      // const float *grad_out,
        );
        unsafe { bwd_fn.launch(cfg, params) }?;
        Ok(())
    }
}

/// CUDA C source of the `reshape_fwd` kernel.
///
/// `$T` is a placeholder for the element type; `forward` substitutes `E::NAME`
/// for it before the source is compiled with NVRTC.
///
/// Kernel arguments:
/// - `numel`: number of linear indices to process.
/// - `inp_num_dims` / `out_num_dims`: number of dimensions of input and output.
/// - `info`: one flat array holding, in order, the input dims, the input
///   strides, the output dims and the output strides.
/// - `inp` / `out`: source and destination buffers.
///
/// Each thread runs a grid-stride loop over `[0, numel)`. Every linear index is
/// passed through `get_strided_index` once with the input layout and once with
/// the output layout, and the element found at the former offset is stored at
/// the latter.
///
/// NOTE(review): `get_strided_index` is not defined in this string; it is
/// presumably provided by `cuda_utils.cuh` -- confirm.
/// NOTE(review): the `intptr_t` typedef is presumably needed by something
/// reachable from `cuda_utils.cuh` when building under NVRTC -- confirm.
/// NOTE(review): the loop counter and both offsets are `unsigned int` while
/// `numel` is `size_t`, so element counts above `UINT_MAX` are not handled.
const FWD_KERNEL: &str = "
#if __WORDSIZE == 64
typedef long int intptr_t;
#else
typedef int intptr_t;
#endif

#include \"cuda_utils.cuh\"

extern \"C\" __global__ void reshape_fwd(
    const size_t numel,
    const size_t inp_num_dims,
    const size_t out_num_dims,
    const size_t *info,
    const $T *inp,
    $T *out
) {
    const size_t *inp_dims = info;
    const size_t *inp_strides = info + inp_num_dims;
    const size_t *out_dims = info + 2 * inp_num_dims;
    const size_t *out_strides = info + 2 * inp_num_dims + out_num_dims;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides);
        unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides);
        out[out_i] = inp[inp_i];
    }
}
";

/// CUDA C source of the `reshape_bwd` kernel.
///
/// `$T` is a placeholder for the element type; `backward` substitutes `E::NAME`
/// for it before the source is compiled with NVRTC.
///
/// Kernel arguments:
/// - `numel`: number of linear indices to process.
/// - `inp_num_dims` / `out_num_dims`: number of dimensions of input and output.
/// - `info`: one flat array holding, in order, the input dims, the input
///   strides, the output dims and the output strides.
/// - `grad_inp`: gradient buffer of the input, accumulated into.
/// - `grad_out`: gradient buffer of the output, read only.
///
/// Each thread runs a grid-stride loop over `[0, numel)`. Every linear index is
/// passed through `get_strided_index` once with the input layout and once with
/// the output layout, and the output gradient at the latter offset is added to
/// the input gradient at the former offset. The addition uses `atomicAdd`, so
/// several indices landing on the same input offset are all accumulated.
///
/// NOTE(review): `get_strided_index` is not defined in this string; it is
/// presumably provided by `cuda_utils.cuh` -- confirm.
/// NOTE(review): the `intptr_t` typedef is presumably needed by something
/// reachable from `cuda_utils.cuh` when building under NVRTC -- confirm.
/// NOTE(review): the loop counter and both offsets are `unsigned int` while
/// `numel` is `size_t`, so element counts above `UINT_MAX` are not handled.
const BWD_KERNEL: &str = "
#if __WORDSIZE == 64
typedef long int intptr_t;
#else
typedef int intptr_t;
#endif

#include \"cuda_utils.cuh\"

extern \"C\" __global__ void reshape_bwd(
    const size_t numel,
    const size_t inp_num_dims,
    const size_t out_num_dims,
    const size_t *info,
    $T *grad_inp,
    const $T *grad_out
) {
    const size_t *inp_dims = info;
    const size_t *inp_strides = info + inp_num_dims;
    const size_t *out_dims = info + 2 * inp_num_dims;
    const size_t *out_strides = info + 2 * inp_num_dims + out_num_dims;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides);
        unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides);
        atomicAdd(grad_inp + inp_i, grad_out[out_i]);
    }
}
";