concision_neural/config/
model_config.rs1use super::Hyperparameters::*;
6use super::{NetworkConfig, TrainingConfiguration};
7
/// Map from hyperparameter name to its value; the storage behind [`StandardModelConfig`].
pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;

/// Configuration holding a batch size, an epoch count, and a map of named
/// hyperparameters whose values are of type `T`.
// Both serde derives are taken from the same `serde` path; the previous mix of
// `serde_derive::Deserialize` and `serde::Serialize` made one attribute depend on
// two crates resolving.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct StandardModelConfig<T> {
    /// Configured batch size.
    pub(crate) batch_size: usize,
    /// Configured number of epochs.
    pub(crate) epochs: usize,
    /// Hyperparameter values keyed by name.
    pub(crate) hyperparameters: ModelConfigMap<T>,
}
17
18impl<T> Default for StandardModelConfig<T> {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl<T> StandardModelConfig<T> {
25 pub fn new() -> Self {
26 Self {
27 batch_size: 0,
28 epochs: 0,
29 hyperparameters: ModelConfigMap::new(),
30 }
31 }
32 pub const fn batch_size(&self) -> usize {
34 self.batch_size
35 }
36 pub const fn batch_size_mut(&mut self) -> &mut usize {
38 &mut self.batch_size
39 }
40 pub const fn epochs(&self) -> usize {
42 self.epochs
43 }
44 pub const fn epochs_mut(&mut self) -> &mut usize {
46 &mut self.epochs
47 }
48 pub const fn hyperparameters(&self) -> &ModelConfigMap<T> {
50 &self.hyperparameters
51 }
52 pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
54 &mut self.hyperparameters
55 }
56 pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
58 self.hyperparameters_mut().insert(key.to_string(), value)
59 }
60 pub fn get_parameter<Q>(&self, key: &Q) -> Option<&T>
62 where
63 Q: ?Sized + Eq + core::hash::Hash,
64 String: core::borrow::Borrow<Q>,
65 {
66 self.hyperparameters().get(key)
67 }
68 pub fn parameter<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
70 where
71 Q: ToString,
72 {
73 self.hyperparameters_mut().entry(key.to_string())
74 }
75 pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
77 self.hyperparameters_mut().remove(&key.to_string())
78 }
79 pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
81 self.batch_size = batch_size;
82 self
83 }
84 pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
86 self.epochs = epochs;
87 self
88 }
89 pub fn with_batch_size(self, batch_size: usize) -> Self {
91 Self { batch_size, ..self }
92 }
93 pub fn with_epochs(self, epochs: usize) -> Self {
95 Self { epochs, ..self }
96 }
97 pub fn set_decay(&mut self, decay: T) -> Option<T> {
99 self.add_parameter(Decay, decay)
100 }
101 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
102 self.add_parameter(LearningRate, learning_rate)
103 }
104 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
106 self.add_parameter(Momentum, momentum)
107 }
108 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
110 self.add_parameter("weight_decay", decay)
111 }
112 pub fn learning_rate(&self) -> Option<&T> {
114 self.get_parameter(LearningRate.as_ref())
115 }
116 pub fn momentum(&self) -> Option<&T> {
118 self.get_parameter(Momentum.as_ref())
119 }
120 pub fn decay(&self) -> Option<&T> {
122 self.get_parameter(Decay.as_ref())
123 }
124 pub fn weight_decay(&self) -> Option<&T> {
126 self.get_parameter("weight_decay")
127 }
128}
129
130impl<T> NetworkConfig<T> for StandardModelConfig<T> {
131 fn get<K>(&self, key: K) -> Option<&T>
132 where
133 K: AsRef<str>,
134 {
135 self.hyperparameters().get(key.as_ref())
136 }
137
138 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
139 where
140 K: AsRef<str>,
141 {
142 self.hyperparameters_mut().get_mut(key.as_ref())
143 }
144
145 fn set<K>(&mut self, key: K, value: T) -> Option<T>
146 where
147 K: AsRef<str>,
148 {
149 self.hyperparameters_mut()
150 .insert(key.as_ref().to_string(), value)
151 }
152
153 fn remove<K>(&mut self, key: K) -> Option<T>
154 where
155 K: AsRef<str>,
156 {
157 self.hyperparameters_mut().remove(key.as_ref())
158 }
159
160 fn contains<K>(&self, key: K) -> bool
161 where
162 K: AsRef<str>,
163 {
164 self.hyperparameters().contains_key(key.as_ref())
165 }
166
167 fn keys(&self) -> Vec<String> {
168 self.hyperparameters().keys().cloned().collect()
169 }
170}
171
172impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
173 fn epochs(&self) -> usize {
174 self.epochs
175 }
176
177 fn batch_size(&self) -> usize {
178 self.batch_size
179 }
180}
181#[allow(deprecated)]
182impl<T> StandardModelConfig<T> {
183 #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")]
184 pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
185 self.add_parameter(key, value)
186 }
187 #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")]
188 pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
189 where
190 Q: ToString,
191 {
192 self.parameter(key)
193 }
194 #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")]
195 pub fn get(&self, key: impl ToString) -> Option<&T> {
196 self.hyperparameters().get(&key.to_string())
197 }
198}