concision_neural/types/hyperparameters.rs

1/*
2    Appellation: hyperparameters <module>
3    Contrib: @FL03
4*/
5
/// A generic key/value pair.
///
/// The type parameters default to `String` keys and `f64` values. The derived
/// comparison traits compare `key` first and fall back to `value` on ties
/// (field declaration order).
#[doc(hidden)]
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Eq,
    Hash,
    Ord,
    PartialEq,
    PartialOrd,
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct KeyValue<K = String, V = f64> {
    /// Identifier half of the pair.
    pub key: K,
    /// Value associated with `key`.
    pub value: V,
}
13
14#[derive(
15    Clone,
16    Copy,
17    Debug,
18    Default,
19    Eq,
20    Hash,
21    Ord,
22    PartialEq,
23    PartialOrd,
24    scsys_derive::VariantConstructors,
25    strum::AsRefStr,
26    strum::Display,
27    strum::EnumCount,
28    strum::EnumIs,
29    strum::EnumIter,
30    strum::EnumString,
31    strum::VariantArray,
32    strum::VariantNames,
33)]
34#[cfg_attr(
35    feature = "serde",
36    derive(serde::Deserialize, serde::Serialize),
37    serde(rename_all = "snake_case", untagged)
38)]
39#[strum(serialize_all = "snake_case")]
40pub enum Hyperparameters {
41    Decay,
42    Dropout,
43    #[default]
44    LearningRate,
45    Momentum,
46    Temperature,
47    WeightDecay,
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use core::str::FromStr;

    /// Checks that a known snake_case name parses to its variant. Also checks that
    /// every variant round-trips through the string returned by `as_ref()`.
    #[test]
    fn test_hyper() {
        use strum::IntoEnumIterator;

        let expected = Ok(Hyperparameters::LearningRate);
        assert_eq!(Hyperparameters::from_str("learning_rate"), expected);

        Hyperparameters::iter().for_each(|hyper| {
            let roundtrip = Hyperparameters::from_str(hyper.as_ref());
            assert_eq!(roundtrip, Ok(hyper));
        });
    }
}