1use std::sync::Arc;
2use pyo3::prelude::*;
3use pyo3::ToPyObject;
4use serde::ser::{Serialize, Serializer, SerializeStruct};
5use serde::de::{self, Deserializer, Visitor, MapAccess};
6use serde::Deserialize;
7use pyo3::types::PyDict;
8use std::fmt;
/// Gene carrying an innovation number plus a learning rate and a learning threshold.
///
/// The serde representation (hand-written `Serialize`/`Deserialize` impls in this module)
/// uses the field names `innovation_number`, `learning_rate` and `learning_threshold`.
#[derive(Clone, Debug, PartialEq)]
pub struct MetaLearningGene {
    /// Innovation number, held as a shared immutable string so clones only bump a refcount.
    pub innovation_number: Arc<str>,
    /// Learning rate (`0.001` by default). NOTE(review): valid range not established here.
    pub learning_rate: f32,
    /// Learning threshold (`0.5` by default). NOTE(review): semantics not established here.
    pub learning_threshold: f32,
}
21
22impl Default for MetaLearningGene {
23 fn default() -> Self {
24 Self {
25 innovation_number: Arc::from("m01"),
26 learning_rate: 0.001,
27 learning_threshold: 0.5,
28 }
29 }
30}
31
32impl ToPyObject for MetaLearningGene {
33 fn to_object(&self, py: Python<'_>) -> PyObject {
34 let dict = PyDict::new(py);
35 dict.set_item("innnovation_number", &self.innovation_number.to_string());
36 dict.set_item("learning_rate", self.learning_rate);
37 dict.set_item("learning_threshold", self.learning_threshold);
38 dict.into()
39 }
40}
41
42impl Serialize for MetaLearningGene {
43 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44 where
45 S: Serializer,
46 {
47 let mut state = serializer.serialize_struct("MetaLearningGene", 3)?;
48 state.serialize_field("innovation_number", &self.innovation_number.as_ref())?;
49 state.serialize_field("learning_rate", &self.learning_rate)?;
50 state.serialize_field("learning_threshold", &self.learning_threshold)?;
51 state.end()
52 }
53}
54
55
56impl<'de> Deserialize<'de> for MetaLearningGene {
57 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
58 where
59 D: Deserializer<'de>,
60 {
61 #[derive(Deserialize)]
62 #[serde(field_identifier, rename_all = "lowercase")]
63 enum Field { Innovation_Number, Learning_Rate, Learning_Threshold }
64
65 struct MetaLearningGeneVisitor;
66
67 impl<'de> Visitor<'de> for MetaLearningGeneVisitor {
68 type Value = MetaLearningGene;
69
70 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
71 formatter.write_str("struct MetaLearningGene")
72 }
73
74 fn visit_map<V>(self, mut map: V) -> Result<MetaLearningGene, V::Error>
75 where
76 V: MapAccess<'de>,
77 {
78 let mut innovation_number = None;
79 let mut learning_rate = None;
80 let mut learning_threshold = None;
81
82 while let Some(key) = map.next_key()? {
83 match key {
84 Field::Innovation_Number => {
85 if innovation_number.is_some() {
86 return Err(de::Error::duplicate_field("innovation_number"));
87 }
88 let value: String = map.next_value()?;
89 innovation_number = Some(Arc::from(value.as_str()));
90 }
91 Field::Learning_Rate => {
92 if learning_rate.is_some() {
93 return Err(de::Error::duplicate_field("learning_rate"));
94 }
95 learning_rate = Some(map.next_value()?);
96 }
97 Field::Learning_Threshold => {
98 if learning_threshold.is_some() {
99 return Err(de::Error::duplicate_field("learning_threshold"));
100 }
101 learning_threshold = Some(map.next_value()?);
102 }
103 }
104 }
105
106 let innovation_number = innovation_number.ok_or_else(|| de::Error::missing_field("innovation_number"))?;
107 let learning_rate = learning_rate.ok_or_else(|| de::Error::missing_field("learning_rate"))?;
108 let learning_threshold = learning_threshold.ok_or_else(|| de::Error::missing_field("learning_threshold"))?;
109
110 Ok(MetaLearningGene {
111 innovation_number,
112 learning_rate,
113 learning_threshold,
114 })
115 }
116 }
117
118 const FIELDS: &'static [&'static str] = &["innovation_number", "learning_rate", "learning_threshold"];
119 deserializer.deserialize_struct("MetaLearningGene", FIELDS, MetaLearningGeneVisitor)
120 }
121}
122
123
#[cfg(test)]
mod tests {
    use super::*;
    use log::*;
    use crate::setup_logger;

    /// The default gene serializes to non-empty JSON using the expected field names.
    #[test]
    fn test_serialize_metalearning() {
        setup_logger();

        let mtg = MetaLearningGene::default();
        let json = serde_json::to_string_pretty(&mtg).unwrap();
        debug!("{}", json);

        assert!(!json.is_empty());
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["innovation_number"], "m01");
        assert!(value.get("learning_rate").is_some());
        assert!(value.get("learning_threshold").is_some());
    }

    /// Serializing then deserializing the default gene yields an equal gene.
    #[test]
    fn test_deserialize_metalearning() {
        setup_logger();

        let mtg = MetaLearningGene::default();
        let json = serde_json::to_string_pretty(&mtg).unwrap();

        let mtg_deserialized: MetaLearningGene = serde_json::from_str(&json).unwrap();

        assert_eq!(MetaLearningGene::default(), mtg_deserialized);
    }

    /// Input lacking required fields is rejected rather than defaulted.
    #[test]
    fn test_deserialize_missing_field_is_error() {
        setup_logger();

        let result = serde_json::from_str::<MetaLearningGene>(r#"{"learning_rate": 0.1}"#);

        assert!(result.is_err());
    }
}