concision_neural/config/traits/config.rs

/*
    Appellation: config <module>
    Contrib: @FL03
*/

/// A string-keyed view over the parameters of a network configuration,
/// generic over the stored value type `T`.
pub trait NetworkConfig<T> {
    /// Returns a reference to the value stored under `key`, if any.
    fn get<K>(&self, key: K) -> Option<&T>
    where
        K: AsRef<str>;
    /// Returns a mutable reference to the value stored under `key`, if any.
    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
    where
        K: AsRef<str>;

    /// Stores `value` under `key`, returning the previously stored value, if any.
    fn set<K>(&mut self, key: K, value: T) -> Option<T>
    where
        K: AsRef<str>;
    /// Removes the entry stored under `key`, returning its value if it existed.
    fn remove<K>(&mut self, key: K) -> Option<T>
    where
        K: AsRef<str>;
    /// Returns `true` if an entry exists under `key`.
    fn contains<K>(&self, key: K) -> bool
    where
        K: AsRef<str>;

    /// Returns the keys currently present in the configuration.
    fn keys(&self) -> Vec<String>;
}
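
// A minimal sketch of a map-backed implementation, assuming plain
// `std::collections::HashMap` storage; `ExampleConfig` is an illustrative
// name and is not part of the crate's public API.
#[allow(dead_code)]
struct ExampleConfig<T> {
    entries: std::collections::HashMap<String, T>,
}

impl<T> NetworkConfig<T> for ExampleConfig<T> {
    fn get<K>(&self, key: K) -> Option<&T>
    where
        K: AsRef<str>,
    {
        self.entries.get(key.as_ref())
    }
    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
    where
        K: AsRef<str>,
    {
        self.entries.get_mut(key.as_ref())
    }
    fn set<K>(&mut self, key: K, value: T) -> Option<T>
    where
        K: AsRef<str>,
    {
        self.entries.insert(key.as_ref().to_string(), value)
    }
    fn remove<K>(&mut self, key: K) -> Option<T>
    where
        K: AsRef<str>,
    {
        self.entries.remove(key.as_ref())
    }
    fn contains<K>(&self, key: K) -> bool
    where
        K: AsRef<str>,
    {
        self.entries.contains_key(key.as_ref())
    }
    fn keys(&self) -> Vec<String> {
        self.entries.keys().cloned().collect()
    }
}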

// Generates one accessor method per hyperparameter: the default arm looks the
// value up by name via `NetworkConfig::get`, while the `#[dyn]` arm
// additionally downcasts the stored value to the requested concrete type.
macro_rules! hyperparam_method {
    (@dyn $name:ident: $type:ty) => {
        fn $name(&self) -> Option<&$type>
        where
            T: core::any::Any,
        {
            self.get(stringify!($name))
                .and_then(|v| (v as &dyn core::any::Any).downcast_ref::<$type>())
        }
    };
    (@impl $name:ident: $type:ty) => {
        fn $name(&self) -> Option<&$type> {
            self.get(stringify!($name))
        }
    };
    (#[dyn] $($name:ident $type:ty),* $(,)?) => {
        $(
            hyperparam_method!(@dyn $name: $type);
        )*
    };
    ($($name:ident $type:ty),* $(,)?) => {
        $(
            hyperparam_method!(@impl $name: $type);
        )*
    };
}
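
// For reference, the default arm turns an entry such as `learning_rate T`
// into roughly the following provided method (here `T` is the generic
// parameter of the surrounding trait):
//
//     fn learning_rate(&self) -> Option<&T> {
//         self.get("learning_rate")
//     }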

/// Training-oriented configuration built on top of [`NetworkConfig`]: the
/// schedule parameters are required, while the remaining hyperparameters are
/// optional and resolved by key.
pub trait TrainingConfiguration<T>: NetworkConfig<T> {
    /// The number of epochs to train for.
    fn epochs(&self) -> usize;

    /// The number of samples in each training batch.
    fn batch_size(&self) -> usize;

    hyperparam_method! {
        learning_rate T,
        momentum T,
        weight_decay T,
        dropout T,
        decay T,
        beta1 T,
        beta2 T,
        epsilon T,
        gradient_clip T,
    }
}
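
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // A small sketch of how the two traits compose, reusing the illustrative
    // `ExampleConfig` defined above. For the test, `epochs` and `batch_size`
    // live in the same `f64` map as the other hyperparameters and are cast
    // back to `usize`; a real configuration would likely use dedicated fields.
    impl TrainingConfiguration<f64> for ExampleConfig<f64> {
        fn epochs(&self) -> usize {
            self.get("epochs").copied().unwrap_or_default() as usize
        }
        fn batch_size(&self) -> usize {
            self.get("batch_size").copied().unwrap_or_default() as usize
        }
    }

    #[test]
    fn hyperparameters_resolve_by_key() {
        let mut cfg = ExampleConfig {
            entries: HashMap::new(),
        };
        assert!(cfg.set("learning_rate", 1e-3).is_none());
        assert!(cfg.set("epochs", 10.0).is_none());
        assert_eq!(cfg.learning_rate(), Some(&1e-3));
        assert_eq!(cfg.momentum(), None);
        assert_eq!(cfg.epochs(), 10);
    }
}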