// concision_neural/model/config.rs
use crate::Hyperparameters::*;
use crate::traits::{NetworkConfig, TrainingConfiguration};
8
/// String-keyed hash map holding hyperparameter values of type `T`.
///
/// This is the storage type of the `hyperparameters` field of `StandardModelConfig`.
pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;
10
11#[derive(Clone, Debug)]
12#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
13pub struct StandardModelConfig<T> {
14 pub(crate) batch_size: usize,
15 pub(crate) epochs: usize,
16 pub(crate) hyperparameters: ModelConfigMap<T>,
17}
18
19impl<T> Default for StandardModelConfig<T> {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl<T> StandardModelConfig<T> {
26 pub fn new() -> Self {
27 Self {
28 batch_size: 0,
29 epochs: 0,
30 hyperparameters: ModelConfigMap::new(),
31 }
32 }
33 pub const fn batch_size(&self) -> usize {
35 self.batch_size
36 }
37 pub const fn batch_size_mut(&mut self) -> &mut usize {
39 &mut self.batch_size
40 }
41 pub const fn epochs(&self) -> usize {
43 self.epochs
44 }
45 pub const fn epochs_mut(&mut self) -> &mut usize {
47 &mut self.epochs
48 }
49 pub const fn hyperparameters(&self) -> &ModelConfigMap<T> {
51 &self.hyperparameters
52 }
53 pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
55 &mut self.hyperparameters
56 }
57 pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
59 self.hyperparameters_mut().insert(key.to_string(), value)
60 }
61 pub fn get_parameter<Q>(&self, key: &Q) -> Option<&T>
63 where
64 Q: ?Sized + Eq + core::hash::Hash,
65 String: core::borrow::Borrow<Q>,
66 {
67 self.hyperparameters().get(key)
68 }
69 pub fn parameter<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
71 where
72 Q: ToString,
73 {
74 self.hyperparameters_mut().entry(key.to_string())
75 }
76 pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
78 self.hyperparameters_mut().remove(&key.to_string())
79 }
80 pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
82 self.batch_size = batch_size;
83 self
84 }
85 pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
87 self.epochs = epochs;
88 self
89 }
90 pub fn with_batch_size(self, batch_size: usize) -> Self {
92 Self { batch_size, ..self }
93 }
94 pub fn with_epochs(self, epochs: usize) -> Self {
96 Self { epochs, ..self }
97 }
98 pub fn set_decay(&mut self, decay: T) -> Option<T> {
100 self.add_parameter(Decay, decay)
101 }
102 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
103 self.add_parameter(LearningRate, learning_rate)
104 }
105 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
107 self.add_parameter(Momentum, momentum)
108 }
109 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
111 self.add_parameter("weight_decay", decay)
112 }
113 pub fn learning_rate(&self) -> Option<&T> {
115 self.get_parameter(LearningRate.as_ref())
116 }
117 pub fn momentum(&self) -> Option<&T> {
119 self.get_parameter(Momentum.as_ref())
120 }
121 pub fn decay(&self) -> Option<&T> {
123 self.get_parameter(Decay.as_ref())
124 }
125 pub fn weight_decay(&self) -> Option<&T> {
127 self.get_parameter("weight_decay")
128 }
129}
130
131impl<T> NetworkConfig<T> for StandardModelConfig<T> {
132 fn get<K>(&self, key: K) -> Option<&T>
133 where
134 K: AsRef<str>,
135 {
136 self.hyperparameters.get(key.as_ref())
137 }
138
139 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
140 where
141 K: AsRef<str>,
142 {
143 self.hyperparameters.get_mut(key.as_ref())
144 }
145
146 fn set<K>(&mut self, key: K, value: T) -> Option<T>
147 where
148 K: AsRef<str>,
149 {
150 self.hyperparameters.insert(key.as_ref().to_string(), value)
151 }
152
153 fn remove<K>(&mut self, key: K) -> Option<T>
154 where
155 K: AsRef<str>,
156 {
157 self.hyperparameters.remove(key.as_ref())
158 }
159
160 fn contains<K>(&self, key: K) -> bool
161 where
162 K: AsRef<str>,
163 {
164 self.hyperparameters.contains_key(key.as_ref())
165 }
166
167 fn keys(&self) -> Vec<String> {
168 self.hyperparameters.keys().cloned().collect()
169 }
170}
171
172impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
173 fn epochs(&self) -> usize {
174 self.epochs
175 }
176
177 fn batch_size(&self) -> usize {
178 self.batch_size
179 }
180}
181#[allow(deprecated)]
182impl<T> StandardModelConfig<T> {
183 #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")]
184 pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
185 self.add_parameter(key, value)
186 }
187 #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")]
188 pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
189 where
190 Q: ToString,
191 {
192 self.parameter(key)
193 }
194 #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")]
195 pub fn get(&self, key: impl ToString) -> Option<&T> {
196 self.hyperparameters().get(&key.to_string())
197 }
198}