border_atari_env/
atari_env.rs

1//! Atari environment for reinforcement learning.
2pub mod ale;
3use std::path::Path;
4
5pub use ale::Ale;
6pub use ale::AleAction as AtariAction;
7pub use ale::AleConfig as EmulatorConfig;
8use gym_core::{ActionSpace, CategoricalActionSpace, GymEnv};
9
10use anyhow::{Context, Result};
11use ndarray::{Array1, ArrayD, ArrayView3, Ix0, Ix1, Ix3};
12use num_traits::cast::FromPrimitive;
13
14pub struct AtariEnv {
15    ale: Ale,
16}
17
18impl AtariEnv {
19    /// about frame-skipping and action-repeat,
20    /// see <https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/>
21    pub fn new<P: AsRef<Path>>(rom_path: P, emulator_config: EmulatorConfig) -> Self {
22        Self {
23            ale: Ale::new(rom_path.as_ref(), emulator_config),
24        }
25    }
26    pub fn width(&self) -> usize {
27        self.ale.width() as usize
28    }
29    pub fn height(&self) -> usize {
30        self.ale.height() as usize
31    }
32    pub fn available_actions(&self) -> Vec<AtariAction> {
33        self.ale.available_actions()
34    }
35    pub fn minimal_actions(&self) -> Vec<AtariAction> {
36        self.ale.minimal_actions()
37    }
38    pub fn available_difficulty_settings(&self) -> Vec<i32> {
39        self.ale.available_difficulty_settings()
40    }
41    pub fn lives(&self) -> usize {
42        self.ale.lives() as usize
43    }
44    pub fn is_game_over(&self) -> bool {
45        self.ale.is_game_over()
46    }
47    pub fn reset(&mut self) {
48        self.ale.reset()
49    }
50    pub fn step(&mut self, action: AtariAction) -> i32 {
51        self.ale.take_action(action)
52    }
53    pub fn rgb32_size(&self) -> usize {
54        self.ale.rgb32_size()
55    }
56    pub fn rgb24_size(&self) -> usize {
57        self.ale.rgb24_size()
58    }
59    pub fn ram_size(&self) -> usize {
60        self.ale.ram_size()
61    }
62    pub fn render_rgb32(&self, buf: &mut [u8]) {
63        self.ale.rgb32(buf);
64    }
65    pub fn render_rgb24(&self, buf: &mut [u8]) {
66        self.ale.rgb24(buf);
67    }
68    pub fn render_ram(&self, buf: &mut [u8]) {
69        self.ale.ram(buf);
70    }
71    pub fn into_ram_env(self) -> AtariRamEnv {
72        AtariRamEnv::new(self)
73    }
74    pub fn into_rgb_env(self) -> AtariRgbEnv {
75        AtariRgbEnv::new(self)
76    }
77    pub fn seed(&self, seed: i32) {
78        self.ale.seed(seed);
79    }
80}
81
82pub struct AtariRamEnv {
83    buf1: Array1<u8>,
84    inner: AtariEnv,
85    available_actions: Vec<AtariAction>,
86}
87
88pub struct AtariRgbEnv {
89    buf1: Array1<u8>,
90    inner: AtariEnv,
91    available_actions: Vec<AtariAction>,
92}
93
94impl AtariRamEnv {
95    pub fn new(env: AtariEnv) -> Self {
96        Self {
97            buf1: Array1::zeros(env.ram_size()),
98            available_actions: env.minimal_actions(),
99            inner: env,
100        }
101    }
102}
103
104impl GymEnv<i32> for AtariRamEnv {
105    fn state_size(&self) -> Vec<usize> {
106        vec![self.inner.ram_size()]
107    }
108    fn action_space(&self) -> ActionSpace<i32> {
109        Box::new(CategoricalActionSpace::new(self.available_actions.len()))
110    }
111    fn state(&self, out: ndarray::ArrayViewMut<f32, ndarray::IxDyn>) -> Result<()> {
112        let mut out = out.into_dimensionality::<Ix1>()?;
113        ndarray::parallel::par_azip!((a in &mut out, &b in &self.buf1) {*a = b as f32 / 255.0;});
114        Ok(())
115    }
116    fn step(&mut self, action: ArrayD<i32>) -> Result<i32> {
117        let action = AtariAction::from_i32(action.into_dimensionality::<Ix0>()?.into_scalar())
118            .context("action out of range")?;
119        let reward = self.inner.step(action);
120        self.inner.render_ram(self.buf1.as_slice_mut().unwrap());
121        Ok(reward)
122    }
123    fn is_over(&self) -> bool {
124        self.inner.is_game_over()
125    }
126    fn reset(&mut self) {
127        self.inner.reset();
128    }
129}
130
131impl AtariRgbEnv {
132    pub fn new(env: AtariEnv) -> Self {
133        Self {
134            buf1: Array1::zeros(env.rgb24_size()),
135            available_actions: env.minimal_actions(),
136            inner: env,
137        }
138    }
139}
140
141impl GymEnv<i32> for AtariRgbEnv {
142    fn state_size(&self) -> Vec<usize> {
143        vec![self.inner.height(), self.inner.width(), 3]
144    }
145    fn action_space(&self) -> ActionSpace<i32> {
146        Box::new(CategoricalActionSpace::new(self.available_actions.len()))
147    }
148    fn state(&self, out: ndarray::ArrayViewMut<f32, ndarray::IxDyn>) -> Result<()> {
149        let mut out = out.into_dimensionality::<Ix3>()?;
150        let from: ArrayView3<_> = self
151            .buf1
152            .view()
153            .into_shape(self.state_size())?
154            .into_dimensionality()?;
155        ndarray::parallel::par_azip!((a in &mut out, &b in &from) {*a = b as f32 / 255.0;});
156        // ndarray::parallel::par_azip!((a in &mut out, &b in &self.buf1) {*a = b as f32 / 255.0;});
157        Ok(())
158    }
159    // fn state(&self) -> ArrayView<f32, IxDyn>{ self.buf2.view().into_dyn() }
160    fn step(&mut self, action: ArrayD<i32>) -> Result<i32> {
161        let action = self.available_actions
162            [(action.into_dimensionality::<Ix0>()?.into_scalar() - 1) as usize];
163        let reward = self.inner.step(action);
164        self.inner.render_rgb24(self.buf1.as_slice_mut().unwrap());
165        Ok(reward)
166    }
167    fn is_over(&self) -> bool {
168        self.inner.is_game_over()
169    }
170    fn reset(&mut self) {
171        self.inner.reset();
172    }
173}