concision_core/config/
hyper_params.rs1#[derive(
7 Clone,
8 Debug,
9 Eq,
10 Hash,
11 Ord,
12 PartialEq,
13 PartialOrd,
14 strum::AsRefStr,
15 strum::Display,
16 strum::EnumCount,
17 strum::EnumIs,
18 strum::EnumIter,
19 strum::EnumString,
20 strum::VariantNames,
21)]
22#[cfg_attr(
23 feature = "serde",
24 derive(serde::Deserialize, serde::Serialize),
25 serde(rename_all = "snake_case", untagged)
26)]
27#[strum(serialize_all = "snake_case")]
28#[non_exhaustive]
29pub enum HyperParam {
30 Decay,
31 #[cfg_attr(feature = "serde", serde(alias = "drop_out", alias = "p"))]
32 Dropout,
33 #[cfg_attr(feature = "serde", serde(alias = "lr", alias = "gamma"))]
34 LearningRate,
35 Momentum,
36 Temperature,
37 WeightDecay,
38 Beta1,
39 Beta2,
40 Epsilon,
41}
42
43impl HyperParam {
44 pub const fn variants() -> &'static [&'static str] {
46 use strum::VariantNames;
47 HyperParam::VARIANTS
48 }
49}
50
51impl core::borrow::Borrow<str> for HyperParam {
52 fn borrow(&self) -> &str {
53 self.as_ref()
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::HyperParam;
    use core::str::FromStr;

    /// Every variant parses from, and renders back to, its `snake_case` key.
    #[test]
    fn test_hyper_params() {
        use HyperParam::*;

        let tests = [
            ("decay", Decay),
            ("dropout", Dropout),
            ("momentum", Momentum),
            ("temperature", Temperature),
            ("beta1", Beta1),
            ("beta2", Beta2),
            ("epsilon", Epsilon),
            ("learning_rate", LearningRate),
            ("weight_decay", WeightDecay),
        ];
        // Fails when a variant is added without a matching row above.
        assert_eq!(tests.len(), <HyperParam as strum::EnumCount>::COUNT);

        for (s, param) in tests {
            assert_eq!(AsRef::<str>::as_ref(&param), s);
            assert_eq!(param.to_string(), s);
            let borrowed: &str = core::borrow::Borrow::borrow(&param);
            assert_eq!(borrowed, s);
            assert_eq!(HyperParam::from_str(s).ok(), Some(param));
        }
    }

    /// `variants()` lists one parseable key per variant.
    #[test]
    fn test_variants_round_trip() {
        let names = HyperParam::variants();
        assert_eq!(names.len(), <HyperParam as strum::EnumCount>::COUNT);
        for name in names {
            let param = HyperParam::from_str(name).expect("listed variant must parse");
            assert_eq!(AsRef::<str>::as_ref(&param), *name);
        }
    }

    /// Strings that are not a variant key are rejected.
    #[test]
    fn test_unknown_hyper_param_is_rejected() {
        assert!(HyperParam::from_str("").is_err());
        assert!(HyperParam::from_str("not_a_hyper_param").is_err());
    }
}