dfdx 0.13.0

Ergonomic auto differentiation in Rust, with pytorch like apis.
Documentation
use crate::{dtypes::*, shapes::Shape, tensor::*};

use cudarc::driver::{DeviceRepr, LaunchAsync};

unsafe impl DeviceRepr for super::RollOp {}

const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/roll.ptx"));

trait HasCudaKernel<E> {
    const FNS: &'static [&'static str];
}
#[cfg(feature = "f16")]
impl HasCudaKernel<f16> for Cuda {
    const FNS: &'static [&'static str] = &["roll_fwd_f16", "roll_bwd_f16"];
}
#[cfg(feature = "f16")]
impl HasCudaKernel<AMP<f16>> for Cuda {
    const FNS: &'static [&'static str] = &["roll_fwd_f16", "roll_bwd_f16"];
}
impl HasCudaKernel<f32> for Cuda {
    const FNS: &'static [&'static str] = &["roll_fwd_f32", "roll_bwd_f32"];
}
impl HasCudaKernel<f64> for Cuda {
    const FNS: &'static [&'static str] = &["roll_fwd_f64", "roll_bwd_f64"];
}

impl<E: Dtype> super::RollKernel<E> for Cuda
where
    Self: HasCudaKernel<E>,
{
    fn forward<S: Shape>(
        &self,
        op: super::RollOp,
        inp: &Tensor<S, E, Self>,
    ) -> Result<Tensor<S, E, Self>, Self::Err> {
        if !self.dev.has_func(Self::FNS[0], Self::FNS[0]) {
            self.dev.load_ptx(PTX_SRC.into(), Self::FNS[0], Self::FNS)?;
        }

        let numel = inp.shape.num_elements();
        let strides = inp.shape.strides();

        let mut out = unsafe { self.alloc_empty::<E>(numel) }?;
        let dims = self.dev.htod_copy(inp.shape.concrete().into())?;
        let inp_strides = self.dev.htod_copy(inp.strides.into())?;
        let out_strides = self.dev.htod_copy(strides.into())?;

        let fwd = self.dev.get_func(Self::FNS[0], Self::FNS[0]).unwrap();
        let cfg = launch_cfg::<128>(inp.shape.num_elements() as u32);
        let params = (
            op,
            S::NUM_DIMS,
            numel,
            &dims,
            &inp_strides,
            &out_strides,
            inp.data.as_ref(),
            &mut out,
        );
        unsafe { fwd.launch(cfg, params) }?;
        Ok(self.build_tensor(inp.shape, strides, out))
    }
    fn backward<S: Shape>(
        &self,
        op: super::RollOp,
        inp: &Tensor<S, E, Self>,
        grad_inp: &mut Self::Vec,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let numel = inp.shape.num_elements();
        let strides = inp.shape.strides();

        let dims = self.dev.htod_copy(inp.shape.concrete().into())?;
        let inp_strides = self.dev.htod_copy(inp.strides.into())?;
        let out_strides = self.dev.htod_copy(strides.into())?;

        let bwd = self.dev.get_func(Self::FNS[0], Self::FNS[1]).unwrap();
        let cfg = launch_cfg::<128>(inp.shape.num_elements() as u32);
        let params = (
            op,
            S::NUM_DIMS,
            numel,
            &dims,
            &inp_strides,
            &out_strides,
            grad_inp,
            grad_out,
        );
        unsafe { bwd.launch(cfg, params) }?;
        Ok(())
    }
}