concision_core/config/
hyper_params.rs

1/*
2    Appellation: hyper_params <module>
3    Contrib: @FL03
4*/
5/// An enumeration of common HyperParams used in neural network configurations.
6#[derive(
7    Clone,
8    Debug,
9    Eq,
10    Hash,
11    Ord,
12    PartialEq,
13    PartialOrd,
14    strum::AsRefStr,
15    strum::Display,
16    strum::EnumCount,
17    strum::EnumIs,
18    strum::EnumIter,
19    strum::EnumString,
20    strum::VariantNames,
21)]
22#[cfg_attr(
23    feature = "serde",
24    derive(serde::Deserialize, serde::Serialize),
25    serde(rename_all = "snake_case", untagged)
26)]
27#[strum(serialize_all = "snake_case")]
28#[non_exhaustive]
29pub enum HyperParam {
30    Decay,
31    #[cfg_attr(feature = "serde", serde(alias = "drop_out", alias = "p"))]
32    Dropout,
33    #[cfg_attr(feature = "serde", serde(alias = "lr", alias = "gamma"))]
34    LearningRate,
35    Momentum,
36    Temperature,
37    WeightDecay,
38    Beta1,
39    Beta2,
40    Epsilon,
41}
42
43impl HyperParam {
44    /// returns a list of variants as strings
45    pub const fn variants() -> &'static [&'static str] {
46        use strum::VariantNames;
47        HyperParam::VARIANTS
48    }
49}
50
51impl core::borrow::Borrow<str> for HyperParam {
52    fn borrow(&self) -> &str {
53        self.as_ref()
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::HyperParam;

    /// Every variant parses from, and renders back to, its canonical name.
    #[test]
    fn test_hyper_params() {
        use HyperParam::*;
        use core::str::FromStr;

        let tests = [
            ("decay", Decay),
            ("dropout", Dropout),
            ("momentum", Momentum),
            ("temperature", Temperature),
            ("beta1", Beta1),
            ("beta2", Beta2),
            ("epsilon", Epsilon),
            ("learning_rate", LearningRate),
            ("weight_decay", WeightDecay),
        ];
        // Fails if a variant is added without a matching case above.
        assert_eq!(tests.len(), HyperParam::variants().len());
        for (s, param) in tests {
            let name: &str = param.as_ref();
            assert_eq!(name, s);
            assert_eq!(param.to_string(), s);
            assert_eq!(HyperParam::from_str(s).ok(), Some(param));
        }
    }

    /// Unknown names are rejected rather than mapped to some default variant.
    #[test]
    fn test_hyper_params_unknown() {
        use core::str::FromStr;

        assert!(HyperParam::from_str("not_a_hyper_param").is_err());
        assert!(HyperParam::from_str("").is_err());
    }
}