use crate::{
    shapes::*,
    tensor::{launch_cfg, Cuda, Tensor},
};
use cudarc::{
    driver::{DeviceSlice, LaunchAsync},
    nvrtc::{compile_ptx_with_opts, CompileOptions},
    types::CudaTypeName,
};
use std::vec::Vec;
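
/// Reshape on CUDA is an element-wise copy: each logical index is translated
/// into a strided offset in the input and output buffers. The kernels are
/// JIT-compiled with NVRTC once per dtype and cached on the device.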
impl<E: Dtype + CudaTypeName> super::ReshapeKernel<E> for Cuda {
    fn forward<Src: Shape, Dst: Shape>(
        &self,
        dst: &Dst,
        inp: &Tensor<Src, E, Self>,
    ) -> Result<Tensor<Dst, E, Self>, Self::Err> {
        let module = std::format!("reshape_fwd_{}", E::NAME);
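        // Compile and load the kernel the first time this dtype is used.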
        if !self.dev.has_func(&module, "reshape_fwd") {
            let src = FWD_KERNEL.replace("$T", E::NAME);
            let opts = CompileOptions {
                arch: Some(env!("CUDA_COMPUTE_CAP")),
                include_paths: vec![
                    env!("CUDA_INCLUDE_DIR").to_string(),
                    env!("OUT_DIR").to_string(),
                ],
                ..Default::default()
            };
            let ptx = compile_ptx_with_opts(src, opts).unwrap();
            self.dev.load_ptx(ptx, &module, &["reshape_fwd"])?;
        }
        let fwd_fn = self.dev.get_func(&module, "reshape_fwd").unwrap();
        let numel = inp.shape.num_elements();
        let mut storage = unsafe { self.alloc_empty::<E>(numel) }?;
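        // Pack shape metadata into a single device buffer laid out as
        // [inp dims | inp strides | out dims | out strides].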
        let mut info = Vec::with_capacity(Src::NUM_DIMS * 2 + Dst::NUM_DIMS * 2);
        info.extend(inp.shape.concrete());
        info.extend(inp.strides);
        info.extend(dst.concrete());
        info.extend(dst.strides());
        let info = self.dev.htod_copy(info)?;
        let cfg = launch_cfg::<128>(numel as u32);
        let params = (
            numel,
            Src::NUM_DIMS,
            Dst::NUM_DIMS,
            &info,
            inp.data.as_ref(),
            &mut storage,
        );
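        // SAFETY: `params` matches the signature of `reshape_fwd`, and
        // `storage` holds `numel` elements of type E.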
        unsafe { fwd_fn.launch(cfg, params) }?;
        Ok(self.build_tensor(*dst, dst.strides(), storage))
    }

    fn backward<Src: Shape, Dst: Shape>(
        &self,
        dst: &Dst,
        inp: &Tensor<Src, E, Self>,
        grad_inp: &mut Self::Vec,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let module = std::format!("reshape_bwd_{}", E::NAME);
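        // Same compile-once-per-dtype caching as in `forward`.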
        if !self.dev.has_func(&module, "reshape_bwd") {
            let src = BWD_KERNEL.replace("$T", E::NAME);
            let opts = CompileOptions {
                arch: Some(env!("CUDA_COMPUTE_CAP")),
                include_paths: vec![
                    env!("CUDA_INCLUDE_DIR").to_string(),
                    env!("OUT_DIR").to_string(),
                ],
                ..Default::default()
            };
            let ptx = compile_ptx_with_opts(src, opts).unwrap();
            self.dev.load_ptx(ptx, &module, &["reshape_bwd"])?;
        }
        let bwd_fn = self.dev.get_func(&module, "reshape_bwd").unwrap();
        let numel = grad_inp.len();
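        // Reshape preserves the element count, so `grad_inp` and `grad_out`
        // cover the same number of logical elements.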
        let mut info = Vec::with_capacity(Src::NUM_DIMS * 2 + Dst::NUM_DIMS * 2);
        info.extend(inp.shape.concrete());
        info.extend(inp.strides);
        info.extend(dst.concrete());
        info.extend(dst.strides());
        let info = self.dev.htod_copy(info)?;
        let cfg = launch_cfg::<128>(numel as u32);
        let params = (
            numel,
            Src::NUM_DIMS,
            Dst::NUM_DIMS,
            &info,
            grad_inp,
            grad_out,
        );
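        // SAFETY: `params` matches the signature of `reshape_bwd`.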
        unsafe { bwd_fn.launch(cfg, params) }?;
        Ok(())
    }
}
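
// CUDA source for the forward kernel. "$T" is substituted with E::NAME
// before NVRTC compilation.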
const FWD_KERNEL: &str = "
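// NVRTC compiles without the host's <stdint.h>, so intptr_t is defined by hand.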
#if __WORDSIZE == 64
typedef long int intptr_t;
#else
typedef int intptr_t;
#endif
#include \"cuda_utils.cuh\"
extern \"C\" __global__ void reshape_fwd(
    const size_t numel,
    const size_t inp_num_dims,
    const size_t out_num_dims,
    const size_t *info,
    const $T *inp,
    $T *out
) {
    const size_t *inp_dims = info;
    const size_t *inp_strides = info + inp_num_dims;
    const size_t *out_dims = info + 2 * inp_num_dims;
    const size_t *out_strides = info + 2 * inp_num_dims + out_num_dims;
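    // Grid-stride loop: each thread copies elements, translating the logical
    // index i into strided offsets in both buffers.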
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides);
        unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides);
        out[out_i] = inp[inp_i];
    }
}
";
const BWD_KERNEL: &str = "
#if __WORDSIZE == 64
typedef long int intptr_t;
#else
typedef int intptr_t;
#endif
#include \"cuda_utils.cuh\"
extern \"C\" __global__ void reshape_bwd(
    const size_t numel,
    const size_t inp_num_dims,
    const size_t out_num_dims,
    const size_t *info,
    $T *grad_inp,
    const $T *grad_out
) {
    const size_t *inp_dims = info;
    const size_t *inp_strides = info + inp_num_dims;
    const size_t *out_dims = info + 2 * inp_num_dims;
    const size_t *out_strides = info + 2 * inp_num_dims + out_num_dims;
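    // Accumulate with atomicAdd so concurrent writes to the same grad_inp
    // element (e.g. through broadcasted, zero-stride inputs) sum correctly.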
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides);
        unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides);
        atomicAdd(grad_inp + inp_i, grad_out[out_i]);
    }
}
";