concision_neural/model/
config.rs1use crate::Hyperparameters::*;
7use crate::traits::{NetworkConfig, TrainingConfiguration};
8
/// String-keyed storage for named hyperparameter values of type `T`.
pub(crate) type ModelConfigMap<T> = std::collections::hash_map::HashMap<String, T>;
10
11#[derive(Clone, Debug)]
12#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
13pub struct StandardModelConfig<T> {
14 pub(crate) batch_size: usize,
15 pub(crate) epochs: usize,
16 pub(crate) hyperparameters: ModelConfigMap<T>,
17}
18
19impl<T> StandardModelConfig<T> {
20 pub fn new() -> Self {
21 Self {
22 batch_size: 0,
23 epochs: 0,
24 hyperparameters: ModelConfigMap::new(),
25 }
26 }
27
28 pub fn batch_size(&self) -> usize {
29 self.batch_size
30 }
31
32 pub fn epochs(&self) -> usize {
33 self.epochs
34 }
35
36 pub fn with_batch_size(self, batch_size: usize) -> Self {
37 Self { batch_size, ..self }
38 }
39
40 pub fn with_epochs(self, epochs: usize) -> Self {
41 Self { epochs, ..self }
42 }
43
44 pub fn hyperparameters(&self) -> &ModelConfigMap<T> {
45 &self.hyperparameters
46 }
47
48 pub fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap<T> {
49 &mut self.hyperparameters
50 }
51
52 pub fn insert_hyperparameter(&mut self, key: impl ToString, value: T) -> Option<T> {
53 self.hyperparameters_mut().insert(key.to_string(), value)
54 }
55
56 pub fn hyperparam<Q>(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T> where Q: ToString {
57 self.hyperparameters_mut().entry(key.to_string())
58 }
59
60 pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option<T> {
61 self.hyperparameters_mut().remove(&key.to_string())
62 }
63
64 pub fn set_decay(&mut self, decay: T) -> Option<T> {
65 self.insert_hyperparameter(Decay, decay)
66 }
67 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
68 self.insert_hyperparameter(LearningRate, learning_rate)
69 }
70
71 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
72 self.insert_hyperparameter(Momentum, momentum)
73 }
74
75 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
76 self.insert_hyperparameter("weight_decay", decay)
77 }
78
79 pub fn get(&self, key: impl ToString) -> Option<&T> {
80 self.hyperparameters().get(&key.to_string())
81 }
82
83 pub fn learning_rate(&self) -> Option<&T> {
84 self.get(LearningRate)
85 }
86
87 pub fn momentum(&self) -> Option<&T> {
88 self.get(Momentum)
89 }
90
91 pub fn decay(&self) -> Option<&T> {
92 self.get(Decay)
93 }
94
95 pub fn weight_decay(&self) -> Option<&T> {
96 self.get("weight_decay")
97 }
98}
99
100impl<T> NetworkConfig<T> for StandardModelConfig<T> {
101 fn get<K>(&self, key: K) -> Option<&T>
102 where
103 K: AsRef<str>,
104 {
105 self.hyperparameters.get(key.as_ref())
106 }
107
108 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
109 where
110 K: AsRef<str>,
111 {
112 self.hyperparameters.get_mut(key.as_ref())
113 }
114
115 fn set<K>(&mut self, key: K, value: T) -> Option<T>
116 where
117 K: AsRef<str>,
118 {
119 self.hyperparameters.insert(key.as_ref().to_string(), value)
120 }
121
122 fn remove<K>(&mut self, key: K) -> Option<T>
123 where
124 K: AsRef<str>,
125 {
126 self.hyperparameters.remove(key.as_ref())
127 }
128
129 fn contains<K>(&self, key: K) -> bool
130 where
131 K: AsRef<str>,
132 {
133 self.hyperparameters.contains_key(key.as_ref())
134 }
135
136 fn keys(&self) -> Vec<String> {
137 self.hyperparameters.keys().cloned().collect()
138 }
139}
140
141impl<T> TrainingConfiguration<T> for StandardModelConfig<T> {
142 fn epochs(&self) -> usize {
143 self.epochs
144 }
145
146 fn batch_size(&self) -> usize {
147 self.batch_size
148 }
149}