border_atari_env/
env.rs

1mod config;
2mod window;
3use super::BorderAtariAct;
4use super::{BorderAtariActFilter, BorderAtariObsFilter};
5use crate::atari_env::{AtariAction, AtariEnv, EmulatorConfig};
6use anyhow::Result;
7use border_core::{record::Record, Act, Env, Info, Obs, Step};
8pub use config::BorderAtariEnvConfig;
9use image::{
10    imageops::{/*grayscale,*/ resize, FilterType::Triangle},
11    ImageBuffer, /*Luma,*/ Rgb,
12};
13use itertools::izip;
14use std::ptr::copy;
15use std::{default::Default, marker::PhantomData};
16use window::AtariWindow;
17#[cfg(feature = "atari-env-sys")]
18use winit::{event_loop::ControlFlow, platform::run_return::EventLoopExtRunReturn};
19
20/// Empty struct.
21pub struct NullInfo;
22
23impl Info for NullInfo {}
24
25fn env(rom_dir: &str, name: &str) -> AtariEnv {
26    AtariEnv::new(
27        rom_dir.to_string() + format!("/{}.bin", name).as_str(),
28        EmulatorConfig {
29            // display_screen: true,
30            // sound: true,
31            frame_skip: 1,
32            color_averaging: false,
33            repeat_action_probability: 0.0,
34            ..EmulatorConfig::default()
35        },
36    )
37}
38
39/// A wrapper of atari learning environment.
40///
41/// Preprocessing is the same in the link:
42/// <https://stable-baselines3.readthedocs.io/en/master/common/atari_wrappers.html#stable_baselines3.common.atari_wrappers.AtariWrapper>.
43pub struct BorderAtariEnv<O, A, OF, AF>
44where
45    O: Obs,
46    A: Act,
47    OF: BorderAtariObsFilter<O>,
48    AF: BorderAtariActFilter<A>,
49{
50    // True for training mode, it affects preprocessing at every steps.
51    train: bool,
52
53    // Environment
54    env: AtariEnv,
55
56    // Window for displaying the current game state
57    window: Option<AtariWindow>,
58
59    // Observation buffer for frame skipping
60    obs_buffer: [Vec<u8>; 2],
61
62    // Lives in the game
63    lives: usize,
64
65    // If the game was done.
66    was_real_done: bool,
67
68    // Buffer for stacking frames
69    frames: Vec<u8>,
70
71    // Filters
72    obs_filter: OF,
73    act_filter: AF,
74    phantom: PhantomData<(O, A)>,
75}
76
77impl<O, A, OF, AF> BorderAtariEnv<O, A, OF, AF>
78where
79    O: Obs,
80    A: Act,
81    OF: BorderAtariObsFilter<O>,
82    AF: BorderAtariActFilter<A>,
83{
84    /// Opens window for display.
85    pub fn open(&mut self) -> Result<()> {
86        // Do nothing if a window is already opened.
87        if !self.window.is_none() {
88            return Ok(());
89        }
90
91        self.window = Some(AtariWindow::new(&self.env)?);
92
93        Ok(())
94    }
95
96    /// Returns the number of actions.
97    pub fn get_num_actions_atari(&self) -> i64 {
98        // self.env.available_actions().len() as i64
99        self.env.minimal_actions().len() as i64
100    }
101
102    fn episodic_life_env_step(&mut self, a: &BorderAtariAct) -> (Vec<u8>, f32, i8) {
103        let actions = self.env.minimal_actions();
104        let ix = a.act;
105        let reward = self.env.step(actions[ix as usize]) as f32;
106
107        let is_terminated = match self.env.is_game_over() {
108            true => 1,
109            false => 0,
110        };
111        self.was_real_done = is_terminated == 1;
112        let lives = self.env.lives();
113
114        // if self.train && lives < self.lives && lives > 0 {
115        //     done = true;
116        // }
117        self.lives = lives;
118
119        let (w, h) = (self.env.width(), self.env.height());
120        let mut obs = vec![0u8; w * h * 3];
121        self.env.render_rgb24(&mut obs);
122
123        (obs, reward, is_terminated)
124    }
125
126    fn skip_and_max(&mut self, a: &BorderAtariAct) -> (Vec<u8>, f32, Vec<i8>) {
127        let mut total_reward = 0f32;
128        let mut is_terminated = 0;
129
130        for i in 0..4 {
131            let (obs, reward, is_terminated_) = self.episodic_life_env_step(a);
132            total_reward += reward;
133            is_terminated = is_terminated_;
134            if i == 2 {
135                self.obs_buffer[0] = obs;
136            } else if i == 3 {
137                self.obs_buffer[1] = obs;
138            }
139            if is_terminated_ == 1 {
140                break;
141            }
142        }
143
144        // Max pooling
145        let obs = self.obs_buffer[0]
146            .iter()
147            .zip(self.obs_buffer[1].iter())
148            .map(|(&a, &b)| a.max(b))
149            .collect::<Vec<_>>();
150
151        (obs, total_reward, vec![is_terminated])
152    }
153
154    fn clip_reward(&self, r: f32) -> Vec<f32> {
155        if self.train {
156            if r == 0.0 {
157                vec![0.0]
158            } else {
159                vec![r.signum()]
160            }
161        } else {
162            vec![r]
163        }
164    }
165
166    fn warp_and_grayscale(w: u32, h: u32, obs: Vec<u8>) -> Vec<u8> {
167        // `obs.len()` is w * h * 3 where (w, h) is the size of the frame.
168        let img = ImageBuffer::<Rgb<_>, _>::from_vec(w, h, obs).unwrap();
169        let img = resize(&img, 84, 84, Triangle);
170        let buf = {
171            let buf = img.to_vec();
172            let i1 = buf.iter().step_by(3);
173            let i2 = buf.iter().skip(1).step_by(3);
174            let i3 = buf.iter().skip(2).step_by(3);
175            izip![i1, i2, i3]
176                .map(|(&b, &g, &r)| {
177                    ((0.299 * r as f32) + (0.587 * g as f32) + (0.114 * b as f32)) as u8
178                })
179                .collect::<Vec<_>>()
180        };
181        // let buf = {
182        //     let img: ImageBuffer<Luma<u8>, _> = grayscale(&img);
183        //     img.to_vec()
184        // };
185        assert_eq!(buf.len(), 84 * 84);
186        buf
187    }
188
189    fn stack_frame(&mut self, obs: Vec<u8>) {
190        unsafe {
191            let src: *const u8 = &self.frames[0];
192            let dst: *mut u8 = &mut self.frames[1 * 84 * 84];
193            copy(src, dst, 3 * 84 * 84);
194
195            let src: *const u8 = &obs[0];
196            let dst: *mut u8 = &mut self.frames[0];
197            copy(src, dst, 1 * 84 * 84);
198        }
199    }
200}
201
202impl<O, A, OF, AF> Default for BorderAtariEnv<O, A, OF, AF>
203where
204    O: Obs,
205    A: Act,
206    OF: BorderAtariObsFilter<O>,
207    AF: BorderAtariActFilter<A>,
208{
209    fn default() -> Self {
210        let config = BorderAtariEnvConfig::<O, A, OF, AF>::default();
211
212        Self {
213            train: false,
214            env: env(config.rom_dir.as_str(), "pong"),
215            window: None,
216            obs_buffer: [vec![], vec![]],
217            lives: 0,
218            was_real_done: true,
219            frames: vec![0; 4 * 84 * 84],
220            obs_filter: OF::build(&config.obs_filter_config).unwrap(),
221            act_filter: AF::build(&config.act_filter_config).unwrap(),
222            phantom: PhantomData,
223        }
224    }
225}
226
227impl<O, A, OF, AF> Env for BorderAtariEnv<O, A, OF, AF>
228where
229    O: Obs,
230    A: Act,
231    OF: BorderAtariObsFilter<O>,
232    AF: BorderAtariActFilter<A>,
233{
234    type Config = BorderAtariEnvConfig<O, A, OF, AF>;
235    type Obs = O;
236    type Act = A;
237    type Info = NullInfo;
238
239    fn build(config: &Self::Config, _seed: i64) -> Result<Self>
240    where
241        Self: Sized,
242    {
243        let mut env = Self {
244            train: config.train,
245            env: env(config.rom_dir.as_str(), config.name.as_str()),
246            window: None,
247            obs_buffer: [vec![], vec![]],
248            lives: 0,
249            was_real_done: true,
250            frames: vec![0; 4 * 84 * 84],
251            obs_filter: OF::build(&config.obs_filter_config)?,
252            act_filter: AF::build(&config.act_filter_config)?,
253            phantom: PhantomData,
254        };
255
256        if config.render {
257            let _ = env.open();
258        }
259
260        Ok(env)
261    }
262
263    fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> Result<Self::Obs> {
264        if self.was_real_done {
265            self.env.reset();
266            // println!("RESET");
267        } else {
268            // no-op step to advance from terminal/lost life state
269            self.env.step(AtariAction::Noop);
270
271            let n = fastrand::u8(0..=30);
272            for _ in 0..n {
273                self.env.step(AtariAction::Noop);
274            }
275        }
276
277        // TODO: noop random steps (?)
278
279        self.was_real_done = false;
280        self.lives = self.env.lives();
281
282        let (w, h) = (self.env.width(), self.env.height());
283        let mut obs = vec![0u8; w * h * 3];
284        self.env.render_rgb24(&mut obs);
285        self.obs_buffer[0] = obs.clone();
286        self.obs_buffer[1] = obs.clone();
287
288        let obs = Self::warp_and_grayscale(w as u32, h as u32, obs);
289
290        unsafe {
291            let src: *const u8 = &obs[0];
292            for i in 0..4 {
293                let dst: *mut u8 = &mut self.frames[i * 84 * 84];
294                copy(src, dst, 84 * 84);
295            }
296        }
297
298        Ok(self.obs_filter.filt(self.frames.clone().into()).0)
299    }
300
301    fn reset_with_index(&mut self, ix: usize) -> Result<Self::Obs> {
302        self.env.seed(ix as i32);
303        self.reset(None)
304    }
305
306    fn step(&mut self, act: &Self::Act) -> (border_core::Step<Self>, border_core::record::Record)
307    where
308        Self: Sized,
309    {
310        #[cfg(feature = "atari-env-sys")]
311        {
312            let act_org = act.clone();
313            let (act, _record) = self.act_filter.filt(act_org.clone());
314            let (obs, reward, is_terminated) = self.skip_and_max(&act);
315            let is_truncated = vec![0]; // not compatible with the official implementation
316            let (w, h) = (self.env.width() as u32, self.env.height() as u32);
317            let obs = Self::warp_and_grayscale(w, h, obs);
318            let reward = self.clip_reward(reward); // in training
319            self.stack_frame(obs);
320            let (obs, _record) = self.obs_filter.filt(self.frames.clone().into());
321            let step = Step::new(
322                obs,
323                act_org,
324                reward,
325                is_terminated,
326                is_truncated,
327                NullInfo,
328                None,
329            );
330            let record = Record::empty();
331
332            if let Some(window) = self.window.as_mut() {
333                window.event_loop.run_return(|_event, _, control_flow| {
334                    *control_flow = ControlFlow::Exit;
335                });
336                self.env.render_rgb32(window.get_frame());
337                window.render_and_request_redraw();
338            }
339
340            (step, record)
341        }
342
343        #[cfg(not(feature = "atari-env-sys"))]
344        unimplemented!();
345    }
346}