// border_atari_env/atari_env/
// ale.rs

1use c_str_macro::c_str;
2use std::ffi::CString;
3use std::path::Path;
4use std::path::PathBuf;
5
6#[derive(Copy, Clone, Debug, num_derive::FromPrimitive)]
7#[repr(i32)]
8pub enum AleAction {
9    Noop = 0,
10    Fire = 1,
11    Up = 2,
12    Right = 3,
13    Left = 4,
14    Down = 5,
15    UpRight = 6,
16    UpLeft = 7,
17    DownRight = 8,
18    DownLeft = 9,
19    UpFire = 10,
20    RightFire = 11,
21    LeftFire = 12,
22    DownFire = 13,
23    UpRightFire = 14,
24    UpLeftFire = 15,
25    DownRightFire = 16,
26    DownLeftFire = 17,
27    // admin actions
28    // Reset = 40,
29    // SaveState = 43,
30    // LoadState = 44,
31    // SystemReset = 45,
32}
33
/// Settings applied by [`Ale::new`] when the emulator is created.
///
/// Most fields are forwarded to the ALE setting of the same name before the
/// ROM is loaded; `difficulty_setting` is applied after the ROM is loaded.
#[derive(Debug, Clone, PartialEq)]
pub struct AleConfig {
    /// Value for the `random_seed` setting. If 0, set to time.
    pub random_seed: i32,
    /// Value for the `display_screen` setting.
    pub display_screen: bool,
    /// Value for the `sound` setting.
    pub sound: bool,
    /// Value for the `color_averaging` setting: average the last 2 frames.
    pub color_averaging: bool,
    /// Value for the `frame_skip` setting. 1 is no skip.
    pub frame_skip: i32,
    /// Value for the `repeat_action_probability` setting.
    pub repeat_action_probability: f32,
    /// When `Some`, forwarded as the `record_screen_dir` setting; when `None`
    /// the setting is left untouched.
    pub record_screen_dir: Option<PathBuf>,
    /// Difficulty passed to `setDifficulty` once the ROM has been loaded.
    pub difficulty_setting: i32,
}

impl Default for AleConfig {
    /// Returns the baseline configuration: seed 0, no display, no sound, no
    /// color averaging, no frame skipping, a repeat-action probability of
    /// 0.25, no screen recording and difficulty 0.
    fn default() -> Self {
        Self {
            random_seed: 0,
            display_screen: false,
            sound: false,
            color_averaging: false, // true is recommended
            frame_skip: 1,
            repeat_action_probability: 0.25,
            record_screen_dir: None,
            difficulty_setting: 0,
        }
    }
}
61
/// Owning handle to one native ALE emulator instance.
///
/// With the `atari-env-sys` feature enabled the struct stores the raw
/// `ALEInterface` pointer obtained from `ALE_new`; the pointer is released in
/// `Drop`. Without the feature the struct has no fields and its methods panic
/// with `unimplemented!()`.
pub struct Ale {
    #[cfg(feature = "atari-env-sys")]
    inner: *mut atari_env_sys::ALEInterface,
}
66
67// TODO: it seems to work, but needs to be verified
68unsafe impl Send for Ale {}
69
70impl Drop for Ale {
71    fn drop(&mut self) {
72        #[cfg(feature = "atari-env-sys")]
73        unsafe {
74            atari_env_sys::ALE_del(self.inner);
75        }
76
77        #[cfg(not(feature = "atari-env-sys"))]
78        unimplemented!();
79    }
80}
81
82impl Ale {
83    pub fn new(rom_path: &Path, config: AleConfig) -> Self {
84        #[cfg(feature = "atari-env-sys")]
85        {
86            let ale = unsafe { atari_env_sys::ALE_new() };
87            unsafe {
88                atari_env_sys::setInt(ale, c_str!("random_seed").as_ptr(), config.random_seed);
89                atari_env_sys::setBool(
90                    ale,
91                    c_str!("display_screen").as_ptr(),
92                    config.display_screen,
93                );
94                atari_env_sys::setBool(ale, c_str!("sound").as_ptr(), config.sound);
95                atari_env_sys::setBool(
96                    ale,
97                    c_str!("color_averaging").as_ptr(),
98                    config.color_averaging,
99                );
100                atari_env_sys::setInt(ale, c_str!("frame_skip").as_ptr(), config.frame_skip);
101                atari_env_sys::setFloat(
102                    ale,
103                    c_str!("repeat_action_probability").as_ptr(),
104                    config.repeat_action_probability,
105                );
106
107                if let Some(path) = config.record_screen_dir {
108                    let path = CString::new(path.to_str().unwrap()).unwrap();
109                    atari_env_sys::setString(
110                        ale,
111                        c_str!("record_screen_dir").as_ptr(),
112                        path.as_ptr(),
113                    );
114                }
115                let rom_path = CString::new(rom_path.to_str().unwrap()).unwrap();
116                atari_env_sys::loadROM(ale, rom_path.as_ptr());
117            }
118            unsafe {
119                atari_env_sys::setDifficulty(ale, config.difficulty_setting);
120                atari_env_sys::reset_game(ale);
121            }
122
123            Self { inner: ale }
124        }
125
126        #[cfg(not(feature = "atari-env-sys"))]
127        unimplemented!();
128    }
129    pub fn available_actions(&self) -> Vec<AleAction> {
130        #[cfg(feature = "atari-env-sys")]
131        {
132            let n = unsafe { atari_env_sys::getLegalActionSize(self.inner) } as usize;
133            let mut buf = vec![AleAction::Noop; n];
134            unsafe {
135                atari_env_sys::getLegalActionSet(self.inner, buf.as_mut_ptr() as *mut i32);
136            }
137            buf
138        }
139
140        #[cfg(not(feature = "atari-env-sys"))]
141        unimplemented!();
142    }
143
144    pub fn minimal_actions(&self) -> Vec<AleAction> {
145        #[cfg(feature = "atari-env-sys")]
146        {
147            let n = unsafe { atari_env_sys::getMinimalActionSize(self.inner) } as usize;
148            let mut buf = vec![AleAction::Noop; n];
149            unsafe {
150                atari_env_sys::getMinimalActionSet(self.inner, buf.as_mut_ptr() as *mut i32);
151            }
152            buf
153        }
154
155        #[cfg(not(feature = "atari-env-sys"))]
156        unimplemented!();
157    }
158
159    pub fn is_game_over(&self) -> bool {
160        #[cfg(feature = "atari-env-sys")]
161        unsafe {
162            atari_env_sys::game_over(self.inner)
163        }
164
165        #[cfg(not(feature = "atari-env-sys"))]
166        unimplemented!();
167    }
168
169    /// frame number since rom loading (Ale::new)
170    pub fn rom_frame_number(&self) -> i32 {
171        #[cfg(feature = "atari-env-sys")]
172        unsafe {
173            atari_env_sys::getFrameNumber(self.inner)
174        }
175
176        #[cfg(not(feature = "atari-env-sys"))]
177        unimplemented!();
178    }
179
180    /// frame number of the current episode
181    pub fn episode_frame_number(&self) -> i32 {
182        #[cfg(feature = "atari-env-sys")]
183        unsafe {
184            atari_env_sys::getEpisodeFrameNumber(self.inner)
185        }
186
187        #[cfg(not(feature = "atari-env-sys"))]
188        unimplemented!();
189    }
190
191    pub fn reset(&mut self) {
192        #[cfg(feature = "atari-env-sys")]
193        unsafe {
194            atari_env_sys::reset_game(self.inner);
195        }
196
197        #[cfg(not(feature = "atari-env-sys"))]
198        unimplemented!();
199    }
200
201    /// returns reward
202    pub fn take_action(&mut self, action: AleAction) -> i32 {
203        #[cfg(feature = "atari-env-sys")]
204        {
205            let ret: ::std::os::raw::c_int =
206                unsafe { atari_env_sys::act(self.inner, action as i32) };
207            ret.into()
208        }
209
210        #[cfg(not(feature = "atari-env-sys"))]
211        unimplemented!();
212    }
213
214    pub fn lives(&self) -> u32 {
215        #[cfg(feature = "atari-env-sys")]
216        unsafe {
217            atari_env_sys::lives(self.inner) as u32
218        }
219
220        #[cfg(not(feature = "atari-env-sys"))]
221        unimplemented!();
222    }
223
224    pub fn available_difficulty_settings(&self) -> Vec<i32> {
225        #[cfg(feature = "atari-env-sys")]
226        {
227            let n = unsafe { atari_env_sys::getAvailableDifficultiesSize(self.inner) } as usize;
228            let mut buf = vec![0i32; n];
229            unsafe {
230                atari_env_sys::getAvailableDifficulties(self.inner, buf.as_mut_ptr() as *mut i32);
231            }
232            buf
233        }
234
235        #[cfg(not(feature = "atari-env-sys"))]
236        unimplemented!();
237    }
238
239    pub fn width(&self) -> u32 {
240        #[cfg(feature = "atari-env-sys")]
241        unsafe {
242            atari_env_sys::getScreenWidth(self.inner) as u32
243        }
244
245        #[cfg(not(feature = "atari-env-sys"))]
246        unimplemented!();
247    }
248    pub fn height(&self) -> u32 {
249        #[cfg(feature = "atari-env-sys")]
250        unsafe {
251            atari_env_sys::getScreenHeight(self.inner) as u32
252        }
253
254        #[cfg(not(feature = "atari-env-sys"))]
255        unimplemented!();
256    }
257
258    pub fn rgb24_size(&self) -> usize {
259        #[cfg(feature = "atari-env-sys")]
260        return (self.width() as usize) * (self.height() as usize) * 3;
261
262        #[cfg(not(feature = "atari-env-sys"))]
263        unimplemented!();
264    }
265    /// bgr on little-endian, rgb on big-endian
266    pub fn rgb24_native_endian(&self, buf: &mut [u8]) {
267        #[cfg(feature = "atari-env-sys")]
268        unsafe {
269            atari_env_sys::getScreenRGB(self.inner, buf.as_mut_ptr());
270        }
271
272        #[cfg(not(feature = "atari-env-sys"))]
273        unimplemented!();
274    }
275    /// always rgb in regardless of endianness
276    pub fn rgb24(&self, buf: &mut [u8]) {
277        #[cfg(feature = "atari-env-sys")]
278        unsafe {
279            atari_env_sys::getScreenRGB2(self.inner, buf.as_mut_ptr());
280        }
281
282        #[cfg(not(feature = "atari-env-sys"))]
283        unimplemented!();
284    }
285
286    pub fn rgb32_size(&self) -> usize {
287        #[cfg(feature = "atari-env-sys")]
288        return (self.width() as usize) * (self.height() as usize) * 4;
289
290        #[cfg(not(feature = "atari-env-sys"))]
291        unimplemented!();
292    }
293    /// always rgb in regardless of endianness
294    pub fn rgb32(&self, buf: &mut [u8]) {
295        #[cfg(feature = "atari-env-sys")]
296        {
297            let n = buf.len() / 4;
298            self.rgb24(&mut buf[n..]);
299            for i in 0..n {
300                buf[i * 4 + 0] = buf[n + (i * 3) + 0];
301                buf[i * 4 + 1] = buf[n + (i * 3) + 1];
302                buf[i * 4 + 2] = buf[n + (i * 3) + 2];
303                buf[i * 4 + 3] = 0;
304            }
305        }
306
307        #[cfg(not(feature = "atari-env-sys"))]
308        unimplemented!();
309    }
310
311    pub fn ram_size(&self) -> usize {
312        #[cfg(feature = "atari-env-sys")]
313        unsafe {
314            atari_env_sys::getRAMSize(self.inner) as usize
315        }
316
317        #[cfg(not(feature = "atari-env-sys"))]
318        unimplemented!();
319    }
320    pub fn ram(&self, buf: &mut [u8]) {
321        #[cfg(feature = "atari-env-sys")]
322        unsafe {
323            atari_env_sys::getRAM(self.inner, buf.as_mut_ptr());
324        }
325
326        #[cfg(not(feature = "atari-env-sys"))]
327        unimplemented!();
328    }
329
330    pub fn save_png<P: AsRef<Path>>(&self, path: P) {
331        #[cfg(feature = "atari-env-sys")]
332        {
333            use std::os::unix::ffi::OsStrExt;
334            let path = path.as_ref();
335            let path = CString::new(path.as_os_str().as_bytes()).unwrap();
336            unsafe {
337                atari_env_sys::saveScreenPNG(self.inner, path.as_ptr());
338            }
339        }
340
341        #[cfg(not(feature = "atari-env-sys"))]
342        unimplemented!();
343    }
344
345    /// Sets the random seed.
346    pub fn seed(&self, seed: i32) {
347        #[cfg(feature = "atari-env-sys")]
348        unsafe {
349            atari_env_sys::setInt(self.inner, c_str!("random_seed").as_ptr(), seed);
350        }
351
352        #[cfg(not(feature = "atari-env-sys"))]
353        unimplemented!();
354    }
355}