border_atari_env/
atari_env.rs1pub mod ale;
3use std::path::Path;
4
5pub use ale::Ale;
6pub use ale::AleAction as AtariAction;
7pub use ale::AleConfig as EmulatorConfig;
8use gym_core::{ActionSpace, CategoricalActionSpace, GymEnv};
9
10use anyhow::{Context, Result};
11use ndarray::{Array1, ArrayD, ArrayView3, Ix0, Ix1, Ix3};
12use num_traits::cast::FromPrimitive;
13
14pub struct AtariEnv {
15 ale: Ale,
16}
17
18impl AtariEnv {
19 pub fn new<P: AsRef<Path>>(rom_path: P, emulator_config: EmulatorConfig) -> Self {
22 Self {
23 ale: Ale::new(rom_path.as_ref(), emulator_config),
24 }
25 }
26 pub fn width(&self) -> usize {
27 self.ale.width() as usize
28 }
29 pub fn height(&self) -> usize {
30 self.ale.height() as usize
31 }
32 pub fn available_actions(&self) -> Vec<AtariAction> {
33 self.ale.available_actions()
34 }
35 pub fn minimal_actions(&self) -> Vec<AtariAction> {
36 self.ale.minimal_actions()
37 }
38 pub fn available_difficulty_settings(&self) -> Vec<i32> {
39 self.ale.available_difficulty_settings()
40 }
41 pub fn lives(&self) -> usize {
42 self.ale.lives() as usize
43 }
44 pub fn is_game_over(&self) -> bool {
45 self.ale.is_game_over()
46 }
47 pub fn reset(&mut self) {
48 self.ale.reset()
49 }
50 pub fn step(&mut self, action: AtariAction) -> i32 {
51 self.ale.take_action(action)
52 }
53 pub fn rgb32_size(&self) -> usize {
54 self.ale.rgb32_size()
55 }
56 pub fn rgb24_size(&self) -> usize {
57 self.ale.rgb24_size()
58 }
59 pub fn ram_size(&self) -> usize {
60 self.ale.ram_size()
61 }
62 pub fn render_rgb32(&self, buf: &mut [u8]) {
63 self.ale.rgb32(buf);
64 }
65 pub fn render_rgb24(&self, buf: &mut [u8]) {
66 self.ale.rgb24(buf);
67 }
68 pub fn render_ram(&self, buf: &mut [u8]) {
69 self.ale.ram(buf);
70 }
71 pub fn into_ram_env(self) -> AtariRamEnv {
72 AtariRamEnv::new(self)
73 }
74 pub fn into_rgb_env(self) -> AtariRgbEnv {
75 AtariRgbEnv::new(self)
76 }
77 pub fn seed(&self, seed: i32) {
78 self.ale.seed(seed);
79 }
80}
81
82pub struct AtariRamEnv {
83 buf1: Array1<u8>,
84 inner: AtariEnv,
85 available_actions: Vec<AtariAction>,
86}
87
88pub struct AtariRgbEnv {
89 buf1: Array1<u8>,
90 inner: AtariEnv,
91 available_actions: Vec<AtariAction>,
92}
93
94impl AtariRamEnv {
95 pub fn new(env: AtariEnv) -> Self {
96 Self {
97 buf1: Array1::zeros(env.ram_size()),
98 available_actions: env.minimal_actions(),
99 inner: env,
100 }
101 }
102}
103
104impl GymEnv<i32> for AtariRamEnv {
105 fn state_size(&self) -> Vec<usize> {
106 vec![self.inner.ram_size()]
107 }
108 fn action_space(&self) -> ActionSpace<i32> {
109 Box::new(CategoricalActionSpace::new(self.available_actions.len()))
110 }
111 fn state(&self, out: ndarray::ArrayViewMut<f32, ndarray::IxDyn>) -> Result<()> {
112 let mut out = out.into_dimensionality::<Ix1>()?;
113 ndarray::parallel::par_azip!((a in &mut out, &b in &self.buf1) {*a = b as f32 / 255.0;});
114 Ok(())
115 }
116 fn step(&mut self, action: ArrayD<i32>) -> Result<i32> {
117 let action = AtariAction::from_i32(action.into_dimensionality::<Ix0>()?.into_scalar())
118 .context("action out of range")?;
119 let reward = self.inner.step(action);
120 self.inner.render_ram(self.buf1.as_slice_mut().unwrap());
121 Ok(reward)
122 }
123 fn is_over(&self) -> bool {
124 self.inner.is_game_over()
125 }
126 fn reset(&mut self) {
127 self.inner.reset();
128 }
129}
130
131impl AtariRgbEnv {
132 pub fn new(env: AtariEnv) -> Self {
133 Self {
134 buf1: Array1::zeros(env.rgb24_size()),
135 available_actions: env.minimal_actions(),
136 inner: env,
137 }
138 }
139}
140
141impl GymEnv<i32> for AtariRgbEnv {
142 fn state_size(&self) -> Vec<usize> {
143 vec![self.inner.height(), self.inner.width(), 3]
144 }
145 fn action_space(&self) -> ActionSpace<i32> {
146 Box::new(CategoricalActionSpace::new(self.available_actions.len()))
147 }
148 fn state(&self, out: ndarray::ArrayViewMut<f32, ndarray::IxDyn>) -> Result<()> {
149 let mut out = out.into_dimensionality::<Ix3>()?;
150 let from: ArrayView3<_> = self
151 .buf1
152 .view()
153 .into_shape(self.state_size())?
154 .into_dimensionality()?;
155 ndarray::parallel::par_azip!((a in &mut out, &b in &from) {*a = b as f32 / 255.0;});
156 Ok(())
158 }
159 fn step(&mut self, action: ArrayD<i32>) -> Result<i32> {
161 let action = self.available_actions
162 [(action.into_dimensionality::<Ix0>()?.into_scalar() - 1) as usize];
163 let reward = self.inner.step(action);
164 self.inner.render_rgb24(self.buf1.as_slice_mut().unwrap());
165 Ok(reward)
166 }
167 fn is_over(&self) -> bool {
168 self.inner.is_game_over()
169 }
170 fn reset(&mut self) {
171 self.inner.reset();
172 }
173}