border_atari_env/atari_env/
ale.rs1use c_str_macro::c_str;
2use std::ffi::CString;
3use std::path::Path;
4use std::path::PathBuf;
5
6#[derive(Copy, Clone, Debug, num_derive::FromPrimitive)]
7#[repr(i32)]
8pub enum AleAction {
9 Noop = 0,
10 Fire = 1,
11 Up = 2,
12 Right = 3,
13 Left = 4,
14 Down = 5,
15 UpRight = 6,
16 UpLeft = 7,
17 DownRight = 8,
18 DownLeft = 9,
19 UpFire = 10,
20 RightFire = 11,
21 LeftFire = 12,
22 DownFire = 13,
23 UpRightFire = 14,
24 UpLeftFire = 15,
25 DownRightFire = 16,
26 DownLeftFire = 17,
27 }
33
/// Settings applied to a freshly created emulator by `Ale::new`.
///
/// Every field except `difficulty_setting` is forwarded to the ALE setting of
/// the same name before the ROM is loaded; `difficulty_setting` is applied
/// through `setDifficulty` once the ROM has been loaded.
#[derive(Debug, Clone, PartialEq)]
pub struct AleConfig {
    /// Value for ALE's `random_seed` setting.
    pub random_seed: i32,
    /// Value for ALE's `display_screen` setting.
    pub display_screen: bool,
    /// Value for ALE's `sound` setting.
    pub sound: bool,
    /// Value for ALE's `color_averaging` setting.
    pub color_averaging: bool,
    /// Value for ALE's `frame_skip` setting.
    pub frame_skip: i32,
    /// Value for ALE's `repeat_action_probability` setting.
    pub repeat_action_probability: f32,
    /// When `Some`, forwarded as ALE's `record_screen_dir` setting; when
    /// `None`, the setting is left untouched.
    pub record_screen_dir: Option<PathBuf>,
    /// Difficulty passed to `setDifficulty` after the ROM is loaded.
    pub difficulty_setting: i32,
}
46
47impl Default for AleConfig {
48 fn default() -> Self {
49 Self {
50 random_seed: 0,
51 display_screen: false,
52 sound: false,
53 color_averaging: false, frame_skip: 1,
55 repeat_action_probability: 0.25,
56 record_screen_dir: None,
57 difficulty_setting: 0,
58 }
59 }
60}
61
/// Owning handle to a native ALE emulator instance.
///
/// When the crate is built without the `atari-env-sys` feature the struct has
/// no fields and every emulator operation is an `unimplemented!()` stub.
pub struct Ale {
    /// Raw pointer returned by `ALE_new`; released with `ALE_del` on drop.
    #[cfg(feature = "atari-env-sys")]
    inner: *mut atari_env_sys::ALEInterface,
}

// SAFETY: NOTE(review): this asserts that the native handle may be moved to
// and used from another thread. Nothing in this file proves that about the
// underlying library — confirm against the ALE implementation.
unsafe impl Send for Ale {}

impl Drop for Ale {
    /// Releases the native emulator.
    ///
    /// Without the `atari-env-sys` feature there is no native resource, so
    /// dropping is a no-op. (Panicking here would abort the process whenever
    /// a value is dropped during unwinding.)
    fn drop(&mut self) {
        // SAFETY: `inner` is only ever set from `ALE_new` in `Ale::new`, and
        // `drop` runs at most once, so the handle is freed exactly once.
        #[cfg(feature = "atari-env-sys")]
        unsafe {
            atari_env_sys::ALE_del(self.inner);
        }
    }
}
81
82impl Ale {
83 pub fn new(rom_path: &Path, config: AleConfig) -> Self {
84 #[cfg(feature = "atari-env-sys")]
85 {
86 let ale = unsafe { atari_env_sys::ALE_new() };
87 unsafe {
88 atari_env_sys::setInt(ale, c_str!("random_seed").as_ptr(), config.random_seed);
89 atari_env_sys::setBool(
90 ale,
91 c_str!("display_screen").as_ptr(),
92 config.display_screen,
93 );
94 atari_env_sys::setBool(ale, c_str!("sound").as_ptr(), config.sound);
95 atari_env_sys::setBool(
96 ale,
97 c_str!("color_averaging").as_ptr(),
98 config.color_averaging,
99 );
100 atari_env_sys::setInt(ale, c_str!("frame_skip").as_ptr(), config.frame_skip);
101 atari_env_sys::setFloat(
102 ale,
103 c_str!("repeat_action_probability").as_ptr(),
104 config.repeat_action_probability,
105 );
106
107 if let Some(path) = config.record_screen_dir {
108 let path = CString::new(path.to_str().unwrap()).unwrap();
109 atari_env_sys::setString(
110 ale,
111 c_str!("record_screen_dir").as_ptr(),
112 path.as_ptr(),
113 );
114 }
115 let rom_path = CString::new(rom_path.to_str().unwrap()).unwrap();
116 atari_env_sys::loadROM(ale, rom_path.as_ptr());
117 }
118 unsafe {
119 atari_env_sys::setDifficulty(ale, config.difficulty_setting);
120 atari_env_sys::reset_game(ale);
121 }
122
123 Self { inner: ale }
124 }
125
126 #[cfg(not(feature = "atari-env-sys"))]
127 unimplemented!();
128 }
129 pub fn available_actions(&self) -> Vec<AleAction> {
130 #[cfg(feature = "atari-env-sys")]
131 {
132 let n = unsafe { atari_env_sys::getLegalActionSize(self.inner) } as usize;
133 let mut buf = vec![AleAction::Noop; n];
134 unsafe {
135 atari_env_sys::getLegalActionSet(self.inner, buf.as_mut_ptr() as *mut i32);
136 }
137 buf
138 }
139
140 #[cfg(not(feature = "atari-env-sys"))]
141 unimplemented!();
142 }
143
144 pub fn minimal_actions(&self) -> Vec<AleAction> {
145 #[cfg(feature = "atari-env-sys")]
146 {
147 let n = unsafe { atari_env_sys::getMinimalActionSize(self.inner) } as usize;
148 let mut buf = vec![AleAction::Noop; n];
149 unsafe {
150 atari_env_sys::getMinimalActionSet(self.inner, buf.as_mut_ptr() as *mut i32);
151 }
152 buf
153 }
154
155 #[cfg(not(feature = "atari-env-sys"))]
156 unimplemented!();
157 }
158
159 pub fn is_game_over(&self) -> bool {
160 #[cfg(feature = "atari-env-sys")]
161 unsafe {
162 atari_env_sys::game_over(self.inner)
163 }
164
165 #[cfg(not(feature = "atari-env-sys"))]
166 unimplemented!();
167 }
168
169 pub fn rom_frame_number(&self) -> i32 {
171 #[cfg(feature = "atari-env-sys")]
172 unsafe {
173 atari_env_sys::getFrameNumber(self.inner)
174 }
175
176 #[cfg(not(feature = "atari-env-sys"))]
177 unimplemented!();
178 }
179
180 pub fn episode_frame_number(&self) -> i32 {
182 #[cfg(feature = "atari-env-sys")]
183 unsafe {
184 atari_env_sys::getEpisodeFrameNumber(self.inner)
185 }
186
187 #[cfg(not(feature = "atari-env-sys"))]
188 unimplemented!();
189 }
190
191 pub fn reset(&mut self) {
192 #[cfg(feature = "atari-env-sys")]
193 unsafe {
194 atari_env_sys::reset_game(self.inner);
195 }
196
197 #[cfg(not(feature = "atari-env-sys"))]
198 unimplemented!();
199 }
200
201 pub fn take_action(&mut self, action: AleAction) -> i32 {
203 #[cfg(feature = "atari-env-sys")]
204 {
205 let ret: ::std::os::raw::c_int =
206 unsafe { atari_env_sys::act(self.inner, action as i32) };
207 ret.into()
208 }
209
210 #[cfg(not(feature = "atari-env-sys"))]
211 unimplemented!();
212 }
213
214 pub fn lives(&self) -> u32 {
215 #[cfg(feature = "atari-env-sys")]
216 unsafe {
217 atari_env_sys::lives(self.inner) as u32
218 }
219
220 #[cfg(not(feature = "atari-env-sys"))]
221 unimplemented!();
222 }
223
224 pub fn available_difficulty_settings(&self) -> Vec<i32> {
225 #[cfg(feature = "atari-env-sys")]
226 {
227 let n = unsafe { atari_env_sys::getAvailableDifficultiesSize(self.inner) } as usize;
228 let mut buf = vec![0i32; n];
229 unsafe {
230 atari_env_sys::getAvailableDifficulties(self.inner, buf.as_mut_ptr() as *mut i32);
231 }
232 buf
233 }
234
235 #[cfg(not(feature = "atari-env-sys"))]
236 unimplemented!();
237 }
238
239 pub fn width(&self) -> u32 {
240 #[cfg(feature = "atari-env-sys")]
241 unsafe {
242 atari_env_sys::getScreenWidth(self.inner) as u32
243 }
244
245 #[cfg(not(feature = "atari-env-sys"))]
246 unimplemented!();
247 }
248 pub fn height(&self) -> u32 {
249 #[cfg(feature = "atari-env-sys")]
250 unsafe {
251 atari_env_sys::getScreenHeight(self.inner) as u32
252 }
253
254 #[cfg(not(feature = "atari-env-sys"))]
255 unimplemented!();
256 }
257
258 pub fn rgb24_size(&self) -> usize {
259 #[cfg(feature = "atari-env-sys")]
260 return (self.width() as usize) * (self.height() as usize) * 3;
261
262 #[cfg(not(feature = "atari-env-sys"))]
263 unimplemented!();
264 }
265 pub fn rgb24_native_endian(&self, buf: &mut [u8]) {
267 #[cfg(feature = "atari-env-sys")]
268 unsafe {
269 atari_env_sys::getScreenRGB(self.inner, buf.as_mut_ptr());
270 }
271
272 #[cfg(not(feature = "atari-env-sys"))]
273 unimplemented!();
274 }
275 pub fn rgb24(&self, buf: &mut [u8]) {
277 #[cfg(feature = "atari-env-sys")]
278 unsafe {
279 atari_env_sys::getScreenRGB2(self.inner, buf.as_mut_ptr());
280 }
281
282 #[cfg(not(feature = "atari-env-sys"))]
283 unimplemented!();
284 }
285
286 pub fn rgb32_size(&self) -> usize {
287 #[cfg(feature = "atari-env-sys")]
288 return (self.width() as usize) * (self.height() as usize) * 4;
289
290 #[cfg(not(feature = "atari-env-sys"))]
291 unimplemented!();
292 }
293 pub fn rgb32(&self, buf: &mut [u8]) {
295 #[cfg(feature = "atari-env-sys")]
296 {
297 let n = buf.len() / 4;
298 self.rgb24(&mut buf[n..]);
299 for i in 0..n {
300 buf[i * 4 + 0] = buf[n + (i * 3) + 0];
301 buf[i * 4 + 1] = buf[n + (i * 3) + 1];
302 buf[i * 4 + 2] = buf[n + (i * 3) + 2];
303 buf[i * 4 + 3] = 0;
304 }
305 }
306
307 #[cfg(not(feature = "atari-env-sys"))]
308 unimplemented!();
309 }
310
311 pub fn ram_size(&self) -> usize {
312 #[cfg(feature = "atari-env-sys")]
313 unsafe {
314 atari_env_sys::getRAMSize(self.inner) as usize
315 }
316
317 #[cfg(not(feature = "atari-env-sys"))]
318 unimplemented!();
319 }
320 pub fn ram(&self, buf: &mut [u8]) {
321 #[cfg(feature = "atari-env-sys")]
322 unsafe {
323 atari_env_sys::getRAM(self.inner, buf.as_mut_ptr());
324 }
325
326 #[cfg(not(feature = "atari-env-sys"))]
327 unimplemented!();
328 }
329
330 pub fn save_png<P: AsRef<Path>>(&self, path: P) {
331 #[cfg(feature = "atari-env-sys")]
332 {
333 use std::os::unix::ffi::OsStrExt;
334 let path = path.as_ref();
335 let path = CString::new(path.as_os_str().as_bytes()).unwrap();
336 unsafe {
337 atari_env_sys::saveScreenPNG(self.inner, path.as_ptr());
338 }
339 }
340
341 #[cfg(not(feature = "atari-env-sys"))]
342 unimplemented!();
343 }
344
345 pub fn seed(&self, seed: i32) {
347 #[cfg(feature = "atari-env-sys")]
348 unsafe {
349 atari_env_sys::setInt(self.inner, c_str!("random_seed").as_ptr(), seed);
350 }
351
352 #[cfg(not(feature = "atari-env-sys"))]
353 unimplemented!();
354 }
355}