1use indexmap::IndexMap;
2use ndarray::{concatenate, s, Array2, Axis};
3use ron::extensions::Extensions;
4use std::collections::HashMap;
5use std::env;
6use std::fmt::Write;
7use std::fs::File;
8use std::io::Read;
9use std::path::Path;
10
11use crate::categorical_action_head::CategoricalActionHead;
12use crate::config::RogueNetConfig;
13use crate::config::TrainConfig;
14use crate::embedding::Embedding;
15use crate::msgpack::decode_state_dict;
16use crate::msgpack::TensorDict;
17use crate::state::{ObsSpace, State};
18use crate::transformer::Transformer;
19
20#[derive(Debug, Clone)]
21pub struct RogueNet {
23 pub config: RogueNetConfig,
24 pub obs_space: ObsSpace,
25 translation: Option<Translate>,
26 embeddings: Vec<(String, Embedding)>,
27 backbone: Transformer,
28 action_heads: IndexMap<String, CategoricalActionHead>,
29}
30
31#[derive(Debug, Clone, Default)]
32pub struct FwdArgs {
34 pub features: HashMap<String, Array2<f32>>,
35 pub actors: Vec<String>,
36}
37
/// Pre-resolved feature column indices used to translate (and optionally
/// rotate) entity positions relative to a reference entity.
#[derive(Debug, Clone)]
struct Translate {
    /// Name of the entity whose first row supplies the origin (and the
    /// rotation vector, when present).
    reference_entity: String,
    /// Columns of the reference entity that hold the two rotation-vector
    /// components, if rotation is configured.
    rotation_vec_indices: Option<[usize; 2]>,
    /// For every entity name, the columns that hold its position features.
    position_feature_indices: HashMap<String, Vec<usize>>,
}
44
45impl RogueNet {
46 pub fn load<P: AsRef<Path>>(path: P) -> RogueNet {
51 let config_path = path.as_ref().join("config.ron");
52 let ron = ron::Options::default().with_default_extension(Extensions::IMPLICIT_SOME);
53
54 let config: TrainConfig = ron
55 .from_reader(
56 File::open(&config_path)
57 .unwrap_or_else(|_| panic!("Failed to open {}", config_path.display())),
58 )
59 .unwrap();
60
61 let state_path = path.as_ref().join("state.ron");
62 let state: State = ron
63 .from_reader(
64 File::open(&state_path)
65 .unwrap_or_else(|_| panic!("Failed to open {}", state_path.display())),
66 )
67 .unwrap();
68
69 let agent_path = path.as_ref().join("state.agent.msgpack");
70 let state_dict = decode_state_dict(File::open(&agent_path).unwrap()).unwrap();
71 RogueNet::new(&state_dict, config.net, &state)
72 }
73
74 pub fn load_archive<R: Read>(r: R) -> Result<RogueNet, std::io::Error> {
87 let mut a = tar::Archive::new(r);
88 let mut config: Option<TrainConfig> = None;
89 let mut state = None;
90 let mut state_dict = None;
91 let ron = ron::Options::default().with_default_extension(Extensions::IMPLICIT_SOME);
92 for file in a.entries()? {
93 let file = file?;
94 match file
95 .path()?
96 .components()
97 .last()
98 .unwrap()
99 .as_os_str()
100 .to_str()
101 .unwrap()
102 {
103 "config.ron" => config = Some(ron.from_reader(file).unwrap()),
104 "state.ron" => state = Some(ron.from_reader(file).unwrap()),
105 "state.agent.msgpack" => state_dict = Some(decode_state_dict(file).unwrap()),
106 _ => {
107 return Err(std::io::Error::new(
108 std::io::ErrorKind::InvalidData,
109 format!("Unexpected file: {}", file.path().unwrap().display()),
110 ))
111 }
112 }
113 }
114 Ok(RogueNet::new(
115 &state_dict.ok_or_else(|| {
116 std::io::Error::new(std::io::ErrorKind::Other, "Missing state.agent.msgpack")
117 })?,
118 config
119 .ok_or_else(|| {
120 std::io::Error::new(std::io::ErrorKind::Other, "Missing config.ron")
121 })?
122 .net,
123 &state.ok_or_else(|| {
124 std::io::Error::new(std::io::ErrorKind::Other, "Missing state.ron")
125 })?,
126 ))
127 }
128
129 pub fn forward(&self, mut args: FwdArgs) -> (Array2<f32>, Vec<u64>) {
148 if env::var("ROGUE_NET_DUMP_INPUTS").is_ok() {
149 args.dump(self.action_heads.get_index(0).unwrap().0)
150 .unwrap();
151 }
152
153 if let Some(t) = &self.translation {
154 let reference_entity = args
155 .features
156 .get(&t.reference_entity)
157 .unwrap_or_else(|| panic!("Missing entity type: {}", t.reference_entity));
158 let origin = t.position_feature_indices[&t.reference_entity]
159 .iter()
160 .map(|&i| reference_entity[[0, i]])
161 .collect::<Vec<_>>();
162 let rotation = t
163 .rotation_vec_indices
164 .map(|r| (reference_entity[[0, r[0]]], reference_entity[[0, r[1]]]));
165 for (entity, feats) in args.features.iter_mut() {
166 if *entity != t.reference_entity {
167 for i in 0..feats.dim().0 {
168 match rotation {
169 Some((rx, ry)) => {
170 let x =
171 feats[[i, t.position_feature_indices[entity][0]]] - origin[0];
172 let y =
173 feats[[i, t.position_feature_indices[entity][1]]] - origin[1];
174 feats[[i, t.position_feature_indices[entity][0]]] = x * rx + y * ry;
175 feats[[i, t.position_feature_indices[entity][1]]] =
176 -x * ry + y * rx;
177 }
178 None => {
179 for (j, x) in
180 t.position_feature_indices[entity].iter().zip(origin.iter())
181 {
182 feats[[i, *j]] -= x;
183 }
184 }
185 }
186 }
187 }
188 }
189 }
190
191 let mut actors = vec![];
192 let mut i = 0;
193 let mut embeddings = Vec::with_capacity(args.features.len());
194 for (key, embedding) in &self.embeddings {
195 let x = embedding.forward(args.features[key].view());
196 if args.actors.iter().any(|a| a == key) {
197 for j in i..i + x.dim().0 {
198 actors.push(j);
199 }
200 }
201 i += x.dim().0;
202 embeddings.push(x);
203 }
204 let x = concatenate(
205 Axis(0),
206 &embeddings.iter().map(|x| x.view()).collect::<Vec<_>>(),
207 )
208 .unwrap();
209 let x = self.backbone.forward(x, &args.features);
210 self.action_heads
211 .values()
212 .next()
213 .unwrap()
214 .forward(x.view(), actors)
215 }
216
217 fn new(state_dict: &TensorDict, config: RogueNetConfig, state: &State) -> Self {
218 assert!(
219 config.embd_pdrop == 0.0 && config.resid_pdrop == 0.0 && config.attn_pdrop == 0.0,
220 "dropout is not supported"
221 );
222 assert!(config.pooling.is_none(), "pooling is not supported");
223
224 let translation = config.translation.as_ref().map(|t| {
225 assert!(
226 t.rotation_angle_feature.is_none(),
227 "rotation_angle_feature not implemented",
228 );
229 assert!(!t.add_dist_feature, "add_dist_features not implemented");
230 let rotation_vec_indices = t.rotation_vec_features.as_ref().map(|rot| {
231 let indices = rot
232 .iter()
233 .map(|s| {
234 state.obs_space.entities[&t.reference_entity]
235 .features
236 .iter()
237 .position(|f| f == s)
238 .unwrap()
239 })
240 .collect::<Vec<_>>();
241 assert_eq!(indices.len(), 2, "rotation_vec_features must have length 2");
242 [indices[0], indices[1]]
243 });
244 let position_feature_indices = state
245 .obs_space
246 .entities
247 .iter()
248 .map(|(name, entity)| {
249 let indices = t
250 .position_features
251 .iter()
252 .map(|f| {
253 entity
254 .features
255 .iter()
256 .position(|f2| f2 == f)
257 .unwrap_or_else(|| {
258 panic!("feature \"{}\" not found in reference entity", f)
259 })
260 })
261 .collect::<Vec<_>>();
262 (name.clone(), indices)
263 })
264 .collect();
265 Translate {
266 reference_entity: t.reference_entity.clone(),
267 rotation_vec_indices,
268 position_feature_indices,
269 }
270 });
271
272 let dict = state_dict.as_dict();
273 let mut embeddings = Vec::new();
274 for (key, value) in dict["embedding"].as_dict()["embeddings"].as_dict() {
275 let embedding = Embedding::from(value);
276 embeddings.push((key.clone(), embedding));
277 }
278 let backbone = Transformer::new(&dict["backbone"], &config, state);
279
280 let mut action_heads = IndexMap::new();
281 for (key, value) in dict["action_heads"].as_dict() {
282 let action_head = CategoricalActionHead::from(value);
283 action_heads.insert(key.clone(), action_head);
284 }
285
286 RogueNet {
287 embeddings,
288 translation,
289 backbone,
290 action_heads,
291 config,
292 obs_space: state.obs_space.clone(),
293 }
294 }
295
296 pub fn with_obs_filter(mut self, obs_space: HashMap<String, Vec<String>>) -> Self {
299 for (entity, received_features) in obs_space {
300 if let Some((_, embedding)) = self.embeddings.iter_mut().find(|(e, _)| *e == entity) {
301 embedding.set_obs_filter(
302 &self.obs_space.entities[&entity].features,
303 &received_features,
304 );
305 }
306 }
307 self
308 }
309}
310
311impl FwdArgs {
312 fn dump(&self, action_name: &str) -> Result<(), std::fmt::Error> {
313 let mut out = String::new();
314 writeln!(out, "obs = Observation(")?;
315
316 writeln!(out, " features={{")?;
318 for (entity_name, features) in &self.features {
319 writeln!(out, " \"{entity_name}\": [")?;
320 for i in 0..features.dim().0 {
321 writeln!(out, " {},", features.slice(s![i, ..]))?;
322 }
323 writeln!(out, " ],")?;
324 }
325 writeln!(out, " }},")?;
326
327 writeln!(out, " ids={{")?;
329 let mut total = 0;
330 for (entity_name, features) in &self.features {
331 let count = features.dim().0;
332 writeln!(
333 out,
334 " \"{entity_name}\": {:?},",
335 &(total..total + count).collect::<Vec<_>>()[..]
336 )?;
337 total += count;
338 }
339 writeln!(out, " }},")?;
340
341 writeln!(out, " done=False,")?;
343 writeln!(out, " reward=0.0,")?;
344
345 writeln!(
347 out,
348 " actions={{\"{action_name}\": CategoricalActionMask(actor_types={:?})}},",
349 &self.actors[..]
350 )?;
351
352 writeln!(out, ")")?;
353
354 println!("{}", out);
355
356 Ok(())
357 }
358}