use c_str_macro::c_str;
use std::ffi::CString;
use std::path::Path;
use std::path::PathBuf;
/// An input action passed to ALE.
///
/// The explicit discriminants are the integer codes exchanged with the ALE C API
/// (actions are sent as `action as i32` in [`Ale::take_action`]), so variants must
/// not be reordered or renumbered.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, num_derive::FromPrimitive)]
#[repr(i32)]
pub enum AleAction {
    Noop = 0,
    Fire = 1,
    Up = 2,
    Right = 3,
    Left = 4,
    Down = 5,
    UpRight = 6,
    UpLeft = 7,
    DownRight = 8,
    DownLeft = 9,
    UpFire = 10,
    RightFire = 11,
    LeftFire = 12,
    DownFire = 13,
    UpRightFire = 14,
    UpLeftFire = 15,
    DownRightFire = 16,
    DownLeftFire = 17,
}
/// Options applied to a new ALE instance by [`Ale::new`].
///
/// Every field except `difficulty_setting` is written to ALE as the named
/// setting of the same name before the ROM is loaded; `difficulty_setting` is
/// applied after loading.
#[derive(Clone, Debug, PartialEq)]
pub struct AleConfig {
    /// Value for ALE's `random_seed` setting.
    pub random_seed: i32,
    /// Value for ALE's `display_screen` setting.
    pub display_screen: bool,
    /// Value for ALE's `sound` setting.
    pub sound: bool,
    /// Value for ALE's `color_averaging` setting.
    pub color_averaging: bool,
    /// Value for ALE's `frame_skip` setting.
    pub frame_skip: i32,
    /// Value for ALE's `repeat_action_probability` setting.
    pub repeat_action_probability: f32,
    /// Directory for ALE's `record_screen_dir` setting; the setting is left
    /// untouched when `None`.
    pub record_screen_dir: Option<PathBuf>,
    /// Difficulty passed to ALE's `setDifficulty` after the ROM is loaded.
    pub difficulty_setting: i32,
}
impl Default for AleConfig {
fn default() -> Self {
Self {
random_seed: 0,
display_screen: false,
sound: false,
color_averaging: false, frame_skip: 1,
repeat_action_probability: 0.25,
record_screen_dir: None,
difficulty_setting: 0,
}
}
}
/// Owner of a native ALE instance (`atari_env_sys::ALEInterface`).
///
/// Without the `atari-env-sys` feature the struct has no fields and every
/// method panics via `unimplemented!()`.
pub struct Ale {
    // Created by `ALE_new` in `Ale::new` and released by `ALE_del` in `Drop`.
    #[cfg(feature = "atari-env-sys")]
    inner: *mut atari_env_sys::ALEInterface,
}
// SAFETY: `Ale` is the only owner of `inner` (it is not `Clone`, and the pointer
// is freed only in `Drop`), so sending an `Ale` to another thread moves exclusive
// ownership of the native instance. This assumes the native ALE instance has no
// thread affinity (e.g. thread-local state) — TODO confirm against ALE.
// `Sync` is not implemented, which matters because `seed` mutates native state
// through `&self`.
unsafe impl Send for Ale {}
impl Drop for Ale {
    /// Frees the native ALE instance held in `inner`.
    fn drop(&mut self) {
        #[cfg(feature = "atari-env-sys")]
        {
            let raw = self.inner;
            // SAFETY: `raw` was produced by `ALE_new` in `Ale::new` and is
            // released exactly once, here.
            unsafe { atari_env_sys::ALE_del(raw) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }
}
impl Ale {
    /// Creates a native ALE instance, applies `config`, loads the ROM at
    /// `rom_path`, sets the difficulty and resets the game.
    ///
    /// All options except `difficulty_setting` are written before `loadROM`;
    /// the difficulty is applied after the ROM has been loaded.
    ///
    /// # Panics
    ///
    /// Panics if `rom_path` or `config.record_screen_dir` contains an interior
    /// NUL byte, and always panics when built without the `atari-env-sys` feature.
    pub fn new(rom_path: &Path, config: AleConfig) -> Self {
        #[cfg(feature = "atari-env-sys")]
        {
            // Convert paths before allocating the native instance so that a
            // conversion panic cannot leak it.
            let rom_path = Self::path_to_cstring(rom_path);
            let record_dir = config
                .record_screen_dir
                .as_deref()
                .map(Self::path_to_cstring);

            // SAFETY: `ALE_new` takes no arguments; the returned instance is owned
            // by the `Ale` built below and freed in `Drop`.
            let ale = unsafe { atari_env_sys::ALE_new() };
            // SAFETY: `ale` is live, and every string pointer comes from a
            // `CString` / `c_str!` value that outlives the call using it.
            unsafe {
                atari_env_sys::setInt(ale, c_str!("random_seed").as_ptr(), config.random_seed);
                atari_env_sys::setBool(
                    ale,
                    c_str!("display_screen").as_ptr(),
                    config.display_screen,
                );
                atari_env_sys::setBool(ale, c_str!("sound").as_ptr(), config.sound);
                atari_env_sys::setBool(
                    ale,
                    c_str!("color_averaging").as_ptr(),
                    config.color_averaging,
                );
                atari_env_sys::setInt(ale, c_str!("frame_skip").as_ptr(), config.frame_skip);
                atari_env_sys::setFloat(
                    ale,
                    c_str!("repeat_action_probability").as_ptr(),
                    config.repeat_action_probability,
                );
                if let Some(dir) = &record_dir {
                    atari_env_sys::setString(
                        ale,
                        c_str!("record_screen_dir").as_ptr(),
                        dir.as_ptr(),
                    );
                }
                atari_env_sys::loadROM(ale, rom_path.as_ptr());
                // Difficulty and reset are applied once the ROM is loaded.
                atari_env_sys::setDifficulty(ale, config.difficulty_setting);
                atari_env_sys::reset_game(ale);
            }
            Self { inner: ale }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns ALE's legal action set.
    ///
    /// # Panics
    ///
    /// Panics if ALE reports a code that is not an [`AleAction`] discriminant.
    pub fn available_actions(&self) -> Vec<AleAction> {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            let len = unsafe { atari_env_sys::getLegalActionSize(self.inner) } as usize;
            let mut codes = vec![0i32; len];
            // SAFETY: `codes` has room for the `len` entries ALE just reported.
            unsafe { atari_env_sys::getLegalActionSet(self.inner, codes.as_mut_ptr()) };
            Self::actions_from_codes(codes)
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns ALE's minimal action set.
    ///
    /// # Panics
    ///
    /// Panics if ALE reports a code that is not an [`AleAction`] discriminant.
    pub fn minimal_actions(&self) -> Vec<AleAction> {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            let len = unsafe { atari_env_sys::getMinimalActionSize(self.inner) } as usize;
            let mut codes = vec![0i32; len];
            // SAFETY: `codes` has room for the `len` entries ALE just reported.
            unsafe { atari_env_sys::getMinimalActionSet(self.inner, codes.as_mut_ptr()) };
            Self::actions_from_codes(codes)
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns what ALE's `game_over` reports for the current episode.
    pub fn is_game_over(&self) -> bool {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::game_over(self.inner) }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns ALE's frame counter (`getFrameNumber`).
    pub fn rom_frame_number(&self) -> i32 {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::getFrameNumber(self.inner) }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns ALE's per-episode frame counter (`getEpisodeFrameNumber`).
    pub fn episode_frame_number(&self) -> i32 {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::getEpisodeFrameNumber(self.inner) }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Resets the game via ALE's `reset_game`.
    pub fn reset(&mut self) {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::reset_game(self.inner) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Applies `action` through ALE's `act` and returns the integer it reports
    /// (presumably the step reward — confirm against ALE's documentation).
    pub fn take_action(&mut self, action: AleAction) -> i32 {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            let ret: ::std::os::raw::c_int =
                unsafe { atari_env_sys::act(self.inner, action as i32) };
            ret.into()
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns ALE's `lives` count, cast to `u32`.
    pub fn lives(&self) -> u32 {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::lives(self.inner) as u32 }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Returns the difficulty values ALE lists via `getAvailableDifficulties`.
    pub fn available_difficulty_settings(&self) -> Vec<i32> {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            let len =
                unsafe { atari_env_sys::getAvailableDifficultiesSize(self.inner) } as usize;
            let mut difficulties = vec![0i32; len];
            // SAFETY: `difficulties` has room for the `len` entries ALE reported.
            unsafe {
                atari_env_sys::getAvailableDifficulties(self.inner, difficulties.as_mut_ptr())
            };
            difficulties
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Screen width in pixels, as reported by ALE.
    pub fn width(&self) -> u32 {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::getScreenWidth(self.inner) as u32 }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Screen height in pixels, as reported by ALE.
    pub fn height(&self) -> u32 {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::getScreenHeight(self.inner) as u32 }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Bytes needed by [`Ale::rgb24`] / [`Ale::rgb24_native_endian`]:
    /// `width * height * 3`.
    pub fn rgb24_size(&self) -> usize {
        #[cfg(feature = "atari-env-sys")]
        {
            self.pixel_count() * 3
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Writes the screen into `buf[..self.rgb24_size()]` using ALE's
    /// `getScreenRGB` (byte order is whatever that call produces).
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than [`Ale::rgb24_size`].
    pub fn rgb24_native_endian(&self, buf: &mut [u8]) {
        #[cfg(feature = "atari-env-sys")]
        {
            Self::check_buffer("rgb24", buf.len(), self.rgb24_size());
            // SAFETY: `self.inner` is live and `buf` holds at least one RGB24 frame.
            unsafe { atari_env_sys::getScreenRGB(self.inner, buf.as_mut_ptr()) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Writes the screen into `buf[..self.rgb24_size()]` using ALE's
    /// `getScreenRGB2`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than [`Ale::rgb24_size`].
    pub fn rgb24(&self, buf: &mut [u8]) {
        #[cfg(feature = "atari-env-sys")]
        {
            Self::check_buffer("rgb24", buf.len(), self.rgb24_size());
            // SAFETY: `self.inner` is live and `buf` holds at least one RGB24 frame.
            unsafe { atari_env_sys::getScreenRGB2(self.inner, buf.as_mut_ptr()) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Bytes needed by [`Ale::rgb32`]: `width * height * 4`.
    pub fn rgb32_size(&self) -> usize {
        #[cfg(feature = "atari-env-sys")]
        {
            self.pixel_count() * 4
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Writes the [`Ale::rgb24`] frame into `buf[..self.rgb32_size()]` with each
    /// pixel padded to 4 bytes; the fourth byte is always 0.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than [`Ale::rgb32_size`].
    pub fn rgb32(&self, buf: &mut [u8]) {
        #[cfg(feature = "atari-env-sys")]
        {
            let pixels = self.pixel_count();
            Self::check_buffer("rgb32", buf.len(), pixels * 4);
            // Stage the 3-byte frame in the last 3/4 of the region, then expand it
            // forwards in place. The write cursor (4*i+3) always stays below the
            // next unread source byte (pixels + 3*(i+1)) because i < pixels.
            self.rgb24(&mut buf[pixels..pixels * 4]);
            for i in 0..pixels {
                let src = pixels + i * 3;
                buf.copy_within(src..src + 3, i * 4);
                buf[i * 4 + 3] = 0;
            }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Size of the emulator RAM in bytes, as reported by ALE.
    pub fn ram_size(&self) -> usize {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is the live instance owned by `self`.
            unsafe { atari_env_sys::getRAMSize(self.inner) as usize }
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Copies the emulator RAM into `buf[..self.ram_size()]`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than [`Ale::ram_size`].
    pub fn ram(&self, buf: &mut [u8]) {
        #[cfg(feature = "atari-env-sys")]
        {
            Self::check_buffer("ram", buf.len(), self.ram_size());
            // SAFETY: `self.inner` is live and `buf` can hold the whole RAM.
            unsafe { atari_env_sys::getRAM(self.inner, buf.as_mut_ptr()) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Saves the current screen to `path` via ALE's `saveScreenPNG`.
    ///
    /// # Panics
    ///
    /// Panics if `path` contains an interior NUL byte.
    pub fn save_png<P: AsRef<Path>>(&self, path: P) {
        #[cfg(feature = "atari-env-sys")]
        {
            let path = Self::path_to_cstring(path.as_ref());
            // SAFETY: `self.inner` is live and `path` outlives the call.
            unsafe { atari_env_sys::saveScreenPNG(self.inner, path.as_ptr()) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Sets ALE's `random_seed` option on the existing instance.
    ///
    /// NOTE(review): ALE generally applies settings on the next `loadROM`, and this
    /// method does not reload the ROM, so the new seed may not affect the running
    /// game — confirm against ALE's documentation.
    pub fn seed(&self, seed: i32) {
        #[cfg(feature = "atari-env-sys")]
        {
            // SAFETY: `self.inner` is live; the key is a static C string.
            unsafe { atari_env_sys::setInt(self.inner, c_str!("random_seed").as_ptr(), seed) };
        }
        #[cfg(not(feature = "atari-env-sys"))]
        unimplemented!();
    }

    /// Number of screen pixels (`width * height`).
    #[cfg(feature = "atari-env-sys")]
    fn pixel_count(&self) -> usize {
        (self.width() as usize) * (self.height() as usize)
    }

    /// Panics unless a caller-supplied buffer can hold `needed` bytes; guards
    /// every place where a Rust slice is handed to C as a raw output pointer.
    #[cfg(feature = "atari-env-sys")]
    fn check_buffer(what: &str, len: usize, needed: usize) {
        assert!(
            len >= needed,
            "{} buffer too small: {} bytes given, {} required",
            what,
            len,
            needed
        );
    }

    /// Converts a path into a C string from its raw OS bytes (so non-UTF-8
    /// paths are accepted), panicking on an interior NUL byte.
    #[cfg(feature = "atari-env-sys")]
    fn path_to_cstring(path: &Path) -> CString {
        use std::os::unix::ffi::OsStrExt;
        CString::new(path.as_os_str().as_bytes())
            .unwrap_or_else(|_| panic!("path contains an interior NUL byte: {:?}", path))
    }

    /// Maps raw ALE action codes onto [`AleAction`] without ever materialising
    /// an invalid enum value (which letting C write straight into a
    /// `Vec<AleAction>` would risk).
    #[cfg(feature = "atari-env-sys")]
    fn actions_from_codes(codes: Vec<i32>) -> Vec<AleAction> {
        use AleAction::*;
        // Indexed by discriminant: must list variants in `AleAction` order.
        const BY_CODE: [AleAction; 18] = [
            Noop, Fire, Up, Right, Left, Down, UpRight, UpLeft, DownRight, DownLeft,
            UpFire, RightFire, LeftFire, DownFire, UpRightFire, UpLeftFire,
            DownRightFire, DownLeftFire,
        ];
        codes
            .into_iter()
            .map(|code| {
                // A negative code wraps to a huge index, so `get` rejects it too.
                BY_CODE
                    .get(code as usize)
                    .copied()
                    .unwrap_or_else(|| panic!("ALE reported unknown action code {}", code))
            })
            .collect()
    }
}