1mod config;
2mod window;
3use super::BorderAtariAct;
4use super::{BorderAtariActFilter, BorderAtariObsFilter};
5use crate::atari_env::{AtariAction, AtariEnv, EmulatorConfig};
6use anyhow::Result;
7use border_core::{record::Record, Act, Env, Info, Obs, Step};
8pub use config::BorderAtariEnvConfig;
9use image::{
10 imageops::{resize, FilterType::Triangle},
11 ImageBuffer, Rgb,
12};
13use itertools::izip;
14use std::ptr::copy;
15use std::{default::Default, marker::PhantomData};
16use window::AtariWindow;
17#[cfg(feature = "atari-env-sys")]
18use winit::{event_loop::ControlFlow, platform::run_return::EventLoopExtRunReturn};
19
20pub struct NullInfo;
22
23impl Info for NullInfo {}
24
25fn env(rom_dir: &str, name: &str) -> AtariEnv {
26 AtariEnv::new(
27 rom_dir.to_string() + format!("/{}.bin", name).as_str(),
28 EmulatorConfig {
29 frame_skip: 1,
32 color_averaging: false,
33 repeat_action_probability: 0.0,
34 ..EmulatorConfig::default()
35 },
36 )
37}
38
/// Atari environment implementing [`Env`] on top of the emulator in `crate::atari_env`.
///
/// Each step sends the (filtered) action to the emulator for up to four emulator steps,
/// max-pools the last two raw RGB frames, converts the result to an 84x84 grayscale
/// image and pushes it onto a stack of four frames, which is then passed through the
/// observation filter `OF`. Actions of type `A` are converted by the action filter `AF`.
pub struct BorderAtariEnv<O, A, OF, AF>
where
    O: Obs,
    A: Act,
    OF: BorderAtariObsFilter<O>,
    AF: BorderAtariActFilter<A>,
{
    /// When `true`, rewards returned by `step` are clipped to their sign (zero stays zero).
    train: bool,

    /// Underlying emulator instance.
    env: AtariEnv,

    /// Render window; `None` until `open` succeeds.
    window: Option<AtariWindow>,

    /// Raw RGB frames that are max-pooled pixel-wise to form the next frame
    /// (seeded with the reset frame by `reset`).
    obs_buffer: [Vec<u8>; 2],

    /// Number of lives reported by the emulator at the latest step or reset.
    lives: usize,

    /// Set when the emulator reports game over; `reset` performs a full emulator reset
    /// when this is `true` and only steps no-op actions otherwise.
    was_real_done: bool,

    /// Four stacked 84x84 grayscale frames (`4 * 84 * 84` bytes), newest first.
    frames: Vec<u8>,

    /// Filter turning the stacked frames into observations of type `O`.
    obs_filter: OF,
    /// Filter turning actions of type `A` into emulator actions.
    act_filter: AF,
    phantom: PhantomData<(O, A)>,
}
76
77impl<O, A, OF, AF> BorderAtariEnv<O, A, OF, AF>
78where
79 O: Obs,
80 A: Act,
81 OF: BorderAtariObsFilter<O>,
82 AF: BorderAtariActFilter<A>,
83{
84 pub fn open(&mut self) -> Result<()> {
86 if !self.window.is_none() {
88 return Ok(());
89 }
90
91 self.window = Some(AtariWindow::new(&self.env)?);
92
93 Ok(())
94 }
95
96 pub fn get_num_actions_atari(&self) -> i64 {
98 self.env.minimal_actions().len() as i64
100 }
101
102 fn episodic_life_env_step(&mut self, a: &BorderAtariAct) -> (Vec<u8>, f32, i8) {
103 let actions = self.env.minimal_actions();
104 let ix = a.act;
105 let reward = self.env.step(actions[ix as usize]) as f32;
106
107 let is_terminated = match self.env.is_game_over() {
108 true => 1,
109 false => 0,
110 };
111 self.was_real_done = is_terminated == 1;
112 let lives = self.env.lives();
113
114 self.lives = lives;
118
119 let (w, h) = (self.env.width(), self.env.height());
120 let mut obs = vec![0u8; w * h * 3];
121 self.env.render_rgb24(&mut obs);
122
123 (obs, reward, is_terminated)
124 }
125
126 fn skip_and_max(&mut self, a: &BorderAtariAct) -> (Vec<u8>, f32, Vec<i8>) {
127 let mut total_reward = 0f32;
128 let mut is_terminated = 0;
129
130 for i in 0..4 {
131 let (obs, reward, is_terminated_) = self.episodic_life_env_step(a);
132 total_reward += reward;
133 is_terminated = is_terminated_;
134 if i == 2 {
135 self.obs_buffer[0] = obs;
136 } else if i == 3 {
137 self.obs_buffer[1] = obs;
138 }
139 if is_terminated_ == 1 {
140 break;
141 }
142 }
143
144 let obs = self.obs_buffer[0]
146 .iter()
147 .zip(self.obs_buffer[1].iter())
148 .map(|(&a, &b)| a.max(b))
149 .collect::<Vec<_>>();
150
151 (obs, total_reward, vec![is_terminated])
152 }
153
154 fn clip_reward(&self, r: f32) -> Vec<f32> {
155 if self.train {
156 if r == 0.0 {
157 vec![0.0]
158 } else {
159 vec![r.signum()]
160 }
161 } else {
162 vec![r]
163 }
164 }
165
166 fn warp_and_grayscale(w: u32, h: u32, obs: Vec<u8>) -> Vec<u8> {
167 let img = ImageBuffer::<Rgb<_>, _>::from_vec(w, h, obs).unwrap();
169 let img = resize(&img, 84, 84, Triangle);
170 let buf = {
171 let buf = img.to_vec();
172 let i1 = buf.iter().step_by(3);
173 let i2 = buf.iter().skip(1).step_by(3);
174 let i3 = buf.iter().skip(2).step_by(3);
175 izip![i1, i2, i3]
176 .map(|(&b, &g, &r)| {
177 ((0.299 * r as f32) + (0.587 * g as f32) + (0.114 * b as f32)) as u8
178 })
179 .collect::<Vec<_>>()
180 };
181 assert_eq!(buf.len(), 84 * 84);
186 buf
187 }
188
189 fn stack_frame(&mut self, obs: Vec<u8>) {
190 unsafe {
191 let src: *const u8 = &self.frames[0];
192 let dst: *mut u8 = &mut self.frames[1 * 84 * 84];
193 copy(src, dst, 3 * 84 * 84);
194
195 let src: *const u8 = &obs[0];
196 let dst: *mut u8 = &mut self.frames[0];
197 copy(src, dst, 1 * 84 * 84);
198 }
199 }
200}
201
202impl<O, A, OF, AF> Default for BorderAtariEnv<O, A, OF, AF>
203where
204 O: Obs,
205 A: Act,
206 OF: BorderAtariObsFilter<O>,
207 AF: BorderAtariActFilter<A>,
208{
209 fn default() -> Self {
210 let config = BorderAtariEnvConfig::<O, A, OF, AF>::default();
211
212 Self {
213 train: false,
214 env: env(config.rom_dir.as_str(), "pong"),
215 window: None,
216 obs_buffer: [vec![], vec![]],
217 lives: 0,
218 was_real_done: true,
219 frames: vec![0; 4 * 84 * 84],
220 obs_filter: OF::build(&config.obs_filter_config).unwrap(),
221 act_filter: AF::build(&config.act_filter_config).unwrap(),
222 phantom: PhantomData,
223 }
224 }
225}
226
227impl<O, A, OF, AF> Env for BorderAtariEnv<O, A, OF, AF>
228where
229 O: Obs,
230 A: Act,
231 OF: BorderAtariObsFilter<O>,
232 AF: BorderAtariActFilter<A>,
233{
234 type Config = BorderAtariEnvConfig<O, A, OF, AF>;
235 type Obs = O;
236 type Act = A;
237 type Info = NullInfo;
238
239 fn build(config: &Self::Config, _seed: i64) -> Result<Self>
240 where
241 Self: Sized,
242 {
243 let mut env = Self {
244 train: config.train,
245 env: env(config.rom_dir.as_str(), config.name.as_str()),
246 window: None,
247 obs_buffer: [vec![], vec![]],
248 lives: 0,
249 was_real_done: true,
250 frames: vec![0; 4 * 84 * 84],
251 obs_filter: OF::build(&config.obs_filter_config)?,
252 act_filter: AF::build(&config.act_filter_config)?,
253 phantom: PhantomData,
254 };
255
256 if config.render {
257 let _ = env.open();
258 }
259
260 Ok(env)
261 }
262
263 fn reset(&mut self, _is_done: Option<&Vec<i8>>) -> Result<Self::Obs> {
264 if self.was_real_done {
265 self.env.reset();
266 } else {
268 self.env.step(AtariAction::Noop);
270
271 let n = fastrand::u8(0..=30);
272 for _ in 0..n {
273 self.env.step(AtariAction::Noop);
274 }
275 }
276
277 self.was_real_done = false;
280 self.lives = self.env.lives();
281
282 let (w, h) = (self.env.width(), self.env.height());
283 let mut obs = vec![0u8; w * h * 3];
284 self.env.render_rgb24(&mut obs);
285 self.obs_buffer[0] = obs.clone();
286 self.obs_buffer[1] = obs.clone();
287
288 let obs = Self::warp_and_grayscale(w as u32, h as u32, obs);
289
290 unsafe {
291 let src: *const u8 = &obs[0];
292 for i in 0..4 {
293 let dst: *mut u8 = &mut self.frames[i * 84 * 84];
294 copy(src, dst, 84 * 84);
295 }
296 }
297
298 Ok(self.obs_filter.filt(self.frames.clone().into()).0)
299 }
300
301 fn reset_with_index(&mut self, ix: usize) -> Result<Self::Obs> {
302 self.env.seed(ix as i32);
303 self.reset(None)
304 }
305
306 fn step(&mut self, act: &Self::Act) -> (border_core::Step<Self>, border_core::record::Record)
307 where
308 Self: Sized,
309 {
310 #[cfg(feature = "atari-env-sys")]
311 {
312 let act_org = act.clone();
313 let (act, _record) = self.act_filter.filt(act_org.clone());
314 let (obs, reward, is_terminated) = self.skip_and_max(&act);
315 let is_truncated = vec![0]; let (w, h) = (self.env.width() as u32, self.env.height() as u32);
317 let obs = Self::warp_and_grayscale(w, h, obs);
318 let reward = self.clip_reward(reward); self.stack_frame(obs);
320 let (obs, _record) = self.obs_filter.filt(self.frames.clone().into());
321 let step = Step::new(
322 obs,
323 act_org,
324 reward,
325 is_terminated,
326 is_truncated,
327 NullInfo,
328 None,
329 );
330 let record = Record::empty();
331
332 if let Some(window) = self.window.as_mut() {
333 window.event_loop.run_return(|_event, _, control_flow| {
334 *control_flow = ControlFlow::Exit;
335 });
336 self.env.render_rgb32(window.get_frame());
337 window.render_and_request_redraw();
338 }
339
340 (step, record)
341 }
342
343 #[cfg(not(feature = "atari-env-sys"))]
344 unimplemented!();
345 }
346}