use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Storage, Tensor},
};
use cudarc::driver::{CudaSlice, LaunchAsync};
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/choose.ptx"));
pub(crate) trait HasCudaKernel<E> {
const MOD: &'static str;
const FNS: &'static [&'static str];
}
#[cfg(feature = "f16")]
impl HasCudaKernel<AMP<f16>> for Cuda {
const MOD: &'static str = "choose_f16";
const FNS: &'static [&'static str] = &["choose_fwd_f16", "choose_bwd_f16"];
}
#[cfg(feature = "f16")]
impl HasCudaKernel<f16> for Cuda {
const MOD: &'static str = "choose_f16";
const FNS: &'static [&'static str] = &["choose_fwd_f16", "choose_bwd_f16"];
}
impl HasCudaKernel<f32> for Cuda {
const MOD: &'static str = "choose_f32";
const FNS: &'static [&'static str] = &["choose_fwd_f32", "choose_bwd_f32"];
}
impl HasCudaKernel<f64> for Cuda {
const MOD: &'static str = "choose_f64";
const FNS: &'static [&'static str] = &["choose_fwd_f64", "choose_bwd_f64"];
}
impl<E: Dtype> super::ChooseKernel<E> for Cuda
where
Self: HasCudaKernel<E>,
{
fn forward<S: Shape>(
&self,
cond: &Tensor<S, bool, Self>,
lhs: &Tensor<S, E, Self>,
rhs: &Tensor<S, E, Self>,
) -> Result<Tensor<S, E, Self>, Self::Err> {
if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}
let shape = lhs.shape;
let strides = lhs.shape.strides();
let numel = shape.num_elements();
let mut storage = unsafe { self.alloc_empty::<E>(numel) }?;
let dims: CudaSlice<usize> = self.dev.htod_copy(shape.concrete().into())?;
let cond_strides: CudaSlice<usize> = self.dev.htod_copy(cond.strides.into())?;
let lhs_strides: CudaSlice<usize> = self.dev.htod_copy(lhs.strides.into())?;
let rhs_strides: CudaSlice<usize> = self.dev.htod_copy(rhs.strides.into())?;
let fwd_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
let cfg = launch_cfg::<128>(numel as u32);
let params = (
numel, S::NUM_DIMS, &dims, cond.data.as_ref(), &cond_strides, lhs.data.as_ref(), &lhs_strides, rhs.data.as_ref(), &rhs_strides, &mut storage, );
unsafe { fwd_fn.launch(cfg, params) }?;
Ok(self.build_tensor(shape, strides, storage))
}
fn backward<S: Shape>(
&self,
cond: &Tensor<S, bool, Self>,
lhs: &Tensor<S, E, Self>,
grad_lhs: &mut <Self as Storage<E>>::Vec,
rhs: &Tensor<S, E, Self>,
grad_rhs: &mut <Self as Storage<E>>::Vec,
grad_out: &<Self as Storage<E>>::Vec,
) -> Result<(), Self::Err> {
let bwd_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap();
let numel = cond.shape.num_elements();
let dims: CudaSlice<usize> = self.dev.htod_copy(cond.shape.concrete().into())?;
let lhs_strides: CudaSlice<usize> = self.dev.htod_copy(lhs.strides.into())?;
let cond_strides: CudaSlice<usize> = self.dev.htod_copy(cond.strides.into())?;
let rhs_strides: CudaSlice<usize> = self.dev.htod_copy(rhs.strides.into())?;
let cfg = launch_cfg::<128>(numel as u32);
let params = (
numel, S::NUM_DIMS, &dims, cond.data.as_ref(), &cond_strides, grad_lhs, &lhs_strides, grad_rhs, &rhs_strides, grad_out, );
unsafe { bwd_fn.launch(cfg, params) }?;
Ok(())
}
}