concision_neural/config/
model_config.rs1use super::Hyperparameters::*;
6use super::{NetworkConfig, RawConfig, TrainingConfiguration};
7
/// Map type used to store named hyperparameters.
///
/// With the `std` feature this is a [`std::collections::HashMap`] keyed by `String`.
#[cfg(feature = "std")]
pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;
/// Map type used to store named hyperparameters.
///
/// With `alloc` but without `std` this is an `alloc::collections::BTreeMap` keyed by `String`.
#[cfg(all(feature = "alloc", not(feature = "std")))]
pub(crate) type ModelConfigMap<T> = alloc::collections::BTreeMap<String, T>;
12
13#[derive(Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
15pub struct StandardModelConfig<T> {
16 pub(crate) batch_size: usize,
17 pub(crate) epochs: usize,
18 pub(crate) hyperparameters: ModelConfigMap<T>,
19}
20
21impl<T> Default for StandardModelConfig<T> {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl<T> StandardModelConfig<T> {
28 pub fn new() -> Self {
29 Self {
30 batch_size: 0,
31 epochs: 0,
32 hyperparameters: ModelConfigMap::new(),
33 }
34 }
35 pub const fn batch_size(&self) -> usize {
37 self.batch_size
38 }
39 pub const fn batch_size_mut(&mut self) -> &mut usize {
41 &mut self.batch_size
42 }
43 pub const fn epochs(&self) -> usize {
45 self.epochs
46 }
47 pub const fn epochs_mut(&mut self) -> &mut usize {
49 &mut self.epochs
50 }
51 pub const fn hyperparameters(&self) -> &ModelConfigMap<T> {
53 &self.hyperparameters
54 }
55 pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
57 &mut self.hyperparameters
58 }
59 pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
61 self.hyperparameters_mut().insert(key.to_string(), value)
62 }
63 pub fn get_parameter<Q>(&self, key: &Q) -> Option<&T>
65 where
66 Q: ?Sized + Eq + core::hash::Hash,
67 String: core::borrow::Borrow<Q>,
68 {
69 self.hyperparameters().get(key)
70 }
71 pub fn parameter<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
73 where
74 Q: ToString,
75 {
76 self.hyperparameters_mut().entry(key.to_string())
77 }
78 pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
80 self.hyperparameters_mut().remove(&key.to_string())
81 }
82 pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
84 self.batch_size = batch_size;
85 self
86 }
87 pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
89 self.epochs = epochs;
90 self
91 }
92 pub fn with_batch_size(self, batch_size: usize) -> Self {
94 Self { batch_size, ..self }
95 }
96 pub fn with_epochs(self, epochs: usize) -> Self {
98 Self { epochs, ..self }
99 }
100 pub fn set_decay(&mut self, decay: T) -> Option<T> {
102 self.add_parameter(Decay, decay)
103 }
104 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
105 self.add_parameter(LearningRate, learning_rate)
106 }
107 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
109 self.add_parameter(Momentum, momentum)
110 }
111 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
113 self.add_parameter("weight_decay", decay)
114 }
115 pub fn learning_rate(&self) -> Option<&T> {
117 self.get_parameter(LearningRate.as_ref())
118 }
119 pub fn momentum(&self) -> Option<&T> {
121 self.get_parameter(Momentum.as_ref())
122 }
123 pub fn decay(&self) -> Option<&T> {
125 self.get_parameter(Decay.as_ref())
126 }
127 pub fn weight_decay(&self) -> Option<&T> {
129 self.get_parameter("weight_decay")
130 }
131}
132
133unsafe impl<T> Send for StandardModelConfig<T> where T: Send {}
134
135unsafe impl<T> Sync for StandardModelConfig<T> where T: Sync {}
136
137impl<T> RawConfig for StandardModelConfig<T> {
138 type Ctx = T;
139}
140
141impl<T> NetworkConfig<T> for StandardModelConfig<T> {
142 fn get<K>(&self, key: K) -> Option<&T>
143 where
144 K: AsRef<str>,
145 {
146 self.hyperparameters().get(key.as_ref())
147 }
148
149 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
150 where
151 K: AsRef<str>,
152 {
153 self.hyperparameters_mut().get_mut(key.as_ref())
154 }
155
156 fn set<K>(&mut self, key: K, value: T) -> Option<T>
157 where
158 K: AsRef<str>,
159 {
160 self.hyperparameters_mut()
161 .insert(key.as_ref().to_string(), value)
162 }
163
164 fn remove<K>(&mut self, key: K) -> Option<T>
165 where
166 K: AsRef<str>,
167 {
168 self.hyperparameters_mut().remove(key.as_ref())
169 }
170
171 fn contains<K>(&self, key: K) -> bool
172 where
173 K: AsRef<str>,
174 {
175 self.hyperparameters().contains_key(key.as_ref())
176 }
177
178 fn keys(&self) -> Vec<String> {
179 self.hyperparameters().keys().cloned().collect()
180 }
181}
182
183impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
184 fn epochs(&self) -> usize {
185 self.epochs
186 }
187
188 fn batch_size(&self) -> usize {
189 self.batch_size
190 }
191}
192#[allow(deprecated)]
193impl<T> StandardModelConfig<T> {
194 #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")]
195 pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
196 self.add_parameter(key, value)
197 }
198 #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")]
199 pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
200 where
201 Q: ToString,
202 {
203 self.parameter(key)
204 }
205 #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")]
206 pub fn get(&self, key: impl ToString) -> Option<&T> {
207 self.hyperparameters().get(&key.to_string())
208 }
209}