use crate::{
dtypes::*,
prelude::cpu::NdIndex,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
};
use cudarc::driver::{CudaSlice, LaunchAsync};
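
/// PTX source for the slice kernels, generated at build time into `OUT_DIR`.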
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/slice.ptx"));
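
/// Maps an element type to the PTX module name and kernel function names used for slicing on [Cuda].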
pub(crate) trait HasCudaKernel<E> {
const MOD: &'static str;
const FNS: &'static [&'static str];
}
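
// Generates [HasCudaKernel] impls that register only a forward kernel for these dtypes.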
macro_rules! has_kernels {
($($dtype:ty),*) => {
$(
impl HasCudaKernel<$dtype> for Cuda {
const MOD: &'static str = concat!("slice_", stringify!($dtype));
const FNS: &'static [&'static str] = &[concat!("slice_fwd_", stringify!($dtype))];
}
)*
}
}
has_kernels!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, bool);
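
// Float dtypes additionally register a backward kernel (`FNS[1]`).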
#[cfg(feature = "f16")]
impl HasCudaKernel<f16> for Cuda {
const MOD: &'static str = "slice_f16";
const FNS: &'static [&'static str] = &["slice_fwd_f16", "slice_bwd_f16"];
}
#[cfg(feature = "f16")]
impl HasCudaKernel<AMP<f16>> for Cuda {
const MOD: &'static str = "slice_f16";
const FNS: &'static [&'static str] = &["slice_fwd_f16", "slice_bwd_f16"];
}
impl HasCudaKernel<f32> for Cuda {
const MOD: &'static str = "slice_f32";
const FNS: &'static [&'static str] = &["slice_fwd_f32", "slice_bwd_f32"];
}
impl HasCudaKernel<f64> for Cuda {
const MOD: &'static str = "slice_f64";
const FNS: &'static [&'static str] = &["slice_fwd_f64", "slice_bwd_f64"];
}
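
// Forward copies the sliced elements into a new contiguous tensor; backward scatters
// `grad_out` back into the corresponding positions of `grad_inp`.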
impl<E: Dtype> super::SliceKernel<E> for Cuda
where
Self: HasCudaKernel<E>,
{
fn forward<Src: Shape + SliceShape<Slice>, Slice>(
&self,
inp: &Tensor<Src, E, Self>,
slice: &Slice,
) -> Result<Tensor<Src::Sliced, E, Self>, Self::Err> {
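        // Load the PTX module the first time a slice kernel is used for this dtype.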
if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}
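
        // Compute the sliced output shape and the linear offset of its first element in `inp`.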
let dst = inp.shape.slice(slice).unwrap();
let strides = inp.strides;
let numel = dst.num_elements();
let start_idx = NdIndex::new(inp.shape, inp.strides)
.get_strided_index(inp.shape.first_idx_in_slice(slice));
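
        // Allocate the output buffer and copy the output dims and input strides to the device.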
let mut storage = unsafe { self.alloc_empty::<E>(numel) }?;
let dims: CudaSlice<usize> = self.dev.htod_copy(dst.concrete().into())?;
let strides: CudaSlice<usize> = self.dev.htod_copy(strides.into())?;
let fwd_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
let cfg = launch_cfg::<128>(numel as u32);
        let params = (
            numel,             // number of elements in the sliced output
            Src::NUM_DIMS,     // rank of the source tensor
            &dims,             // concrete dims of the sliced output (device buffer)
            &strides,          // strides of the source tensor (device buffer)
            start_idx,         // linear offset of the first sliced element in `inp`
            inp.data.as_ref(), // input data
            &mut storage,      // output data
        );
unsafe { fwd_fn.launch(cfg, params) }?;
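
        // The output is freshly allocated, so it uses the sliced shape's contiguous strides.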
Ok(self.build_tensor(dst, dst.strides(), storage))
}
fn backward<Src: Shape + SliceShape<Slice>, Slice>(
&self,
inp: &Tensor<Src, E, Self>,
grad_inp: &mut Self::Vec,
grad_out: &Self::Vec,
slice: &Slice,
) -> Result<(), Self::Err> {
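        // Load the PTX module if needed; `FNS[1]` is the backward kernel.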
if !self.dev.has_func(Self::MOD, Self::FNS[1]) {
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}
let dst = inp.shape.slice(slice).unwrap();
let strides = inp.strides;
let numel = dst.num_elements();
let start_idx = NdIndex::new(inp.shape, inp.strides)
.get_strided_index(inp.shape.first_idx_in_slice(slice));
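
        // Copy the sliced dims and input strides to the device for the backward kernel.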
let dims: CudaSlice<usize> = self.dev.htod_copy(dst.concrete().into())?;
let strides: CudaSlice<usize> = self.dev.htod_copy(strides.into())?;
let bwd_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap();
let cfg = launch_cfg::<128>(numel as u32);
        let params = (
            numel,         // number of elements in the sliced region
            Src::NUM_DIMS, // rank of the source tensor
            &dims,         // concrete dims of the sliced region (device buffer)
            &strides,      // strides of the source tensor (device buffer)
            start_idx,     // linear offset of the first sliced element
            grad_inp,      // gradient of the input
            grad_out,      // gradient of the sliced output
        );
unsafe { bwd_fn.launch(cfg, params) }?;
Ok(())
}
}