#include "rlglue_controller.hpp"
#ifdef __USE_RLGLUE
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <sstream>
#include <string>
#include "../environment/ale_ram.hpp"
#include <rlglue/utils/C/RLStruct_util.h>
#include "../common/Log.hpp"
namespace ale {
// Builds the RL-Glue controller on top of ALEController and caches the
// settings consulted while running.
//
// _osystem is forwarded to the ALEController base. The values below are read
// through m_osystem->settings():
//   "max_num_frames"        -> m_max_num_frames (frame budget used by isDone())
//   "restricted_action_set" -> selects the minimal or the full action list
//   "send_rgb"              -> m_send_rgb (RGB triples vs. raw pixel values)
RLGlueController::RLGlueController(OSystem* _osystem)
    : ALEController(_osystem) {
  m_max_num_frames = m_osystem->settings().getInt("max_num_frames");

  const bool use_minimal_set =
      m_osystem->settings().getBool("restricted_action_set");
  if (!use_minimal_set) {
    available_actions = m_settings->getAllActions();
  } else {
    available_actions = m_settings->getMinimalActionSet();
  }

  m_send_rgb = m_osystem->settings().getBool("send_rgb");
}
// Destructor. The body is empty: the connection and buffer are released by
// endRLGlue(), not here.
RLGlueController::~RLGlueController() {}
// Entry point of the controller: connects to RL-Glue, services commands in
// rlGlueLoop() until that loop returns, then closes the connection.
void RLGlueController::run() {
initRLGlue();
rlGlueLoop();
endRLGlue();
}
// Reports whether the configured frame budget has been used up.
//
// A non-positive m_max_num_frames disables the limit, in which case this
// always returns false. Otherwise returns true once the environment's frame
// counter has reached m_max_num_frames.
bool RLGlueController::isDone() {
  if (m_max_num_frames <= 0) {
    return false;
  }
  return m_environment.getFrameNumber() >= m_max_num_frames;
}
void RLGlueController::initRLGlue() {
ale::Logger::Info << "Initializing ALE RL-Glue ..." << std::endl;
const char* host = kLocalHost;
short port = kDefaultPort;
const char* envptr = 0;
envptr = getenv("RLGLUE_PORT");
if (envptr != 0) {
port = strtol(envptr, 0, 10);
if (port == 0) {
port = kDefaultPort;
}
}
rlBufferCreate(&m_buffer, 4096);
m_connection = rlWaitForConnection(host, port, kRetryTimeout);
rlBufferClear(&m_buffer);
rlSendBufferData(m_connection, &m_buffer, kEnvironmentConnection);
}
// Closes the RL-Glue connection and destroys the message buffer, undoing the
// setup performed in initRLGlue().
void RLGlueController::endRLGlue() {
rlClose(m_connection);
rlBufferDestroy(&m_buffer);
}
void RLGlueController::rlGlueLoop() {
int envState = 0;
bool error = false;
while (!isDone() && !error && envState != kRLTerm) {
rlBufferClear(&m_buffer);
rlRecvBufferData(m_connection, &m_buffer, &envState);
switch (envState) {
case kEnvInit:
envInit();
break;
case kEnvStart:
envStart();
break;
case kEnvStep:
envStep();
break;
case kEnvCleanup:
envCleanup();
break;
case kEnvMessage:
envMessage();
break;
case kRLTerm:
break;
default:
ale::Logger::Error << "Unknown RL-Glue command: " << envState
<< std::endl;
error = true;
break;
};
rlSendBufferData(m_connection, &m_buffer, envState);
display();
}
}
// Handles kEnvInit: builds the RL-Glue task specification string, allocates
// the action and observation structures, and writes the specification
// (length followed by characters) into m_buffer as the reply.
//
// The observation is declared as a 128-int group followed by the screen:
// 210*160*3 ints when m_send_rgb is set, 210*160 ints otherwise. The leading
// group corresponds to the RAM values that
// constructRewardObservationTerminal() writes first.
void RLGlueController::envInit() {
  unsigned int observation_dimensions;
  std::stringstream taskSpec;
  taskSpec << "VERSION RL-Glue-3.0 "
              "PROBLEMTYPE episodic "
              "DISCOUNTFACTOR 1 "
              "OBSERVATIONS INTS (128 0 255)";
  if (m_send_rgb) {
    taskSpec << "(100800 0 255) ";
    observation_dimensions = 128 + 210 * 160 * 3;
  } else {
    taskSpec << "(33600 0 127) ";
    observation_dimensions = 128 + 210 * 160;
  }

  // Task-spec integer ranges are inclusive on both ends, so the largest valid
  // action is size() - 1. Advertising size() would declare one action that
  // does not exist. Guard the subtraction against an empty action set.
  const size_t action_count = available_actions.size();
  const size_t max_action_index = action_count > 0 ? action_count - 1 : 0;
  taskSpec << "ACTIONS INTS (0 " << max_action_index << ") "
              "REWARDS (UNSPEC UNSPEC) "
              "EXTRA Name: Arcade Learning Environment ";

  // One int carries the chosen action index; the observation holds the
  // dimensions computed above.
  allocateRLStruct(&m_rlglue_action, 1, 0, 0);
  allocateRLStruct(&m_observation, observation_dimensions, 0, 0);

  // Materialise the string once so the length and the characters written
  // below come from the same object.
  const std::string spec = taskSpec.str();
  unsigned int taskSpecLength = static_cast<unsigned int>(spec.size());

  rlBufferClear(&m_buffer);
  unsigned int offset = 0;
  offset = rlBufferWrite(&m_buffer, offset, &taskSpecLength, 1, sizeof(int));
  rlBufferWrite(&m_buffer, offset, spec.c_str(), taskSpecLength, sizeof(char));
}
// Handles kEnvStart: resets the environment and replies with the first
// observation of the new episode.
void RLGlueController::envStart() {
  m_environment.reset();

  // Called for its side effect of filling m_observation; the returned struct
  // is not needed because only the observation is sent. The reward passed in
  // is zero.
  constructRewardObservationTerminal(0);

  rlBufferClear(&m_buffer);
  rlCopyADTToBuffer(&m_observation, &m_buffer, 0);
}
// Handles kEnvStep: decodes the agent's action from m_buffer, applies it to
// the environment (player B always gets PLAYER_B_NOOP), and writes the reply
// as terminal flag, reward, then observation.
//
// A missing or out-of-range action index falls back to entry 0 of
// available_actions.
void RLGlueController::envStep() {
  rlCopyBufferToADT(&m_buffer, 0, &m_rlglue_action);
  __RL_CHECK_STRUCT(&m_rlglue_action);

  // The action arrives over the network, so do not assume it carries an int:
  // reading intArray[0] of an empty action would be an out-of-bounds access.
  unsigned int player_a_action_index = 0;
  if (m_rlglue_action.numInts > 0 && m_rlglue_action.intArray != 0) {
    // A negative value converts to a large unsigned number and is caught by
    // the range check below.
    player_a_action_index = m_rlglue_action.intArray[0];
  }
  if (player_a_action_index >= available_actions.size()) {
    player_a_action_index = 0;
  }

  Action player_a_action = available_actions[player_a_action_index];
  Action player_b_action = static_cast<Action>(PLAYER_B_NOOP);
  filterActions(player_a_action, player_b_action);

  reward_t reward = applyActions(player_a_action, player_b_action);
  reward_observation_terminal_t ro = constructRewardObservationTerminal(reward);

  // Reply layout: terminal (int), reward (double), observation.
  rlBufferClear(&m_buffer);
  unsigned int offset = 0;
  offset = rlBufferWrite(&m_buffer, offset, &ro.terminal, 1, sizeof(int));
  offset = rlBufferWrite(&m_buffer, offset, &ro.reward, 1, sizeof(double));
  rlCopyADTToBuffer(ro.observation, &m_buffer, offset);
}
void RLGlueController::envCleanup() {
rlBufferClear(&m_buffer);
clearRLStruct(&m_observation);
}
void RLGlueController::envMessage() {
unsigned int messageLength;
unsigned int offset = 0;
offset = rlBufferRead(&m_buffer, offset, &messageLength, 1, sizeof(int));
if (messageLength > 0) {
char* message = new char[messageLength + 1];
rlBufferRead(&m_buffer, offset, message, messageLength, sizeof(char));
message[messageLength] = 0;
ale::Logger::Error << "Message from RL-Glue: " << message << std::endl;
delete[] message;
}
}
// Replaces out-of-range actions with the corresponding no-op, in place.
//
// player_a_action is reset to PLAYER_A_NOOP when it is not below
// PLAYER_A_MAX. player_b_action is reset to PLAYER_B_NOOP when it lies
// outside [PLAYER_B_NOOP, PLAYER_B_MAX).
void RLGlueController::filterActions(Action& player_a_action,
                                     Action& player_b_action) {
  const bool a_in_range = player_a_action < PLAYER_A_MAX;
  if (!a_in_range) {
    player_a_action = PLAYER_A_NOOP;
  }

  const bool b_in_range =
      player_b_action >= PLAYER_B_NOOP && player_b_action < PLAYER_B_MAX;
  if (!b_in_range) {
    player_b_action = PLAYER_B_NOOP;
  }
}
// Fills m_observation from the current environment state and packages it
// with the given reward and the terminal flag.
//
// Layout written into m_observation.intArray: every RAM value first, then
// the screen. With m_send_rgb set, each pixel becomes three ints (red,
// green, blue) looked up through the colour palette; otherwise each pixel
// contributes its raw value. The returned struct points at m_observation
// rather than copying it.
reward_observation_terminal_t
RLGlueController::constructRewardObservationTerminal(reward_t reward) {
  const ALERAM& ram = m_environment.getRAM();
  const ALEScreen& screen = m_environment.getScreen();
  const size_t ramBytes = ram.size();
  const size_t pixelCount = screen.arraySize();

  int cursor = 0;

  // RAM values occupy the front of the observation.
  for (size_t addr = 0; addr < ramBytes; ++addr) {
    m_observation.intArray[cursor++] = ram.get(addr);
  }

  if (!m_send_rgb) {
    // One int per pixel.
    assert(pixelCount + ramBytes == m_observation.numInts);
    pixel_t* pixels = screen.getArray();
    for (size_t p = 0; p < pixelCount; ++p) {
      m_observation.intArray[cursor++] = pixels[p];
    }
  } else {
    // Three ints per pixel.
    assert(pixelCount * 3 + ramBytes == m_observation.numInts);
    pixel_t* pixels = screen.getArray();
    int r, g, b;
    for (size_t p = 0; p < pixelCount; ++p) {
      m_osystem->colourPalette().getRGB(pixels[p], r, g, b);
      m_observation.intArray[cursor] = r;
      m_observation.intArray[cursor + 1] = g;
      m_observation.intArray[cursor + 2] = b;
      cursor += 3;
    }
  }

  reward_observation_terminal_t result;
  result.observation = &m_observation;
  result.reward = reward;
  result.terminal = m_settings->isTerminal();

  __RL_CHECK_STRUCT(result.observation)

  return result;
}
}
#else
namespace ale {
// Stub constructor compiled when __USE_RLGLUE is not defined; it only
// forwards the system object to ALEController.
RLGlueController::RLGlueController(OSystem* system) : ALEController(system) {}
// Stub run() compiled when __USE_RLGLUE is not defined: logs that the
// RL-Glue interface is unavailable and returns without doing anything else.
void RLGlueController::run() {
ale::Logger::Error
<< "RL-Glue interface unavailable. Please recompile with RL-Glue support."
<< std::endl;
}
}
#endif