use libc::c_float;
/// Tunable parameters carried by the `TrainAlgorithm::Incremental` variant.
///
/// Fields use `c_float` (imported from `libc`); presumably so the values can be
/// handed to the underlying C library unchanged — confirm against FFI call sites.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct IncrementalParams {
    /// Momentum term used by incremental training. Default: `0.0`.
    pub learning_momentum: c_float,
    /// Learning rate used by incremental training. Default: `0.7`.
    pub learning_rate: c_float,
}

impl Default for IncrementalParams {
    /// Builds the parameter set with `learning_rate = 0.7` and
    /// `learning_momentum = 0.0`.
    fn default() -> Self {
        Self {
            learning_rate: 0.7,
            learning_momentum: 0.0,
        }
    }
}
/// Tunable parameters carried by the `TrainAlgorithm::Batch` variant.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct BatchParams {
    /// Learning rate used by batch training. Default: `0.7`.
    pub learning_rate: c_float,
}

impl Default for BatchParams {
    /// Builds the parameter set with `learning_rate = 0.7`.
    fn default() -> Self {
        let learning_rate = 0.7;
        Self { learning_rate }
    }
}
/// Tunable parameters carried by the `TrainAlgorithm::Rprop` variant.
///
/// NOTE(review): field meanings below follow their names; they presumably
/// mirror the C library's RPROP settings — confirm against the FFI layer.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct RpropParams {
    /// Step-size decrease factor. Default: `0.5`.
    pub decrease_factor: c_float,
    /// Step-size increase factor. Default: `1.2`.
    pub increase_factor: c_float,
    /// Lower bound on the step size. Default: `0.0`.
    pub delta_min: c_float,
    /// Upper bound on the step size. Default: `50.0`.
    pub delta_max: c_float,
    /// Initial step size. Default: `0.1`.
    pub delta_zero: c_float,
}

impl Default for RpropParams {
    /// Builds the parameter set with the defaults listed on each field.
    fn default() -> Self {
        // Step-size bounds and initial value.
        let (delta_min, delta_max, delta_zero) = (0.0, 50.0, 0.1);
        // Multiplicative adjustment factors.
        let (decrease_factor, increase_factor) = (0.5, 1.2);
        Self {
            decrease_factor,
            increase_factor,
            delta_min,
            delta_max,
            delta_zero,
        }
    }
}
/// Tunable parameters carried by the `TrainAlgorithm::Quickprop` variant.
///
/// NOTE(review): field meanings follow their names; they presumably mirror
/// the C library's Quickprop settings — confirm against the FFI layer.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct QuickpropParams {
    /// Decay term. Default: `-0.0001` (note the negative sign).
    pub decay: c_float,
    /// The `mu` factor. Default: `1.75`.
    pub mu: c_float,
    /// Learning rate used by Quickprop training. Default: `0.7`.
    pub learning_rate: c_float,
}

impl Default for QuickpropParams {
    /// Builds the parameter set with the defaults listed on each field.
    fn default() -> Self {
        Self {
            learning_rate: 0.7,
            mu: 1.75,
            decay: -0.0001,
        }
    }
}
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum TrainAlgorithm {
Incremental(IncrementalParams),
Batch(BatchParams),
Rprop(RpropParams),
Quickprop(QuickpropParams),
}
impl Default for TrainAlgorithm {
fn default() -> TrainAlgorithm {
TrainAlgorithm::Rprop(Default::default())
}
}