use std::ptr::null_mut;
use std::ffi::{CString, CStr};
use std::convert::TryInto;
use std::os::raw::c_int;
use std::io;
/// Safe owner of a native `ale_sys::ALEInterface` handle.
///
/// The handle is created by [`Ale::new`] and released when the value is
/// dropped. The `Vec` fields are scratch buffers: each query method resizes
/// one from a size reported by the native library, lets the library fill it,
/// and returns a slice borrowing it.
pub struct Ale {
    /// Native interface pointer owned by this value (null-checked in `Ale::new`).
    ptr: *mut ale_sys::ALEInterface,
    /// Scratch buffer returned by `legal_action_set`.
    legal_actions: Vec<i32>,
    /// Scratch buffer returned by `minimal_action_set`.
    minimal_actions: Vec<i32>,
    /// Scratch buffer returned by `available_modes`.
    available_modes: Vec<i32>,
    /// Scratch buffer returned by `available_difficulties`.
    available_difficulties: Vec<i32>,
}
impl Ale {
    /// Creates a new native ALE interface via `ale_sys::ALE_new`.
    ///
    /// Load a game with [`Ale::load_rom`] or [`Ale::load_rom_file`] before
    /// stepping the emulator.
    ///
    /// # Panics
    ///
    /// Panics if the native constructor returns a null pointer.
    pub fn new() -> Ale {
        // SAFETY: `ALE_new` takes no arguments; the returned pointer is
        // null-checked before it is stored.
        let ptr = unsafe { ale_sys::ALE_new() };
        assert!(ptr != null_mut(), "ALE_new returned a null pointer");
        Ale {
            ptr,
            available_difficulties: vec![],
            available_modes: vec![],
            legal_actions: vec![],
            minimal_actions: vec![],
        }
    }

    /// Writes one of the bundled ROMs into a fresh temporary directory and
    /// loads it with [`Ale::load_rom_file`].
    ///
    /// # Errors
    ///
    /// Returns an error if the temporary directory or ROM file cannot be
    /// created, or if the temporary path cannot be turned into a C string
    /// (it contains a NUL byte or, on non-Unix platforms, is not valid UTF-8).
    pub fn load_rom(&mut self, rom: BundledRom) -> io::Result<()> {
        let dir = tempdir::TempDir::new("ale-rs")?;
        let rom_path = dir.path().join(rom.filename());
        std::fs::write(&rom_path, rom.data())?;
        let rom_path_c_str = Self::path_to_c_string(&rom_path)?;
        self.load_rom_file(&rom_path_c_str);
        // `dir` goes out of scope here; `TempDir` is expected to delete the
        // directory on drop, which assumes ALE has finished reading the ROM
        // inside `loadROM` (TODO confirm).
        Ok(())
    }

    /// Converts a filesystem path into the NUL-terminated string `loadROM` expects.
    ///
    /// On Unix the raw path bytes are passed through unchanged, so non-UTF-8
    /// paths are not silently rewritten into a different (nonexistent) path.
    /// Elsewhere the path must be valid UTF-8.
    fn path_to_c_string(path: &std::path::Path) -> io::Result<CString> {
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            path.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let bytes = path
            .to_str()
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "ROM path is not valid UTF-8")
            })?
            .as_bytes()
            .to_vec();
        // An interior NUL cannot be represented as a C string; report it
        // instead of panicking.
        CString::new(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
    }

    /// Loads the ROM file named by `rom_file` via `ale_sys::loadROM`.
    ///
    /// Any value returned by the native call is discarded, so load failures
    /// are not reported through this method.
    pub fn load_rom_file(&mut self, rom_file: &CStr) {
        // SAFETY: `self.ptr` is the live handle created in `new`, and
        // `rom_file` is a NUL-terminated string borrowed for the whole call.
        unsafe {
            ale_sys::loadROM(self.ptr, rom_file.as_ptr());
        }
    }

    /// Applies `action` and returns the value reported by `ale_sys::act`
    /// (presumably the reward for this step — TODO confirm).
    ///
    /// # Panics
    ///
    /// Panics if `action` is not contained in [`Ale::legal_action_set`].
    pub fn act(&mut self, action: i32) -> i32 {
        assert!(self.legal_action_set().contains(&action), "Illegal action: {}", action);
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::act(self.ptr, action) }
    }

    /// Returns the result of `ale_sys::game_over` for the current episode.
    pub fn is_game_over(&mut self) -> bool {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::game_over(self.ptr) }
    }

    /// Resets the game via `ale_sys::reset_game`.
    pub fn reset_game(&mut self) {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe {
            ale_sys::reset_game(self.ptr);
        }
    }

    /// Returns the game modes reported by ALE.
    ///
    /// The slice borrows an internal buffer that is refilled on every call.
    ///
    /// # Panics
    ///
    /// Panics if ALE reports a negative number of modes.
    pub fn available_modes(&mut self) -> &[i32] {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        let size = unsafe { ale_sys::getAvailableModesSize(self.ptr) };
        assert!(size >= 0);
        self.available_modes.resize(size as usize, 0);
        // SAFETY: the buffer was just resized to the element count ALE
        // reported, which is how many entries it is expected to write.
        unsafe {
            ale_sys::getAvailableModes(self.ptr, self.available_modes.as_mut_ptr());
        }
        &self.available_modes
    }

    /// Selects game mode `mode`.
    ///
    /// # Panics
    ///
    /// Panics if `mode` is not one of [`Ale::available_modes`].
    pub fn set_mode(&mut self, mode: i32) {
        assert!(self.available_modes().contains(&mode), "Invalid mode: {}", mode);
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe {
            ale_sys::setMode(self.ptr, mode);
        }
    }

    /// Returns the difficulty settings reported by ALE.
    ///
    /// The slice borrows an internal buffer that is refilled on every call.
    ///
    /// # Panics
    ///
    /// Panics if ALE reports a negative number of difficulties.
    pub fn available_difficulties(&mut self) -> &[i32] {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        let size = unsafe { ale_sys::getAvailableDifficultiesSize(self.ptr) };
        assert!(size >= 0);
        self.available_difficulties.resize(size as usize, 0);
        // SAFETY: the buffer was just resized to the element count ALE
        // reported, which is how many entries it is expected to write.
        unsafe {
            ale_sys::getAvailableDifficulties(self.ptr, self.available_difficulties.as_mut_ptr());
        }
        &self.available_difficulties
    }

    /// Selects difficulty `difficulty`.
    ///
    /// # Panics
    ///
    /// Panics if `difficulty` is not one of [`Ale::available_difficulties`].
    pub fn set_difficulty(&mut self, difficulty: i32) {
        assert!(
            self.available_difficulties().contains(&difficulty),
            "Invalid difficulty: {}",
            difficulty
        );
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe {
            ale_sys::setDifficulty(self.ptr, difficulty);
        }
    }

    /// Returns the legal action set reported by ALE.
    ///
    /// The slice borrows an internal buffer that is refilled on every call.
    ///
    /// # Panics
    ///
    /// Panics if ALE reports a negative number of actions.
    pub fn legal_action_set(&mut self) -> &[i32] {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        let size = unsafe { ale_sys::getLegalActionSize(self.ptr) };
        assert!(size >= 0);
        self.legal_actions.resize(size as usize, 0);
        // SAFETY: the buffer was just resized to the element count ALE
        // reported, which is how many entries it is expected to write.
        unsafe {
            ale_sys::getLegalActionSet(self.ptr, self.legal_actions.as_mut_ptr());
        }
        &self.legal_actions
    }

    /// Returns the minimal action set reported by ALE.
    ///
    /// The slice borrows an internal buffer that is refilled on every call.
    ///
    /// # Panics
    ///
    /// Panics if ALE reports a negative number of actions.
    pub fn minimal_action_set(&mut self) -> &[i32] {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        let size = unsafe { ale_sys::getMinimalActionSize(self.ptr) };
        assert!(size >= 0);
        self.minimal_actions.resize(size as usize, 0);
        // SAFETY: the buffer was just resized to the element count ALE
        // reported, which is how many entries it is expected to write.
        unsafe {
            ale_sys::getMinimalActionSet(self.ptr, self.minimal_actions.as_mut_ptr());
        }
        &self.minimal_actions
    }

    /// Returns the frame counter reported by `ale_sys::getFrameNumber`.
    pub fn frame_number(&mut self) -> i32 {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::getFrameNumber(self.ptr) as i32 }
    }

    /// Returns the value reported by `ale_sys::lives`.
    pub fn lives(&mut self) -> i32 {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::lives(self.ptr) }
    }

    /// Returns the per-episode frame counter reported by `ale_sys::getEpisodeFrameNumber`.
    pub fn episode_frame_number(&mut self) -> i32 {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::getEpisodeFrameNumber(self.ptr) }
    }

    /// Copies the emulator RAM into the front of `ram`.
    ///
    /// # Panics
    ///
    /// Panics if `ram` is shorter than [`Ale::ram_size`].
    pub fn get_ram(&mut self, ram: &mut [u8]) {
        assert!(ram.len() >= self.ram_size());
        // SAFETY: `ram` has room for at least `ram_size()` bytes, checked above.
        unsafe {
            ale_sys::getRAM(self.ptr, ram.as_mut_ptr());
        }
    }

    /// Returns the RAM size, in bytes, that [`Ale::get_ram`] writes.
    ///
    /// # Panics
    ///
    /// Panics if the native value does not fit in `usize` (e.g. is negative).
    pub fn ram_size(&mut self) -> usize {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::getRAMSize(self.ptr) }.try_into().expect("invalid size")
    }

    /// Returns the screen width in pixels.
    ///
    /// # Panics
    ///
    /// Panics if the native value does not fit in `usize` (e.g. is negative).
    pub fn screen_width(&mut self) -> usize {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::getScreenWidth(self.ptr) }.try_into().expect("invalid size")
    }

    /// Returns the screen height in pixels.
    ///
    /// # Panics
    ///
    /// Panics if the native value does not fit in `usize` (e.g. is negative).
    pub fn screen_height(&mut self) -> usize {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe { ale_sys::getScreenHeight(self.ptr) }.try_into().expect("invalid size")
    }

    /// Copies the current screen as RGB (three bytes per pixel) into `screen_data`.
    ///
    /// # Panics
    ///
    /// Panics if `screen_data` is shorter than `width * height * 3` bytes.
    pub fn get_screen_rgb(&mut self, screen_data: &mut [u8]) {
        assert!(screen_data.len() >= self.screen_width() * self.screen_height() * 3);
        // SAFETY: the buffer length was checked against the full RGB frame size.
        unsafe {
            ale_sys::getScreenRGB(self.ptr, screen_data.as_mut_ptr());
        }
    }

    /// Copies the current screen as grayscale (one byte per pixel) into `screen_data`.
    ///
    /// # Panics
    ///
    /// Panics if `screen_data` is shorter than `width * height` bytes.
    pub fn get_screen_grayscale(&mut self, screen_data: &mut [u8]) {
        assert!(screen_data.len() >= self.screen_width() * self.screen_height());
        // SAFETY: the buffer length was checked against the full grayscale frame size.
        unsafe {
            ale_sys::getScreenGrayscale(self.ptr, screen_data.as_mut_ptr());
        }
    }

    /// Calls `ale_sys::saveState` (presumably an internal checkpoint that
    /// [`Ale::load_state`] restores — TODO confirm).
    pub fn save_state(&mut self) {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe {
            ale_sys::saveState(self.ptr);
        }
    }

    /// Calls `ale_sys::loadState` (presumably restoring the checkpoint taken by
    /// [`Ale::save_state`] — TODO confirm).
    pub fn load_state(&mut self) {
        // SAFETY: `self.ptr` is the live handle owned by `self`.
        unsafe {
            ale_sys::loadState(self.ptr);
        }
    }

    /// Returns an owned snapshot produced by `ale_sys::cloneState`.
    pub fn clone_state(&mut self) -> AleState {
        AleState {
            // SAFETY: `self.ptr` is the live handle owned by `self`; the
            // returned state is owned (and later freed) by `AleState`.
            ptr: unsafe { ale_sys::cloneState(self.ptr) },
        }
    }

    /// Restores a snapshot previously obtained from this wrapper.
    pub fn restore_state(&mut self, state: &AleState) {
        // SAFETY: both pointers are live handles owned by their wrappers.
        unsafe {
            ale_sys::restoreState(self.ptr, state.ptr);
        }
    }

    /// Returns an owned snapshot produced by `ale_sys::cloneSystemState`.
    pub fn clone_system_state(&mut self) -> AleState {
        AleState {
            // SAFETY: `self.ptr` is the live handle owned by `self`; the
            // returned state is owned (and later freed) by `AleState`.
            ptr: unsafe { ale_sys::cloneSystemState(self.ptr) },
        }
    }

    /// Restores a snapshot via `ale_sys::restoreSystemState`.
    pub fn restore_system_state(&mut self, state: &AleState) {
        // SAFETY: both pointers are live handles owned by their wrappers.
        unsafe {
            ale_sys::restoreSystemState(self.ptr, state.ptr);
        }
    }

    /// Saves the current screen to `filename` via `ale_sys::saveScreenPNG`.
    ///
    /// # Safety
    ///
    /// The native contract of `saveScreenPNG` is not visible from this
    /// wrapper. At minimum `filename` must name a location the process can
    /// write to; NOTE(review): confirm what else the native call requires, or
    /// whether this can be made a safe function.
    pub unsafe fn save_screen_png(&mut self, filename: &CStr) {
        ale_sys::saveScreenPNG(self.ptr, filename.as_ptr());
    }

    /// Sets ALE's logger mode.
    ///
    /// This takes no ALE instance, so it presumably affects all interfaces in
    /// the process.
    pub fn set_logger_mode(mode: LoggerMode) {
        // SAFETY: the enum's explicit discriminants are passed as a plain C int.
        unsafe {
            ale_sys::setLoggerMode(mode as c_int);
        }
    }
}
impl Drop for Ale {
    /// Releases the native interface exactly once.
    fn drop(&mut self) {
        // Detach the handle first so the struct never holds a freed pointer.
        let interface = std::mem::replace(&mut self.ptr, std::ptr::null_mut());
        // SAFETY: `interface` is the handle obtained from `ALE_new` and owned
        // by this value; it is passed to `ALE_del` only here.
        unsafe {
            ale_sys::ALE_del(interface);
        }
    }
}
/// Owned handle to a native `ale_sys::ALEState` snapshot.
///
/// Produced by [`Ale::clone_state`], [`Ale::clone_system_state`] or
/// [`AleState::decode_state`], and released with `ale_sys::deleteState` on drop.
pub struct AleState {
    /// Native snapshot pointer owned by this value.
    ptr: *mut ale_sys::ALEState,
}
impl AleState {
    /// Serializes this snapshot into the front of `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than [`AleState::encode_state_len`].
    pub fn encode_state(&self, buf: &mut [u8]) {
        // Query the length once: it is a native call, and the value is needed
        // for both the check and the panic message.
        let needed = self.encode_state_len();
        assert!(
            buf.len() >= needed,
            "Buffer not long enough to store encoded state. Expected {}, got {}",
            needed,
            buf.len()
        );
        // The native side takes the capacity as a C int. A plain `as` cast
        // would wrap for buffers longer than `c_int::MAX` and pass a bogus
        // (possibly negative) length. Such a buffer is still large enough
        // (`needed` itself came from a C int), so clamp instead.
        let capacity: c_int = buf.len().try_into().unwrap_or(c_int::MAX);
        // SAFETY: `self.ptr` is the live snapshot owned by `self`, and `buf`
        // provides at least `capacity >= needed` writable bytes.
        unsafe {
            ale_sys::encodeState(self.ptr, buf.as_mut_ptr() as *mut _, capacity);
        }
    }

    /// Returns the number of bytes [`AleState::encode_state`] needs.
    ///
    /// # Panics
    ///
    /// Panics if the native library reports a negative size.
    pub fn encode_state_len(&self) -> usize {
        // SAFETY: `self.ptr` is the live snapshot owned by `self`.
        let size = unsafe { ale_sys::encodeStateLen(self.ptr) };
        assert!(size >= 0, "Invalid size: {}", size);
        size as usize
    }

    /// Builds a new snapshot from `serialized`, which is expected to hold
    /// bytes produced by [`AleState::encode_state`].
    ///
    /// # Panics
    ///
    /// Panics if `serialized.len()` does not fit in a C int.
    pub fn decode_state(serialized: &[u8]) -> AleState {
        let len: c_int = serialized.len().try_into().expect("Length too long");
        AleState {
            // SAFETY: `serialized` is valid for `len` bytes for the duration of
            // the call; the returned snapshot is owned (and freed) by `AleState`.
            ptr: unsafe { ale_sys::decodeState(serialized.as_ptr() as *const _, len) },
        }
    }
}
impl Drop for AleState {
    /// Releases the native snapshot exactly once.
    fn drop(&mut self) {
        // Detach the handle first so the struct never holds a freed pointer.
        let state = std::mem::replace(&mut self.ptr, std::ptr::null_mut());
        // SAFETY: `state` is the snapshot owned by this value; it is passed to
        // `deleteState` only here.
        unsafe {
            ale_sys::deleteState(state);
        }
    }
}
/// Logger mode accepted by [`Ale::set_logger_mode`].
///
/// The explicit discriminants are what reaches the native library (the mode
/// is cast to `c_int`), so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoggerMode {
    /// Native value `0`.
    Info = 0,
    /// Native value `1`.
    Warning = 1,
    /// Native value `2`.
    Error = 2,
}
/// ROM images compiled into this crate.
///
/// Each variant maps to a file name (`BundledRom::filename`) and embedded bytes
/// (`BundledRom::data`). Pass one to [`Ale::load_rom`] to load it. The type is
/// `Copy`, so the same value can be reused after loading.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BundledRom {
    Adventure,
    AirRaid,
    Alien,
    Amidar,
    Assault,
    Asterix,
    Asteroids,
    Atlantis,
    BankHeist,
    BattleZone,
    BeamRider,
    Berzerk,
    Bowling,
    Boxing,
    Breakout,
    Carnival,
    Centipede,
    ChopperCommand,
    CrazyClimber,
    Defender,
    DemonAttack,
    DoubleDunk,
    ElevatorAction,
    Enduro,
    FishingDerby,
    Freeway,
    Frostbite,
    Gopher,
    Gravitar,
    Hero,
    IceHockey,
    JamesBond,
    JourneyEscape,
    Kaboom,
    Kangaroo,
    Krull,
    KungFuMaster,
    MontezumaRevenge,
    MsPacman,
    NameThisGame,
    Phoenix,
    Pitfall,
    Pong,
    Pooyan,
    PrivateEye,
    QBert,
    RiverRaid,
    RoadRunner,
    RoboTank,
    Seaquest,
    Skiing,
    SpaceInvaders,
    StarGunner,
    Tennis,
    TimePilot,
    Tutankham,
    UpNDown,
    Venture,
    VideoPinball,
    WizardOfWor,
    YarsRevenge,
    Zaxxon,
}
impl BundledRom {
pub fn filename(&self) -> &'static str {
use BundledRom::*;
match self {
Adventure => "adventure.bin",
AirRaid => "air_raid.bin",
Alien => "alien.bin",
Amidar => "amidar.bin",
Assault => "assault.bin",
Asterix => "asterix.bin",
Asteroids => "asteroids.bin",
Atlantis => "atlantis.bin",
BankHeist => "bank_heist.bin",
BattleZone => "battle_zone.bin",
BeamRider => "beam_rider.bin",
Berzerk => "berzerk.bin",
Bowling => "bowling.bin",
Boxing => "boxing.bin",
Breakout => "breakout.bin",
Carnival => "carnival.bin",
Centipede => "centipede.bin",
ChopperCommand => "chopper_command.bin",
CrazyClimber => "crazy_climber.bin",
Defender => "defender.bin",
DemonAttack => "demon_attack.bin",
DoubleDunk => "double_dunk.bin",
ElevatorAction => "elevator_action.bin",
Enduro => "enduro.bin",
FishingDerby => "fishing_derby.bin",
Freeway => "freeway.bin",
Frostbite => "frostbite.bin",
Gopher => "gopher.bin",
Gravitar => "gravitar.bin",
Hero => "hero.bin",
IceHockey => "ice_hockey.bin",
JamesBond => "jamesbond.bin",
JourneyEscape => "journey_escape.bin",
Kaboom => "kaboom.bin",
Kangaroo => "kangaroo.bin",
Krull => "krull.bin",
KungFuMaster => "kung_fu_master.bin",
MontezumaRevenge => "montezuma_revenge.bin",
MsPacman => "ms_pacman.bin",
NameThisGame => "name_this_game.bin",
Phoenix => "phoenix.bin",
Pitfall => "pitfall.bin",
Pong => "pong.bin",
Pooyan => "pooyan.bin",
PrivateEye => "private_eye.bin",
QBert => "qbert.bin",
RiverRaid => "riverraid.bin",
RoadRunner => "road_runner.bin",
RoboTank => "robotank.bin",
Seaquest => "seaquest.bin",
Skiing => "skiing.bin",
SpaceInvaders => "space_invaders.bin",
StarGunner => "star_gunner.bin",
Tennis => "tennis.bin",
TimePilot => "time_pilot.bin",
Tutankham => "tutankham.bin",
UpNDown => "up_n_down.bin",
Venture => "venture.bin",
VideoPinball => "video_pinball.bin",
WizardOfWor => "wizard_of_wor.bin",
YarsRevenge => "yars_revenge.bin",
Zaxxon => "zaxxon.bin",
}
}
pub fn data(&self) -> &'static [u8] {
use BundledRom::*;
match self {
Adventure => include_bytes!("../roms/adventure.bin"),
AirRaid => include_bytes!("../roms/air_raid.bin"),
Alien => include_bytes!("../roms/alien.bin"),
Amidar => include_bytes!("../roms/amidar.bin"),
Assault => include_bytes!("../roms/assault.bin"),
Asterix => include_bytes!("../roms/asterix.bin"),
Asteroids => include_bytes!("../roms/asteroids.bin"),
Atlantis => include_bytes!("../roms/atlantis.bin"),
BankHeist => include_bytes!("../roms/bank_heist.bin"),
BattleZone => include_bytes!("../roms/battle_zone.bin"),
BeamRider => include_bytes!("../roms/beam_rider.bin"),
Berzerk => include_bytes!("../roms/berzerk.bin"),
Bowling => include_bytes!("../roms/bowling.bin"),
Boxing => include_bytes!("../roms/boxing.bin"),
Breakout => include_bytes!("../roms/breakout.bin"),
Carnival => include_bytes!("../roms/carnival.bin"),
Centipede => include_bytes!("../roms/centipede.bin"),
ChopperCommand => include_bytes!("../roms/chopper_command.bin"),
CrazyClimber => include_bytes!("../roms/crazy_climber.bin"),
Defender => include_bytes!("../roms/defender.bin"),
DemonAttack => include_bytes!("../roms/demon_attack.bin"),
DoubleDunk => include_bytes!("../roms/double_dunk.bin"),
ElevatorAction => include_bytes!("../roms/elevator_action.bin"),
Enduro => include_bytes!("../roms/enduro.bin"),
FishingDerby => include_bytes!("../roms/fishing_derby.bin"),
Freeway => include_bytes!("../roms/freeway.bin"),
Frostbite => include_bytes!("../roms/frostbite.bin"),
Gopher => include_bytes!("../roms/gopher.bin"),
Gravitar => include_bytes!("../roms/gravitar.bin"),
Hero => include_bytes!("../roms/hero.bin"),
IceHockey => include_bytes!("../roms/ice_hockey.bin"),
JamesBond => include_bytes!("../roms/jamesbond.bin"),
JourneyEscape => include_bytes!("../roms/journey_escape.bin"),
Kaboom => include_bytes!("../roms/kaboom.bin"),
Kangaroo => include_bytes!("../roms/kangaroo.bin"),
Krull => include_bytes!("../roms/krull.bin"),
KungFuMaster => include_bytes!("../roms/kung_fu_master.bin"),
MontezumaRevenge => include_bytes!("../roms/montezuma_revenge.bin"),
MsPacman => include_bytes!("../roms/ms_pacman.bin"),
NameThisGame => include_bytes!("../roms/name_this_game.bin"),
Phoenix => include_bytes!("../roms/phoenix.bin"),
Pitfall => include_bytes!("../roms/pitfall.bin"),
Pong => include_bytes!("../roms/pong.bin"),
Pooyan => include_bytes!("../roms/pooyan.bin"),
PrivateEye => include_bytes!("../roms/private_eye.bin"),
QBert => include_bytes!("../roms/qbert.bin"),
RiverRaid => include_bytes!("../roms/riverraid.bin"),
RoadRunner => include_bytes!("../roms/road_runner.bin"),
RoboTank => include_bytes!("../roms/robotank.bin"),
Seaquest => include_bytes!("../roms/seaquest.bin"),
Skiing => include_bytes!("../roms/skiing.bin"),
SpaceInvaders => include_bytes!("../roms/space_invaders.bin"),
StarGunner => include_bytes!("../roms/star_gunner.bin"),
Tennis => include_bytes!("../roms/tennis.bin"),
TimePilot => include_bytes!("../roms/time_pilot.bin"),
Tutankham => include_bytes!("../roms/tutankham.bin"),
UpNDown => include_bytes!("../roms/up_n_down.bin"),
Venture => include_bytes!("../roms/venture.bin"),
VideoPinball => include_bytes!("../roms/video_pinball.bin"),
WizardOfWor => include_bytes!("../roms/wizard_of_wor.bin"),
YarsRevenge => include_bytes!("../roms/yars_revenge.bin"),
Zaxxon => include_bytes!("../roms/zaxxon.bin"),
}
}
}