import json
import string
import time
import os.path as osp
import time
import atexit
import os
import warnings
import numpy as np
import torch
# Default location for experiment output: the 'data' directory next to
# this file's parent directory (i.e. <project_root>/data).
DEFAULT_DATA_DIR = osp.join(osp.abspath(
osp.dirname(osp.dirname(__file__))), 'data')
# If True, setup_logger_kwargs() always prefixes output dirs with a date,
# even when the caller passes datestamp=False.
FORCE_DATESTAMP = False
# ANSI foreground color codes used by colorize(); +10 gives the background
# variant. NOTE(review): 38 is not a plain foreground code in standard
# ANSI (it introduces extended colors) -- confirm 'crimson' renders as intended.
color2num = dict(gray=30,
red=31,
green=32,
yellow=33,
blue=34,
magenta=35,
cyan=36,
white=37,
crimson=38)
def colorize(string, color, bold=False, highlight=False):
    """Wrap *string* in ANSI escape codes so it prints in *color*.

    NOTE(review): the parameter name shadows the stdlib ``string`` module;
    kept unchanged for backward compatibility with keyword callers.
    """
    # Base code from the module-level table; +10 selects the highlight
    # (background) variant of the same color.
    code = color2num[color]
    if highlight:
        code += 10
    attrs = [str(code)]
    if bold:
        attrs.append('1')
    return '\x1b[{}m{}\x1b[0m'.format(';'.join(attrs), string)
def convert_json(obj):
    """Recursively convert *obj* into a structure built only of
    JSON-serializable values.

    Dicts/lists/tuples are converted element-wise; named objects
    (functions, classes) are replaced by their ``__name__``; generic
    objects become ``{repr(obj): converted __dict__}``; anything else
    falls back to ``str(obj)``.
    """
    if is_json_serializable(obj):
        return obj
    if isinstance(obj, dict):
        return {convert_json(k): convert_json(v) for k, v in obj.items()}
    if isinstance(obj, tuple):
        # BUG FIX: the original returned a generator expression here,
        # which is itself not JSON-serializable; materialize a tuple.
        return tuple(convert_json(x) for x in obj)
    if isinstance(obj, list):
        return [convert_json(x) for x in obj]
    if hasattr(obj, '__name__') and 'lambda' not in obj.__name__:
        # Functions/classes: represent by their (non-lambda) name.
        return convert_json(obj.__name__)
    if hasattr(obj, '__dict__') and obj.__dict__:
        # Generic objects: repr as key, converted attributes as value.
        obj_dict = {convert_json(k): convert_json(v)
                    for k, v in obj.__dict__.items()}
        return {str(obj): obj_dict}
    return str(obj)


def is_json_serializable(v):
    """Return True if ``json.dumps`` can encode *v* directly."""
    try:
        json.dumps(v)
        return True
    except (TypeError, OverflowError, ValueError):
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. json.dumps raises TypeError for
        # unsupported types and ValueError for circular/out-of-range data.
        return False
def statistics_scalar(x, with_min_and_max=False):
    """Compute the mean and (population) std of the values in *x*.

    Returns (mean, std), or (mean, std, min, max) when
    ``with_min_and_max`` is True.

    NOTE(review): an empty *x* yields a 0/0 division (NaN + warning),
    matching the original behavior.
    """
    arr = np.array(x, dtype=np.float32)
    n = len(arr)
    mean = np.sum(arr) / n
    # Biased (divide-by-n) standard deviation.
    std = np.sqrt(np.sum((arr - mean) ** 2) / n)
    if not with_min_and_max:
        return mean, std
    lo = np.min(arr) if n > 0 else np.inf
    hi = np.max(arr) if n > 0 else -np.inf
    return mean, std, lo, hi
class Logger:
    """Writes tab-separated diagnostics to a progress file and colorized
    messages to stdout; can also snapshot experiment state and PyTorch
    models.
    """

    def __init__(self,
                 output_dir=None,
                 output_fname='progress.txt',
                 exp_name=None):
        """
        Args:
            output_dir (str): directory for all output files; defaults to
                a timestamped directory under /tmp/experiments.
            output_fname (str): name of the tab-separated progress file.
            exp_name (str): optional experiment name, saved into configs.
        """
        self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
        if osp.exists(self.output_dir):
            print(
                "Warning: Log dir %s already exists! Storing info there anyway."
                % self.output_dir)
        else:
            os.makedirs(self.output_dir)
        self.output_file = open(
            osp.join(self.output_dir, output_fname), 'w')
        # Ensure the progress file is closed on interpreter exit.
        atexit.register(self.output_file.close)
        print(colorize("Logging data to %s" % self.output_file.name, 'green', bold=True))
        self.first_row = True       # first dump also writes the header line
        self.log_headers = []       # column order is fixed by the first row
        self.log_current_row = {}   # values accumulated for the current row
        self.exp_name = exp_name

    def log(self, msg, color='green'):
        """Print a colorized message to stdout."""
        print(colorize(msg, color, bold=True))

    def log_tabular(self, key, val):
        """Record *val* for column *key* in the current row.

        New keys may only be introduced during the first iteration, and
        each key may be set at most once per row.
        """
        if self.first_row:
            self.log_headers.append(key)
        else:
            assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration" % key
        assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        self.log_current_row[key] = val

    def save_config(self, config):
        """Serialize *config* (after JSON-safe conversion) to config.json
        in the output dir, and echo it to stdout.
        """
        config_json = convert_json(config)
        if self.exp_name is not None:
            config_json['exp_name'] = self.exp_name
        output = json.dumps(config_json,
                            separators=(',', ':\t'),
                            indent=4,
                            sort_keys=True)
        print(colorize('Saving config:\n', color='cyan', bold=True))
        print(output)
        with open(osp.join(self.output_dir, "config.json"), 'w') as out:
            out.write(output)

    def save_state(self, state_dict, itr=None):
        """Pickle *state_dict* to vars[itr].pkl (best-effort), then save
        any registered TF/PyTorch model elements.
        """
        fname = 'vars.pkl' if itr is None else 'vars%d.pkl' % itr
        try:
            # NOTE(review): joblib is never imported in this file, so this
            # raises NameError (caught below) unless joblib is brought into
            # scope elsewhere -- confirm and add the import at module level.
            joblib.dump(state_dict, osp.join(self.output_dir, fname))
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Saving stays best-effort.
            self.log('Warning: could not pickle state_dict.', color='red')
        if hasattr(self, 'tf_saver_elements'):
            self._tf_simple_save(itr)
        if hasattr(self, 'pytorch_saver_elements'):
            self._pytorch_simple_save(itr)

    def setup_pytorch_saver(self, what_to_save):
        """Register the object (e.g. a module or dict of modules) that
        _pytorch_simple_save should serialize.
        """
        self.pytorch_saver_elements = what_to_save

    def _pytorch_simple_save(self, itr=None):
        """torch.save the registered elements to pyt_save/model[itr].pt."""
        assert hasattr(self, 'pytorch_saver_elements'), \
            "First have to setup saving with self.setup_pytorch_saver"
        fpath = 'pyt_save'
        fpath = osp.join(self.output_dir, fpath)
        fname = 'model' + ('%d' % itr if itr is not None else '') + '.pt'
        fname = osp.join(fpath, fname)
        os.makedirs(fpath, exist_ok=True)
        with warnings.catch_warnings():
            # torch.save on whole modules warns about pickling; silence it.
            warnings.simplefilter("ignore")
            torch.save(self.pytorch_saver_elements, fname)

    def dump_tabular(self):
        """Print the current row as an aligned table, append it to the
        progress file (writing headers on the first row), and reset the
        row buffer.
        """
        vals = []
        key_lens = [len(key) for key in self.log_headers]
        max_key_len = max(15, max(key_lens))
        keystr = '%' + '%d' % max_key_len
        fmt = "| " + keystr + "s | %15s |"
        n_slashes = 22 + max_key_len
        print("-" * n_slashes)
        for key in self.log_headers:
            val = self.log_current_row.get(key, "")
            # Numeric values get compact %8.3g formatting; others print as-is.
            valstr = "%8.3g" % val if hasattr(val, "__float__") else val
            print(fmt % (key, valstr))
            vals.append(val)
        print("-" * n_slashes, flush=True)
        if self.output_file is not None:
            if self.first_row:
                self.output_file.write("\t".join(self.log_headers) + "\n")
            self.output_file.write("\t".join(map(str, vals)) + "\n")
            self.output_file.flush()
        self.log_current_row.clear()
        self.first_row = False
class EpochLogger(Logger):
    """Logger that also accumulates values across an epoch and reports
    their statistics (mean/std and optionally min/max) at dump time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-key buffers of values stored since the last log_tabular call.
        self.epoch_dict = dict()

    def store(self, **kwargs):
        """Append each keyword value to that key's epoch buffer."""
        for key, value in kwargs.items():
            self.epoch_dict.setdefault(key, []).append(value)

    def log_tabular(self,
                    key,
                    val=None,
                    with_min_and_max=False,
                    average_only=False):
        """Log *val* directly, or (when val is None) the statistics of the
        values stored under *key*, then clear that buffer.
        """
        if val is not None:
            super().log_tabular(key, val)
            return
        buf = self.epoch_dict[key]
        # Buffers of non-scalar arrays are flattened into one array first.
        if isinstance(buf[0], np.ndarray) and len(buf[0].shape) > 0:
            values = np.concatenate(buf)
        else:
            values = buf
        stats = statistics_scalar(values, with_min_and_max=with_min_and_max)
        super().log_tabular(key if average_only else 'Average' + key, stats[0])
        if not average_only:
            super().log_tabular('Std' + key, stats[1])
        if with_min_and_max:
            super().log_tabular('Max' + key, stats[3])
            super().log_tabular('Min' + key, stats[2])
        self.epoch_dict[key] = []

    def get_stats(self, key):
        """Return the (mean, std) of the values stored under *key*
        without logging or clearing them.
        """
        buf = self.epoch_dict[key]
        if isinstance(buf[0], np.ndarray) and len(buf[0].shape) > 0:
            values = np.concatenate(buf)
        else:
            values = buf
        return statistics_scalar(values)
def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=False):
    """Build the kwargs dict (output_dir, exp_name) for a Logger.

    Layout: data_dir/[YYYY-MM-DD_]exp_name[/[datetime-]exp_name_sSEED].
    """
    # Module-level FORCE_DATESTAMP is consulted only when the caller did
    # not already request a datestamp (preserves original short-circuit).
    if not datestamp:
        datestamp = FORCE_DATESTAMP
    prefix = time.strftime("%Y-%m-%d_") if datestamp else ''
    relpath = prefix + exp_name
    if seed is not None:
        # Seeded runs get their own subfolder so different seeds of the
        # same experiment do not collide.
        if datestamp:
            stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
            subfolder = '%s-%s_s%s' % (stamp, exp_name, seed)
        else:
            subfolder = '%s_s%s' % (exp_name, seed)
        relpath = osp.join(relpath, subfolder)
    return dict(output_dir=osp.join(data_dir or DEFAULT_DATA_DIR, relpath),
                exp_name=exp_name)
def all_bools(vals):
    """Return True iff every element of *vals* is a bool."""
    return all(isinstance(v, bool) for v in vals)
def valid_str(v):
    """Convert *v* into a lowercase string safe for filesystem paths.

    Named objects are replaced by their ``__name__``; tuples/lists are
    joined element-wise with dashes; any character other than letters,
    digits, '-' and '_' becomes '-'.
    """
    if hasattr(v, '__name__'):
        return valid_str(v.__name__)
    if isinstance(v, (tuple, list)):
        return '-'.join(valid_str(item) for item in v)
    allowed = set("-_%s%s" % (string.ascii_letters, string.digits))
    lowered = str(v).lower()
    return ''.join(ch if ch in allowed else '-' for ch in lowered)