rogue_net/
rogue_net.rs

1use indexmap::IndexMap;
2use ndarray::{concatenate, s, Array2, Axis};
3use ron::extensions::Extensions;
4use std::collections::HashMap;
5use std::env;
6use std::fmt::Write;
7use std::fs::File;
8use std::io::Read;
9use std::path::Path;
10
11use crate::categorical_action_head::CategoricalActionHead;
12use crate::config::RogueNetConfig;
13use crate::config::TrainConfig;
14use crate::embedding::Embedding;
15use crate::msgpack::decode_state_dict;
16use crate::msgpack::TensorDict;
17use crate::state::{ObsSpace, State};
18use crate::transformer::Transformer;
19
20#[derive(Debug, Clone)]
21/// Implements the [RogueNet](https://github.com/entity-neural-network/rogue-net) entity neural network.
22pub struct RogueNet {
23    pub config: RogueNetConfig,
24    pub obs_space: ObsSpace,
25    translation: Option<Translate>,
26    embeddings: Vec<(String, Embedding)>,
27    backbone: Transformer,
28    action_heads: IndexMap<String, CategoricalActionHead>,
29}
30
31#[derive(Debug, Clone, Default)]
32/// Arguments for RogueNet forward pass.
33pub struct FwdArgs {
34    pub features: HashMap<String, Array2<f32>>,
35    pub actors: Vec<String>,
36}
37
/// Precomputed feature indices used to express entity positions relative to a
/// reference entity during the forward pass.
#[derive(Debug, Clone)]
struct Translate {
    /// Entity type whose first entity supplies the origin (and optional rotation).
    reference_entity: String,
    /// Column indices, within the reference entity's features, of the two
    /// components of the rotation vector; `None` disables rotation.
    rotation_vec_indices: Option<[usize; 2]>,
    /// For every entity type, the column indices of its position features.
    position_feature_indices: HashMap<String, Vec<usize>>,
}
44
45impl RogueNet {
46    /// Loads the parameters for a trained RogueNet neural network from a checkpoint directory produced by [enn-trainer](https://github.com/entity-neural-network/enn-trainer).
47    ///
48    /// # Arguments
49    /// * `path` - Path to the checkpoint directory.
50    pub fn load<P: AsRef<Path>>(path: P) -> RogueNet {
51        let config_path = path.as_ref().join("config.ron");
52        let ron = ron::Options::default().with_default_extension(Extensions::IMPLICIT_SOME);
53
54        let config: TrainConfig = ron
55            .from_reader(
56                File::open(&config_path)
57                    .unwrap_or_else(|_| panic!("Failed to open {}", config_path.display())),
58            )
59            .unwrap();
60
61        let state_path = path.as_ref().join("state.ron");
62        let state: State = ron
63            .from_reader(
64                File::open(&state_path)
65                    .unwrap_or_else(|_| panic!("Failed to open {}", state_path.display())),
66            )
67            .unwrap();
68
69        let agent_path = path.as_ref().join("state.agent.msgpack");
70        let state_dict = decode_state_dict(File::open(&agent_path).unwrap()).unwrap();
71        RogueNet::new(&state_dict, config.net, &state)
72    }
73
74    /// Loads the parameters for a trained RogueNet neural network from a tar archive of a checkpoint directory.
75    ///
76    /// # Arguments
77    /// * `r` - A reader for the tar archive.
78    ///
79    /// # Example
80    /// ```
81    /// use std::fs::File;
82    /// use rogue_net::RogueNet;
83    ///
84    /// let rogue_net = RogueNet::load_archive(File::open("test-data/simple.roguenet").unwrap());
85    /// ```
86    pub fn load_archive<R: Read>(r: R) -> Result<RogueNet, std::io::Error> {
87        let mut a = tar::Archive::new(r);
88        let mut config: Option<TrainConfig> = None;
89        let mut state = None;
90        let mut state_dict = None;
91        let ron = ron::Options::default().with_default_extension(Extensions::IMPLICIT_SOME);
92        for file in a.entries()? {
93            let file = file?;
94            match file
95                .path()?
96                .components()
97                .last()
98                .unwrap()
99                .as_os_str()
100                .to_str()
101                .unwrap()
102            {
103                "config.ron" => config = Some(ron.from_reader(file).unwrap()),
104                "state.ron" => state = Some(ron.from_reader(file).unwrap()),
105                "state.agent.msgpack" => state_dict = Some(decode_state_dict(file).unwrap()),
106                _ => {
107                    return Err(std::io::Error::new(
108                        std::io::ErrorKind::InvalidData,
109                        format!("Unexpected file: {}", file.path().unwrap().display()),
110                    ))
111                }
112            }
113        }
114        Ok(RogueNet::new(
115            &state_dict.ok_or_else(|| {
116                std::io::Error::new(std::io::ErrorKind::Other, "Missing state.agent.msgpack")
117            })?,
118            config
119                .ok_or_else(|| {
120                    std::io::Error::new(std::io::ErrorKind::Other, "Missing config.ron")
121                })?
122                .net,
123            &state.ok_or_else(|| {
124                std::io::Error::new(std::io::ErrorKind::Other, "Missing state.ron")
125            })?,
126        ))
127    }
128
129    /// Runs a forward pass of the RogueNet neural network.
130    ///
131    /// # Arguments
132    /// * `entities` - Maps each entity type to an `Array2<f32>` containing the entities' features.
133    ///
134    /// # Example
135    /// ```
136    /// use std::collections::HashMap;
137    /// use ndarray::prelude::*;
138    /// use rogue_net::{RogueNet, FwdArgs};
139    ///
140    /// let rogue_net = RogueNet::load("test-data/simple");
141    /// let mut features = HashMap::new();
142    /// features.insert("Head".to_string(), array![[3.0, 4.0]]);
143    /// features.insert("SnakeSegment".to_string(), array![[3.0, 4.0], [4.0, 4.0]]);
144    /// features.insert("Food".to_string(), array![[3.0, 5.0], [8.0, 4.0]]);
145    /// let (action_probs, actions) = rogue_net.forward(FwdArgs { features, ..Default::default() });
146    /// ```
147    pub fn forward(&self, mut args: FwdArgs) -> (Array2<f32>, Vec<u64>) {
148        if env::var("ROGUE_NET_DUMP_INPUTS").is_ok() {
149            args.dump(self.action_heads.get_index(0).unwrap().0)
150                .unwrap();
151        }
152
153        if let Some(t) = &self.translation {
154            let reference_entity = args
155                .features
156                .get(&t.reference_entity)
157                .unwrap_or_else(|| panic!("Missing entity type: {}", t.reference_entity));
158            let origin = t.position_feature_indices[&t.reference_entity]
159                .iter()
160                .map(|&i| reference_entity[[0, i]])
161                .collect::<Vec<_>>();
162            let rotation = t
163                .rotation_vec_indices
164                .map(|r| (reference_entity[[0, r[0]]], reference_entity[[0, r[1]]]));
165            for (entity, feats) in args.features.iter_mut() {
166                if *entity != t.reference_entity {
167                    for i in 0..feats.dim().0 {
168                        match rotation {
169                            Some((rx, ry)) => {
170                                let x =
171                                    feats[[i, t.position_feature_indices[entity][0]]] - origin[0];
172                                let y =
173                                    feats[[i, t.position_feature_indices[entity][1]]] - origin[1];
174                                feats[[i, t.position_feature_indices[entity][0]]] = x * rx + y * ry;
175                                feats[[i, t.position_feature_indices[entity][1]]] =
176                                    -x * ry + y * rx;
177                            }
178                            None => {
179                                for (j, x) in
180                                    t.position_feature_indices[entity].iter().zip(origin.iter())
181                                {
182                                    feats[[i, *j]] -= x;
183                                }
184                            }
185                        }
186                    }
187                }
188            }
189        }
190
191        let mut actors = vec![];
192        let mut i = 0;
193        let mut embeddings = Vec::with_capacity(args.features.len());
194        for (key, embedding) in &self.embeddings {
195            let x = embedding.forward(args.features[key].view());
196            if args.actors.iter().any(|a| a == key) {
197                for j in i..i + x.dim().0 {
198                    actors.push(j);
199                }
200            }
201            i += x.dim().0;
202            embeddings.push(x);
203        }
204        let x = concatenate(
205            Axis(0),
206            &embeddings.iter().map(|x| x.view()).collect::<Vec<_>>(),
207        )
208        .unwrap();
209        let x = self.backbone.forward(x, &args.features);
210        self.action_heads
211            .values()
212            .next()
213            .unwrap()
214            .forward(x.view(), actors)
215    }
216
217    fn new(state_dict: &TensorDict, config: RogueNetConfig, state: &State) -> Self {
218        assert!(
219            config.embd_pdrop == 0.0 && config.resid_pdrop == 0.0 && config.attn_pdrop == 0.0,
220            "dropout is not supported"
221        );
222        assert!(config.pooling.is_none(), "pooling is not supported");
223
224        let translation = config.translation.as_ref().map(|t| {
225            assert!(
226                t.rotation_angle_feature.is_none(),
227                "rotation_angle_feature not implemented",
228            );
229            assert!(!t.add_dist_feature, "add_dist_features not implemented");
230            let rotation_vec_indices = t.rotation_vec_features.as_ref().map(|rot| {
231                let indices = rot
232                    .iter()
233                    .map(|s| {
234                        state.obs_space.entities[&t.reference_entity]
235                            .features
236                            .iter()
237                            .position(|f| f == s)
238                            .unwrap()
239                    })
240                    .collect::<Vec<_>>();
241                assert_eq!(indices.len(), 2, "rotation_vec_features must have length 2");
242                [indices[0], indices[1]]
243            });
244            let position_feature_indices = state
245                .obs_space
246                .entities
247                .iter()
248                .map(|(name, entity)| {
249                    let indices = t
250                        .position_features
251                        .iter()
252                        .map(|f| {
253                            entity
254                                .features
255                                .iter()
256                                .position(|f2| f2 == f)
257                                .unwrap_or_else(|| {
258                                    panic!("feature \"{}\" not found in reference entity", f)
259                                })
260                        })
261                        .collect::<Vec<_>>();
262                    (name.clone(), indices)
263                })
264                .collect();
265            Translate {
266                reference_entity: t.reference_entity.clone(),
267                rotation_vec_indices,
268                position_feature_indices,
269            }
270        });
271
272        let dict = state_dict.as_dict();
273        let mut embeddings = Vec::new();
274        for (key, value) in dict["embedding"].as_dict()["embeddings"].as_dict() {
275            let embedding = Embedding::from(value);
276            embeddings.push((key.clone(), embedding));
277        }
278        let backbone = Transformer::new(&dict["backbone"], &config, state);
279
280        let mut action_heads = IndexMap::new();
281        for (key, value) in dict["action_heads"].as_dict() {
282            let action_head = CategoricalActionHead::from(value);
283            action_heads.insert(key.clone(), action_head);
284        }
285
286        RogueNet {
287            embeddings,
288            translation,
289            backbone,
290            action_heads,
291            config,
292            obs_space: state.obs_space.clone(),
293        }
294    }
295
296    /// Adapts the RogueNet neural network to the given observation space by
297    /// filtering out any features that were not present during training.
298    pub fn with_obs_filter(mut self, obs_space: HashMap<String, Vec<String>>) -> Self {
299        for (entity, received_features) in obs_space {
300            if let Some((_, embedding)) = self.embeddings.iter_mut().find(|(e, _)| *e == entity) {
301                embedding.set_obs_filter(
302                    &self.obs_space.entities[&entity].features,
303                    &received_features,
304                );
305            }
306        }
307        self
308    }
309}
310
311impl FwdArgs {
312    fn dump(&self, action_name: &str) -> Result<(), std::fmt::Error> {
313        let mut out = String::new();
314        writeln!(out, "obs = Observation(")?;
315
316        // Features
317        writeln!(out, "    features={{")?;
318        for (entity_name, features) in &self.features {
319            writeln!(out, "        \"{entity_name}\": [")?;
320            for i in 0..features.dim().0 {
321                writeln!(out, "            {},", features.slice(s![i, ..]))?;
322            }
323            writeln!(out, "        ],")?;
324        }
325        writeln!(out, "    }},")?;
326
327        // IDs
328        writeln!(out, "    ids={{")?;
329        let mut total = 0;
330        for (entity_name, features) in &self.features {
331            let count = features.dim().0;
332            writeln!(
333                out,
334                "        \"{entity_name}\": {:?},",
335                &(total..total + count).collect::<Vec<_>>()[..]
336            )?;
337            total += count;
338        }
339        writeln!(out, "    }},")?;
340
341        // done, reward
342        writeln!(out, "    done=False,")?;
343        writeln!(out, "    reward=0.0,")?;
344
345        // Actions
346        writeln!(
347            out,
348            "    actions={{\"{action_name}\": CategoricalActionMask(actor_types={:?})}},",
349            &self.actors[..]
350        )?;
351
352        writeln!(out, ")")?;
353
354        println!("{}", out);
355
356        Ok(())
357    }
358}