// concision_neural/traits/config.rs

/// String-keyed access to configuration values of a single type `T`.
///
/// Every keyed method is generic over `K: AsRef<str>`, so `&str`, `String`
/// and other string-like keys are accepted. The trait has no provided
/// methods: storage and lookup semantics are entirely up to the implementor.
pub trait NetworkConfig<T> {
    /// Borrows the value associated with `key`, if there is one.
    fn get<K: AsRef<str>>(&self, key: K) -> Option<&T>;

    /// Mutably borrows the value associated with `key`, if there is one.
    fn get_mut<K: AsRef<str>>(&mut self, key: K) -> Option<&mut T>;

    /// Associates `value` with `key`.
    ///
    /// NOTE(review): the returned `Option<T>` presumably carries the value
    /// previously stored under `key` (as `HashMap::insert` does) — confirm
    /// against the implementors.
    fn set<K: AsRef<str>>(&mut self, key: K, value: T) -> Option<T>;

    /// Removes the entry for `key`, handing back the owned value if the
    /// implementor had one to return.
    fn remove<K: AsRef<str>>(&mut self, key: K) -> Option<T>;

    /// Reports whether `key` is present.
    fn contains<K: AsRef<str>>(&self, key: K) -> bool;

    /// Returns the keys as owned `String`s.
    ///
    /// The trait itself makes no promise about the order of the result.
    fn keys(&self) -> Vec<String>;
}

/// Generates accessor methods whose lookup key is the method's own name.
///
/// Public forms (comma-separated `name Type` pairs, trailing comma allowed):
///
/// * `hyperparam_method! { a T, b T }` emits, for each pair,
///   `fn a(&self) -> Option<&T> { self.get("a") }`.
/// * `hyperparam_method! { #[dyn] a f64, b usize }` emits, for each pair, a
///   method bounded by `T: 'static` that looks the key up with `self.get`
///   and then calls `downcast_ref::<Type>()` on the stored value.
///
/// Requirements on the expansion site:
///
/// * a `get` method callable as `self.get(&'static str)` must be in scope;
/// * for the `#[dyn]` form, a type parameter literally named `T` must be in
///   scope (the generated `where T: 'static` refers to it by name), and
///   `downcast_ref` must resolve on the `&T` returned by `get`.
///   NOTE(review): this presumably expects `T` to dereference to `dyn Any`
///   — confirm before first use, as no visible call site exercises it.
macro_rules! hyperparam_method {
    // Internal: one dynamically typed accessor. `None` is produced both when
    // the key is missing and when the downcast does not yield `$type`.
    (@dyn $name:ident: $type:ty) => {
        fn $name(&self) -> Option<&$type>
        where
            T: 'static,
        {
            // `and_then` collapses the nested `Option` directly instead of
            // the former `map(..).flatten()` round trip.
            self.get(stringify!($name))
                .and_then(|v| v.downcast_ref::<$type>())
        }
    };
    // Internal: one statically typed accessor that forwards the result of
    // `get` unchanged.
    (@impl $name:ident: $type:ty) => {
        fn $name(&self) -> Option<&$type> {
            self.get(stringify!($name))
        }
    };
    // Public: `#[dyn]`-prefixed list; every pair goes through the `@dyn` arm.
    // Must stay ahead of the plain-list arm so the prefix is recognised.
    (#[dyn] $($name:ident $type:ty),* $(,)?) => {
        $(
            hyperparam_method!(@dyn $name: $type);
        )*
    };
    // Public: plain list; every pair goes through the `@impl` arm.
    ($($name:ident $type:ty),* $(,)?) => {
        $(
            hyperparam_method!(@impl $name: $type);
        )*
    };
}

50pub trait TrainingConfiguration<T>: NetworkConfig<T> {
51 fn epochs(&self) -> usize;
52
53 fn batch_size(&self) -> usize;
54
55 hyperparam_method! {
56 learning_rate T,
57 momentum T,
58 weight_decay T,
59 dropout T,
60 decay T,
61 beta1 T,
62 beta2 T,
63 epsilon T,
64 gradient_clip T,
65
66 }
67}