concision_core/config/
model_config.rs

1/*
2    Appellation: config <module>
3    Contrib: @FL03
4*/
5use super::HyperParam;
6use super::{ExtendedModelConfig, ModelConfiguration, RawConfig};
7use alloc::string::{String, ToString};
8use hashbrown::DefaultHashBuilder;
9use hashbrown::hash_map::{self, HashMap};
10
/// The [`StandardModelConfig`] struct is a standard implementation of the model-configuration
/// traits: it records a batch size, a number of epochs, and a map of named hyperparameters whose
/// values are of type `T`.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    // `rename_all` applies the case convention to the fields; the previous `rename` would have
    // renamed the struct itself to the literal name "snake_case".
    serde(rename_all = "snake_case")
)]
pub struct StandardModelConfig<T> {
    /// the configured batch size
    pub batch_size: usize,
    /// the configured number of epochs
    pub epochs: usize,
    /// the named hyperparameters, keyed by their string name
    pub hyperspace: HashMap<String, T>,
}
23
24impl<T> StandardModelConfig<T> {
25    pub fn new() -> Self {
26        Self {
27            batch_size: 0,
28            epochs: 0,
29            hyperspace: HashMap::new(),
30        }
31    }
32    /// returns a copy of the batch size
33    pub const fn batch_size(&self) -> usize {
34        self.batch_size
35    }
36    /// returns a mutable reference to the batch size
37    pub const fn batch_size_mut(&mut self) -> &mut usize {
38        &mut self.batch_size
39    }
40    /// returns a copy of the epochs
41    pub const fn epochs(&self) -> usize {
42        self.epochs
43    }
44    /// returns a mutable reference to the epochs
45    pub const fn epochs_mut(&mut self) -> &mut usize {
46        &mut self.epochs
47    }
48    /// returns a reference to the hyperparameters map
49    pub const fn hyperparameters(&self) -> &HashMap<String, T> {
50        &self.hyperspace
51    }
52    /// returns a mutable reference to the hyperparameters map
53    pub const fn hyperparameters_mut(&mut self) -> &mut HashMap<String, T> {
54        &mut self.hyperspace
55    }
56    /// inserts a hyperparameter into the map, returning the previous value if it exists
57    pub fn add_parameter<K: ToString>(&mut self, key: K, value: T) -> Option<T> {
58        self.hyperparameters_mut().insert(key.to_string(), value)
59    }
60    /// gets a reference to a hyperparameter by key, returning None if it does not exist
61    pub fn get<Q>(&self, key: &Q) -> Option<&T>
62    where
63        Q: ?Sized + Eq + core::hash::Hash,
64        String: core::borrow::Borrow<Q>,
65    {
66        self.hyperparameters().get(key)
67    }
68    /// returns an entry for the hyperparameter, allowing for insertion or modification
69    pub fn parameter<Q: ToString>(
70        &mut self,
71        key: Q,
72    ) -> hash_map::Entry<'_, String, T, DefaultHashBuilder> {
73        self.hyperparameters_mut().entry(key.to_string())
74    }
75    /// removes a hyperparameter from the map, returning the value if it exists
76    pub fn remove_hyperparameter<Q>(&mut self, key: &Q) -> Option<T>
77    where
78        Q: ?Sized + core::hash::Hash + Eq,
79        String: core::borrow::Borrow<Q>,
80    {
81        self.hyperparameters_mut().remove(key)
82    }
83    /// sets the batch size, returning a mutable reference to the current instance
84    pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
85        self.batch_size = batch_size;
86        self
87    }
88    /// sets the number of epochs, returning a mutable reference to the current instance
89    pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
90        self.epochs = epochs;
91        self
92    }
93    /// consumes the current instance to create another with the given batch size
94    pub fn with_batch_size(self, batch_size: usize) -> Self {
95        Self { batch_size, ..self }
96    }
97    /// consumes the current instance to create another with the given epochs
98    pub fn with_epochs(self, epochs: usize) -> Self {
99        Self { epochs, ..self }
100    }
101}
102
103use HyperParam::*;
104
105impl<T> StandardModelConfig<T> {
106    /// sets the decay hyperparameter, returning the previous value if it exists
107    pub fn set_decay(&mut self, decay: T) -> Option<T> {
108        self.add_parameter(Decay, decay)
109    }
110    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
111        self.add_parameter(LearningRate, learning_rate)
112    }
113    /// sets the momentum hyperparameter, returning the previous value if it exists
114    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
115        self.add_parameter(Momentum, momentum)
116    }
117    /// sets the weight decay hyperparameter, returning the previous value if it exists
118    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
119        self.add_parameter(WeightDecay, decay)
120    }
121    /// returns a reference to the learning rate hyperparameter, if it exists
122    pub fn learning_rate(&self) -> Option<&T> {
123        self.get("learning_rate")
124    }
125    /// returns a reference to the momentum hyperparameter, if it exists
126    pub fn momentum(&self) -> Option<&T> {
127        self.get("momentum")
128    }
129    /// returns a reference to the decay hyperparameter, if it exists
130    pub fn decay(&self) -> Option<&T> {
131        self.get("decay")
132    }
133    /// returns a reference to the weight decay hyperparameter, if it exists
134    pub fn weight_decay(&self) -> Option<&T> {
135        self.get("weight_decay")
136    }
137}
138
139impl<T> Default for StandardModelConfig<T> {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145unsafe impl<T> Send for StandardModelConfig<T> where T: Send {}
146
147unsafe impl<T> Sync for StandardModelConfig<T> where T: Sync {}
148
149impl<T> crate::nn::NetworkConfig<String, T> for StandardModelConfig<T> {
150    type Store = HashMap<String, T, DefaultHashBuilder>;
151
152    fn store(&self) -> &Self::Store {
153        &self.hyperspace
154    }
155
156    fn store_mut(&mut self) -> &mut Self::Store {
157        &mut self.hyperspace
158    }
159}
160
161impl<T> RawConfig for StandardModelConfig<T> {
162    type Ctx = T;
163}
164
165impl<T> ModelConfiguration<T> for StandardModelConfig<T> {
166    fn get<K>(&self, key: K) -> Option<&T>
167    where
168        K: AsRef<str>,
169    {
170        self.hyperparameters().get(key.as_ref())
171    }
172
173    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
174    where
175        K: AsRef<str>,
176    {
177        self.hyperparameters_mut().get_mut(key.as_ref())
178    }
179
180    fn set<K>(&mut self, key: K, value: T) -> Option<T>
181    where
182        K: AsRef<str>,
183    {
184        self.hyperparameters_mut()
185            .insert(key.as_ref().into(), value)
186    }
187
188    fn remove<K>(&mut self, key: K) -> Option<T>
189    where
190        K: AsRef<str>,
191    {
192        self.hyperparameters_mut().remove(key.as_ref())
193    }
194
195    fn contains<K>(&self, key: K) -> bool
196    where
197        K: AsRef<str>,
198    {
199        self.hyperparameters().contains_key(key.as_ref())
200    }
201
202    fn keys(&self) -> Vec<&str> {
203        self.hyperparameters().keys().map(|k| k.as_str()).collect()
204    }
205}
206
207impl<T> ExtendedModelConfig<T> for StandardModelConfig<T> {
208    fn epochs(&self) -> usize {
209        self.epochs
210    }
211
212    fn batch_size(&self) -> usize {
213        self.batch_size
214    }
215}