concision_core/config/
mod.rs

1/*
2    appellation: config <module>
3    authors: @FL03
4*/
//! This module establishes the common interfaces shared by configuration objects throughout the
//! framework, and provides a standard implementation, [`StandardModelConfig`], for quickly
//! configuring a new model.
7#[doc(inline)]
8pub use self::{hyper_params::HyperParam, model_config::StandardModelConfig};
9
10pub mod hyper_params;
11pub mod model_config;
12// prelude (local)
13#[doc(hidden)]
14pub(crate) mod prelude {
15    pub use super::hyper_params::HyperParam;
16    pub use super::model_config::*;
17    pub use super::{ExtendedModelConfig, ModelConfiguration, RawConfig};
18}
19
/// The common base for every _configuration_ used throughout the framework: neural networks,
/// their layers, and anything else that needs to be configured.
///
/// Implementors declare an associated context type through [`RawConfig::Ctx`].
pub trait RawConfig {
    /// The context type associated with this configuration.
    type Ctx;
}
25
26/// The [`ModelConfiguration`] trait extends the [`RawConfig`] trait to provide a more robust
27/// interface for neural network configurations.
28pub trait ModelConfiguration<T>: RawConfig {
29    fn get<K>(&self, key: K) -> Option<&T>
30    where
31        K: AsRef<str>;
32    fn get_mut<K>(&mut self, key: K) -> Option<&mut T>
33    where
34        K: AsRef<str>;
35
36    fn set<K>(&mut self, key: K, value: T) -> Option<T>
37    where
38        K: AsRef<str>;
39    fn remove<K>(&mut self, key: K) -> Option<T>
40    where
41        K: AsRef<str>;
42    fn contains<K>(&self, key: K) -> bool
43    where
44        K: AsRef<str>;
45
46    fn keys(&self) -> Vec<&str>;
47}
48
/// Expands to default hyper-parameter accessor methods inside a trait body.
///
/// The surrounding trait must have a type parameter named `T` and a method
/// `get(&self, key) -> Option<&T>` that accepts a `&str` key, as [`ModelConfiguration`]
/// does. The generated code resolves both `T` and `self.get` at the invocation site.
///
/// Each comma-separated entry produces one method whose lookup key is the method's own name.
/// A trailing comma is allowed.
///
/// - `name::<Ty>` expands to `fn name(&self) -> Option<&Ty>`, which returns
///   `self.get("name")` unchanged, so `Ty` must be `T`.
/// - `dyn name::<Ty>` expands to the same signature with a `T: 'static` bound. It views the
///   stored `&T` as `&dyn Any` and downcasts it. The result is `Some` only when the entry
///   exists *and* `T` is exactly `Ty`.
macro_rules! hyperparam_method {
    // internal: type-erased accessor that downcasts the stored value to `$type`
    (@impl dyn $name:ident::<$type:ty>) => {
        fn $name(&self) -> Option<&$type>
        where
            T: 'static,
        {
            // `downcast_ref` is an inherent method of `dyn Any`, so the `&T` must be coerced
            // first; `T: 'static` is what makes `T: Any` hold.
            self.get(stringify!($name))
                .and_then(|value| (value as &dyn core::any::Any).downcast_ref::<$type>())
        }
    };
    // internal: plain accessor forwarding the lookup result unchanged
    (@impl $name:ident::<$type:ty>) => {
        fn $name(&self) -> Option<&$type> {
            self.get(stringify!($name))
        }
    };
    // entry list exhausted
    () => {};
    // `dyn` entry. It needs its own arm, placed before the plain one, because `dyn` is also a
    // valid `ident`. An optional `$(dyn)?` before `$name:ident` is a local ambiguity that rustc
    // rejects, and it could not forward the marker anyway.
    (dyn $name:ident::<$type:ty> $(, $($rest:tt)*)?) => {
        hyperparam_method!(@impl dyn $name::<$type>);
        hyperparam_method!($($($rest)*)?);
    };
    // plain entry
    ($name:ident::<$type:ty> $(, $($rest:tt)*)?) => {
        hyperparam_method!(@impl $name::<$type>);
        hyperparam_method!($($($rest)*)?);
    };
}
66
67pub trait ExtendedModelConfig<T>: ModelConfiguration<T> {
68    fn epochs(&self) -> usize;
69
70    fn batch_size(&self) -> usize;
71
72    hyperparam_method! {
73        learning_rate::<T>,
74        epsilon::<T>,
75        momentum::<T>,
76        weight_decay::<T>,
77        dropout::<T>,
78        decay::<T>,
79        beta::<T>,
80        beta1::<T>,
81        beta2::<T>,
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::StandardModelConfig;

    #[test]
    fn test_standard_model_config() {
        // build a configuration with the given epochs and batch size
        let mut conf = StandardModelConfig::new().with_epochs(1000).with_batch_size(32);

        // record a few hyperparameters
        conf.set_learning_rate(0.01);
        conf.set_momentum(0.9);
        conf.set_decay(0.0001);

        // the training settings survive the builder chain
        assert_eq!(conf.epochs(), 1000);
        assert_eq!(conf.batch_size(), 32);

        // each hyperparameter reads back exactly as it was stored
        assert_eq!(conf.learning_rate(), Some(&0.01));
        assert_eq!(conf.momentum(), Some(&0.9));
        assert_eq!(conf.decay(), Some(&0.0001));
    }
}