concision_neural/config/
model_config.rs

1/*
2    Appellation: config <module>
3    Contrib: @FL03
4*/
5use super::Hyperparameters::*;
6use super::{NetworkConfig, RawConfig, TrainingConfiguration};
7
/// Backing map for [`StandardModelConfig`] hyperparameters in `alloc`-only (`no_std`) builds.
///
/// `String` is not part of the `no_std` prelude, so it is named through the `alloc` crate.
#[cfg(all(feature = "alloc", not(feature = "std")))]
pub(crate) type ModelConfigMap<T> = alloc::collections::BTreeMap<alloc::string::String, T>;
/// Backing map for [`StandardModelConfig`] hyperparameters when `std` is available.
#[cfg(feature = "std")]
pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;
12
13#[derive(Clone, Debug)]
14#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
15pub struct StandardModelConfig<T> {
16    pub(crate) batch_size: usize,
17    pub(crate) epochs: usize,
18    pub(crate) hyperparameters: ModelConfigMap<T>,
19}
20
21impl<T> Default for StandardModelConfig<T> {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl<T> StandardModelConfig<T> {
28    pub fn new() -> Self {
29        Self {
30            batch_size: 0,
31            epochs: 0,
32            hyperparameters: ModelConfigMap::new(),
33        }
34    }
35    /// returns a copy of the batch size
36    pub const fn batch_size(&self) -> usize {
37        self.batch_size
38    }
39    /// returns a mutable reference to the batch size
40    pub const fn batch_size_mut(&mut self) -> &mut usize {
41        &mut self.batch_size
42    }
43    /// returns a copy of the epochs
44    pub const fn epochs(&self) -> usize {
45        self.epochs
46    }
47    /// returns a mutable reference to the epochs
48    pub const fn epochs_mut(&mut self) -> &mut usize {
49        &mut self.epochs
50    }
51    /// returns a reference to the hyperparameters map
52    pub const fn hyperparameters(&self) -> &ModelConfigMap<T> {
53        &self.hyperparameters
54    }
55    /// returns a mutable reference to the hyperparameters map
56    pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
57        &mut self.hyperparameters
58    }
59    /// inserts a hyperparameter into the map, returning the previous value if it exists
60    pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
61        self.hyperparameters_mut().insert(key.to_string(), value)
62    }
63    /// gets a reference to a hyperparameter by key, returning None if it does not exist
64    pub fn get_parameter<Q>(&self, key: &Q) -> Option<&T>
65    where
66        Q: ?Sized + Eq + core::hash::Hash,
67        String: core::borrow::Borrow<Q>,
68    {
69        self.hyperparameters().get(key)
70    }
71    /// returns an entry for the hyperparameter, allowing for insertion or modification
72    pub fn parameter<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
73    where
74        Q: ToString,
75    {
76        self.hyperparameters_mut().entry(key.to_string())
77    }
78    /// removes a hyperparameter from the map, returning the value if it exists
79    pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
80        self.hyperparameters_mut().remove(&key.to_string())
81    }
82    /// sets the batch size, returning a mutable reference to the current instance
83    pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
84        self.batch_size = batch_size;
85        self
86    }
87    /// sets the number of epochs, returning a mutable reference to the current instance
88    pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
89        self.epochs = epochs;
90        self
91    }
92    /// consumes the current instance to create another with the given batch size
93    pub fn with_batch_size(self, batch_size: usize) -> Self {
94        Self { batch_size, ..self }
95    }
96    /// consumes the current instance to create another with the given epochs
97    pub fn with_epochs(self, epochs: usize) -> Self {
98        Self { epochs, ..self }
99    }
100    /// sets the decay hyperparameter, returning the previous value if it exists
101    pub fn set_decay(&mut self, decay: T) -> Option<T> {
102        self.add_parameter(Decay, decay)
103    }
104    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
105        self.add_parameter(LearningRate, learning_rate)
106    }
107    /// sets the momentum hyperparameter, returning the previous value if it exists
108    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
109        self.add_parameter(Momentum, momentum)
110    }
111    /// sets the weight decay hyperparameter, returning the previous value if it exists
112    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
113        self.add_parameter("weight_decay", decay)
114    }
115    /// returns a reference to the learning rate hyperparameter, if it exists
116    pub fn learning_rate(&self) -> Option<&T> {
117        self.get_parameter(LearningRate.as_ref())
118    }
119    /// returns a reference to the momentum hyperparameter, if it exists
120    pub fn momentum(&self) -> Option<&T> {
121        self.get_parameter(Momentum.as_ref())
122    }
123    /// returns a reference to the decay hyperparameter, if it exists
124    pub fn decay(&self) -> Option<&T> {
125        self.get_parameter(Decay.as_ref())
126    }
127    /// returns a reference to the weight decay hyperparameter, if it exists
128    pub fn weight_decay(&self) -> Option<&T> {
129        self.get_parameter("weight_decay")
130    }
131}
132
133unsafe impl<T> Send for StandardModelConfig<T> where T: Send {}
134
135unsafe impl<T> Sync for StandardModelConfig<T> where T: Sync {}
136
137impl<T> RawConfig for StandardModelConfig<T> {
138    type Ctx = T;
139}
140
141impl<T> NetworkConfig<T> for StandardModelConfig<T> {
142    fn get<K>(&self, key: K) -> Option<&T>
143    where
144        K: AsRef<str>,
145    {
146        self.hyperparameters().get(key.as_ref())
147    }
148
149    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
150    where
151        K: AsRef<str>,
152    {
153        self.hyperparameters_mut().get_mut(key.as_ref())
154    }
155
156    fn set<K>(&mut self, key: K, value: T) -> Option<T>
157    where
158        K: AsRef<str>,
159    {
160        self.hyperparameters_mut()
161            .insert(key.as_ref().to_string(), value)
162    }
163
164    fn remove<K>(&mut self, key: K) -> Option<T>
165    where
166        K: AsRef<str>,
167    {
168        self.hyperparameters_mut().remove(key.as_ref())
169    }
170
171    fn contains<K>(&self, key: K) -> bool
172    where
173        K: AsRef<str>,
174    {
175        self.hyperparameters().contains_key(key.as_ref())
176    }
177
178    fn keys(&self) -> Vec<String> {
179        self.hyperparameters().keys().cloned().collect()
180    }
181}
182
183impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
184    fn epochs(&self) -> usize {
185        self.epochs
186    }
187
188    fn batch_size(&self) -> usize {
189        self.batch_size
190    }
191}
192#[allow(deprecated)]
193impl<T> StandardModelConfig<T> {
194    #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")]
195    pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
196        self.add_parameter(key, value)
197    }
198    #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")]
199    pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
200    where
201        Q: ToString,
202    {
203        self.parameter(key)
204    }
205    #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")]
206    pub fn get(&self, key: impl ToString) -> Option<&T> {
207        self.hyperparameters().get(&key.to_string())
208    }
209}