from __future__ import print_function
import sys
import time
import pickle
import numpy as np
import emcee
from emcee.utils import MPIPool
# RNG seed, presumably intended for np.random.seed() so that the random
# initial walker positions drawn by the benchmark are reproducible.
seed = 1234567
def sort_on_runtime(pos):
    """Order walkers so that the ones expected to run longest come first.

    ``lnprob`` sleeps for ``x[0]`` seconds, so a walker's first coordinate is
    a proxy for how long its evaluation takes. This function is passed to
    ``emcee.EnsembleSampler`` as ``runtime_sortingfn``. Presumably, handing
    out the slowest evaluations first helps the pool balance work.

    Parameters
    ----------
    pos : array_like
        Walker positions, promoted to 2-D with one walker per row.

    Returns
    -------
    tuple of (ndarray, ndarray)
        The rows of ``pos`` sorted by descending first coordinate, and the
        index array that produces that ordering.
    """
    walkers = np.atleast_2d(pos)
    # argsort is ascending; reversing gives the largest first coordinate first.
    descending = walkers[:, 0].argsort()[::-1]
    ordered = walkers[descending]
    return ordered, descending
def lnprob(x):
    """Toy log-probability whose evaluation cost is controlled by ``x[0]``.

    When ``x[0]`` is non-negative, the function first sleeps for ``x[0]``
    seconds; negative values return immediately. It then returns
    ``-0.5 * x[0]**2``, the log of an unnormalised zero-mean, unit-variance
    Gaussian in the first coordinate. Any further coordinates are ignored.
    """
    head = x[0]
    if head >= 0.0:
        # Simulated, position-dependent runtime for the load-balancing test.
        time.sleep(head)
    return -0.5*head*head
# Benchmark grid: every (mean_time, variance) pair is timed under all four
# combinations of MPIPool load balancing and runtime sorting, for `niters`
# sampler iterations each.
mean_times = [0.5, 1.0]
# Spread of the initial positions, as a multiple of mean_time.
variances = [0.1, 0.2, 0.5, 1.0, 2.0]
loadbalancing_options = [False, True]
runtime_sorting_options = [None, sort_on_runtime]
ndim = 1
nwalkers = 496
niters = 5
status_file = 'timings.txt'
pickle_file = 'initial_pos.pkl'

# Make the random initial walker positions reproducible between runs.
np.random.seed(seed)


def _ideal_time(positions, pool):
    """Return the wall time a perfectly balanced evaluation round would take.

    ``lnprob`` sleeps ``x[0]`` seconds for non-negative ``x[0]``, so the total
    work is the sum of the positive first coordinates. That sum is divided by
    ``Get_size() - 1``, presumably the number of worker ranks excluding the
    master.
    """
    runtimes = np.asarray(positions)[:, 0]
    return runtimes[runtimes > 0].sum()/(pool.comm.Get_size()-1)


# NOTE(review): every MPI rank executes this open(), so worker ranks also
# truncate the status file and buffer the same header lines before they reach
# MPIPool. The bytes they write are identical to the master's header, but
# opening the file on the master only would be cleaner.
with open(status_file, "w") as f:
    f.write("#####################################################################################################\n")
    f.write("## loadbalance runtime_sorting iteration mean_time variance ideal actual \n")
    f.write("#####################################################################################################\n")
    for mean_time in mean_times:
        for variance_fac in variances:
            variance = variance_fac*mean_time
            # The first configuration for each (mean_time, variance) pair draws
            # and burns in a starting state. Later configurations reload it, so
            # all four configurations start from the same walkers.
            have_initial_state = False
            for loadbalance in loadbalancing_options:
                for runtime_sorting_option in runtime_sorting_options:
                    pool = MPIPool(loadbalance=loadbalance)
                    if not pool.is_master():
                        # Non-master ranks block in wait() and exit once it returns.
                        pool.wait()
                        sys.exit(0)

                    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, pool=pool,
                                                    runtime_sortingfn=runtime_sorting_option)
                    print("Before running the iterations. loadbalance = {0}".format(loadbalance))
                    if not have_initial_state:
                        p0 = [(mean_time + variance * np.random.randn(ndim)) for _ in range(nwalkers)]
                        pos, prob, rstate = sampler.run_mcmc(p0, 1)
                        with open(pickle_file, 'wb') as pkl_file:
                            pickle.dump(pos, pkl_file, -1)
                            pickle.dump(prob, pkl_file, -1)
                            pickle.dump(rstate, pkl_file, -1)
                        have_initial_state = True
                    else:
                        # This pickle was written by an earlier configuration in
                        # this same loop, so loading it is safe.
                        with open(pickle_file, 'rb') as pkl_file:
                            pos = pickle.load(pkl_file)
                            prob = pickle.load(pkl_file)
                            rstate = pickle.load(pkl_file)

                    # Column value for the status file: 1 if a sorting function is used.
                    integer_runtime_sort = 0 if runtime_sorting_option is None else 1
                    t0 = time.time()
                    cumulative_time = 0.0
                    ideal_time = _ideal_time(pos, pool)
                    samples = sampler.sample(pos, prob, rstate, iterations=niters, storechain=False)
                    for iternum, (pos, prob, rstate) in enumerate(samples):
                        t1 = time.time()
                        print("Done with iteration {0:2d}. time = {1:8.3f} seconds. perfect scaling = {2:8.3f} ".format(
                            iternum, t1-t0, ideal_time))
                        f.write(" {0:11b} {1:14d} {2:9d} {3:9.1f} {4:8.2f} {5:6.2f} {6:6.2f}\n".format(
                            loadbalance, integer_runtime_sort, iternum+1, mean_time, variance, ideal_time, t1-t0))
                        f.flush()
                        cumulative_time += t1 - t0
                        t0 = t1
                        # The next iteration's ideal time uses the latest positions.
                        ideal_time = _ideal_time(pos, pool)
                    print("Loadbalancing = {0}, time variance = {1}. Total Time taken = {2:0.2f} seconds (avg = {3:0.3f})".format(
                        loadbalance, variance, cumulative_time, cumulative_time/niters))
                    f.write("\n")

# NOTE(review): only the most recently created pool is closed. This appears to
# rely on worker ranks staying in their first pool.wait() and serving every
# later pool on the same communicator. Confirm against emcee's MPIPool.
pool.close()