// concision_neural/model/config.rs

/*
    Appellation: config <module>
    Contrib: @FL03
*/
5
6use crate::Hyperparameters::*;
7use crate::traits::{NetworkConfig, TrainingConfiguration};
8
/// Storage used for named hyperparameters: each owned `String` name maps to a value of type `T`.
pub(crate) type ModelConfigMap<T> = ::std::collections::HashMap<::std::string::String, T>;
10
11#[derive(Clone, Debug)]
12#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
13pub struct StandardModelConfig<T> {
14    pub(crate) batch_size: usize,
15    pub(crate) epochs: usize,
16    pub(crate) hyperparameters: ModelConfigMap<T>,
17}
18
19impl<T> StandardModelConfig<T> {
20    pub fn new() -> Self {
21        Self {
22            batch_size: 0,
23            epochs: 0,
24            hyperparameters: ModelConfigMap::new(),
25        }
26    }
27
28    pub fn batch_size(&self) -> usize {
29        self.batch_size
30    }
31
32    pub fn epochs(&self) -> usize {
33        self.epochs
34    }
35
36    pub fn with_batch_size(self, batch_size: usize) -> Self {
37        Self { batch_size, ..self }
38    }
39
40    pub fn with_epochs(self, epochs: usize) -> Self {
41        Self { epochs, ..self }
42    }
43
44    pub fn hyperparameters(&self) -> &ModelConfigMap<T> {
45        &self.hyperparameters
46    }
47
48    pub fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
49        &mut self.hyperparameters
50    }
51
52    pub fn insert_hyperparameter(&mut self, key: impl ToString, value: T) -> Option<T> {
53        self.hyperparameters_mut().insert(key.to_string(), value)
54    }
55
56    pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T> where Q: ToString  {
57        self.hyperparameters_mut().entry(key.to_string())
58    }
59
60    pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
61        self.hyperparameters_mut().remove(&key.to_string())
62    }
63
64    pub fn set_decay(&mut self, decay: T) -> Option<T> {
65        self.insert_hyperparameter(Decay, decay)
66    }
67    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
68        self.insert_hyperparameter(LearningRate, learning_rate)
69    }
70
71    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
72        self.insert_hyperparameter(Momentum, momentum)
73    }
74
75    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
76        self.insert_hyperparameter("weight_decay", decay)
77    }
78
79    pub fn get(&self, key: impl ToString) -> Option<&T> {
80        self.hyperparameters().get(&key.to_string())
81    }
82
83    pub fn learning_rate(&self) -> Option<&T> {
84        self.get(LearningRate)
85    }
86
87    pub fn momentum(&self) -> Option<&T> {
88        self.get(Momentum)
89    }
90
91    pub fn decay(&self) -> Option<&T> {
92        self.get(Decay)
93    }
94
95    pub fn weight_decay(&self) -> Option<&T> {
96        self.get("weight_decay")
97    }
98}
99
100impl<T> NetworkConfig<T> for StandardModelConfig<T> {
101    fn get<K>(&self, key: K) -> Option<&T>
102    where
103        K: AsRef<str>,
104    {
105        self.hyperparameters.get(key.as_ref())
106    }
107
108    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
109    where
110        K: AsRef<str>,
111    {
112        self.hyperparameters.get_mut(key.as_ref())
113    }
114
115    fn set<K>(&mut self, key: K, value: T) -> Option<T>
116    where
117        K: AsRef<str>,
118    {
119        self.hyperparameters.insert(key.as_ref().to_string(), value)
120    }
121
122    fn remove<K>(&mut self, key: K) -> Option<T>
123    where
124        K: AsRef<str>,
125    {
126        self.hyperparameters.remove(key.as_ref())
127    }
128
129    fn contains<K>(&self, key: K) -> bool
130    where
131        K: AsRef<str>,
132    {
133        self.hyperparameters.contains_key(key.as_ref())
134    }
135
136    fn keys(&self) -> Vec<String> {
137        self.hyperparameters.keys().cloned().collect()
138    }
139}
140
141impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
142    fn epochs(&self) -> usize {
143        self.epochs
144    }
145
146    fn batch_size(&self) -> usize {
147        self.batch_size
148    }
149}