concision_neural/model/
config.rs

1/*
2    Appellation: config <module>
3    Contrib: @FL03
4*/
5
6use crate::Hyperparameters::*;
7
/// String-keyed storage for a model's hyperparameter values of type `T`.
pub(crate) type ModelConfigMap<T> = ::std::collections::HashMap<String, T>;
9
10#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
12pub struct StandardModelConfig<T> {
13    pub(crate) batch_size: usize,
14    pub(crate) epochs: usize,
15    pub(crate) hyperparameters: ModelConfigMap<T>,
16}
17
18impl<T> StandardModelConfig<T> {
19    pub fn new() -> Self {
20        Self {
21            batch_size: 0,
22            epochs: 0,
23            hyperparameters: ModelConfigMap::new(),
24        }
25    }
26
27    pub fn batch_size(&self) -> usize {
28        self.batch_size
29    }
30
31    pub fn epochs(&self) -> usize {
32        self.epochs
33    }
34
35    pub fn with_batch_size(self, batch_size: usize) -> Self {
36        Self { batch_size, ..self }
37    }
38
39    pub fn with_epochs(self, epochs: usize) -> Self {
40        Self { epochs, ..self }
41    }
42
43    pub fn insert_hyperparameter(&mut self, key: impl ToString, value: T) -> Option<T> {
44        self.hyperparameters.insert(key.to_string(), value)
45    }
46
47    pub fn set_decay(&mut self, decay: T) -> Option<T> {
48        self.insert_hyperparameter(Decay, decay)
49    }
50    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
51        self.insert_hyperparameter(LearningRate, learning_rate)
52    }
53
54    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
55        self.insert_hyperparameter(Momentum, momentum)
56    }
57
58    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
59        self.insert_hyperparameter("weight_decay", decay)
60    }
61
62    pub fn get(&self, key: impl ToString) -> Option<&T> {
63        self.hyperparameters.get(&key.to_string())
64    }
65
66    pub fn learning_rate(&self) -> Option<&T> {
67        self.get(LearningRate)
68    }
69
70    pub fn momentum(&self) -> Option<&T> {
71        self.get(Momentum)
72    }
73
74    pub fn decay(&self) -> Option<&T> {
75        self.get(Decay)
76    }
77}
78
79
80impl<T> crate::NetworkConfig<T> for StandardModelConfig<T> {
81    fn get<K>(&self, key: K) -> Option<&T>
82    where
83        K: AsRef<str>,
84    {
85        self.hyperparameters.get(key.as_ref())
86    }
87
88    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
89    where
90        K: AsRef<str>,
91    {
92        self.hyperparameters.get_mut(key.as_ref())
93    }
94
95    fn set<K>(&mut self, key: K, value: T) -> Option<T>
96    where
97        K: AsRef<str>,
98    {
99        self.hyperparameters.insert(key.as_ref().to_string(), value)
100    }
101
102    fn remove<K>(&mut self, key: K) -> Option<T>
103    where
104        K: AsRef<str>,
105    {
106        self.hyperparameters.remove(key.as_ref())
107    }
108
109    fn contains<K>(&self, key: K) -> bool
110    where
111        K: AsRef<str>,
112    {
113        self.hyperparameters.contains_key(key.as_ref())
114    }
115
116    fn keys(&self) -> Vec<String> {
117        self.hyperparameters.keys().cloned().collect()
118    }
119}
120
121impl<T> crate::TrainingConfiguration<T> for StandardModelConfig<T> {
122    fn epochs(&self) -> usize {
123        self.epochs
124    }
125
126    fn batch_size(&self) -> usize {
127        self.batch_size
128    }
129}