eryon_surface/model/
config.rs1use num_traits::FromPrimitive;
6
/// Hyperparameter bundle for a surface model, generic over the scalar type `T`
/// (`f64` unless specified otherwise).
///
/// The derived `PartialEq`/`PartialOrd`/`Ord`/`Hash`/`Debug` impls work field by field
/// in declaration order (`decay`, then `learning_rate`, then `momentum`), so that order
/// is observable behaviour — do not reorder the fields.
///
/// Because the traits are derived on a generic struct, each impl only exists when `T`
/// itself implements the trait (e.g. `Eq`, `Ord` and `Hash` are not available for the
/// default `T = f64`).
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SurfaceModelConfig<T = f64> {
    /// The `decay` value (NOTE(review): presumably a decay coefficient — confirm).
    pub(crate) decay: T,
    /// The learning rate.
    pub(crate) learning_rate: T,
    /// The `momentum` value.
    pub(crate) momentum: T,
}
15
16impl<T> SurfaceModelConfig<T> {
17 pub fn new(learning_rate: T, momentum: T, decay: T) -> Self {
18 Self {
19 learning_rate,
20 momentum,
21 decay,
22 }
23 }
24 pub const fn decay(&self) -> &T {
26 &self.decay
27 }
28 pub const fn decay_mut(&mut self) -> &mut T {
30 &mut self.decay
31 }
32 pub const fn learning_rate(&self) -> &T {
34 &self.learning_rate
35 }
36 pub const fn learning_rate_mut(&mut self) -> &mut T {
38 &mut self.learning_rate
39 }
40 pub const fn momentum(&self) -> &T {
42 &self.momentum
43 }
44 pub const fn momentum_mut(&mut self) -> &mut T {
46 &mut self.momentum
47 }
48 pub fn set_decay(&mut self, decay: T) -> &mut Self {
50 self.decay = decay;
51 self
52 }
53 pub fn set_learning_rate(&mut self, learning_rate: T) -> &mut Self {
55 self.learning_rate = learning_rate;
56 self
57 }
58 pub fn set_momentum(&mut self, momentum: T) -> &mut Self {
60 self.momentum = momentum;
61 self
62 }
63 pub fn with_decay(self, decay: T) -> Self {
65 Self { decay, ..self }
66 }
67 pub fn with_learning_rate(self, learning_rate: T) -> Self {
69 Self {
70 learning_rate,
71 ..self
72 }
73 }
74 pub fn with_momentum(self, momentum: T) -> Self {
76 Self { momentum, ..self }
77 }
78}
79
80impl<T> Default for SurfaceModelConfig<T>
81where
82 T: FromPrimitive,
83{
84 fn default() -> Self {
85 Self {
86 learning_rate: T::from_f32(0.06).unwrap(),
87 momentum: T::from_f32(0.5).unwrap(),
88 decay: T::from_f32(0.5).unwrap(),
89 }
90 }
91}