#[allow(unused_imports)]
use crate::{
dtypes::*,
shapes::{RemoveDimTo, ReplaceDimTo, Shape},
tensor::{launch_cfg, Cuda, Storage, Tensor},
};
use cudarc::driver::{DeviceSlice, LaunchAsync};
/// PTX module backing the select (`RemoveDimKernel`) kernels, embedded at compile
/// time from `$OUT_DIR/select.ptx` (presumably produced by the crate's build script).
const SELECT_PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/select.ptx"));
/// PTX module backing the gather (`ReplaceDimKernel`) kernels, embedded at compile
/// time from `$OUT_DIR/gather.ptx` (presumably produced by the crate's build script).
const GATHER_PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/gather.ptx"));
/// Converts an element count into the `u32` accepted by `launch_cfg`.
///
/// A plain `numel as u32` cast would silently truncate counts above `u32::MAX`
/// and (presumably) launch too few threads, producing a wrong result without
/// any error. Failing loudly is preferable.
///
/// # Panics
///
/// Panics if `numel` exceeds `u32::MAX`.
fn checked_launch_len(numel: usize) -> u32 {
    assert!(
        numel <= u32::MAX as usize,
        "select/gather CUDA kernels support at most {} elements, got {}",
        u32::MAX,
        numel
    );
    numel as u32
}

/// Implements `super::ReplaceDimKernel` (gather) and `super::RemoveDimKernel`
/// (select) for `Cuda` for a single element type.
///
/// Arguments: the element type, then the PTX module name and the forward/backward
/// kernel names for gather, then the same three names for select. Each PTX module
/// is loaded into the device on first use by either its forward or backward pass.
macro_rules! impl_cuda_kernels {
    (
        $TypeName:ty,
        $GatherMod:tt, $GatherFwd:tt, $GatherBwd:tt,
        $SelectMod:tt, $SelectFwd:tt, $SelectBwd:tt
    ) => {
        impl super::ReplaceDimKernel<$TypeName> for Cuda {
            /// Gathers from `inp` using `idx`, producing a tensor whose shape is
            /// `inp.shape.replace(idx.shape)`.
            fn forward<Src: Shape, Dst: Shape, Idx: Shape>(
                &self,
                inp: &Tensor<Src, $TypeName, Self>,
                idx: &Tensor<Idx, usize, Self>,
            ) -> Result<Tensor<Dst, $TypeName, Self>, Self::Err>
            where
                Src: ReplaceDimTo<Dst, Idx>,
            {
                // Load both gather kernels on first use.
                if !self.dev.has_func($GatherMod, $GatherFwd) {
                    self.dev.load_ptx(
                        GATHER_PTX_SRC.into(),
                        $GatherMod,
                        &[$GatherFwd, $GatherBwd],
                    )?;
                }

                let dst = inp.shape.replace(idx.shape);
                let numel = dst.num_elements();
                // SAFETY: the uninitialized allocation is zero-filled immediately
                // below, before anything reads it.
                let mut storage = unsafe { self.alloc_empty::<$TypeName>(numel) }?;
                self.dev.memset_zeros(&mut storage)?;

                // An empty output needs no kernel work. Assuming `launch_cfg` rounds up
                // to whole blocks, a 0-element launch would have an empty grid, which
                // the driver rejects.
                if numel > 0 {
                    let inp_dims = self.dev.htod_copy(inp.shape.concrete().into())?;
                    let idx_dims = self.dev.htod_copy(idx.shape.concrete().into())?;
                    let inp_strides = self.dev.htod_copy(inp.strides.into())?;
                    let idx_strides = self.dev.htod_copy(idx.strides.into())?;

                    let fwd_fn = self
                        .dev
                        .get_func($GatherMod, $GatherFwd)
                        .expect("gather forward kernel is loaded at the start of forward");
                    let cfg = launch_cfg::<128>(checked_launch_len(numel));
                    // Tuple order must match the forward kernel's parameter list in
                    // gather.ptx (not visible here) - keep in sync with the CUDA source.
                    let params = (
                        numel,
                        inp.data.as_ref(),
                        Src::NUM_DIMS,
                        &inp_dims,
                        &inp_strides,
                        idx.data.as_ref(),
                        Idx::NUM_DIMS,
                        &idx_dims,
                        &idx_strides,
                        &mut storage,
                        Dst::NUM_DIMS,
                    );
                    // SAFETY: relies on `params` matching the kernel signature in type
                    // and order, and on the kernel staying within the extents described
                    // by the dims/strides passed (TODO confirm against the CUDA source).
                    unsafe { fwd_fn.launch(cfg, params) }?;
                }
                Ok(self.build_tensor(dst, dst.strides(), storage))
            }

            /// Launches the gather backward kernel with `grad_inp` (mutable) and
            /// `grad_out`. Presumably it accumulates `grad_out` into `grad_inp` at the
            /// gathered positions (TODO confirm against the CUDA source).
            fn backward<Src: Shape, Dst: Shape, Idx: Shape>(
                &self,
                inp: &Tensor<Src, $TypeName, Self>,
                grad_inp: &mut <Self as Storage<$TypeName>>::Vec,
                idx: &Tensor<Idx, usize, Self>,
                _: &Tensor<Dst, $TypeName, Self>,
                grad_out: &<Self as Storage<$TypeName>>::Vec,
            ) -> Result<(), Self::Err>
            where
                Src: ReplaceDimTo<Dst, Idx>,
            {
                let numel = grad_out.len();
                if numel == 0 {
                    // No output gradient elements means nothing to propagate.
                    return Ok(());
                }
                // Backward may run on a device where forward never loaded the module.
                if !self.dev.has_func($GatherMod, $GatherBwd) {
                    self.dev.load_ptx(
                        GATHER_PTX_SRC.into(),
                        $GatherMod,
                        &[$GatherFwd, $GatherBwd],
                    )?;
                }
                let bwd_fn = self
                    .dev
                    .get_func($GatherMod, $GatherBwd)
                    .expect("gather backward kernel is loaded at the start of backward");

                let inp_dims = self.dev.htod_copy(inp.shape.concrete().into())?;
                let idx_dims = self.dev.htod_copy(idx.shape.concrete().into())?;
                let inp_strides = self.dev.htod_copy(inp.strides.into())?;
                let idx_strides = self.dev.htod_copy(idx.strides.into())?;

                let cfg = launch_cfg::<128>(checked_launch_len(numel));
                // Tuple order must match the backward kernel's parameter list in
                // gather.ptx (not visible here).
                let params = (
                    numel,
                    grad_inp,
                    Src::NUM_DIMS,
                    &inp_dims,
                    &inp_strides,
                    idx.data.as_ref(),
                    Idx::NUM_DIMS,
                    &idx_dims,
                    &idx_strides,
                    grad_out,
                    Dst::NUM_DIMS,
                );
                // SAFETY: relies on `params` matching the kernel signature in type and
                // order (TODO confirm against the CUDA source).
                unsafe { bwd_fn.launch(cfg, params) }?;
                Ok(())
            }
        }

        impl super::RemoveDimKernel<$TypeName> for Cuda {
            /// Selects from `inp` using `idx`, producing a tensor whose shape is
            /// `inp.shape.remove(idx.shape)`.
            fn forward<Src: Shape, Dst: Shape, Idx: Shape>(
                &self,
                inp: &Tensor<Src, $TypeName, Self>,
                idx: &Tensor<Idx, usize, Self>,
            ) -> Result<Tensor<Dst, $TypeName, Self>, Self::Err>
            where
                Src: RemoveDimTo<Dst, Idx>,
            {
                // Load both select kernels on first use.
                if !self.dev.has_func($SelectMod, $SelectFwd) {
                    self.dev.load_ptx(
                        SELECT_PTX_SRC.into(),
                        $SelectMod,
                        &[$SelectFwd, $SelectBwd],
                    )?;
                }

                let dst = inp.shape.remove(idx.shape);
                let numel = dst.num_elements();
                // SAFETY: the uninitialized allocation is zero-filled immediately
                // below, before anything reads it.
                let mut storage = unsafe { self.alloc_empty::<$TypeName>(numel) }?;
                self.dev.memset_zeros(&mut storage)?;

                // An empty output needs no kernel work (and an empty grid would be
                // rejected by the driver).
                if numel > 0 {
                    let inp_dims = self.dev.htod_copy(inp.shape.concrete().into())?;
                    let idx_dims = self.dev.htod_copy(idx.shape.concrete().into())?;
                    let dst_dims = self.dev.htod_copy(dst.concrete().into())?;
                    let inp_strides = self.dev.htod_copy(inp.strides.into())?;
                    let idx_strides = self.dev.htod_copy(idx.strides.into())?;
                    let dst_strides = self.dev.htod_copy(dst.strides().into())?;

                    let fwd_fn = self
                        .dev
                        .get_func($SelectMod, $SelectFwd)
                        .expect("select forward kernel is loaded at the start of forward");
                    let cfg = launch_cfg::<128>(checked_launch_len(numel));
                    // Tuple order must match the forward kernel's parameter list in
                    // select.ptx (not visible here).
                    let params = (
                        numel,
                        inp.data.as_ref(),
                        Src::NUM_DIMS,
                        &inp_dims,
                        &inp_strides,
                        idx.data.as_ref(),
                        Idx::NUM_DIMS,
                        &idx_dims,
                        &idx_strides,
                        &mut storage,
                        &dst_dims,
                        &dst_strides,
                    );
                    // SAFETY: relies on `params` matching the kernel signature in type
                    // and order, and on the kernel staying within the extents described
                    // by the dims/strides passed (TODO confirm against the CUDA source).
                    unsafe { fwd_fn.launch(cfg, params) }?;
                }
                Ok(self.build_tensor(dst, dst.strides(), storage))
            }

            /// Launches the select backward kernel with `grad_inp` (mutable) and
            /// `grad_out`. Presumably it accumulates `grad_out` into `grad_inp` at the
            /// selected positions (TODO confirm against the CUDA source).
            fn backward<Src: Shape, Dst: Shape, Idx: Shape>(
                &self,
                inp: &Tensor<Src, $TypeName, Self>,
                grad_inp: &mut <Self as Storage<$TypeName>>::Vec,
                idx: &Tensor<Idx, usize, Self>,
                out: &Tensor<Dst, $TypeName, Self>,
                grad_out: &<Self as Storage<$TypeName>>::Vec,
            ) -> Result<(), Self::Err>
            where
                Src: RemoveDimTo<Dst, Idx>,
            {
                let numel = grad_out.len();
                if numel == 0 {
                    // No output gradient elements means nothing to propagate.
                    return Ok(());
                }
                // Backward may run on a device where forward never loaded the module.
                if !self.dev.has_func($SelectMod, $SelectBwd) {
                    self.dev.load_ptx(
                        SELECT_PTX_SRC.into(),
                        $SelectMod,
                        &[$SelectFwd, $SelectBwd],
                    )?;
                }
                let bwd_fn = self
                    .dev
                    .get_func($SelectMod, $SelectBwd)
                    .expect("select backward kernel is loaded at the start of backward");

                let inp_dims = self.dev.htod_copy(inp.shape.concrete().into())?;
                let idx_dims = self.dev.htod_copy(idx.shape.concrete().into())?;
                let out_dims = self.dev.htod_copy(out.shape.concrete().into())?;
                let inp_strides = self.dev.htod_copy(inp.strides.into())?;
                let idx_strides = self.dev.htod_copy(idx.strides.into())?;
                let out_strides = self.dev.htod_copy(out.strides.into())?;

                let cfg = launch_cfg::<128>(checked_launch_len(numel));
                // Tuple order must match the backward kernel's parameter list in
                // select.ptx (not visible here).
                let params = (
                    numel,
                    grad_inp,
                    Src::NUM_DIMS,
                    &inp_dims,
                    &inp_strides,
                    idx.data.as_ref(),
                    Idx::NUM_DIMS,
                    &idx_dims,
                    &idx_strides,
                    grad_out,
                    &out_dims,
                    &out_strides,
                );
                // SAFETY: relies on `params` matching the kernel signature in type and
                // order (TODO confirm against the CUDA source).
                unsafe { bwd_fn.launch(cfg, params) }?;
                Ok(())
            }
        }
    };
}
// One instantiation per supported element type: (type, gather module + fwd/bwd
// kernels, select module + fwd/bwd kernels). `f32` and `f64` are always built.
impl_cuda_kernels! {
    f32,
    "gather_f32", "gather_fwd_f32", "gather_bwd_f32",
    "select_f32", "select_fwd_f32", "select_bwd_f32"
}
impl_cuda_kernels! {
    f64,
    "gather_f64", "gather_fwd_f64", "gather_bwd_f64",
    "select_f64", "select_fwd_f64", "select_bwd_f64"
}
// Half-precision variants exist only with the "f16" feature. `AMP<f16>` reuses
// the plain f16 PTX modules and kernel entry points.
#[cfg(feature = "f16")]
impl_cuda_kernels! {
    f16,
    "gather_f16", "gather_fwd_f16", "gather_bwd_f16",
    "select_f16", "select_fwd_f16", "select_bwd_f16"
}
#[cfg(feature = "f16")]
impl_cuda_kernels! {
    AMP<f16>,
    "gather_f16", "gather_fwd_f16", "gather_bwd_f16",
    "select_f16", "select_fwd_f16", "select_bwd_f16"
}