import asyncio
import json
import logging
import sys
import time
import queue
import threading
import os
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
from relayrl_framework import ConfigLoader
from utils.plot import get_newest_dataset
# Route all module log output to stdout at INFO level (configured at import time).
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
class TensorboardWriter:
    """Mirror scalar columns from a run's ``progress.txt`` into TensorBoard.

    On construction a (non-daemon) background thread starts polling
    ``<env_dir>/logs/<newest_run>/progress.txt`` — a tab-separated table —
    queues newly appended rows, and publishes them with ``SummaryWriter``
    under a per-run TensorBoard log directory. Call :meth:`shutdown` to stop
    the thread.
    """

    def __init__(self, config_path: str, env_dir: str = None, algorithm_name: str = 'run'):
        """Read TensorBoard settings from config and start the writer thread.

        Args:
            config_path: Path handed to ``ConfigLoader``.
            env_dir: Environment directory containing the ``logs`` folder.
                Defaults to the current working directory. (BUG FIX: the
                original evaluated ``os.getcwd()`` once at import time as the
                default; it is now resolved at call time.)
            algorithm_name: Lower-cased into the TensorBoard log dir name.
        """
        if env_dir is None:
            env_dir = os.getcwd()
        config = ConfigLoader(config_path=config_path)
        tb_params = config.get_tb_params()

        self.shutdown_flag = False
        self.launch_tb_on_startup = tb_params['launch_tb_on_startup']
        self.writer = None  # SummaryWriter is created lazily once tags validate

        self._data_log_dir = os.path.join(env_dir, 'logs')
        # Newest run directory under logs/ (helper from utils.plot).
        self._file_root = get_newest_dataset(self._data_log_dir, return_file_root=True)
        self._tb_log_dir = os.path.join(
            self._file_root, f'tb_{algorithm_name.lower()}_{int(time.time())}'
        )
        self._file = os.path.join(self._file_root, 'progress.txt')

        self.data_queue = queue.Queue()
        self.valid_tags = False
        self.scalar_tags = tb_params['scalar_tags']
        self._global_step_tag = tb_params['global_step_tag']
        self._recent_global_step = 0

        # Set on shutdown; also used to wake the polling thread's back-off waits.
        self._loop_stop_signal = threading.Event()
        self._tb_thread = threading.Thread(target=self._tensorboard_writer_processes)
        self._tb_thread.daemon = False
        self._tb_thread.start()
        logging.info("[TensorboardWriter] Initialized")

    def manually_queue_scalar(self, tag: str, scalar_value: float, global_step: int):
        """Queue a single scalar for the writer thread to emit."""
        self.data_queue.put(('scalar', tag, scalar_value, global_step))

    def shutdown(self, timeout=None):
        """Signal the writer thread to stop and optionally wait for it.

        Args:
            timeout: Seconds to wait for the thread to exit. ``None`` or ``0``
                means fire-and-forget (signal only, no join).

        Returns:
            ``True``/``False`` (thread finished within ``timeout``) when
            waiting; ``None`` when not waiting.
        """
        logging.info("[TensorboardWriter - shutdown] Initiating shutdown...")
        self.shutdown_flag = True
        self._loop_stop_signal.set()
        # BUG FIX: the original condition used `or`, which is always true
        # (None fails the first test but passes the second; 0 passes the
        # first), so the fire-and-forget branch below was unreachable.
        if timeout is not None and timeout != 0:
            logging.info(f"[TensorboardWriter - shutdown] Waiting for writer thread to complete (timeout: {timeout}s)...")
            self._tb_thread.join(timeout=timeout)
            thread_completed = not self._tb_thread.is_alive()
            if thread_completed:
                logging.info("[TensorboardWriter - shutdown] Writer thread has completed successfully.")
            else:
                logging.warning("[TensorboardWriter - shutdown] Timeout reached, writer thread still running.")
            return thread_completed
        else:
            logging.info("[TensorboardWriter - shutdown] Shutdown signal sent, not waiting for thread completion.")
            return None

    def _tensorboard_writer_processes(self):
        """Writer-thread main loop: validate tags, then poll and publish rows."""

        def _validate_tag_existence():
            """Check progress.txt for the configured scalar tags.

            Returns ``None`` once validation has run (result stored in
            ``self.valid_tags``), or ``False`` while the data file is not yet
            available/parseable.
            """
            logging.info("[TensorboardWriter - _validate_tag_existence] Validating scalar tags...")
            if not os.path.exists(self._file_root):
                logging.info("[TensorboardWriter - _validate_tag_existence] Data directory not found. Tensorboard not started.")
                return False
            if not os.path.exists(self._file):
                logging.info("[TensorboardWriter - _validate_tag_existence] progress.txt not found. Tensorboard not started.")
                return False
            if os.path.getsize(self._file) == 0:
                logging.info("[TensorboardWriter - _validate_tag_existence] progress.txt is empty. Tensorboard not started.")
                return False
            data = pd.read_table(self._file)
            if data.empty:
                logging.info("[TensorboardWriter - _validate_tag_existence] Data is empty. Tensorboard not started.")
                return False
            # BUG FIX: the original called self.scalar_tags.remove() while
            # iterating self.scalar_tags, which skips the element following
            # each removal; build the filtered list instead.
            kept = []
            for scalar in self.scalar_tags:
                if scalar in data.columns:
                    kept.append(scalar)
                else:
                    logging.info(f"[TensorboardWriter - _validate_tag_existence] {scalar} not found in progress.txt. Removing from scalar tags.")
            self.scalar_tags = kept
            if not self.scalar_tags:
                logging.info("[TensorboardWriter - _validate_tag_existence] No scalar tags found. Tensorboard not started.")
                self.valid_tags = False
            else:
                self.valid_tags = True
            return None

        def _retrieve_and_queue_data(_previous_last_step: int):
            """Queue rows strictly after ``_previous_last_step``.

            Returns:
                tuple: (new last step index, number of scalars queued).
                On any failure the previous step and a zero count are returned.
            """
            logging.info("[TensorboardWriter - _retrieve_and_queue_data] Retrieving data from progress.txt...")
            if not os.path.exists(self._file_root):
                logging.info("[TensorboardWriter - _retrieve_and_queue_data] Data directory not found. Tensorboard not started.")
                return _previous_last_step, 0
            if not os.path.exists(self._file):
                logging.info("[TensorboardWriter - _retrieve_and_queue_data] progress.txt not found. Tensorboard not started.")
                return _previous_last_step, 0
            if os.path.getsize(self._file) == 0:
                logging.info("[TensorboardWriter - _retrieve_and_queue_data] progress.txt is empty. Tensorboard not started.")
                return _previous_last_step, 0
            data = pd.read_table(self._file)
            if data.empty:
                logging.info("[TensorboardWriter - _retrieve_and_queue_data] Data is empty. Tensorboard not started.")
                return _previous_last_step, 0
            try:
                # Index of the row holding the largest global step so far.
                new_last_step = int(data[self._global_step_tag].idxmax())
            except (KeyError, ValueError):
                logging.warning(f"[TensorboardWriter - _retrieve_and_queue_data] Could not find {self._global_step_tag} in data.")
                return _previous_last_step, 0
            _queued_count = 0
            for scalar in self.scalar_tags:
                for i in range(_previous_last_step + 1, new_last_step + 1):
                    try:
                        self.data_queue.put(('scalar', scalar, data[scalar][i], data[self._global_step_tag][i]))
                        _queued_count += 1
                    except (KeyError, IndexError):
                        logging.warning(f"[TensorboardWriter - _retrieve_and_queue_data] Error accessing data for {scalar} at step {i}.")
                        continue
            return new_last_step, _queued_count

        # Phase 1: wait until progress.txt exists and the tags are validated.
        while not self.shutdown_flag:
            if _validate_tag_existence() is not None:
                # Data not ready yet. Event.wait instead of time.sleep so a
                # shutdown() call wakes this thread immediately.
                self._loop_stop_signal.wait(timeout=5)
            else:
                break
        if self.shutdown_flag:
            logging.info("[TensorboardWriter - _tensorboard_processes] Shutdown requested during validation. Stopping.")
            self._loop_stop_signal.set()
            return
        if not self.valid_tags:
            logging.info("[TensorboardWriter - _tensorboard_processes] No valid tags found. Stopping.")
            self._loop_stop_signal.set()
            return

        # Phase 2: poll for new rows and publish them until shutdown.
        self.writer = SummaryWriter(log_dir=self._tb_log_dir, filename_suffix='_tb')
        previous_last_step = 0
        while not self.shutdown_flag:
            previous_last_step, queued_count = _retrieve_and_queue_data(previous_last_step)
            if queued_count > 0:
                try:
                    for _ in range(queued_count):
                        if self.shutdown_flag:
                            break
                        # BUG FIX: the original used a blocking get() whose
                        # `except queue.Empty` was unreachable; get_nowait()
                        # is safe here (items were just queued) and makes the
                        # handler meaningful.
                        write_type, *args = self.data_queue.get_nowait()
                        if write_type == 'scalar':
                            tag, scalar_value, global_step = args
                            logging.info(f"[TensorboardWriter - _tensorboard_processes] Writing scalar: {args}")
                            self.writer.add_scalar(tag, scalar_value, global_step)
                            self._recent_global_step += 1
                            # Launch the dashboard once, after the first write.
                            if self._recent_global_step == 1 and self.launch_tb_on_startup:
                                launch_tensorboard(self._tb_log_dir)
                    if not self.shutdown_flag:
                        self.writer.flush()
                except queue.Empty:
                    continue
            else:
                # Nothing new; back off ~10s but wake early on shutdown
                # (original slept in ten 1s increments checking the flag).
                self._loop_stop_signal.wait(timeout=10)

        logging.info("[TensorboardWriter - _tensorboard_processes] Shutdown flag detected. Stopping writer.")
        if self.writer:
            self.writer.close()
        return
def launch_tensorboard(logdir: str):
    """Start a TensorBoard server for *logdir* as a background process.

    Args:
        logdir: Base directory; ``logs`` is appended before launching.

    Returns:
        The ``subprocess.Popen`` handle on success, ``None`` on failure.
        (Backward compatible: the original implicitly returned ``None``.)
    """
    import subprocess
    try:
        logging.info("[launch_tensorboard] Starting Tensorboard.")
        # NOTE(review): appends 'logs' to the supplied logdir even though the
        # caller passes the tb log dir directly — confirm this is intended.
        logdir = os.path.join(logdir, 'logs')
        # BUG FIX: subprocess.run blocked the calling (writer) thread for the
        # entire lifetime of the TensorBoard server; Popen starts it without
        # waiting for it to exit.
        return subprocess.Popen(["tensorboard", "--logdir", logdir], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        logging.info(f"[launch_tensorboard] Error: {e}")
        # BUG FIX: this hint lived in a `finally` block, so it was logged even
        # when TensorBoard started successfully; it now only logs on failure.
        logging.info("[launch_tensorboard] Tensorboard dashboard unable to start. Please start it manually using `tensorboard --logdir <log_dir>`.")
        return None
async def main():
    """Run a TensorboardWriter and relay newline-delimited JSON from stdin.

    A ``{"command": "shutdown"}`` message triggers writer shutdown (which
    ends the loop via ``shutdown_flag``); every input line is echoed to
    stdout unchanged.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_dir", type=str, default='<directory_of_environment>')
    parser.add_argument("--config_path", type=str, default='<path_to_config_json>')
    parser.add_argument("--algorithm_name", type=str, default='<algorithm_name>')
    args = parser.parse_args()
    tb_writer = TensorboardWriter(
        config_path=args.config_path,
        env_dir=args.env_dir,
        algorithm_name=args.algorithm_name
    )
    loop = asyncio.get_running_loop()
    # BUG FIX: these two statements were fused onto a single line in the
    # original, which is a SyntaxError.
    reader = asyncio.StreamReader(limit=5120 * 5120)
    protocol = asyncio.StreamReaderProtocol(reader)
    # Hook stdin into the asyncio event loop for non-blocking line reads.
    await loop.connect_read_pipe(lambda: protocol, sys.stdin)
    while not tb_writer.shutdown_flag:
        line = await reader.readline()
        if not line:
            break  # EOF on stdin
        decoded = line.decode()
        try:
            command_json = json.loads(decoded.strip())
            command = command_json.get("command")
            if command == "shutdown":
                tb_writer.shutdown(timeout=10)
            else:
                logging.info(f"[TensorboardWriter - main] Unknown command: {command}")
        except json.JSONDecodeError:
            logging.info(f"[TensorboardWriter - main] Invalid JSON received: {decoded.strip()}")
        # Echo every received line back to stdout, valid JSON or not.
        print(decoded, end='', flush=True)
if __name__ == "__main__":
    # NOTE: logging.basicConfig already ran at import time (top of file), so
    # this second call is a no-op unless the root logger's handlers changed.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())