use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor_ops::reduction_utils::*,
};
use cudarc::driver::{DeviceSlice, LaunchAsync};
use std::vec::Vec;
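
/// PTX for the `max_to` kernels, compiled into `OUT_DIR` at build time.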
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/max_to.ptx"));
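
/// Dtype-specific constants for the max-reduction kernels: the identity
/// element used to seed the output, and the names of the compiled kernels.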
trait HasCudaKernel<E> {
    /// Identity element for `max` (negative infinity).
    const INIT: E;
    /// Module name the PTX is registered under.
    const MOD: &'static str;
    /// Kernel names, in `[forward, backward, fill]` order.
    const FNS: &'static [&'static str];
}

#[cfg(feature = "f16")]
impl HasCudaKernel<f16> for Cuda {
const INIT: f16 = f16::NEG_INFINITY;
const MOD: &'static str = "max_f16";
const FNS: &'static [&'static str] = &["max_to_fwd_f16", "max_to_bwd_f16", "fill_with_f16"];
}

#[cfg(feature = "f16")]
impl HasCudaKernel<AMP<f16>> for Cuda {
    const INIT: AMP<f16> = AMP(f16::NEG_INFINITY);
const MOD: &'static str = "max_f16";
const FNS: &'static [&'static str] = &["max_to_fwd_f16", "max_to_bwd_f16", "fill_with_f16"];
}

impl HasCudaKernel<f32> for Cuda {
const INIT: f32 = f32::NEG_INFINITY;
const MOD: &'static str = "max_f32";
const FNS: &'static [&'static str] = &["max_to_fwd_f32", "max_to_bwd_f32", "fill_with_f32"];
}

impl HasCudaKernel<f64> for Cuda {
const INIT: f64 = f64::NEG_INFINITY;
const MOD: &'static str = "max_f64";
const FNS: &'static [&'static str] = &["max_to_fwd_f64", "max_to_bwd_f64", "fill_with_f64"];
}
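
/// Max-reduction on CUDA: `forward` seeds the output with `-inf` and folds the
/// input in with `max`; `backward` routes each `grad_out` value back to the
/// input positions that attained the maximum.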
impl<E: Dtype> super::MaxReduceKernel<E> for Cuda
where
Self: HasCudaKernel<E>,
{
fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
dst: Dst,
inp: &Tensor<Src, E, Self>,
) -> Result<Tensor<Dst, E, Self>, Self::Err>
where
Src: ReduceShapeTo<Dst, Ax>,
{
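        // Lazily load (and JIT-compile) the PTX module the first time any of
        // these kernels is used on this device.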
if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}
let fill_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap();
let fwd_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
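
        // Allocate the output and fill it with the identity (`-inf`) so the
        // forward kernel can fold each input element in with `max`.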
let mut storage = unsafe {
let mut storage = self.alloc_empty::<E>(dst.num_elements())?;
fill_fn.launch(
launch_cfg::<128>(dst.num_elements() as u32),
(&mut storage, Self::INIT, dst.num_elements()),
)?;
storage
};
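
        // Move the reduced axes innermost, then pack the permuted dims and
        // strides into a single `info` buffer for the kernel.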
let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides);
let num_dims = dims.len();
let mut info = Vec::with_capacity(num_dims * 2);
info.extend(dims);
info.extend(strides);
let info = self.dev.htod_copy(info)?;
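
        // After the permutation, each output element reduces over `chunk_len`
        // input elements.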
let physical_numel = inp.data.len();
let (dst_physical_numel, dst_strides) =
reduction_output_strides::<Ax, Src, Dst>(inp.strides, dst);
let chunk_len = physical_numel / dst_physical_numel;
let cfg = launch_cfg::<128>(physical_numel as u32);
        let params = (
            physical_numel,    // number of physical input elements
            num_dims,          // rank after permutation
            chunk_len,         // input elements per output element
            &info,             // packed [dims, strides]
            inp.data.as_ref(), // input values
            &mut storage,      // output values, pre-filled with -inf
        );
unsafe { fwd_fn.launch(cfg, params) }?;
Ok(self.build_tensor(dst, dst_strides, storage))
    }

fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
inp: &Tensor<Src, E, Self>,
grad_inp: &mut Self::Vec,
out: &Tensor<Dst, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), Self::Err>
where
Src: ReduceShapeTo<Dst, Ax>,
{
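        // `forward` has already loaded the module, so `get_func` cannot fail here.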
let bwd_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap();
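
        // Broadcast the output strides back up to the input shape so `out` and
        // `grad_out` can be indexed once per input element.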
let out_strides: Src::Concrete =
BroadcastStridesTo::<Src, Ax>::broadcast_strides(&out.shape, out.strides);
let physical_numel = grad_inp.len();
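
        // Number of logical elements each physical input element covers along
        // the reduced axes; accounts for inputs that are physically broadcast.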
let elems_per_thread = E::from_usize(reduction_elems_per_thread::<_, Src>(
inp.shape.concrete(),
inp.strides,
Ax::as_array(),
))
.unwrap();
let cfg = launch_cfg::<128>(physical_numel as u32);
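
        // Pack [shape, input strides, broadcast output strides] for the kernel.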
let mut info = Vec::with_capacity(Src::NUM_DIMS * 3);
info.extend(inp.shape.concrete());
info.extend(inp.strides);
info.extend(out_strides);
let info = self.dev.htod_copy(info)?;
        let params = (
            physical_numel,     // number of physical input elements
            Src::NUM_DIMS,      // rank of the input shape
            elems_per_thread,   // logical elements per physical element
            &info,              // packed [shape, input strides, output strides]
            inp.data.as_ref(),  // input values
            grad_inp,           // input gradient (written)
            out.data.as_ref(),  // reduced values (the maxes)
            grad_out,           // output gradient
        );
unsafe { bwd_fn.launch(cfg, params) }?;
Ok(())
}
}