concision_neural/model/
config.rs

/*
    Appellation: config <module>
    Contrib: @FL03
*/
5
6use crate::Hyperparameters::*;
7use crate::traits::{NetworkConfig, TrainingConfiguration};
8
/// Backing storage for named hyperparameters: a hash map from each
/// parameter's name to its value of type `V`.
pub(crate) type ModelConfigMap<V> = std::collections::HashMap<String, V>;
10
11#[derive(Clone, Debug)]
12#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
13pub struct StandardModelConfig<T> {
14    pub(crate) batch_size: usize,
15    pub(crate) epochs: usize,
16    pub(crate) hyperparameters: ModelConfigMap<T>,
17}
18
19impl<T> Default for StandardModelConfig<T> {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl<T> StandardModelConfig<T> {
26    pub fn new() -> Self {
27        Self {
28            batch_size: 0,
29            epochs: 0,
30            hyperparameters: ModelConfigMap::new(),
31        }
32    }
33    /// returns a copy of the batch size
34    pub const fn batch_size(&self) -> usize {
35        self.batch_size
36    }
37    /// returns a mutable reference to the batch size
38    pub const fn batch_size_mut(&mut self) -> &mut usize {
39        &mut self.batch_size
40    }
41    /// returns a copy of the epochs
42    pub const fn epochs(&self) -> usize {
43        self.epochs
44    }
45    /// returns a mutable reference to the epochs
46    pub const fn epochs_mut(&mut self) -> &mut usize {
47        &mut self.epochs
48    }
49    /// returns a reference to the hyperparameters map
50    pub const fn hyperparameters(&self) -> &ModelConfigMap<T> {
51        &self.hyperparameters
52    }
53    /// returns a mutable reference to the hyperparameters map
54    pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
55        &mut self.hyperparameters
56    }
57    /// inserts a hyperparameter into the map, returning the previous value if it exists
58    pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
59        self.hyperparameters_mut().insert(key.to_string(), value)
60    }
61    /// gets a reference to a hyperparameter by key, returning None if it does not exist
62    pub fn get_parameter<Q>(&self, key: &Q) -> Option<&T>
63    where
64        Q: ?Sized + Eq + core::hash::Hash,
65        String: core::borrow::Borrow<Q>,
66    {
67        self.hyperparameters().get(key)
68    }
69    /// returns an entry for the hyperparameter, allowing for insertion or modification
70    pub fn parameter<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
71    where
72        Q: ToString,
73    {
74        self.hyperparameters_mut().entry(key.to_string())
75    }
76    /// removes a hyperparameter from the map, returning the value if it exists
77    pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
78        self.hyperparameters_mut().remove(&key.to_string())
79    }
80    /// sets the batch size, returning a mutable reference to the current instance
81    pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self {
82        self.batch_size = batch_size;
83        self
84    }
85    /// sets the number of epochs, returning a mutable reference to the current instance
86    pub fn set_epochs(&mut self, epochs: usize) -> &mut Self {
87        self.epochs = epochs;
88        self
89    }
90    /// consumes the current instance to create another with the given batch size
91    pub fn with_batch_size(self, batch_size: usize) -> Self {
92        Self { batch_size, ..self }
93    }
94    /// consumes the current instance to create another with the given epochs
95    pub fn with_epochs(self, epochs: usize) -> Self {
96        Self { epochs, ..self }
97    }
98    /// sets the decay hyperparameter, returning the previous value if it exists
99    pub fn set_decay(&mut self, decay: T) -> Option<T> {
100        self.add_parameter(Decay, decay)
101    }
102    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
103        self.add_parameter(LearningRate, learning_rate)
104    }
105    /// sets the momentum hyperparameter, returning the previous value if it exists
106    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
107        self.add_parameter(Momentum, momentum)
108    }
109    /// sets the weight decay hyperparameter, returning the previous value if it exists
110    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
111        self.add_parameter("weight_decay", decay)
112    }
113    /// returns a reference to the learning rate hyperparameter, if it exists
114    pub fn learning_rate(&self) -> Option<&T> {
115        self.get_parameter(LearningRate.as_ref())
116    }
117    /// returns a reference to the momentum hyperparameter, if it exists
118    pub fn momentum(&self) -> Option<&T> {
119        self.get_parameter(Momentum.as_ref())
120    }
121    /// returns a reference to the decay hyperparameter, if it exists
122    pub fn decay(&self) -> Option<&T> {
123        self.get_parameter(Decay.as_ref())
124    }
125    /// returns a reference to the weight decay hyperparameter, if it exists
126    pub fn weight_decay(&self) -> Option<&T> {
127        self.get_parameter("weight_decay")
128    }
129}
130
131impl<T> NetworkConfig<T> for StandardModelConfig<T> {
132    fn get<K>(&self, key: K) -> Option<&T>
133    where
134        K: AsRef<str>,
135    {
136        self.hyperparameters.get(key.as_ref())
137    }
138
139    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
140    where
141        K: AsRef<str>,
142    {
143        self.hyperparameters.get_mut(key.as_ref())
144    }
145
146    fn set<K>(&mut self, key: K, value: T) -> Option<T>
147    where
148        K: AsRef<str>,
149    {
150        self.hyperparameters.insert(key.as_ref().to_string(), value)
151    }
152
153    fn remove<K>(&mut self, key: K) -> Option<T>
154    where
155        K: AsRef<str>,
156    {
157        self.hyperparameters.remove(key.as_ref())
158    }
159
160    fn contains<K>(&self, key: K) -> bool
161    where
162        K: AsRef<str>,
163    {
164        self.hyperparameters.contains_key(key.as_ref())
165    }
166
167    fn keys(&self) -> Vec<String> {
168        self.hyperparameters.keys().cloned().collect()
169    }
170}
171
172impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
173    fn epochs(&self) -> usize {
174        self.epochs
175    }
176
177    fn batch_size(&self) -> usize {
178        self.batch_size
179    }
180}
181#[allow(deprecated)]
182impl<T> StandardModelConfig<T> {
183    #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")]
184    pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option<T> {
185        self.add_parameter(key, value)
186    }
187    #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")]
188    pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T>
189    where
190        Q: ToString,
191    {
192        self.parameter(key)
193    }
194    #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")]
195    pub fn get(&self, key: impl ToString) -> Option<&T> {
196        self.hyperparameters().get(&key.to_string())
197    }
198}