from __future__ import print_function
import subprocess
import select
import atexit
import collections
import os
import sys
import numpy as np
import emcee
# Registry of every Pool whose worker subprocesses may still be running.
_pools = []


@atexit.register
def _finish_pools():
    """Shut down every Pool still registered when the interpreter exits."""
    # Relies on each pool's finish() removing that pool from _pools;
    # otherwise this loop would never end.
    while len(_pools) > 0:
        pending = _pools[0]
        pending.finish()
class Pool(object):
    """emcee-compatible pool that evaluates lnprob in worker subprocesses.

    Protocol: for each job, one line of space-separated parameter values is
    written to a worker's stdin, and the worker must reply with one line on
    its stdout containing the log-probability as a float.

    NOTE: uses ``select.select`` on pipes, which the Python docs restrict to
    sockets on Windows, so this is effectively POSIX-only.
    """

    def __init__(self, commands):
        """Start one worker per entry of *commands*.

        :param commands: iterable of argv lists, each passed to
            ``subprocess.Popen`` with piped stdin and stdout.
        """
        self.popens = []
        # Per-worker-stdout bytes received but not yet parsed as a result.
        self.buffer = collections.defaultdict(bytes)
        for cmd in commands:
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE)
            self.init_subprocess(p)
            self.popens.append(p)
        # Registered so the module's exit hook can shut the workers down.
        _pools.append(self)

    def finish(self):
        """Close every worker's stdin, wait for it to exit, and unregister.

        Safe to call more than once.
        """
        for p in self.popens:
            self.close_subprocess(p)
        for p in self.popens:
            p.wait()
            # Release the read end of the pipe only after the worker exited,
            # so a late write from the worker cannot hit a closed pipe.
            p.stdout.close()
        del self.popens[:]
        self.buffer.clear()
        if self in _pools:
            _pools.remove(self)

    def init_subprocess(self, popen):
        """Hook called once for each newly started worker; no-op here."""
        pass

    def close_subprocess(self, popen):
        """Signal *popen* to stop by closing its stdin (it should see EOF)."""
        popen.stdin.close()

    def send_parameters(self, stdin, params):
        """Write *params* as one space-separated line to a worker's stdin.

        Floats (including numpy float64) are written with ``repr`` so each
        value round-trips exactly; Python 2's ``str(float)`` keeps only 12
        significant digits. Other values use ``str`` as before.
        """
        txt = ' '.join(repr(float(x)) if isinstance(x, float) else str(x)
                       for x in params)
        # Popen pipes are binary on Python 3, so send encoded bytes.
        stdin.write((txt + '\n').encode('ascii'))
        stdin.flush()

    def identify_lnprob(self, text):
        """Parse a complete reply from a worker's accumulated output.

        :param text: everything received since the last result (bytes or
            text).
        :returns: the float value if *text* is a complete, newline-terminated
            number, otherwise None (the reply is incomplete or unparsable).
        """
        if not text:
            return None
        try:
            if isinstance(text, bytes):
                text = text.decode('ascii')
            if not text.endswith(u'\n'):
                return None
            return float(text.strip())
        except ValueError:  # also covers UnicodeDecodeError
            return None

    def get_lnprob(self, stdout):
        """Read whatever a worker has written and return its lnprob if ready.

        :returns: the parsed float, or None while the reply is incomplete.
        :raises RuntimeError: if the worker closed its stdout (it exited or
            crashed) before sending a result.
        """
        # os.read returns as soon as any data is available, unlike a
        # buffered file's read(), which would block for the full 4096 bytes.
        chunk = os.read(stdout.fileno(), 4096)
        if not chunk:
            # EOF: otherwise a dead worker would stall map() indefinitely.
            raise RuntimeError(
                'worker subprocess closed its stdout before returning a result')
        self.buffer[stdout] += chunk
        val = self.identify_lnprob(self.buffer[stdout])
        if val is not None:
            self.buffer[stdout] = b''
        return val

    def map(self, function, paramlist):
        """Evaluate lnprob for every parameter vector in *paramlist*.

        :param function: ignored; accepted only for the ``map(func, seq)``
            signature emcee calls. The workers compute the value themselves.
        :param paramlist: iterable of parameter sequences.
        :returns: list of floats in the same order as *paramlist*.
        :raises ValueError: if there is work but the pool has no workers.
        """
        paramlist = list(paramlist)
        results = [None] * len(paramlist)
        if paramlist and not self.popens:
            raise ValueError('Pool has no worker subprocesses')
        pending = collections.deque(enumerate(paramlist))
        free = list(self.popens)
        # worker stdout -> (index of the job it is running, the worker)
        waiting = {}
        while pending or waiting:
            # Give a job to every idle worker.
            while free and pending:
                idx, params = pending.popleft()
                popen = free.pop()
                self.send_parameters(popen.stdin, params)
                waiting[popen.stdout] = (idx, popen)
            # Block (no busy polling) until a busy worker has output.
            readable = select.select(list(waiting), [], [])[0]
            for stdout in readable:
                lnprob = self.get_lnprob(stdout)
                if lnprob is not None:
                    idx, popen = waiting.pop(stdout)
                    results[idx] = lnprob
                    free.append(popen)
        return results
def main():
    """Run emcee on the straight-line model evaluated by remote() workers.

    Starts four copies of this script in ``remote`` mode and uses them as the
    sampler's pool. It then runs a burn-in, resets the sampler, runs a
    production chain, and prints the median of each parameter.
    """
    cmds = [[sys.executable, __file__, 'remote'] for _ in range(4)]
    pool = Pool(cmds)
    try:
        ndim, nwalkers, nburn, nchain = 2, 100, 100, 1000
        p0 = [np.random.rand(ndim) for _ in range(nwalkers)]
        # lnprob is None: the pool's workers compute it themselves.
        sampler = emcee.EnsembleSampler(nwalkers, ndim, None, pool=pool)
        pos, prob, state = sampler.run_mcmc(p0, nburn)
        sampler.reset()
        sampler.run_mcmc(pos, nchain, rstate0=state)
        print("a = %g, b = %g" % (np.median(sampler.chain[:, :, 0]),
                                  np.median(sampler.chain[:, :, 1])))
    finally:
        # Stop the workers deterministically (also on error) rather than
        # leaving them to the atexit hook.
        pool.finish()
def remote(infile=None, outfile=None):
    """Worker loop: answer each parameter line with one lnprob line.

    Each input line holds ``a b``. The reply is ``-0.5 * chi2`` of the line
    ``a + b*x`` against the fixed data below. The loop stops at end of input.

    :param infile: text stream to read from (default ``sys.stdin``).
    :param outfile: text stream to write to (default ``sys.stdout``).
    """
    if infile is None:
        infile = sys.stdin
    if outfile is None:
        outfile = sys.stdout
    x = np.arange(9)
    y = np.array([1.97, 2.95, 4.1, 5.04, 5.95, 6.03, 8, 8.85, 10.1])
    err = 0.2
    while True:
        line = infile.readline()
        if not line:
            break
        params = [float(v) for v in line.split()]
        mody = params[0] + params[1] * x
        chi2 = np.sum(((y - mody) / err) ** 2)
        lnprob = -0.5 * chi2
        # repr of a plain float is the shortest string that round-trips
        # exactly; str() of a numpy float may lose digits on older stacks.
        outfile.write(repr(float(lnprob)) + '\n')
        # The parent waits for each reply before sending more work.
        outfile.flush()
if __name__ == '__main__':
    # Exactly one argument, 'remote', makes this process a worker;
    # any other command line runs the sampler driver.
    if sys.argv[1:] == ['remote']:
        remote()
    else:
        main()