concision_neural/model/
config.rs

1/*
2    Appellation: config <module>
3    Contrib: @FL03
4*/
5
6use crate::Hyperparameters::*;
7
8pub(crate) type ModelConfigMap<T> = std::collections::HashMap<String, T>;
9
10#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))]
12pub struct StandardModelConfig<T> {
13    pub(crate) batch_size: usize,
14    pub(crate) epochs: usize,
15    pub(crate) hyperparameters: ModelConfigMap<T>,
16}
17
18impl<T> StandardModelConfig<T> {
19    pub fn new() -> Self {
20        Self {
21            batch_size: 0,
22            epochs: 0,
23            hyperparameters: ModelConfigMap::new(),
24        }
25    }
26
27    pub fn batch_size(&self) -> usize {
28        self.batch_size
29    }
30
31    pub fn epochs(&self) -> usize {
32        self.epochs
33    }
34
35    pub fn with_batch_size(self, batch_size: usize) -> Self {
36        Self { batch_size, ..self }
37    }
38
39    pub fn with_epochs(self, epochs: usize) -> Self {
40        Self { epochs, ..self }
41    }
42
43    pub fn insert_hyperparameter(&mut self, key: impl ToString, value: T) -> Option<T> {
44        self.hyperparameters.insert(key.to_string(), value)
45    }
46
47    pub fn set_decay(&mut self, decay: T) -> Option<T> {
48        self.insert_hyperparameter(Decay, decay)
49    }
50    pub fn set_learning_rate(&mut self, learning_rate: T) -> Option<T> {
51        self.insert_hyperparameter(LearningRate, learning_rate)
52    }
53
54    pub fn set_momentum(&mut self, momentum: T) -> Option<T> {
55        self.insert_hyperparameter(Momentum, momentum)
56    }
57
58    pub fn set_weight_decay(&mut self, decay: T) -> Option<T> {
59        self.insert_hyperparameter("weight_decay", decay)
60    }
61
62    pub fn get(&self, key: impl ToString) -> Option<&T> {
63        self.hyperparameters.get(&key.to_string())
64    }
65
66    pub fn learning_rate(&self) -> Option<&T> {
67        self.get(LearningRate)
68    }
69
70    pub fn momentum(&self) -> Option<&T> {
71        self.get(Momentum)
72    }
73
74    pub fn decay(&self) -> Option<&T> {
75        self.get(Decay)
76    }
77}