dfdx 0.13.0

Ergonomic automatic differentiation in Rust, with PyTorch-like APIs.
Documentation
use crate::{
    dtypes::*,
    shapes::*,
    tensor::{launch_cfg, Cuda, Tensor},
};

use std::vec::Vec;

use cudarc::driver::{DeviceSlice, LaunchAsync};

use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Bernoulli, Distribution};

// PTX text for the dropout kernels, embedded as a string at compile time from
// `$OUT_DIR/dropout.ptx` (presumably emitted by the crate's build script — confirm).
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/dropout.ptx"));

/// Maps an element type `E` to the names used to load and look up its dropout kernels.
trait HasCudaKernel<E> {
    /// Module name the PTX is registered under on the device.
    const MOD: &'static str;
    /// Kernel function names inside the module; this file uses index 0 as the forward
    /// kernel and index 1 as the backward kernel.
    const FNS: &'static [&'static str];
}

// Kernel names for `f16`; only compiled when the "f16" feature is enabled.
#[cfg(feature = "f16")]
impl HasCudaKernel<f16> for Cuda {
    const MOD: &'static str = "dropout_f16";
    // Order matters: [forward, backward].
    const FNS: &'static [&'static str] = &["dropout_fwd_f16", "dropout_bwd_f16"];
}

// Kernel names for `AMP<f16>`; uses the same module and kernel names as plain `f16`.
// Only compiled when the "f16" feature is enabled.
#[cfg(feature = "f16")]
impl HasCudaKernel<AMP<f16>> for Cuda {
    const MOD: &'static str = "dropout_f16";
    // Order matters: [forward, backward].
    const FNS: &'static [&'static str] = &["dropout_fwd_f16", "dropout_bwd_f16"];
}

// Kernel names for `f32`.
impl HasCudaKernel<f32> for Cuda {
    const MOD: &'static str = "dropout_f32";
    // Order matters: [forward, backward].
    const FNS: &'static [&'static str] = &["dropout_fwd_f32", "dropout_bwd_f32"];
}

// Kernel names for `f64`.
impl HasCudaKernel<f64> for Cuda {
    const MOD: &'static str = "dropout_f64";
    // Order matters: [forward, backward].
    const FNS: &'static [&'static str] = &["dropout_fwd_f64", "dropout_bwd_f64"];
}

/// Samples the host-side dropout mask for `numel` elements.
///
/// The RNG is re-seeded from `seed` on every call, so identical `(seed, prob, numel)`
/// arguments always yield the identical mask. The forward and backward passes rely on
/// this to agree on the mask without storing it between calls.
///
/// Each entry is `true` with probability `prob`. How the kernels interpret a `true`
/// entry is defined in the PTX source, not in this file.
///
/// # Panics
///
/// Panics if `prob` is not a valid Bernoulli probability.
fn sample_dropout_mask(seed: u64, prob: f64, numel: usize) -> Vec<bool> {
    let mut rng = StdRng::seed_from_u64(seed);
    let dist = Bernoulli::new(prob).expect("dropout probability must be within [0, 1]");
    (0..numel).map(|_| dist.sample(&mut rng)).collect()
}

/// CUDA implementation of the dropout kernel.
///
/// The mask is not kept between passes: both `forward` and `backward` regenerate it on
/// the host from `op.seed` / `op.prob` and copy it to the device before launching.
impl<E: Dtype> super::DropoutKernel<E> for Cuda
where
    Self: HasCudaKernel<E>,
{
    /// Launches the forward dropout kernel over `inp` and returns a new tensor with the
    /// same shape and strides as `inp`.
    ///
    /// # Errors
    ///
    /// Propagates failures from the host-to-device mask copy, PTX loading, output
    /// allocation, and the kernel launch.
    ///
    /// # Panics
    ///
    /// Panics if `op.prob` is not a valid probability or cannot be represented as `E`.
    fn forward<S: Shape>(
        &self,
        op: super::DropoutKernelOp,
        inp: &Tensor<S, E, Self>,
    ) -> Result<Tensor<S, E, Self>, Self::Err> {
        let numel = inp.data.len();
        let mask = self
            .dev
            .htod_copy(sample_dropout_mask(op.seed, op.prob, numel))?;
        let prob = E::from_f64(op.prob).expect("dropout probability must be representable as E");

        // Load the module lazily, the first time any dropout kernel is needed.
        if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
            self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
        }
        let fwd_fn = self
            .dev
            .get_func(Self::MOD, Self::FNS[0])
            .expect("dropout module is loaded by the check above");

        // SAFETY: `alloc_empty` is unsafe, presumably because the buffer is returned
        // uninitialized; the forward kernel launched below is expected to write all
        // `numel` elements before the tensor is handed back (confirm against the
        // kernel source).
        let mut storage = unsafe { self.alloc_empty::<E>(numel) }?;

        // NOTE(review): `numel as u32` truncates for tensors with more than
        // `u32::MAX` elements; left unchanged because `launch_cfg` takes a `u32`.
        let cfg = launch_cfg::<128>(numel as u32);
        let params = (prob, numel, inp.data.as_ref(), &mask, &mut storage);
        // SAFETY: the argument tuple must match the kernel's parameter list; that
        // signature lives in the PTX source and cannot be verified from this file.
        unsafe { fwd_fn.launch(cfg, params) }?;
        Ok(self.build_tensor(inp.shape, inp.strides, storage))
    }

    /// Launches the backward dropout kernel, passing `grad_inp` and `grad_out` along
    /// with the same mask `forward` used for this `op`.
    ///
    /// `inp` is only consulted for its element count.
    ///
    /// # Errors
    ///
    /// Propagates failures from the host-to-device mask copy, PTX loading, and the
    /// kernel launch.
    ///
    /// # Panics
    ///
    /// Panics if `op.prob` is not a valid probability or cannot be represented as `E`.
    fn backward<S: Shape>(
        &self,
        op: super::DropoutKernelOp,
        inp: &Tensor<S, E, Self>,
        grad_inp: &mut Self::Vec,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let numel = inp.data.len();
        let mask = self
            .dev
            .htod_copy(sample_dropout_mask(op.seed, op.prob, numel))?;
        let prob = E::from_f64(op.prob).expect("dropout probability must be representable as E");

        // Do not assume `forward` already ran on this device: load the module here
        // too, so a missing module is loaded instead of panicking on the lookup.
        if !self.dev.has_func(Self::MOD, Self::FNS[1]) {
            self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
        }
        let bwd_fn = self
            .dev
            .get_func(Self::MOD, Self::FNS[1])
            .expect("dropout module is loaded by the check above");

        // NOTE(review): `numel as u32` truncates for tensors with more than
        // `u32::MAX` elements; left unchanged because `launch_cfg` takes a `u32`.
        let cfg = launch_cfg::<128>(numel as u32);
        let params = (prob, numel, &mask, grad_inp, grad_out);
        // SAFETY: the argument tuple must match the kernel's parameter list; that
        // signature lives in the PTX source and cannot be verified from this file.
        unsafe { bwd_fn.launch(cfg, params) }?;
        Ok(())
    }
}