// concision_neural/config/traits/config.rs
/// The most basic configuration abstraction: it only names the type of the
/// context (`Ctx`) associated with a configuration.
///
/// Higher-level traits such as `NetworkConfig<T>` pin this associated type
/// through a `RawConfig<Ctx = T>` bound.
pub trait RawConfig {
    /// Type of the context associated with this configuration.
    type Ctx;
}

12pub trait NetworkConfig<T>: RawConfig<Ctx = T> {
15 fn get<K>(&self, key: K) -> Option<&T>
16 where
17 K: AsRef<str>;
18 fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
19 where
20 K: AsRef<str>;
21
22 fn set<K>(&mut self, key: K, value: T) -> Option<T>
23 where
24 K: AsRef<str>;
25 fn remove<K>(&mut self, key: K) -> Option<T>
26 where
27 K: AsRef<str>;
28 fn contains<K>(&self, key: K) -> bool
29 where
30 K: AsRef<str>;
31
32 fn keys(&self) -> Vec<String>;
33}
34
35macro_rules! hyperparam_method {
36 (@dyn $name:ident: $type:ty) => {
37 fn $name(&self) -> Option<&$type> where T: 'static {
38 self.get(stringify!($name)).map(|v| v.downcast_ref::<$type>()).flatten()
39 }
40 };
41 (@impl $name:ident: $type:ty) => {
42 fn $name(&self) -> Option<&$type> {
43 self.get(stringify!($name))
44 }
45 };
46 (#[dyn] $($name:ident $type:ty),* $(,)?) => {
47 $(
48 hyperparam_method!(@dyn $name: $type);
49 )*
50 };
51 ($($name:ident $type:ty),* $(,)?) => {
52 $(
53 hyperparam_method!(@impl $name: $type);
54 )*
55 };
56}
57
58pub trait TrainingConfiguration<T>: NetworkConfig<T> {
59 fn epochs(&self) -> usize;
60
61 fn batch_size(&self) -> usize;
62
63 hyperparam_method! {
64 learning_rate T,
65 momentum T,
66 weight_decay T,
67 dropout T,
68 decay T,
69 beta1 T,
70 beta2 T,
71 epsilon T,
72 gradient_clip T,
73
74 }
75}
76
77