// evo_rl/enecode/meta.rs

use std::fmt;
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeStruct, Serializer};
use serde::Deserialize;
/// Gene that defines the meta-learning rules for the neural network.
///
/// Equality is field-wise (derived `PartialEq`), so two genes compare equal only
/// when their innovation numbers and both learning parameters match exactly.
#[derive(Clone, Debug, PartialEq)]
pub struct MetaLearningGene {
    /// Unique identifier for this particular gene.
    pub innovation_number: Arc<str>,
    /// Learning rate for synaptic adjustments.
    pub learning_rate: f32,
    /// Learning threshold for synaptic adjustments.
    pub learning_threshold: f32,
}

22impl Default for MetaLearningGene {
23    fn default() -> Self {
24        Self {
25            innovation_number: Arc::from("m01"),
26            learning_rate: 0.001,
27            learning_threshold: 0.5,
28        }
29    }
30}
31
32impl ToPyObject for MetaLearningGene {
33    fn to_object(&self, py: Python<'_>) -> PyObject {
34        let dict = PyDict::new(py);
35        dict.set_item("innnovation_number", &self.innovation_number.to_string());
36        dict.set_item("learning_rate", self.learning_rate);
37        dict.set_item("learning_threshold", self.learning_threshold);
38        dict.into()
39    }
40}
41
42impl Serialize for MetaLearningGene {
43    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 
44    where 
45        S: Serializer,
46        {
47            let mut state = serializer.serialize_struct("MetaLearningGene", 3)?;
48            state.serialize_field("innovation_number", &self.innovation_number.as_ref())?;
49            state.serialize_field("learning_rate", &self.learning_rate)?;
50            state.serialize_field("learning_threshold", &self.learning_threshold)?;
51            state.end()
52        }
53}
54
55
56impl<'de> Deserialize<'de> for MetaLearningGene {
57    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
58    where
59        D: Deserializer<'de>,
60    {
61        #[derive(Deserialize)]
62        #[serde(field_identifier, rename_all = "lowercase")]
63        enum Field { Innovation_Number, Learning_Rate, Learning_Threshold }
64
65        struct MetaLearningGeneVisitor;
66
67        impl<'de> Visitor<'de> for MetaLearningGeneVisitor {
68            type Value = MetaLearningGene;
69
70            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
71                formatter.write_str("struct MetaLearningGene")
72            }
73
74            fn visit_map<V>(self, mut map: V) -> Result<MetaLearningGene, V::Error>
75            where
76                V: MapAccess<'de>,
77            {
78                let mut innovation_number = None;
79                let mut learning_rate = None;
80                let mut learning_threshold = None;
81
82                while let Some(key) = map.next_key()? {
83                    match key {
84                        Field::Innovation_Number => {
85                            if innovation_number.is_some() {
86                                return Err(de::Error::duplicate_field("innovation_number"));
87                            }
88                            let value: String = map.next_value()?;
89                            innovation_number = Some(Arc::from(value.as_str()));
90                        }
91                        Field::Learning_Rate => {
92                            if learning_rate.is_some() {
93                                return Err(de::Error::duplicate_field("learning_rate"));
94                            }
95                            learning_rate = Some(map.next_value()?);
96                        }
97                        Field::Learning_Threshold => {
98                            if learning_threshold.is_some() {
99                                return Err(de::Error::duplicate_field("learning_threshold"));
100                            }
101                            learning_threshold = Some(map.next_value()?);
102                        }
103                    }
104                }
105
106                let innovation_number = innovation_number.ok_or_else(|| de::Error::missing_field("innovation_number"))?;
107                let learning_rate = learning_rate.ok_or_else(|| de::Error::missing_field("learning_rate"))?;
108                let learning_threshold = learning_threshold.ok_or_else(|| de::Error::missing_field("learning_threshold"))?;
109
110                Ok(MetaLearningGene {
111                    innovation_number,
112                    learning_rate,
113                    learning_threshold,
114                })
115            }
116        }
117
118        const FIELDS: &'static [&'static str] = &["innovation_number", "learning_rate", "learning_threshold"];
119        deserializer.deserialize_struct("MetaLearningGene", FIELDS, MetaLearningGeneVisitor)
120    }
121}
122
123
#[cfg(test)]
mod tests {
    use super::*;
    use log::*;
    use crate::setup_logger;

    /// A gene whose every field differs from `MetaLearningGene::default()`, so a
    /// round trip cannot pass merely by reproducing the defaults.
    fn custom_gene() -> MetaLearningGene {
        MetaLearningGene {
            innovation_number: Arc::from("m42"),
            learning_rate: 0.25,
            learning_threshold: 0.75,
        }
    }

    #[test]
    fn test_serialize_metalearning() {
        setup_logger();

        let mtg: MetaLearningGene = MetaLearningGene::default();
        let json = serde_json::to_string_pretty(&mtg).unwrap();
        debug!("{}", json);

        // Inspect the emitted fields rather than only checking that something was written.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value.as_object().map(|o| o.len()), Some(3));
        assert_eq!(value["innovation_number"], "m01");
        assert_eq!(value["learning_rate"].as_f64().map(|x| x as f32), Some(0.001));
        assert_eq!(value["learning_threshold"].as_f64().map(|x| x as f32), Some(0.5));
    }

    #[test]
    fn test_deserialize_metalearning() {
        setup_logger();

        let mtg: MetaLearningGene = MetaLearningGene::default();
        let json = serde_json::to_string_pretty(&mtg).unwrap();

        let mtg_deserialized: MetaLearningGene = serde_json::from_str(&json).unwrap();

        assert_eq!(MetaLearningGene::default(), mtg_deserialized);
    }

    #[test]
    fn test_roundtrip_non_default_metalearning() {
        let mtg = custom_gene();
        let json = serde_json::to_string(&mtg).unwrap();

        let mtg_deserialized: MetaLearningGene = serde_json::from_str(&json).unwrap();

        assert_eq!(mtg, mtg_deserialized);
    }

    #[test]
    fn test_deserialize_metalearning_rejects_missing_field() {
        let result = serde_json::from_str::<MetaLearningGene>(
            r#"{"innovation_number": "m01", "learning_rate": 0.001}"#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_metalearning_rejects_duplicate_field() {
        let result = serde_json::from_str::<MetaLearningGene>(
            r#"{"innovation_number": "m01", "innovation_number": "m02",
                "learning_rate": 0.001, "learning_threshold": 0.5}"#,
        );

        assert!(result.is_err());
    }
}