concision_neural/model/
config.rs1use crate::Hyperparameters::*;
7
/// Crate-internal alias for the hyperparameter table: a
/// [`std::collections::HashMap`] from an owned `String` name to a value of
/// type `T`.
pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;
9
10#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
12pub struct StandardModelConfig<T> {
13 pub(crate) batch_size: usize,
14 pub(crate) epochs: usize,
15 pub(crate) hyperparameters: ModelConfigMap<T>,
16}
17
18impl<T> StandardModelConfig<T> {
19 pub fn new() -> Self {
20 Self {
21 batch_size: 0,
22 epochs: 0,
23 hyperparameters: ModelConfigMap::new(),
24 }
25 }
26
27 pub fn batch_size(&self) -> usize {
28 self.batch_size
29 }
30
31 pub fn epochs(&self) -> usize {
32 self.epochs
33 }
34
35 pub fn with_batch_size(self, batch_size: usize) -> Self {
36 Self { batch_size, ..self }
37 }
38
39 pub fn with_epochs(self, epochs: usize) -> Self {
40 Self { epochs, ..self }
41 }
42
43 pub fn insert_hyperparameter(&mut self, key: impl ToString, value: T) -> Option<T> {
44 self.hyperparameters.insert(key.to_string(), value)
45 }
46
47 pub fn set_decay(&mut self, decay: T) -> Option<T> {
48 self.insert_hyperparameter(Decay, decay)
49 }
50 pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
51 self.insert_hyperparameter(LearningRate, learning_rate)
52 }
53
54 pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
55 self.insert_hyperparameter(Momentum, momentum)
56 }
57
58 pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
59 self.insert_hyperparameter("weight_decay", decay)
60 }
61
62 pub fn get(&self, key: impl ToString) -> Option<&T> {
63 self.hyperparameters.get(&key.to_string())
64 }
65
66 pub fn learning_rate(&self) -> Option<&T> {
67 self.get(LearningRate)
68 }
69
70 pub fn momentum(&self) -> Option<&T> {
71 self.get(Momentum)
72 }
73
74 pub fn decay(&self) -> Option<&T> {
75 self.get(Decay)
76 }
77}