dfdx 0.13.0

Ergonomic automatic differentiation in Rust, with PyTorch-like APIs.
Documentation
use super::RMSpropConfig;
use crate::{
    dtypes::*,
    tensor::{launch_cfg, Cuda},
    tensor_ops::optim::*,
};

use cudarc::driver::{DeviceRepr, DeviceSlice, LaunchAsync};

/// Flattened, C-layout copy of [`RMSpropConfig`] that is passed by value as the
/// first argument of the CUDA kernel launch.
///
/// Optional settings of the host-side config are split into a flag/tag plus a
/// plain value so the struct contains no `Option`.
///
/// NOTE(review): the field order and types presumably have to match a struct of
/// the same layout in the CUDA source — confirm before reordering anything.
#[repr(C)]
struct CudaRMSpropConfig {
    /// Copied from `RMSpropConfig::lr`.
    lr: f64,
    /// Copied from `RMSpropConfig::alpha`.
    alpha: f64,
    /// Copied from `RMSpropConfig::eps`.
    eps: f64,
    /// Copied from `RMSpropConfig::centered`.
    centered: bool,
    /// `true` when `RMSpropConfig::momentum` is `Some`.
    has_momentum: bool,
    /// The momentum value; holds the default value when `has_momentum` is `false`.
    momentum: f64,
    /// Tag produced by `weight_decay_to_cuda` describing the kind of weight decay.
    weight_decay_type: WeightDecayType,
    /// Weight decay value produced by `weight_decay_to_cuda`.
    weight_decay: f64,
}

// SAFETY: `CudaRMSpropConfig` is `#[repr(C)]` and is handed to the kernel launch
// by value. NOTE(review): soundness also relies on `WeightDecayType` having a
// device-compatible representation, which is defined outside this file — confirm.
unsafe impl DeviceRepr for CudaRMSpropConfig {}

/// Builds the device-side [`CudaRMSpropConfig`] from the host-side
/// [`RMSpropConfig`].
///
/// The optional momentum is split into a presence flag and a value (the default
/// value when absent), and the weight decay setting is translated into its
/// `(type, value)` pair by `weight_decay_to_cuda`. All other fields are copied
/// over unchanged.
fn rmsprop_config_to_cuda(config: &RMSpropConfig) -> CudaRMSpropConfig {
    // Split the optional momentum into "is it set?" plus a concrete value.
    let (has_momentum, momentum) = match config.momentum {
        Some(value) => (true, value),
        None => (false, Default::default()),
    };
    let decay = weight_decay_to_cuda(config.weight_decay);

    CudaRMSpropConfig {
        lr: config.lr,
        alpha: config.alpha,
        eps: config.eps,
        centered: config.centered,
        has_momentum,
        momentum,
        weight_decay_type: decay.0,
        weight_decay: decay.1,
    }
}

/// Contents of `rmsprop.ptx`, embedded at compile time from the directory named
/// by the `OUT_DIR` environment variable (presumably written there by the build
/// script — confirm). Loaded into the device on first use of a kernel.
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/rmsprop.ptx"));

/// Associates an element type `E` with the names needed to look up its RMSprop
/// kernel on the device.
trait HasCudaKernel<E> {
    /// Module name under which the PTX is loaded and the function is looked up.
    const MOD: &'static str;
    /// Name of the kernel function inside [`PTX_SRC`] for this element type.
    const FWD: &'static str;
}

/// Kernel names for `f16` elements; only compiled with the `f16` feature.
#[cfg(feature = "f16")]
impl HasCudaKernel<f16> for Cuda {
    const MOD: &'static str = "rmsprop_f16";
    const FWD: &'static str = "rmsprop_update_f16";
}

/// Kernel names for `AMP<f16>` elements; only compiled with the `f16` feature.
#[cfg(feature = "f16")]
impl HasCudaKernel<AMP<f16>> for Cuda {
    const MOD: &'static str = "rmsprop_amp_f16";
    const FWD: &'static str = "rmsprop_update_amp_f16";
}

/// Kernel names for `f32` elements.
impl HasCudaKernel<f32> for Cuda {
    const MOD: &'static str = "rmsprop_f32";
    const FWD: &'static str = "rmsprop_update_f32";
}

/// Kernel names for `f64` elements.
impl HasCudaKernel<f64> for Cuda {
    const MOD: &'static str = "rmsprop_f64";
    const FWD: &'static str = "rmsprop_update_f64";
}

/// CUDA implementation of the RMSprop update for every element type that has a
/// kernel registered through [`HasCudaKernel`].
impl<E: Dtype> super::RMSpropKernel<E> for Cuda
where
    Self: HasCudaKernel<E>,
{
    /// Launches the RMSprop update kernel for element type `E`.
    ///
    /// The PTX module is loaded on first use. The kernel receives the converted
    /// optimizer config, the number of elements in `param`, and then the
    /// `param`, `momentum`, `square_avg`, `grad_avg` and `grad` buffers, in that
    /// order.
    ///
    /// # Errors
    ///
    /// Returns an error if loading the PTX or launching the kernel fails.
    ///
    /// # Panics
    ///
    /// Panics if the kernel function cannot be found after the load step.
    fn rmsprop_kernel(
        &self,
        cfg: &RMSpropConfig,
        param: &mut Self::Vec,
        momentum: &mut Self::Vec,
        square_avg: &mut Self::Vec,
        grad_avg: &mut Self::Vec,
        grad: &Self::Vec,
    ) -> Result<(), Self::Err> {
        let module = <Self as HasCudaKernel<E>>::MOD;
        let kernel = <Self as HasCudaKernel<E>>::FWD;

        // Load the module lazily: only when this kernel is not yet registered.
        if !self.dev.has_func(module, kernel) {
            self.dev.load_ptx(PTX_SRC.into(), module, &[kernel])?;
        }

        // The launch configuration is derived from the element count of `param`.
        let numel = param.len();
        let launch = launch_cfg::<128>(numel as u32);

        let args = (
            rmsprop_config_to_cuda(cfg),
            numel,
            param,
            momentum,
            square_avg,
            grad_avg,
            grad,
        );
        let func = self.dev.get_func(module, kernel).unwrap();

        // SAFETY: the argument tuple must match the parameter list of the kernel
        // named by `FWD`. NOTE(review): this presumably also requires every
        // buffer to hold at least `numel` elements — confirm against callers.
        unsafe { func.launch(launch, args) }?;
        Ok(())
    }
}