#include "stella_environment.hpp"
#include <sstream>
#include "../emucore/m6502/src/System.hxx"
namespace ale {
// Builds the environment on top of an existing emulator instance.
//
// osystem:  emulator handle. It is dereferenced right away (no null check) to
//           size the screen buffer and to read console properties, settings
//           and the event object.
// settings: per-ROM settings object; only stored by the constructor.
StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings)
    : m_osystem(osystem),
      m_settings(settings),
      m_phosphor_blend(osystem),
      m_screen(m_osystem->console().mediaSource().height(),
               m_osystem->console().mediaSource().width()),
      m_player_a_action(PLAYER_A_NOOP),
      m_player_b_action(PLAYER_B_NOOP) {
  // The game is treated as paddle-driven when either controller port is
  // declared as "PADDLES" in the cartridge properties.
  if (m_osystem->console().properties().get(Controller_Left) == "PADDLES" ||
      m_osystem->console().properties().get(Controller_Right) == "PADDLES") {
    m_use_paddles = true;

    // A configured value of -1 selects the built-in PADDLE_MIN / PADDLE_MAX.
    int paddle_min_val = m_osystem->settings().getInt("paddle_min");
    int paddle_max_val = m_osystem->settings().getInt("paddle_max");
    m_state.setPaddleLimits(paddle_min_val != -1 ? paddle_min_val : PADDLE_MIN,
                            paddle_max_val != -1 ? paddle_max_val : PADDLE_MAX);
    m_state.resetPaddles(m_osystem->event());
  } else {
    m_use_paddles = false;
  }

  // Number of frames softReset() holds the RESET action for.
  m_num_reset_steps = 4;

  // The MD5 is handed to ALEState::save()/load() by the clone/restore methods.
  m_cartridge_md5 = m_osystem->console().properties().get(Cartridge_MD5);

  m_max_num_frames_per_episode =
      m_osystem->settings().getInt("max_num_frames_per_episode");
  m_colour_averaging = m_osystem->settings().getBool("color_averaging");
  m_repeat_action_probability =
      m_osystem->settings().getFloat("repeat_action_probability");

  // Validate the frame skip as a signed int *before* storing it. act()
  // compares m_frame_skip against a size_t counter, so the member may be
  // unsigned; in that case a negative setting would wrap to a huge value and
  // slip past a `< 1` check performed on the member itself.
  const int requested_frame_skip = m_osystem->settings().getInt("frame_skip");
  if (requested_frame_skip < 1) {
    ale::Logger::Warning << "Warning: frame skip set to < 1. Setting to 1."
                         << std::endl;
    m_frame_skip = 1;
  } else {
    m_frame_skip = requested_frame_skip;
  }

  // Screen recording is enabled only when a target directory is configured.
  std::string recordDir = m_osystem->settings().getString("record_screen_dir");
  if (!recordDir.empty()) {
    ale::Logger::Info << "Recording screens to directory: " << recordDir
                      << std::endl;
    m_screen_exporter.reset(
        new ScreenExporter(m_osystem->colourPalette(), recordDir));
  }
}
void StellaEnvironment::reset() {
m_state.resetEpisodeFrameNumber();
m_state.resetPaddles(m_osystem->event());
m_osystem->console().system().reset();
int noopSteps;
noopSteps = 60;
emulate(PLAYER_A_NOOP, PLAYER_B_NOOP, noopSteps);
softReset();
m_settings->reset();
m_settings->setMode(m_state.getCurrentMode(), m_osystem->console().system(),
getWrapper());
softReset();
ActionVect startingActions = m_settings->getStartingActions();
for (size_t i = 0; i < startingActions.size(); i++) {
emulate(startingActions[i], PLAYER_B_NOOP);
}
}
void StellaEnvironment::save() {
ALEState new_state = cloneState();
m_saved_states.push(new_state);
}
void StellaEnvironment::load() {
ALEState& target_state = m_saved_states.top();
restoreState(target_state);
m_saved_states.pop();
}
// Returns a snapshot of the environment built by ALEState::save() with the
// trailing flag set to false (cloneSystemState() passes true instead).
// NOTE(review): the exact effect of that flag is defined in ALEState and is
// not visible from this file.
ALEState StellaEnvironment::cloneState() {
  const bool with_system_state = false;
  ALEState snapshot =
      m_state.save(m_osystem, m_settings, m_cartridge_md5, with_system_state);
  return snapshot;
}
// Loads `target_state` into the environment through ALEState::load() with
// the trailing flag set to false; this is the counterpart of cloneState().
void StellaEnvironment::restoreState(const ALEState& target_state) {
  const bool with_system_state = false;
  m_state.load(m_osystem, m_settings, m_cartridge_md5, target_state,
               with_system_state);
}
// Same as cloneState(), except that ALEState::save() is called with the
// trailing flag set to true.
// NOTE(review): what extra data the flag adds to the snapshot is decided by
// ALEState and is not visible from this file.
ALEState StellaEnvironment::cloneSystemState() {
  const bool with_system_state = true;
  ALEState snapshot =
      m_state.save(m_osystem, m_settings, m_cartridge_md5, with_system_state);
  return snapshot;
}
// Loads `target_state` through ALEState::load() with the trailing flag set
// to true; this is the counterpart of cloneSystemState().
void StellaEnvironment::restoreSystemState(const ALEState& target_state) {
  const bool with_system_state = true;
  m_state.load(m_osystem, m_settings, m_cartridge_md5, target_state,
               with_system_state);
}
// Replaces actions the environment will not execute with the matching
// player's no-op, modifying both arguments in place.
//
// Player A: values below PLAYER_B_NOOP are checked with
// RomSettings::isLegal(); RESET is always dropped.
// Player B: values below RESET are shifted down by PLAYER_B_NOOP before the
// same legality check; RESET is always dropped.
void StellaEnvironment::noopIllegalActions(Action& player_a_action,
                                           Action& player_b_action) {
  const Action noop_a = static_cast<Action>(PLAYER_A_NOOP);
  const Action noop_b = static_cast<Action>(PLAYER_B_NOOP);

  // isLegal() is only consulted for values in player A's range.
  const bool a_illegal =
      player_a_action < static_cast<Action>(PLAYER_B_NOOP) &&
      !m_settings->isLegal(player_a_action);
  if (a_illegal || player_a_action == RESET) {
    player_a_action = noop_a;
  }

  // Player B's action is shifted by PLAYER_B_NOOP so that the same legality
  // query can be reused for it.
  const bool b_illegal =
      player_b_action < static_cast<Action>(RESET) &&
      !m_settings->isLegal(static_cast<Action>(
          static_cast<int>(player_b_action) - PLAYER_B_NOOP));
  if (b_illegal || player_b_action == RESET) {
    player_b_action = noop_b;
  }
}
// Applies the requested actions for m_frame_skip frames and returns the sum
// of the per-frame rewards.
//
// On every frame each player gets an independent random draw. The newly
// requested action replaces the stored one only when the draw is at least
// m_repeat_action_probability; otherwise the previously stored action is
// repeated.
reward_t StellaEnvironment::act(Action player_a_action,
                                Action player_b_action) {
  reward_t total_reward = 0;
  Random& rng = m_osystem->rng();

  for (size_t frame = 0; frame < m_frame_skip; ++frame) {
    const bool refresh_a = rng.nextDouble() >= m_repeat_action_probability;
    if (refresh_a) {
      m_player_a_action = player_a_action;
    }

    const bool refresh_b = rng.nextDouble() >= m_repeat_action_probability;
    if (refresh_b) {
      m_player_b_action = player_b_action;
    }

    m_osystem->sound().recordNextFrame();

    // The exporter exists only when recording is configured. It receives the
    // screen as it is before this frame is emulated.
    if (m_screen_exporter.get() != nullptr) {
      m_screen_exporter->saveNext(m_screen);
    }

    total_reward += oneStepAct(m_player_a_action, m_player_b_action);
  }

  return total_reward;
}
// Holds the RESET action for m_num_reset_steps frames (player B idle), then
// sets both stored player actions back to their no-ops.
void StellaEnvironment::softReset() {
  emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps);

  m_player_b_action = PLAYER_B_NOOP;
  m_player_a_action = PLAYER_A_NOOP;
}
// Advances the environment by one emulate() call and returns the reward
// reported by the ROM settings.
//
// Returns 0 without emulating anything once the episode is terminal.
// Illegal actions and RESET are replaced with no-ops before emulation.
reward_t StellaEnvironment::oneStepAct(Action player_a_action,
                                       Action player_b_action) {
  reward_t step_reward = 0;

  if (!isTerminal()) {
    // The arguments are by-value copies, so filtering them in place does not
    // affect the caller.
    noopIllegalActions(player_a_action, player_b_action);
    emulate(player_a_action, player_b_action);
    m_state.incrementFrame();
    step_reward = m_settings->getReward();
  }

  return step_reward;
}
// Reports whether the episode is over: either the ROM settings say so, or a
// positive frame cap is configured and the episode frame count has reached
// it. A cap of zero or less disables the frame limit.
bool StellaEnvironment::isTerminal() const {
  if (m_settings->isTerminal()) {
    return true;
  }
  if (m_max_num_frames_per_episode <= 0) {
    return false;
  }
  return m_state.getEpisodeFrameNumber() >= m_max_num_frames_per_episode;
}
// Presses the console's SELECT switch and runs the emulator for `num_steps`
// frames, then emulates one no-op step and counts it as a frame.
//
// Unlike emulate(), the frames run here do not call RomSettings::step().
void StellaEnvironment::pressSelect(size_t num_steps) {
  m_state.pressSelect(m_osystem->event());

  for (size_t remaining = num_steps; remaining > 0; --remaining) {
    m_osystem->console().mediaSource().update();
  }

  // Refresh the cached screen and RAM before the follow-up no-op step.
  processScreen();
  processRAM();

  emulate(PLAYER_A_NOOP, PLAYER_B_NOOP);
  m_state.incrementFrame();
}
// Records the requested difficulty in the environment state. Nothing else is
// touched here.
void StellaEnvironment::setDifficulty(difficulty_t value) {
  m_state.setDifficulty(
      value);
}
// Records the requested game mode in the environment state. reset() later
// reads it back with getCurrentMode() and passes it to the ROM settings.
void StellaEnvironment::setMode(game_mode_t value) {
  m_state.setCurrentMode(
      value);
}
// Runs the emulator for `num_steps` frames with the given actions, then
// refreshes the cached screen and RAM.
//
// Joystick games set their input events once before the loop. Paddle games
// re-apply the actions on every frame. RomSettings::step() is called after
// each emulated frame in both cases.
void StellaEnvironment::emulate(Action player_a_action, Action player_b_action,
                                size_t num_steps) {
  Event* event = m_osystem->event();

  // Joystick events are set even when num_steps is zero.
  if (!m_use_paddles) {
    m_state.setActionJoysticks(event, player_a_action, player_b_action);
  }

  for (size_t step = 0; step < num_steps; ++step) {
    if (m_use_paddles) {
      m_state.applyActionPaddles(event, player_a_action, player_b_action);
    }
    m_osystem->console().mediaSource().update();
    m_settings->step(m_osystem->console().system());
  }

  processScreen();
  processRAM();
}
// Overwrites the environment's state object with a copy of `state`. Only
// m_state is assigned here.
void StellaEnvironment::setState(const ALEState& state) {
  m_state = state;
}
// Gives read-only access to the environment's current state object. The
// reference refers to the member itself, not to a copy.
const ALEState& StellaEnvironment::getState() const {
  return m_state;
}
// Creates a new StellaEnvironmentWrapper bound to this environment and
// hands ownership of it to the caller.
std::unique_ptr<StellaEnvironmentWrapper> StellaEnvironment::getWrapper() {
  std::unique_ptr<StellaEnvironmentWrapper> wrapper(
      new StellaEnvironmentWrapper(*this));
  return wrapper;
}
// Refreshes m_screen from the emulator. Without colour averaging, the
// media source's current frame buffer is copied as-is; with it, the
// phosphor blender fills the screen instead.
void StellaEnvironment::processScreen() {
  if (!m_colour_averaging) {
    // Raw copy of m_screen.arraySize() bytes from the current frame buffer.
    memcpy(m_screen.getArray(),
           m_osystem->console().mediaSource().currentFrameBuffer(),
           m_screen.arraySize());
    return;
  }

  m_phosphor_blend.process(m_screen);
}
// Copies m_ram.size() bytes from the emulated system into m_ram, reading
// consecutive addresses starting at 0x80.
void StellaEnvironment::processRAM() {
  const size_t ram_base_address = 0x80;

  for (size_t offset = 0; offset < m_ram.size(); ++offset) {
    *m_ram.byte(offset) =
        m_osystem->console().system().peek(offset + ram_base_address);
  }
}
}