concision_core/config/
model_config.rs1use super::HyperParam;
6use super::{ExtendedModelConfig, ModelConfiguration, RawConfig};
7use alloc::string::{String, ToString};
8use hashbrown::DefaultHashBuilder;
9use hashbrown::hash_map::{self, HashMap};
10
/// A model configuration holding a batch size, an epoch count, and a string-keyed map of
/// hyperparameter values of type `T`.
///
/// When the `serde` feature is enabled the type derives `Serialize`/`Deserialize`.
// `rename_all` applies the snake_case convention to the field names. The previous
// `rename = "snake_case"` instead renamed the struct itself to the literal "snake_case",
// which only shows up in formats that record the struct name.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "snake_case")
)]
pub struct StandardModelConfig<T> {
    /// The batch size stored in this configuration.
    pub batch_size: usize,
    /// The number of epochs stored in this configuration.
    pub epochs: usize,
    /// Hyperparameter values, keyed by name.
    pub hyperspace: HashMap<String, T>,
}
23
24impl<T> StandardModelConfig<T> {
25 pub fn new() -> Self {
26 Self {
27 batch_size: 0,
28 epochs: 0,
29 hyperspace: HashMap::new(),
30 }
31 }
32 pub const fn batch_size(&self) -> usize {
34 self.batch_size
35 }
36 pub const fn batch_size_mut(&mut self) -> &mut usize {
38 &mut self.batch_size
39 }
40 pub const fn epochs(&self) -> usize {
42 self.epochs
43 }
44 pub const fn epochs_mut(&mut self) -> &mut usize {
46 &mut self.epochs
47 }
48 pub const fn hyperparameters(&self) -> &HashMap<String, T> {
50 &self.hyperspace
51 }
52 pub const fn hyperparameters_mut(&mut self) -> &mut HashMap<String, T> {
54 &mut self.hyperspace
55 }
56 pub fn add_parameter<K: ToString>(&mut self, key: K, value: T) -> Option<T> {
58 self.hyperparameters_mut().insert(key.to_string(), value)
59 }
60 pub fn get<Q>(&self, key: &Q) -> Option<&T>
62 where
63 Q: ?Sized + Eq + core::hash::Hash,
64 String: core::borrow::Borrow<Q>,
65 {
66 self.hyperparameters().get(key)
67 }
68 pub fn parameter<Q: ToString>(
70 &mut self,
71 key: Q,
72 ) -> hash_map::Entry<'_, String, T, DefaultHashBuilder> {
73 self.hyperparameters_mut().entry(key.to_string())
74 }
75 pub fn remove_hyperparameter<Q>(&mut self, key: &Q) -> Option<T>
77 where
78 Q: ?Sized + core::hash::Hash + Eq,
79 String: core::borrow::Borrow<Q>,
80 {
81 self.hyperparameters_mut().remove(key)
82 }
83 pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
85 self.batch_size = batch_size;
86 self
87 }
88 pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
90 self.epochs = epochs;
91 self
92 }
93 pub fn with_batch_size(self, batch_size: usize) -> Self {
95 Self { batch_size, ..self }
96 }
97 pub fn with_epochs(self, epochs: usize) -> Self {
99 Self { epochs, ..self }
100 }
101}
102
103use HyperParam::*;
104
105impl<T> StandardModelConfig<T> {
106 pub fn set_decay(&mut self, decay: T) -> Option<T> {
108 self.add_parameter(Decay, decay)
109 }
110 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
111 self.add_parameter(LearningRate, learning_rate)
112 }
113 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
115 self.add_parameter(Momentum, momentum)
116 }
117 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
119 self.add_parameter(WeightDecay, decay)
120 }
121 pub fn learning_rate(&self) -> Option<&T> {
123 self.get("learning_rate")
124 }
125 pub fn momentum(&self) -> Option<&T> {
127 self.get("momentum")
128 }
129 pub fn decay(&self) -> Option<&T> {
131 self.get("decay")
132 }
133 pub fn weight_decay(&self) -> Option<&T> {
135 self.get("weight_decay")
136 }
137}
138
139impl<T> Default for StandardModelConfig<T> {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145unsafe impl<T> Send for StandardModelConfig<T> where T: Send {}
146
147unsafe impl<T> Sync for StandardModelConfig<T> where T: Sync {}
148
149impl<T> crate::nn::NetworkConfig<String, T> for StandardModelConfig<T> {
150 type Store = HashMap<String, T, DefaultHashBuilder>;
151
152 fn store(&self) -> &Self::Store {
153 &self.hyperspace
154 }
155
156 fn store_mut(&mut self) -> &mut Self::Store {
157 &mut self.hyperspace
158 }
159}
160
161impl<T> RawConfig for StandardModelConfig<T> {
162 type Ctx = T;
163}
164
165impl<T> ModelConfiguration<T> for StandardModelConfig<T> {
166 fn get<K>(&self, key: K) -> Option<&T>
167 where
168 K: AsRef<str>,
169 {
170 self.hyperparameters().get(key.as_ref())
171 }
172
173 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
174 where
175 K: AsRef<str>,
176 {
177 self.hyperparameters_mut().get_mut(key.as_ref())
178 }
179
180 fn set<K>(&mut self, key: K, value: T) -> Option<T>
181 where
182 K: AsRef<str>,
183 {
184 self.hyperparameters_mut()
185 .insert(key.as_ref().into(), value)
186 }
187
188 fn remove<K>(&mut self, key: K) -> Option<T>
189 where
190 K: AsRef<str>,
191 {
192 self.hyperparameters_mut().remove(key.as_ref())
193 }
194
195 fn contains<K>(&self, key: K) -> bool
196 where
197 K: AsRef<str>,
198 {
199 self.hyperparameters().contains_key(key.as_ref())
200 }
201
202 fn keys(&self) -> Vec<&str> {
203 self.hyperparameters().keys().map(|k| k.as_str()).collect()
204 }
205}
206
207impl<T> ExtendedModelConfig<T> for StandardModelConfig<T> {
208 fn epochs(&self) -> usize {
209 self.epochs
210 }
211
212 fn batch_size(&self) -> usize {
213 self.batch_size
214 }
215}