concision_neural/model/
config.rs1use crate::Hyperparameters::*;
7
/// String-keyed map holding hyperparameter values of type `T`.
pub(crate) type ModelConfigMap<T> = std::collections::hash_map::HashMap<String, T>;
9
10#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
12pub struct StandardModelConfig<T> {
13 pub(crate) batch_size: usize,
14 pub(crate) epochs: usize,
15 pub(crate) hyperparameters: ModelConfigMap<T>,
16}
17
18impl<T> StandardModelConfig<T> {
19 pub fn new() -> Self {
20 Self {
21 batch_size: 0,
22 epochs: 0,
23 hyperparameters: ModelConfigMap::new(),
24 }
25 }
26
27 pub fn batch_size(&self) -> usize {
28 self.batch_size
29 }
30
31 pub fn epochs(&self) -> usize {
32 self.epochs
33 }
34
35 pub fn with_batch_size(self, batch_size: usize) -> Self {
36 Self { batch_size, ..self }
37 }
38
39 pub fn with_epochs(self, epochs: usize) -> Self {
40 Self { epochs, ..self }
41 }
42
43 pub fn insert_hyperparameter(&mut self, key: impl ToString, value: T) -> Option<T> {
44 self.hyperparameters.insert(key.to_string(), value)
45 }
46
47 pub fn set_decay(&mut self, decay: T) -> Option<T> {
48 self.insert_hyperparameter(Decay, decay)
49 }
50 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
51 self.insert_hyperparameter(LearningRate, learning_rate)
52 }
53
54 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
55 self.insert_hyperparameter(Momentum, momentum)
56 }
57
58 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
59 self.insert_hyperparameter("weight_decay", decay)
60 }
61
62 pub fn get(&self, key: impl ToString) -> Option<&T> {
63 self.hyperparameters.get(&key.to_string())
64 }
65
66 pub fn learning_rate(&self) -> Option<&T> {
67 self.get(LearningRate)
68 }
69
70 pub fn momentum(&self) -> Option<&T> {
71 self.get(Momentum)
72 }
73
74 pub fn decay(&self) -> Option<&T> {
75 self.get(Decay)
76 }
77}
78
79
80impl<T> crate::NetworkConfig<T> for StandardModelConfig<T> {
81 fn get<K>(&self, key: K) -> Option<&T>
82 where
83 K: AsRef<str>,
84 {
85 self.hyperparameters.get(key.as_ref())
86 }
87
88 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
89 where
90 K: AsRef<str>,
91 {
92 self.hyperparameters.get_mut(key.as_ref())
93 }
94
95 fn set<K>(&mut self, key: K, value: T) -> Option<T>
96 where
97 K: AsRef<str>,
98 {
99 self.hyperparameters.insert(key.as_ref().to_string(), value)
100 }
101
102 fn remove<K>(&mut self, key: K) -> Option<T>
103 where
104 K: AsRef<str>,
105 {
106 self.hyperparameters.remove(key.as_ref())
107 }
108
109 fn contains<K>(&self, key: K) -> bool
110 where
111 K: AsRef<str>,
112 {
113 self.hyperparameters.contains_key(key.as_ref())
114 }
115
116 fn keys(&self) -> Vec<String> {
117 self.hyperparameters.keys().cloned().collect()
118 }
119}
120
121impl<T> crate::TrainingConfiguration<T> for StandardModelConfig<T> {
122 fn epochs(&self) -> usize {
123 self.epochs
124 }
125
126 fn batch_size(&self) -> usize {
127 self.batch_size
128 }
129}