concision_neural/config/
model_config.rs

/*
    Appellation: config <module>
    Contrib: @FL03
*/
5use super::Hyperparameters::*;
6use super::{NetworkConfig, TrainingConfiguration};
7
/// Crate-internal alias for the map used to store configuration values:
/// a standard-library [`HashMap`](std::collections::HashMap) keyed by `String`.
pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;
9
10#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
12pub struct StandardModelConfig<T> {
13    pub(crate) batch_size: usize,
14    pub(crate) epochs: usize,
15    pub(crate) hyperparameters: ModelConfigMap<T>,
16}
17
18impl<T> Default for StandardModelConfig<T> {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl<T> StandardModelConfig<T> {
25    pub fn new() -> Self {
26        Self {
27            batch_size: 0,
28            epochs: 0,
29            hyperparameters: ModelConfigMap::new(),
30        }
31    }
32    /// returns a copy of the batch size
33    pub const fn batch_size(&self) -> usize {
34        self.batch_size
35    }
36    /// returns a mutable reference to the batch size
37    pub const fn batch_size_mut(&mut self) -> &mut usize {
38        &mut self.batch_size
39    }
40    /// returns a copy of the epochs
41    pub const fn epochs(&self) -> usize {
42        self.epochs
43    }
44    /// returns a mutable reference to the epochs
45    pub const fn epochs_mut(&mut self) -> &mut usize {
46        &mut self.epochs
47    }
48    /// returns a reference to the hyperparameters map
49    pub const fn hyperparameters(&self) -> &ModelConfigMap<T> {
50        &self.hyperparameters
51    }
52    /// returns a mutable reference to the hyperparameters map
53    pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
54        &mut self.hyperparameters
55    }
56    /// inserts a hyperparameter into the map, returning the previous value if it exists
57    pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
58        self.hyperparameters_mut().insert(key.to_string(), value)
59    }
60    /// gets a reference to a hyperparameter by key, returning None if it does not exist
61    pub fn get_parameter<Q>(&self, key: &Q) -> Option<&T>
62    where
63        Q: ?Sized + Eq + core::hash::Hash,
64        String: core::borrow::Borrow<Q>,
65    {
66        self.hyperparameters().get(key)
67    }
68    /// returns an entry for the hyperparameter, allowing for insertion or modification
69    pub fn parameter<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
70    where
71        Q: ToString,
72    {
73        self.hyperparameters_mut().entry(key.to_string())
74    }
75    /// removes a hyperparameter from the map, returning the value if it exists
76    pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
77        self.hyperparameters_mut().remove(&key.to_string())
78    }
79    /// sets the batch size, returning a mutable reference to the current instance
80    pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
81        self.batch_size = batch_size;
82        self
83    }
84    /// sets the number of epochs, returning a mutable reference to the current instance
85    pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
86        self.epochs = epochs;
87        self
88    }
89    /// consumes the current instance to create another with the given batch size
90    pub fn with_batch_size(self, batch_size: usize) -> Self {
91        Self { batch_size, ..self }
92    }
93    /// consumes the current instance to create another with the given epochs
94    pub fn with_epochs(self, epochs: usize) -> Self {
95        Self { epochs, ..self }
96    }
97    /// sets the decay hyperparameter, returning the previous value if it exists
98    pub fn set_decay(&mut self, decay: T) -> Option<T> {
99        self.add_parameter(Decay, decay)
100    }
101    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
102        self.add_parameter(LearningRate, learning_rate)
103    }
104    /// sets the momentum hyperparameter, returning the previous value if it exists
105    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
106        self.add_parameter(Momentum, momentum)
107    }
108    /// sets the weight decay hyperparameter, returning the previous value if it exists
109    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
110        self.add_parameter("weight_decay", decay)
111    }
112    /// returns a reference to the learning rate hyperparameter, if it exists
113    pub fn learning_rate(&self) -> Option<&T> {
114        self.get_parameter(LearningRate.as_ref())
115    }
116    /// returns a reference to the momentum hyperparameter, if it exists
117    pub fn momentum(&self) -> Option<&T> {
118        self.get_parameter(Momentum.as_ref())
119    }
120    /// returns a reference to the decay hyperparameter, if it exists
121    pub fn decay(&self) -> Option<&T> {
122        self.get_parameter(Decay.as_ref())
123    }
124    /// returns a reference to the weight decay hyperparameter, if it exists
125    pub fn weight_decay(&self) -> Option<&T> {
126        self.get_parameter("weight_decay")
127    }
128}
129
130impl<T> NetworkConfig<T> for StandardModelConfig<T> {
131    fn get<K>(&self, key: K) -> Option<&T>
132    where
133        K: AsRef<str>,
134    {
135        self.hyperparameters().get(key.as_ref())
136    }
137
138    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
139    where
140        K: AsRef<str>,
141    {
142        self.hyperparameters_mut().get_mut(key.as_ref())
143    }
144
145    fn set<K>(&mut self, key: K, value: T) -> Option<T>
146    where
147        K: AsRef<str>,
148    {
149        self.hyperparameters_mut()
150            .insert(key.as_ref().to_string(), value)
151    }
152
153    fn remove<K>(&mut self, key: K) -> Option<T>
154    where
155        K: AsRef<str>,
156    {
157        self.hyperparameters_mut().remove(key.as_ref())
158    }
159
160    fn contains<K>(&self, key: K) -> bool
161    where
162        K: AsRef<str>,
163    {
164        self.hyperparameters().contains_key(key.as_ref())
165    }
166
167    fn keys(&self) -> Vec<String> {
168        self.hyperparameters().keys().cloned().collect()
169    }
170}
171
172impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
173    fn epochs(&self) -> usize {
174        self.epochs
175    }
176
177    fn batch_size(&self) -> usize {
178        self.batch_size
179    }
180}
181#[allow(deprecated)]
182impl<T> StandardModelConfig<T> {
183    #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")]
184    pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
185        self.add_parameter(key, value)
186    }
187    #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")]
188    pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
189    where
190        Q: ToString,
191    {
192        self.parameter(key)
193    }
194    #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")]
195    pub fn get(&self, key: impl ToString) -> Option<&T> {
196        self.hyperparameters().get(&key.to_string())
197    }
198}