// concision_neural/types/hyperparameters.rs

/// A generic key/value pair.
///
/// Both type parameters have defaults, so the bare name `KeyValue` means
/// `KeyValue<String, f64>`.
///
/// The derived traits are only implemented when the matching bounds hold for
/// `K` and `V`. For example, `Copy` requires both to be `Copy`. `Eq`, `Ord` and
/// `Hash` are unavailable for the default `V = f64`, because `f64` implements
/// none of them. The derived ordering compares `key` first, then `value`.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct KeyValue<K = String, V = f64> {
    /// The key of the pair.
    pub key: K,
    /// The value associated with `key`.
    pub value: V,
}

14#[derive(
15 Clone,
16 Copy,
17 Debug,
18 Default,
19 Eq,
20 Hash,
21 Ord,
22 PartialEq,
23 PartialOrd,
24 scsys_derive::VariantConstructors,
25 strum::AsRefStr,
26 strum::Display,
27 strum::EnumCount,
28 strum::EnumIs,
29 strum::EnumIter,
30 strum::EnumString,
31 strum::VariantArray,
32 strum::VariantNames,
33)]
34#[cfg_attr(
35 feature = "serde",
36 derive(serde::Deserialize, serde::Serialize),
37 serde(rename_all = "snake_case", untagged)
38)]
39#[strum(serialize_all = "snake_case")]
40pub enum Hyperparameters {
41 Decay,
42 Dropout,
43 #[default]
44 LearningRate,
45 Momentum,
46 Temperature,
47 WeightDecay,
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use core::str::FromStr;
    use strum::IntoEnumIterator;

    #[test]
    fn test_hyper() {
        assert_eq!(
            Hyperparameters::from_str("learning_rate"),
            Ok(Hyperparameters::LearningRate)
        );

        // Every variant must round-trip through its string name, and the
        // `Display` output must agree with the `AsRef<str>` name.
        for variant in Hyperparameters::iter() {
            let name = variant.as_ref();
            assert_eq!(Hyperparameters::from_str(name), Ok(variant));
            assert_eq!(variant.to_string(), name);
        }
    }

    #[test]
    fn test_hyper_default_and_unknown() {
        // The `#[default]` variant.
        assert_eq!(Hyperparameters::default(), Hyperparameters::LearningRate);
        // Names outside the enum must be rejected rather than mapped to a variant.
        assert!(Hyperparameters::from_str("not_a_hyperparameter").is_err());
    }
}