// dfdx 0.13.0
//
// Ergonomic auto-differentiation in Rust, with PyTorch-like APIs.
// Documentation: https://docs.rs/dfdx/0.13.0
use cudarc::cudnn::{self, Conv2dBackwardData, Conv2dBackwardFilter, Conv2dForward, CudnnDataType};
use cudarc::driver::DeviceSlice;

use crate::{
    dtypes::*,
    shapes::*,
    tensor::{Cuda, Tensor, Tensorlike},
};

use std::sync::Arc;

/// Marker trait: `Cuda` implements `HasCudnnKernel<E>` for every element type `E`
/// for which the cuDNN-backed `Conv2DKernel` impl in this file is enabled.
trait HasCudnnKernel<E> {}

impl HasCudnnKernel<f32> for Cuda {}
impl HasCudnnKernel<f64> for Cuda {}

// Half-precision variants are only available when the `f16` feature is on.
#[cfg(feature = "f16")]
impl HasCudnnKernel<f16> for Cuda {}
#[cfg(feature = "f16")]
impl HasCudnnKernel<AMP<f16>> for Cuda {}

/// Widens a rank-3 or rank-4 per-dimension array (a shape or a stride list) to the
/// fixed 4-element form that cuDNN's 4d descriptors take.
///
/// For a rank-3 `S`, `pad` is inserted as the new leading entry and the three
/// original entries follow it. For a rank-4 `S`, the entries are copied unchanged
/// and `pad` is ignored.
///
/// # Panics
/// Panics if `S::NUM_DIMS` is neither 3 nor 4.
fn make_4d<S: Shape>(strides: S::Concrete, pad: usize) -> [usize; 4] {
    if S::NUM_DIMS == 4 {
        return [strides[0], strides[1], strides[2], strides[3]];
    }
    if S::NUM_DIMS == 3 {
        let (a, b, c) = (strides[0], strides[1], strides[2]);
        return [pad, a, b, c];
    }
    unreachable!("Only implemented for 3d & 4d arrays")
}

/// cuDNN-backed 2d convolution for the `Cuda` device.
///
/// Only available for element types `E` with a `HasCudnnKernel<E>` impl on `Cuda`.
impl<E: Dtype + CudnnDataType> super::Conv2DKernel<E> for Cuda
where
    Self: HasCudnnKernel<E>,
{
    /// Allocates a tensor of `shape` with its default (`shape.strides()`) layout.
    ///
    /// The device memory comes from `alloc_empty`, so its contents are unspecified
    /// until something writes them.
    fn alloc<S: Shape>(&self, shape: S) -> Result<Tensor<S, E, Self>, Self::Err> {
        // SAFETY: the buffer is left uninitialised. NOTE(review): this relies on the
        // caller overwriting every element before reading it (e.g. `forward` below
        // launches with beta = 0, so cuDNN never reads `out`) — TODO confirm call sites.
        let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
        Ok(self.build_tensor(shape, shape.strides(), data))
    }

    /// Computes `out = conv2d(lhs, rhs)` with cuDNN in cross-correlation mode.
    ///
    /// Padding, stride and dilation come from `op` and are applied identically to both
    /// spatial axes; `op.groups` sets the cuDNN group count. `lhs`, `rhs` and `out`
    /// must each be rank 3 or rank 4. A rank-3 shape gets a leading dimension of size 1
    /// and a stride of 0 for that dimension (see `make_4d`). `rhs` is described to
    /// cuDNN as an NCHW filter.
    ///
    /// The result overwrites `out` (alpha = 1, beta = 0).
    ///
    /// # Panics
    /// Panics if `out.data` is shared (`Arc::get_mut` fails), or if the workspace
    /// returned for the chosen algorithm does not have the requested length.
    fn forward<L: Shape, R: Shape, O: Shape>(
        &self,
        op: super::Conv2DOp,
        lhs: &Tensor<L, E, Self>,
        rhs: &Tensor<R, E, Self>,
        out: &mut Tensor<O, E, Self>,
    ) -> Result<(), Self::Err> {
        let mut conv = self.cudnn.create_conv2d::<E>(
            [op.padding as i32, op.padding as i32],
            [op.stride as i32, op.stride as i32],
            [op.dilation as i32, op.dilation as i32],
            cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
        )?;
        conv.set_group_count(op.groups as i32)?;
        let img = self.cudnn.create_4d_tensor_ex::<E>(
            make_4d::<L>(lhs.shape.concrete(), 1).map(|x| x as i32),
            make_4d::<L>(lhs.strides, 0).map(|x| x as i32),
        )?;
        let filter = self.cudnn.create_4d_filter::<E>(
            cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            make_4d::<R>(rhs.shape.concrete(), 1).map(|x| x as i32),
        )?;
        let y = self.cudnn.create_4d_tensor_ex::<E>(
            make_4d::<O>(out.shape.concrete(), 1).map(|x| x as i32),
            make_4d::<O>(out.strides, 0).map(|x| x as i32),
        )?;
        let op = Conv2dForward {
            conv: &conv,
            x: &img,
            w: &filter,
            y: &y,
        };

        let algo = op.pick_algorithm()?;
        let workspace_size_in_bytes = op.get_workspace_size(algo)?;

        // SAFETY: the descriptors above were built from the shapes/strides of exactly
        // the buffers passed to `launch`. The workspace length is checked against the
        // size cuDNN asked for before launching.
        unsafe {
            let mut workspace = self.get_workspace::<u8>(workspace_size_in_bytes)?;
            let mut workspace = workspace
                .transmute_mut::<u8>(workspace_size_in_bytes)
                .unwrap();
            assert_eq!(workspace.len(), workspace_size_in_bytes);
            op.launch(
                algo,
                Some(&mut workspace),
                // alpha = 1, beta = 0: overwrite `out` (it may be uninitialised).
                (E::ONE, Default::default()),
                lhs.data.as_ref(),
                rhs.data.as_ref(),
                Arc::get_mut(&mut out.data)
                    .expect("conv2d forward: output storage must be uniquely owned"),
            )?;
        }

        Ok(())
    }

    /// Back-propagates `grad_out` through the convolution.
    ///
    /// Uses cuDNN backward-data to add the input gradient into `grad_lhs`, and cuDNN
    /// backward-filter to add the filter gradient into `grad_rhs`. Both launches use
    /// alpha = 1 and beta = 1, so existing contents of the gradient buffers are
    /// accumulated into rather than overwritten. This matters when `lhs`/`rhs` also
    /// receive gradient from other operations. NOTE(review): this assumes the gradient
    /// buffers start zeroed (or hold earlier contributions) — TODO confirm in the
    /// gradient allocator.
    ///
    /// `grad_lhs`, `grad_rhs` and `grad_out` are described to cuDNN with the
    /// shapes/strides of `lhs`, `rhs` and `out` respectively.
    ///
    /// # Panics
    /// Panics if a workspace returned for the chosen algorithm does not have the
    /// requested length.
    fn backward<L: Shape, R: Shape, O: Shape>(
        &self,
        op: super::Conv2DOp,
        lhs: &Tensor<L, E, Self>,
        grad_lhs: &mut Self::Vec,
        rhs: &Tensor<R, E, Self>,
        grad_rhs: &mut Self::Vec,
        out: &impl Tensorlike<O, E, Self>,
        grad_out: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let mut conv = self.cudnn.create_conv2d::<E>(
            [op.padding as i32, op.padding as i32],
            [op.stride as i32, op.stride as i32],
            [op.dilation as i32, op.dilation as i32],
            cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
        )?;
        conv.set_group_count(op.groups as i32)?;
        let img = self.cudnn.create_4d_tensor_ex::<E>(
            make_4d::<L>(lhs.shape.concrete(), 1).map(|x| x as i32),
            make_4d::<L>(lhs.strides, 0).map(|x| x as i32),
        )?;
        // Same descriptor serves as `w` (backward-data) and `dw` (backward-filter):
        // the filter gradient has the filter's shape.
        let filter = self.cudnn.create_4d_filter::<E>(
            cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
            make_4d::<R>(rhs.shape.concrete(), 1).map(|x| x as i32),
        )?;
        let dy_desc = self.cudnn.create_4d_tensor_ex::<E>(
            make_4d::<O>(out.shape().concrete(), 1).map(|x| x as i32),
            make_4d::<O>(out.strides(), 0).map(|x| x as i32),
        )?;

        // grad_lhs += conv2d_backward_data(rhs, grad_out)
        {
            let op = Conv2dBackwardData {
                conv: &conv,
                dx: &img,
                w: &filter,
                dy: &dy_desc,
            };
            let algo = op.pick_algorithm()?;
            let workspace_size_in_bytes = op.get_workspace_size(algo)?;

            // SAFETY: descriptors match the buffers passed to `launch`. The workspace
            // length is checked before launching.
            unsafe {
                let mut workspace = self.get_workspace::<u8>(workspace_size_in_bytes)?;
                let mut workspace = workspace
                    .transmute_mut::<u8>(workspace_size_in_bytes)
                    .unwrap();
                assert_eq!(workspace.len(), workspace_size_in_bytes);
                op.launch(
                    algo,
                    Some(&mut workspace),
                    // alpha = 1, beta = 1: accumulate into grad_lhs instead of
                    // overwriting gradient contributed by other ops.
                    (E::ONE, E::ONE),
                    grad_lhs,
                    rhs.data.as_ref(),
                    grad_out,
                )
            }?;
        }

        // grad_rhs += conv2d_backward_filter(lhs, grad_out)
        {
            let op = Conv2dBackwardFilter {
                conv: &conv,
                x: &img,
                dw: &filter,
                dy: &dy_desc,
            };

            let algo = op.pick_algorithm()?;
            let workspace_size_in_bytes = op.get_workspace_size(algo)?;

            // SAFETY: descriptors match the buffers passed to `launch`. The workspace
            // length is checked before launching.
            unsafe {
                let mut workspace = self.get_workspace::<u8>(workspace_size_in_bytes)?;
                let mut workspace = workspace
                    .transmute_mut::<u8>(workspace_size_in_bytes)
                    .unwrap();
                assert_eq!(workspace.len(), workspace_size_in_bytes);
                op.launch(
                    algo,
                    Some(&mut workspace),
                    // alpha = 1, beta = 1: accumulate into grad_rhs.
                    (E::ONE, E::ONE),
                    lhs.data.as_ref(),
                    grad_rhs,
                    grad_out,
                )
            }?;
        }
        Ok(())
    }
}