eryon_surface/model/
config.rs

1/*
2    Appellation: config <module>
3    Contrib: @FL03
4*/
5use num_traits::FromPrimitive;
6
/// Hyperparameter set for the multi-layer perceptron model.
///
/// The configuration is generic over the scalar type `T`, which defaults to `f64`.
/// The derived comparison traits compare fields in declaration order: `decay` first,
/// then `learning_rate`, then `momentum`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SurfaceModelConfig<T = f64> {
    /// the decay hyperparameter
    pub(crate) decay: T,
    /// the learning-rate hyperparameter
    pub(crate) learning_rate: T,
    /// the momentum hyperparameter
    pub(crate) momentum: T,
}
15
16impl<T> SurfaceModelConfig<T> {
17    pub fn new(learning_rate: T, momentum: T, decay: T) -> Self {
18        Self {
19            learning_rate,
20            momentum,
21            decay,
22        }
23    }
24    /// returns a reference to the decay of the model
25    pub const fn decay(&self) -> &T {
26        &self.decay
27    }
28    /// returns a mutable reference to the decay of the model
29    pub const fn decay_mut(&mut self) -> &mut T {
30        &mut self.decay
31    }
32    /// returns a reference to the learning rate of the model
33    pub const fn learning_rate(&self) -> &T {
34        &self.learning_rate
35    }
36    /// returns a mutable reference to the learning rate of the model
37    pub const fn learning_rate_mut(&mut self) -> &mut T {
38        &mut self.learning_rate
39    }
40    /// returns a reference to the momentum of the model
41    pub const fn momentum(&self) -> &T {
42        &self.momentum
43    }
44    /// returns a mutable reference to the momentum of the model
45    pub const fn momentum_mut(&mut self) -> &mut T {
46        &mut self.momentum
47    }
48    /// update the decay and return a mutable reference to the config
49    pub fn set_decay(&mut self, decay: T) -> &mut Self {
50        self.decay = decay;
51        self
52    }
53    /// update the learning rate and return a mutable reference to the config
54    pub fn set_learning_rate(&mut self, learning_rate: T) -> &mut Self {
55        self.learning_rate = learning_rate;
56        self
57    }
58    /// update the momentum and return a mutable reference to the config
59    pub fn set_momentum(&mut self, momentum: T) -> &mut Self {
60        self.momentum = momentum;
61        self
62    }
63    /// consumes the current instance to create another with the given decay
64    pub fn with_decay(self, decay: T) -> Self {
65        Self { decay, ..self }
66    }
67    /// consumes the current instance to create another with the given learning rate
68    pub fn with_learning_rate(self, learning_rate: T) -> Self {
69        Self {
70            learning_rate,
71            ..self
72        }
73    }
74    /// consumes the current instance to create another with the given momentum
75    pub fn with_momentum(self, momentum: T) -> Self {
76        Self { momentum, ..self }
77    }
78}
79
80impl<T> Default for SurfaceModelConfig<T>
81where
82    T: FromPrimitive,
83{
84    fn default() -> Self {
85        Self {
86            learning_rate: T::from_f32(0.06).unwrap(),
87            momentum: T::from_f32(0.5).unwrap(),
88            decay: T::from_f32(0.5).unwrap(),
89        }
90    }
91}