import sys
import logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
import json
import importlib
import asyncio
from relayrl_framework import RelayRLTrajectory
class LoadScripts:
    """Import an algorithm class by name and instantiate it with hyperparameters.

    The class is looked up as the attribute ``<algorithm_name>`` of the module
    ``<algorithm_name>.<algorithm_name>``, which is imported after
    ``resolved_algorithm_dir`` has been appended to ``sys.path``.

    Attributes:
        algorithm: The instantiated algorithm object, or ``None`` when the
            resolved class attribute is falsy.
    """

    def __init__(self, algorithm_name, hyperparams, env_dir, resolved_algorithm_dir, config_path):
        """Load and construct the algorithm.

        Args:
            algorithm_name: Name of both the package/module and the class
                inside it.
            hyperparams: JSON-encoded object of keyword arguments for the
                algorithm constructor. A falsy value (``None``, ``""``) or
                JSON ``null`` means "no extra hyperparameters".
            env_dir: Accepted for interface compatibility; not used here.
            resolved_algorithm_dir: Directory appended to ``sys.path`` so the
                algorithm module can be imported.
            config_path: Always forwarded to the algorithm constructor as the
                ``config_path`` keyword argument.

        Raises:
            json.JSONDecodeError: If ``hyperparams`` is not valid JSON.
            ImportError: If the algorithm module cannot be imported.
            AttributeError: If the module has no attribute named
                ``algorithm_name``.
        """
        sys.path.append(resolved_algorithm_dir)

        # Fall back to an empty mapping so that a worker started without
        # --hyperparams still constructs the algorithm instead of failing
        # on item assignment / ``**None``.
        decoded = json.loads(hyperparams) if hyperparams else None
        kwargs = {
            key: self._coerce_numeric_string(value)
            for key, value in (decoded or {}).items()
        }
        kwargs["config_path"] = config_path

        module = importlib.import_module(f"{algorithm_name}.{algorithm_name}")
        algorithm_class = getattr(module, algorithm_name)
        self.algorithm = algorithm_class(**kwargs) if algorithm_class else None

        # NOTE(review): these bracketed markers look like a handshake for
        # whoever launched this process -- keep the text and level unchanged
        # unless the consumer is updated too.
        if self.algorithm is None:
            logging.critical("[algorithm_pyscript_false]")
        else:
            logging.critical("[algorithm_pyscript_true]")

    @staticmethod
    def _coerce_numeric_string(value):
        """Return ``value`` as ``int``/``float`` if it is a plain numeric string.

        Recognised forms are an optional leading sign followed by digits with
        at most one decimal point (``"3"``, ``"-2"``, ``"0.5"``, ``".5"``).
        Anything else -- non-strings, exponents, ``"nan"``, text -- is
        returned unchanged.
        """
        if not isinstance(value, str):
            return value

        unsigned = value[1:] if value[:1] in ("+", "-") else value
        whole, dot, fraction = unsigned.partition(".")
        # Digits only (and at least one of them) on both sides of the first
        # dot; a second dot would land in ``fraction`` and fail this check.
        if not (whole + fraction).isdigit():
            return value

        try:
            return float(value) if dot else int(value)
        except ValueError:
            # ``str.isdigit`` accepts characters such as superscript digits
            # that ``int``/``float`` reject; leave those as strings.
            return value
class PythonWorker:
    """Dispatch JSON commands to a loaded algorithm, one at a time.

    Attributes:
        loaded_scripts: ``LoadScripts`` instance holding the algorithm
            (``loaded_scripts.algorithm`` is ``None`` when nothing is loaded).
        lock: ``asyncio.Lock`` held around every call into the algorithm, so
            ``save`` and ``receive_trajectory`` never overlap.
        shutdown_flag: Boolean initialised to ``False``; no method of this
            class changes it.
    """

    def __init__(self, algorithm_name, hyperparams, env_dir, resolved_algorithm_dir, config_path):
        """Load the algorithm and set up the worker state.

        All arguments are forwarded unchanged to ``LoadScripts``.
        """
        self.loaded_scripts = LoadScripts(algorithm_name, hyperparams, env_dir, resolved_algorithm_dir, config_path)
        self.lock = asyncio.Lock()
        self.shutdown_flag = False

    async def save_model(self):
        """Call the algorithm's ``save()`` in a worker thread.

        Returns:
            ``{"status": "success"}`` on success, otherwise
            ``{"status": "error", "message": ...}`` when no algorithm is
            loaded or ``save`` raised.
        """
        async with self.lock:
            algorithm = self.loaded_scripts.algorithm
            # Compare against None explicitly: a loaded algorithm object may
            # itself be falsy (e.g. it defines __len__ or __bool__).
            if algorithm is None:
                return {"status": "error", "message": "Algorithm not initialized"}
            try:
                await asyncio.to_thread(algorithm.save)
            except Exception as e:
                return {"status": "error", "message": str(e)}
            return {"status": "success"}

    async def receive_trajectory(self, trajectory):
        """Decode a trajectory and hand it to the algorithm.

        Args:
            trajectory: Serialized trajectory passed to
                ``RelayRLTrajectory.traj_from_json``.

        Returns:
            ``{"status": "success"}`` when the algorithm's
            ``receive_trajectory`` returns a truthy value,
            ``{"status": "not_updated"}`` when it returns a falsy one, and
            ``{"status": "error", "message": ...}`` when no algorithm is
            loaded or the call raised.

        Raises:
            Exception: Whatever ``RelayRLTrajectory.traj_from_json`` raises;
                decoding happens before the error-capturing section.
        """
        # Decoding does not touch the algorithm, so it runs outside the lock.
        trajectory_obj = await asyncio.to_thread(RelayRLTrajectory.traj_from_json, trajectory)
        async with self.lock:
            algorithm = self.loaded_scripts.algorithm
            if algorithm is None:
                return {"status": "error", "message": "Algorithm not initialized"}
            try:
                status = await asyncio.to_thread(algorithm.receive_trajectory, trajectory_obj)
            except Exception as e:
                return {"status": "error", "message": str(e)}
            return {"status": "success"} if status else {"status": "not_updated"}

    async def handle_command(self, command_json):
        """Route a decoded command to the matching handler.

        Args:
            command_json: Mapping with a ``"command"`` key; the
                ``"receive_trajectory"`` command also reads ``"trajectory"``.

        Returns:
            The handler's response dict, or an error dict for an unknown
            command.
        """
        command = command_json.get("command")
        if command == "save_model":
            return await self.save_model()
        if command == "receive_trajectory":
            return await self.receive_trajectory(command_json.get("trajectory"))
        return {"status": "error", "message": f"Unknown command: {command}"}
async def main():
    """Run the worker: read JSON commands from stdin, answer on stdout.

    Command-line options (all optional strings, default ``None``) are
    forwarded to ``PythonWorker``. Each line read from stdin is parsed as one
    JSON command; exactly one JSON response line is printed and flushed per
    input line. The loop ends at end-of-input or once
    ``worker.shutdown_flag`` is set.
    """
    import argparse

    parser = argparse.ArgumentParser()
    for option in (
        "--algorithm_name",
        "--env_dir",
        "--resolved_algorithm_dir",
        "--config_path",
        "--hyperparams",
    ):
        parser.add_argument(option, type=str, default=None)
    args = parser.parse_args()

    worker = PythonWorker(
        algorithm_name=args.algorithm_name,
        hyperparams=args.hyperparams,
        env_dir=args.env_dir,
        resolved_algorithm_dir=args.resolved_algorithm_dir,
        config_path=args.config_path,
    )

    loop = asyncio.get_running_loop()
    # Raise the per-line buffer limit to 10240 * 10240 bytes (~100 MiB) so a
    # whole serialized command fits on a single line.
    reader = asyncio.StreamReader(limit=10240 * 10240)
    protocol = asyncio.StreamReaderProtocol(reader)
    await loop.connect_read_pipe(lambda: protocol, sys.stdin)

    # NOTE(review): nothing visible in this file sets shutdown_flag, so in
    # practice the loop ends when stdin reaches end-of-file.
    while not worker.shutdown_flag:
        line = await reader.readline()
        if not line:
            # Empty bytes means EOF: the writer closed our stdin.
            break
        try:
            command_json = json.loads(line.decode().strip())
            response = await worker.handle_command(command_json)
        except json.JSONDecodeError:
            response = {"status": "error", "message": "Invalid JSON"}
        except Exception as e:
            # Top-level boundary: report any failure to the caller instead of
            # terminating the worker.
            response = {"status": "error", "message": str(e)}
        print(json.dumps(response))
        sys.stdout.flush()

    print("[PythonWorker] Shutting down worker...")
if __name__ == "__main__":
    # Executed as a script: build the coroutine and let asyncio drive it on a
    # fresh event loop until it finishes.
    worker_entry = main()
    asyncio.run(worker_entry)