/// Weight-decay (regularization) mode for an optimizer.
///
/// The `f64` payload of each variant is presumably the decay coefficient; it is
/// forwarded unchanged when converted for the CUDA backend — confirm against the
/// optimizer that consumes it.
#[derive(Clone, Copy, Debug)]
pub enum WeightDecay {
    /// L2-style decay. Conventionally this means the penalty is folded into the
    /// gradient — NOTE(review): confirm against the optimizer implementation.
    L2(f64),
    /// Decoupled decay. Conventionally (AdamW-style) this is applied to the weights
    /// separately from the gradient — NOTE(review): confirm against the optimizer.
    Decoupled(f64),
}
/// C-layout tag describing which weight-decay mode (if any) is active.
///
/// Discriminants are spelled out explicitly because this type is `#[repr(C)]`;
/// it presumably mirrors an enum on the CUDA side, so the values must stay in
/// sync with that definition — TODO confirm against the kernel headers.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub(super) enum WeightDecayType {
    /// No weight decay.
    None = 0,
    /// Corresponds to `WeightDecay::L2`.
    L2 = 1,
    /// Corresponds to `WeightDecay::Decoupled`.
    Decoupled = 2,
}
/// Flattens an optional `WeightDecay` into a `(tag, value)` pair for the CUDA backend.
///
/// When no decay is configured, the tag is `WeightDecayType::None` and the value
/// is `f64::default()` (i.e. `0.0`). Otherwise the variant's payload is passed
/// through unchanged alongside the matching tag.
#[cfg(feature = "cuda")]
pub(super) fn weight_decay_to_cuda(wd: Option<WeightDecay>) -> (WeightDecayType, f64) {
    wd.map_or((WeightDecayType::None, Default::default()), |decay| match decay {
        WeightDecay::L2(coef) => (WeightDecayType::L2, coef),
        WeightDecay::Decoupled(coef) => (WeightDecayType::Decoupled, coef),
    })
}
/// Momentum mode for an optimizer.
///
/// The `f64` payload of each variant is presumably the momentum factor; it is
/// forwarded unchanged when converted for the CUDA backend — confirm against the
/// optimizer that consumes it.
#[derive(Clone, Copy, Debug)]
pub enum Momentum {
    /// Classic (heavy-ball) momentum — NOTE(review): confirm update rule in the optimizer.
    Classic(f64),
    /// Nesterov momentum — NOTE(review): confirm update rule in the optimizer.
    Nesterov(f64),
}
/// C-layout tag describing which momentum mode (if any) is active.
///
/// Discriminants are spelled out explicitly because this type is `#[repr(C)]`;
/// it presumably mirrors an enum on the CUDA side, so the values must stay in
/// sync with that definition — TODO confirm against the kernel headers.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub(super) enum MomentumType {
    /// No momentum.
    None = 0,
    /// Corresponds to `Momentum::Classic`.
    Classic = 1,
    /// Corresponds to `Momentum::Nesterov`.
    Nesterov = 2,
}
/// Flattens an optional `Momentum` into a `(tag, value)` pair for the CUDA backend.
///
/// When no momentum is configured, the tag is `MomentumType::None` and the value
/// is `f64::default()` (i.e. `0.0`). Otherwise the variant's payload is passed
/// through unchanged alongside the matching tag.
// The parameter was previously named `wd` (copied from `weight_decay_to_cuda`);
// renamed to reflect what it holds. Rust has no named arguments, so callers are
// unaffected.
#[cfg(feature = "cuda")]
pub(super) fn momentum_to_cuda(momentum: Option<Momentum>) -> (MomentumType, f64) {
    match momentum {
        None => (MomentumType::None, Default::default()),
        Some(Momentum::Classic(x)) => (MomentumType::Classic, x),
        Some(Momentum::Nesterov(x)) => (MomentumType::Nesterov, x),
    }
}