import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import os.path as osp
import numpy as np
from packaging import version
# Width of the '=' divider rows printed around the "Plotting from..." listing.
DIV_LINE_WIDTH = 50
# Global counter giving every experiment run a unique index; used by
# get_datasets() to build the per-run 'Condition2' label ("<name>-<idx>").
exp_idx = 0
# Maps condition name -> number of runs seen so far under that condition;
# used by get_datasets() to assign each run its 'Unit' (seed) index.
units = dict()
def get_simple_dataset_plot(data, x, y, title, xlabel='Epoch', ylabel='AverageEpRet'):
    """Draw a single seaborn line plot of *y* against *x* from *data*.

    Fixes vs. original: the return annotation ``-> sns.lineplot`` named a
    function, not a type, so it was removed; the axis labels (previously
    hard-coded) are now parameters whose defaults preserve the old output.

    Args:
        data: DataFrame (or other seaborn-compatible data) to plot from.
        x: column name for the x axis.
        y: column name for the y axis.
        title: title set on the resulting axes.
        xlabel: x-axis label (default 'Epoch', matching prior behavior).
        ylabel: y-axis label (default 'AverageEpRet', matching prior behavior).

    Returns:
        The matplotlib ``Axes`` object that seaborn drew on.
    """
    plot = sns.lineplot(data=data, x=x, y=y)
    plot.set_title(title)
    plot.set_xlabel(xlabel)
    plot.set_ylabel(ylabel)
    return plot
def plot_data(data, xaxis='Epoch', value="AverageEpRet", condition="Condition1", smooth=1, **kwargs):
    """Plot *value* vs *xaxis* for one or more experiment DataFrames.

    Args:
        data: a DataFrame or a list of DataFrames (one per run). Lists are
            concatenated so seaborn can aggregate across runs.
        xaxis: column used for the x axis.
        value: column used for the y axis.
        condition: column used for the hue/legend grouping.
        smooth: window size for a symmetric moving average applied in place
            to ``data[value]`` before plotting (1 = no smoothing).
        **kwargs: forwarded to the underlying seaborn call.
    """
    if smooth > 1:
        # Moving average: smoothed[i] = mean of value[i-k:i+k], k = smooth/2,
        # computed as convolve(x, ones) / convolve(ones, ones) so the window
        # is correctly shortened at the edges.
        y = np.ones(smooth)
        # Fix: the original iterated `data` directly, which for a single
        # DataFrame (a supported input, see the isinstance check below)
        # yields column *names* and crashes on `datum[value]`.
        frames = data if isinstance(data, list) else [data]
        for datum in frames:
            x = np.asarray(datum[value])
            z = np.ones(len(x))
            smoothed_x = np.convolve(x, y, 'same') / np.convolve(z, y, 'same')
            datum[value] = smoothed_x
    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)
    # tsplot was removed after seaborn 0.8.1; use lineplot on newer versions.
    if version.parse(sns.__version__) <= version.parse("0.8.1"):
        sns.tsplot(data=data, time=xaxis, value=value, unit="Unit", condition=condition, ci='sd', **kwargs)
    else:
        sns.lineplot(data=data, x=xaxis, y=value, hue=condition, errorbar='sd', **kwargs)
    # loc=4 == 'lower right'; draggable so the user can move it off the curves.
    plt.legend(loc=4).set_draggable(True)
    # Switch to scientific notation once the x range exceeds 5000.
    xscale = np.max(np.asarray(data[xaxis])) > 5e3
    if xscale:
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.tight_layout(pad=0.5)
def get_newest_dataset(data_log_dir: str, return_file_root: bool = False):
    """Find the most recently created 'progress.txt' under *data_log_dir*.

    Args:
        data_log_dir: directory tree to search recursively.
        return_file_root: if True, return the absolute directory containing
            the newest progress file instead of its parsed contents.

    Returns:
        A DataFrame read from the newest progress file, or its containing
        directory when *return_file_root* is set, or ``None`` when the
        directory does not exist or holds no progress file.
    """
    if not osp.exists(data_log_dir):
        return None
    # Collect the full path of every 'progress.txt' in the tree.
    candidates = [
        osp.join(dirpath, fname)
        for dirpath, _, filenames in os.walk(data_log_dir)
        for fname in filenames
        if fname == 'progress.txt'
    ]
    if not candidates:
        return None
    # "Newest" is decided by filesystem creation time.
    newest_file = max(candidates, key=os.path.getctime)
    if return_file_root:
        return os.path.abspath(os.path.dirname(newest_file))
    return pd.read_table(newest_file)
def get_datasets(logdir, condition=None, other_algos=False):
    """Recursively load every 'progress.txt' under *logdir* as a DataFrame.

    Each run's frame gains 'Unit' (seed index within its condition),
    'Condition1'/'Condition2' (experiment name / name-plus-run-index) and
    'Performance' (copy of AverageTestEpRet if present, else AverageEpRet)
    columns. With *other_algos* set, two extra frames per run expose the
    negated 'F1' and 'SJF' baseline columns as their own conditions.

    Args:
        logdir: root directory to walk.
        condition: label overriding the exp_name read from config.json.
        other_algos: also emit the F1/SJF baseline frames.

    Returns:
        List of DataFrames (possibly empty).

    Side effects: increments the module-level ``exp_idx`` counter and the
    per-condition ``units`` tally; prints a message for unreadable files.
    """
    global exp_idx
    global units
    datasets = []
    for root, _, files in os.walk(logdir):
        if 'progress.txt' not in files:
            continue
        exp_name = None
        try:
            # Fix: open config.json in a context manager — the original
            # leaked the file handle. Exceptions narrowed from a bare
            # `except:` to the failures this lookup can actually produce.
            with open(os.path.join(root, 'config.json')) as config_file:
                config = json.load(config_file)
            if 'exp_name' in config:
                exp_name = config['exp_name']
        except (OSError, json.JSONDecodeError):
            print('No file named config.json')
        condition1 = condition or exp_name or 'exp'
        condition2 = condition1 + '-' + str(exp_idx)
        exp_idx += 1
        if condition1 not in units:
            units[condition1] = 0
        unit = units[condition1]
        units[condition1] += 1
        try:
            exp_data = pd.read_table(os.path.join(root, 'progress.txt'))
        except Exception:
            # Best-effort: skip unreadable/corrupt runs but keep going.
            print('Could not read from %s' % os.path.join(root, 'progress.txt'))
            continue
        # Prefer the deterministic test return when the run logged one.
        performance = 'AverageTestEpRet' if 'AverageTestEpRet' in exp_data else 'AverageEpRet'
        exp_data.insert(len(exp_data.columns), 'Unit', unit)
        if other_algos:
            # Baseline columns are logged as costs; negate so "higher is
            # better" matches the Performance convention.
            exp_data2 = exp_data.copy()
            exp_data2.insert(len(exp_data2.columns), 'Condition1', "F1")
            exp_data2.insert(len(exp_data2.columns), 'Condition2', "F1")
            exp_data2.insert(len(exp_data2.columns), 'Performance', -exp_data["F1"])
            datasets.append(exp_data2)
            exp_data3 = exp_data.copy()
            exp_data3.insert(len(exp_data3.columns), 'Condition1', "SJF")
            exp_data3.insert(len(exp_data3.columns), 'Condition2', "SJF")
            exp_data3.insert(len(exp_data3.columns), 'Performance', -exp_data["SJF"])
            datasets.append(exp_data3)
        exp_data.insert(len(exp_data.columns), 'Condition1', condition1)
        exp_data.insert(len(exp_data.columns), 'Condition2', condition2)
        exp_data.insert(len(exp_data.columns), 'Performance', exp_data[performance])
        datasets.append(exp_data)
    return datasets
def get_all_datasets(all_logdirs, legend=None, select=None, exclude=None, other_algos=False):
    """Expand *all_logdirs* into concrete run directories and load them all.

    Each entry that is an existing directory ending in the path separator is
    taken literally; anything else is treated as ``basedir/prefix`` and
    matched against the entries of ``basedir``. The resulting list can be
    narrowed with *select* (every substring must appear) and *exclude* (no
    substring may appear). Loaded frames come from ``get_datasets``; when
    *legend* is given it must supply one label per matched directory.
    """
    logdirs = []
    for logdir in all_logdirs:
        if osp.isdir(logdir) and logdir[-1] == os.sep:
            logdirs.append(logdir)
        else:
            # Interpret the entry as a prefix inside its parent directory.
            basedir = osp.dirname(logdir)
            prefix = logdir.split(os.sep)[-1]
            matches = [osp.join(basedir, name) for name in os.listdir(basedir) if prefix in name]
            logdirs += sorted(matches)
    if select is not None:
        logdirs = [log for log in logdirs if all(token in log for token in select)]
    if exclude is not None:
        logdirs = [log for log in logdirs if not any(token in log for token in exclude)]
    # Show the user exactly which runs ended up in the plot.
    print('Plotting from...\n' + '=' * DIV_LINE_WIDTH + '\n')
    for logdir in logdirs:
        print(logdir)
    print('\n' + '=' * DIV_LINE_WIDTH)
    assert not (legend) or (len(legend) == len(logdirs)), \
        "Must give a legend title for each set of experiments."
    data = []
    if legend:
        for log, leg in zip(logdirs, legend):
            data += get_datasets(log, leg, other_algos)
    else:
        for log in logdirs:
            data += get_datasets(log, other_algos=other_algos)
    return data
def make_plots(all_logdirs, legend=None, xaxis=None, values=None, count=False,
               font_scale=1.5, smooth=1, select=None, exclude=None, estimator='mean', other_algos=False):
    """Load all matching experiment data and draw one figure per value.

    Args:
        all_logdirs: log directories / prefixes, see ``get_all_datasets``.
        legend: optional legend labels, one per matched directory.
        xaxis: column plotted on the x axis.
        values: column name or list of column names to plot (one figure each).
        count: if True, group curves by the per-run 'Condition2' label
            instead of the aggregated 'Condition1'.
        font_scale: accepted for API compatibility; not used here.
        smooth: moving-average window forwarded to ``plot_data``.
        select / exclude: substring filters forwarded to ``get_all_datasets``.
        estimator: name of a numpy aggregation function (e.g. 'mean').
        other_algos: also plot the F1/SJF baselines, see ``get_datasets``.
    """
    data = get_all_datasets(all_logdirs, legend, select, exclude, other_algos=other_algos)
    values = values if isinstance(values, list) else [values]
    condition = 'Condition2' if count else 'Condition1'
    # Fix: these two statements were fused onto one syntactically invalid
    # line ("estimator = getattr(np, estimator) for value in values:").
    estimator = getattr(np, estimator)  # e.g. 'mean' -> np.mean
    for value in values:
        plt.figure()
        plot_data(data, xaxis=xaxis, value=value, condition=condition, smooth=smooth, estimator=estimator)
    plt.show()
def main():
    """Command-line entry point: parse plotting options and call make_plots.

    Improvement: every flag now carries a ``help=`` string so ``--help``
    documents the tool; defaults and parsing behavior are unchanged.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('logdir', nargs='*',
                        help='log directories (or directory prefixes) to plot from')
    parser.add_argument('--legend', '-l', nargs='*',
                        help='legend label for each matched log directory')
    parser.add_argument('--xaxis', '-x', default='TotalEnvInteracts',
                        help='column to plot on the x axis')
    parser.add_argument('--value', '-y', default='Performance', nargs='*',
                        help='column(s) to plot, one figure per value')
    parser.add_argument('--count', action='store_true',
                        help='plot each run separately instead of aggregating per condition')
    parser.add_argument('--smooth', '-s', type=int, default=2,
                        help='moving-average window applied to the plotted values')
    parser.add_argument('--select', nargs='*',
                        help='only keep logdirs containing all of these substrings')
    parser.add_argument('--exclude', nargs='*',
                        help='drop logdirs containing any of these substrings')
    parser.add_argument('--est', default='mean',
                        help='numpy aggregation function used as estimator (e.g. mean)')
    parser.add_argument('--other_algos', type=int, default=0,
                        help='nonzero to also plot the F1/SJF baseline columns')
    args = parser.parse_args()
    make_plots(args.logdir, args.legend, args.xaxis, args.value, args.count,
               smooth=args.smooth, select=args.select, exclude=args.exclude,
               estimator=args.est, other_algos=args.other_algos)
# Standard script entry guard: run the CLI only when executed directly,
# not when this module is imported for its plotting helpers.
if __name__ == "__main__":
    main()