from flask import Flask, request, jsonify
import uuid
import gym
import numpy as np
import argparse
import json
import logging
# Werkzeug's logger, restricted to ERROR and above. The same logger object is
# what the request-parameter helpers further down log through.
logger = logging.getLogger('werkzeug')
logger.setLevel(logging.ERROR)
class Envs(object):
    """
    Registry of gym environment instances, addressed by a short instance id.

    Every environment made through `create` is stored in `self.envs` under a
    random hexadecimal id. The other methods look the environment up by that
    id, forward the call to it and convert numpy results into plain Python
    values so the caller can serialise them as JSON.

    Unknown instance ids are reported by raising `InvalidUsage`.
    """

    def __init__(self):
        # Maps instance_id (str) -> environment object.
        self.envs = {}
        # Number of hex characters kept from a uuid4 to build an instance id.
        self.id_len = 8

    @staticmethod
    def _to_jsonable(value):
        """
        Convert numpy values, also when nested inside tuples, to plain Python.

        Tuples become lists; anything that is neither numpy nor a tuple is
        returned unchanged.
        """
        # The type-name test is kept alongside isinstance so that every type
        # provided by numpy is still converted, as before.
        if isinstance(value, (np.ndarray, np.generic)) or 'numpy' in str(type(value)):
            return np.array(value).tolist()
        if isinstance(value, tuple):
            return [Envs._to_jsonable(item) for item in value]
        return value

    @staticmethod
    def _finite(x):
        """Return x as a float, replacing +/-infinity with +/-1e100."""
        if x == np.inf:
            return 1e100
        if x == -np.inf:
            return -1e100
        return float(x)

    def _lookup_env(self, instance_id):
        """Return the environment for instance_id or raise InvalidUsage."""
        try:
            return self.envs[instance_id]
        except KeyError:
            raise InvalidUsage('Instance_id {} unknown'.format(instance_id))

    def _remove_env(self, instance_id):
        """Forget the environment for instance_id or raise InvalidUsage."""
        try:
            del self.envs[instance_id]
        except KeyError:
            raise InvalidUsage('Instance_id {} unknown'.format(instance_id))

    def create(self, env_id, seed=None):
        """
        Make a new environment and register it.

        Parameters:
            env_id: identifier handed to `gym.make`.
            seed: optional seed passed to `env.seed`; skipped only when None.
        Returns:
            The instance id (str) under which the environment is stored.
        Raises:
            InvalidUsage: when gym reports an error while creating or
                seeding the environment.
        """
        try:
            env = gym.make(env_id)
            # Compare against None so that a seed of 0 is still applied.
            if seed is not None:
                env.seed(seed)
        except gym.error.DependencyNotInstalled as e:
            raise InvalidUsage("Dependency not installed: {}".format(str(e)))
        except gym.error.DeprecatedEnv:
            raise InvalidUsage("Deprecated environment with ID {}".format(env_id))
        except gym.error.Error as e:
            raise InvalidUsage("Error creating environment with ID {}: {}".format(env_id, str(e)))

        instance_id = str(uuid.uuid4().hex)[:self.id_len]
        self.envs[instance_id] = env
        return instance_id

    def list_all(self):
        """Return a dict mapping every instance id to its `env.spec.id`."""
        return {instance_id: env.spec.id for instance_id, env in self.envs.items()}

    def reset(self, instance_id):
        """Reset the environment and return its observation in JSON-able form."""
        env = self._lookup_env(instance_id)
        return self._to_jsonable(env.reset())

    def step(self, instance_id, action, render):
        """
        Advance the environment by one step.

        Parameters:
            instance_id: id of the environment to step.
            action: action from the request; turned into a numpy array for a
                "Box" action space and into a list of ints for a "Tuple" one.
            render: when truthy, `env.render()` is called before stepping.
        Returns:
            [observation, reward, done, info] with numpy values in the first
            three entries converted to plain Python; info is passed through.
        """
        env = self._lookup_env(instance_id)
        if render:
            env.render()

        space_name = env.action_space.__class__.__name__
        if space_name == "Box":
            action = np.array(action)
        elif space_name == "Tuple":
            action = [int(a) for a in action]

        observation, reward, done, info = env.step(action)
        # reward/done can be numpy scalars, which the JSON encoder rejects.
        return [
            self._to_jsonable(observation),
            self._to_jsonable(reward),
            self._to_jsonable(done),
            info,
        ]

    def get_action_space_contains(self, instance_id, x):
        """
        Report whether int(x) is a member of the environment's action space.

        Raises:
            InvalidUsage: when x cannot be converted to an integer.
        """
        env = self._lookup_env(instance_id)
        try:
            candidate = int(x)
        except (TypeError, ValueError):
            # A malformed value is a client error, not a server failure.
            raise InvalidUsage('Action {} is not an integer'.format(x))
        return bool(env.action_space.contains(candidate))

    def get_action_space_info(self, instance_id):
        """Return the property dict describing the action space."""
        env = self._lookup_env(instance_id)
        return self._get_space_properties(env.action_space)

    def get_action_space_sample(self, instance_id):
        """Return a sample from the action space in JSON-able form."""
        env = self._lookup_env(instance_id)
        # Tuples are converted to lists here (the result used to be dropped).
        return self._to_jsonable(env.action_space.sample())

    def get_observation_space_contains(self, instance_id, j):
        """
        Check that every key/value in j matches the observation space info.

        Values are compared through their JSON encoding. Returns False on the
        first mismatch, or when j names a key the space info does not have;
        returns True otherwise.
        """
        env = self._lookup_env(instance_id)
        info = self._get_space_properties(env.observation_space)
        for key, value in j.items():
            if key not in info:
                print('Key "{}" is not a property of the observation space.'.format(key))
                return False
            if json.dumps(info[key]) != json.dumps(value):
                print('Values for "{}" do not match. Passed "{}", Observed "{}".'.format(key, value, info[key]))
                return False
        return True

    def get_observation_space_info(self, instance_id):
        """Return the property dict describing the observation space."""
        env = self._lookup_env(instance_id)
        return self._get_space_properties(env.observation_space)

    def _get_space_properties(self, space):
        """
        Describe a space as a dict, keyed on the space's class name.

        'name' is always present. Depending on the class name the dict also
        holds: 'n' (Discrete); 'shape', 'low', 'high' (Box); 'num_rows',
        'matrix' (HighLow); 'n', 'spaces' (Tuple, described recursively).
        Infinite bounds are replaced by +/-1e100.
        """
        info = {}
        info['name'] = space.__class__.__name__
        if info['name'] == 'Discrete':
            info['n'] = int(space.n)
        elif info['name'] == 'Box':
            info['shape'] = space.shape
            info['low'] = [self._finite(x) for x in np.array(space.low).flatten()]
            info['high'] = [self._finite(x) for x in np.array(space.high).flatten()]
        elif info['name'] == 'HighLow':
            info['num_rows'] = space.num_rows
            info['matrix'] = [self._finite(x) for x in np.array(space.matrix).flatten()]
        elif info['name'] == 'Tuple':
            info['n'] = len(space.spaces)
            info['spaces'] = [self._get_space_properties(sub_space) for sub_space in space.spaces]
        return info

    def monitor_start(self, instance_id, directory, force, resume, video_callable):
        """
        Replace the stored environment with a `gym.wrappers.Monitor` around it.

        Parameters:
            instance_id: id of the environment to wrap.
            directory, force, resume: forwarded to the Monitor wrapper.
            video_callable: falsy to never record; otherwise a number N, in
                which case the callable given to the Monitor returns True
                when its `count` argument is a multiple of N.
        """
        env = self._lookup_env(instance_id)
        if not video_callable:
            def v_c(count):
                return False
        else:
            def v_c(count):
                return count % video_callable == 0
        self.envs[instance_id] = gym.wrappers.Monitor(env, directory, force=force, resume=resume, video_callable=v_c)

    def monitor_close(self, instance_id):
        """Call `close()` on the stored environment; it stays registered."""
        env = self._lookup_env(instance_id)
        env.close()

    def env_close(self, instance_id):
        """Call `close()` on the stored environment and unregister it."""
        env = self._lookup_env(instance_id)
        env.close()
        self._remove_env(instance_id)
# Flask application object; the route decorators below register handlers on it.
app = Flask(__name__)
# Switch off Flask's pretty-printing of jsonify() responses.
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
# Disable the application's own logger.
app.logger.disabled = True
# Single module-level environment registry shared by all route handlers.
envs = Envs()
class InvalidUsage(Exception):
    """
    Error raised for a bad client request.

    Attributes:
        message: human-readable description, returned under the 'message' key.
        status_code: HTTP status to respond with; 400 unless overridden.
        payload: optional mapping (or iterable of pairs) of extra fields to
            include in the response body.
    """
    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        # Hand the message to Exception so str(error) and tracebacks show it
        # instead of an empty string.
        super(InvalidUsage, self).__init__(message)
        self.message = message
        if status_code is not None:
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return the payload fields plus 'message' as a new dict."""
        rv = dict(self.payload or ())
        rv['message'] = self.message
        return rv
def get_required_param(json, param):
    """
    Return the value of a mandatory request parameter.

    Parameters:
        json: the parsed request body (expected to offer a dict-style `get`),
            or None when the request carried no valid JSON.
        param: name of the parameter to read.
    Returns:
        The parameter's value.
    Raises:
        InvalidUsage: when the body is None, is not an object, or the
            parameter is missing, None, an empty string or an empty list.
    """
    if json is None:
        logger.info("Request is not a valid json")
        raise InvalidUsage("Request is not a valid json")
    if not hasattr(json, 'get'):
        # Valid JSON that is not an object (a list, a number, ...) cannot
        # carry named parameters; report it instead of failing on .get().
        logger.info("Request json is not an object")
        raise InvalidUsage("Request json must be an object")
    value = json.get(param)
    if value is None or value == '' or value == []:
        logger.info("A required request parameter <%s> had value %s", param, value)
        raise InvalidUsage("A required request parameter <{}> was not provided".format(param))
    return value
def get_optional_param(json, param, default):
    """
    Return the value of an optional request parameter, or a default.

    Parameters:
        json: the parsed request body (expected to offer a dict-style `get`),
            or None when the request carried no valid JSON.
        param: name of the parameter to read.
        default: value returned when the parameter is missing, None, an
            empty string or an empty list.
    Returns:
        The parameter's value, or `default`.
    Raises:
        InvalidUsage: when the body is None or is not an object.
    """
    if json is None:
        logger.info("Request is not a valid json")
        raise InvalidUsage("Request is not a valid json")
    if not hasattr(json, 'get'):
        # Valid JSON that is not an object (a list, a number, ...) cannot
        # carry named parameters; report it instead of failing on .get().
        logger.info("Request json is not an object")
        raise InvalidUsage("Request json must be an object")
    value = json.get(param)
    if value is None or value == '' or value == []:
        return default
    return value
@app.errorhandler(InvalidUsage)
def handle_invalid_usage(error):
    """Turn an InvalidUsage error into a JSON response with its status code."""
    reply = jsonify(error.to_dict())
    reply.status_code = error.status_code
    return reply
@app.route('/v1/envs/', methods=['POST'])
def env_create():
    """
    Create a new environment instance.

    Request JSON:
        - env_id: (required) identifier handed to `envs.create`
        - seed: (optional) seed handed to `envs.create`; defaults to None
    Response JSON:
        - instance_id: id under which the new environment is registered
    """
    body = request.get_json()
    requested_env = get_required_param(body, 'env_id')
    requested_seed = get_optional_param(body, 'seed', None)
    return jsonify(instance_id=envs.create(requested_env, requested_seed))
@app.route('/v1/envs/', methods=['GET'])
def env_list_all():
    """
    List every registered environment.

    Response JSON:
        - all_envs: the mapping returned by `envs.list_all()`
    """
    return jsonify(all_envs=envs.list_all())
@app.route('/v1/envs/<instance_id>/reset/', methods=['POST'])
def env_reset(instance_id):
    """
    Reset the environment identified by instance_id.

    Response JSON:
        - observation: the value returned by `envs.reset`
    """
    return jsonify(observation=envs.reset(instance_id))
@app.route('/v1/envs/<instance_id>/step/', methods=['POST'])
def env_step(instance_id):
    """
    Run one step of the environment identified by instance_id.

    Request JSON:
        - action: (required) action handed to `envs.step`
        - render: (optional) render flag handed to `envs.step`; defaults to False
    Response JSON:
        - observation, reward, done, info: the four values returned by
          `envs.step`, in that order
    """
    body = request.get_json()
    chosen_action = get_required_param(body, 'action')
    want_render = get_optional_param(body, 'render', False)
    observation, reward, done, info = envs.step(instance_id, chosen_action, want_render)
    return jsonify(observation=observation, reward=reward, done=done, info=info)
@app.route('/v1/envs/<instance_id>/action_space/', methods=['GET'])
def env_action_space_info(instance_id):
    """
    Describe the action space of the environment identified by instance_id.

    Response JSON:
        - info: the value returned by `envs.get_action_space_info`
    """
    return jsonify(info=envs.get_action_space_info(instance_id))
@app.route('/v1/envs/<instance_id>/action_space/sample', methods=['GET'])
def env_action_space_sample(instance_id):
    """
    Sample an action from the environment identified by instance_id.

    Response JSON:
        - action: the value returned by `envs.get_action_space_sample`
    """
    return jsonify(action=envs.get_action_space_sample(instance_id))
@app.route('/v1/envs/<instance_id>/action_space/contains/<x>', methods=['GET'])
def env_action_space_contains(instance_id, x):
    """
    Ask whether the URL segment x is a member of the action space.

    Response JSON:
        - member: the value returned by `envs.get_action_space_contains`
    """
    return jsonify(member=envs.get_action_space_contains(instance_id, x))
@app.route('/v1/envs/<instance_id>/observation_space/', methods=['GET'])
def env_observation_space_info(instance_id):
    """
    Describe the observation space of the environment identified by instance_id.

    Response JSON:
        - info: the value returned by `envs.get_observation_space_info`
    """
    return jsonify(info=envs.get_observation_space_info(instance_id))
@app.route('/v1/envs/<instance_id>/observation_space/contains', methods=['POST'])
def env_observation_space_contains(instance_id):
    """
    Check the request body against the observation space.

    The parsed JSON body is handed unchanged to
    `envs.get_observation_space_contains`.

    Response JSON:
        - member: the value that call returns
    """
    body = request.get_json()
    return jsonify(member=envs.get_observation_space_contains(instance_id, body))
@app.route('/v1/envs/<instance_id>/monitor/start/', methods=['POST'])
def env_monitor_start(instance_id):
    """
    Start monitoring the environment identified by instance_id.

    Request JSON:
        - directory: (required) handed to `envs.monitor_start`
        - force: (optional) defaults to False
        - resume: (optional) defaults to False
        - video_callable: (optional) defaults to False
    Responds with an empty body and status 204.
    """
    body = request.get_json()
    target_dir = get_required_param(body, 'directory')
    force_flag = get_optional_param(body, 'force', False)
    resume_flag = get_optional_param(body, 'resume', False)
    video_setting = get_optional_param(body, 'video_callable', False)
    envs.monitor_start(instance_id, target_dir, force_flag, resume_flag, video_setting)
    return '', 204
@app.route('/v1/envs/<instance_id>/monitor/close/', methods=['POST'])
def env_monitor_close(instance_id):
    """
    Call `envs.monitor_close` for instance_id.

    Responds with an empty body and status 204.
    """
    envs.monitor_close(instance_id)
    return '', 204
@app.route('/v1/envs/<instance_id>/close/', methods=['POST'])
def env_close(instance_id):
    """
    Call `envs.env_close` for instance_id.

    Responds with an empty body and status 204.
    """
    envs.env_close(instance_id)
    return '', 204
def main():
    """
    Parse the command line and run the Flask application.

    Options:
        -l / --listen: interface to listen on (default 127.0.0.1)
        -p / --port: port to bind to (default 5000)
    """
    cli = argparse.ArgumentParser(description='Start a Gym HTTP API server')
    cli.add_argument('-l', '--listen', help='interface to listen to', default='127.0.0.1')
    cli.add_argument('-p', '--port', default=5000, type=int, help='port to bind to')
    options = cli.parse_args()
    # Threading, debug mode, the debugger and the reloader are all switched off.
    app.run(
        host=options.listen,
        port=options.port,
        threaded=False,
        debug=False,
        use_debugger=False,
        use_reloader=False,
    )
# Start the server.
# NOTE(review): this call is unconditional, so importing this module also
# parses the command line and starts the server; confirm that is intended.
main()