concision_core/config/
mod.rs1#[doc(inline)]
8pub use self::{hyper_params::HyperParam, model_config::StandardModelConfig};
9
10pub mod hyper_params;
11pub mod model_config;
12#[doc(hidden)]
14pub(crate) mod prelude {
15 pub use super::hyper_params::HyperParam;
16 pub use super::model_config::*;
17 pub use super::{ExtendedModelConfig, ModelConfiguration, RawConfig};
18}
19
/// Root configuration trait; it declares only an associated context type.
///
/// `ModelConfiguration` uses this trait as its supertrait.
pub trait RawConfig {
    /// Context type attached to a configuration.
    ///
    /// It carries no bounds here, so its meaning is up to each implementor
    /// (NOTE(review): presumably runtime/model context — confirm with implementors).
    type Ctx;
}
25
26pub trait ModelConfiguration<T>: RawConfig {
29 fn get<K>(&self, key: K) -> Option<&T>
30 where
31 K: AsRef<str>;
32 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
33 where
34 K: AsRef<str>;
35
36 fn set<K>(&mut self, key: K, value: T) -> Option<T>
37 where
38 K: AsRef<str>;
39 fn remove<K>(&mut self, key: K) -> Option<T>
40 where
41 K: AsRef<str>;
42 fn contains<K>(&self, key: K) -> bool
43 where
44 K: AsRef<str>;
45
46 fn keys(&self) -> Vec<&str>;
47}
48
/// Generates hyperparameter accessor methods inside a trait whose `self.get(key)`
/// returns `Option<&T>` (for example `ModelConfiguration`).
///
/// Each comma-separated entry produces one method. The entry's name, via
/// `stringify!`, is both the method name and the lookup key:
///
/// * `name::<T>` expands to `fn name(&self) -> Option<&T>`, which returns
///   `self.get("name")` unchanged.
/// * `dyn name::<U>` expands to `fn name(&self) -> Option<&U>`, bounded by
///   `T: 'static`. It returns the stored value only when its concrete type is `U`
///   (checked with `Any::downcast_ref`), and `None` otherwise.
///
/// A trailing comma is accepted.
macro_rules! hyperparam_method {
    // `dyn` accessor: type-checked downcast of the stored value.
    (@impl dyn $name:ident::<$type:ty>) => {
        fn $name(&self) -> Option<&$type>
        where
            T: 'static,
        {
            self.get(stringify!($name)).and_then(|value| {
                // `T: 'static` implies `T: Any`, so the stored value can be viewed as `dyn Any`.
                let value: &dyn core::any::Any = value;
                value.downcast_ref::<$type>()
            })
        }
    };
    // Plain accessor: return the stored value as-is.
    (@impl $name:ident::<$type:ty>) => {
        fn $name(&self) -> Option<&$type> {
            self.get(stringify!($name))
        }
    };
    // Entry points. Entries are peeled off one at a time so the optional `dyn` marker
    // is forwarded to the matching `@impl` arm. A bare `$(dyn)?` matcher binds no
    // metavariable, so it cannot be re-emitted and would be silently dropped.
    () => {};
    (dyn $name:ident::<$type:ty> $(, $($rest:tt)*)?) => {
        hyperparam_method!(@impl dyn $name::<$type>);
        $(hyperparam_method!($($rest)*);)?
    };
    ($name:ident::<$type:ty> $(, $($rest:tt)*)?) => {
        hyperparam_method!(@impl $name::<$type>);
        $(hyperparam_method!($($rest)*);)?
    };
}
66
67pub trait ExtendedModelConfig<T>: ModelConfiguration<T> {
68 fn epochs(&self) -> usize;
69
70 fn batch_size(&self) -> usize;
71
72 hyperparam_method! {
73 learning_rate::<T>,
74 epsilon::<T>,
75 momentum::<T>,
76 weight_decay::<T>,
77 dropout::<T>,
78 decay::<T>,
79 beta::<T>,
80 beta1::<T>,
81 beta2::<T>,
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::StandardModelConfig;

    /// Builder values and hyperparameter setters should read back unchanged.
    #[test]
    fn test_standard_model_config() {
        let mut config = StandardModelConfig::new().with_epochs(1000).with_batch_size(32);

        // Hyperparameters are assigned after construction through the `set_*` methods.
        config.set_learning_rate(0.01);
        config.set_momentum(0.9);
        config.set_decay(0.0001);

        // Values supplied through the builder.
        assert_eq!(config.epochs(), 1000);
        assert_eq!(config.batch_size(), 32);

        // Hyperparameter getters hand back borrowed values wrapped in `Option`.
        assert_eq!(config.learning_rate(), Some(&0.01));
        assert_eq!(config.momentum(), Some(&0.9));
        assert_eq!(config.decay(), Some(&0.0001));
    }
}