import atexit
import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading
import warnings
from . import connection
from . import process
from .context import reduction
from . import resource_tracker
from . import spawn
from . import util
# Public API of this module.  Each name is bound at the bottom of the file to
# a method of the module-level ForkServer instance.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
'set_forkserver_preload']
# Upper bound on the number of file descriptors exchanged in one fork request
# (checked by the client before sending and by the server after receiving).
MAXFDS_TO_SEND = 256
# Wire format used by read_signed()/write_signed(): one native signed
# 'long long'.  main() uses it to send child pids and exit codes.
SIGNED_STRUCT = struct.Struct('q')
class ForkServer(object):
    """Client-side handle used to start and talk to the fork server process.

    The instance remembers the server's socket address, its pid and the
    write end of the "alive" pipe, and serialises start/stop with a lock.
    """

    def __init__(self):
        # Address of the server's AF_UNIX listening socket (None: not started).
        self._forkserver_address = None
        # Write end of the "alive" pipe kept open by this process.
        self._forkserver_alive_fd = None
        # Pid of the server process (None: not started).
        self._forkserver_pid = None
        # Filled in by _serve_one() inside a forked child.
        self._inherited_fds = None
        self._lock = threading.Lock()
        # Module names the server imports before serving requests.
        self._preload_modules = ['__main__']

    def _stop(self):
        """Stop the fork server while holding the instance lock."""
        with self._lock:
            self._stop_unlocked()

    def _stop_unlocked(self):
        """Stop the fork server; the caller must already hold the lock.

        Does nothing when no server pid is recorded.
        """
        if self._forkserver_pid is None:
            return

        # Drop this process's write end of the alive pipe; main() exits once
        # it reads EOF on the read end.
        os.close(self._forkserver_alive_fd)
        self._forkserver_alive_fd = None

        # Blocking wait so the server process is reaped.
        os.waitpid(self._forkserver_pid, 0)
        self._forkserver_pid = None

        # Abstract-namespace sockets have no filesystem entry to remove.
        if not util.is_abstract_socket_namespace(self._forkserver_address):
            os.unlink(self._forkserver_address)
        self._forkserver_address = None

    def set_forkserver_preload(self, modules_names):
        """Set the module names the fork server should try to import.

        modules_names: iterable of str.  It is copied into a list first so
        that a one-shot iterable is not exhausted by the validation below
        before being stored.

        Raises TypeError if any element is not exactly of type str.
        """
        modules_names = list(modules_names)
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        """Return the stored list of inherited fds (None until it is set)."""
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        """Ask the fork server for a new child process.

        fds: sequence of file descriptors to hand to the child, in addition
        to the four descriptors this method always sends.

        Returns (parent_r, parent_w): the read end of one fresh pipe and the
        write end of another; the opposite ends are sent to the server.  The
        caller owns both returned descriptors.

        Raises ValueError when too many descriptors would be sent.
        """
        self.ensure_running()
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # Do not leak the first pipe when the second cannot be made.
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                # Built inside the try so a failure here (for instance in
                # resource_tracker.getfd()) still closes all four pipe ends.
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          resource_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child-side ends are never needed in this process.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        """Make sure the fork server is running, (re)starting it if needed.

        May be called from several threads; the whole check-and-start
        sequence runs under the instance lock.
        """
        with self._lock:
            resource_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # A server was started earlier: is it still running?
                pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # Still alive, nothing to do.
                    return
                # It has exited: forget it and fall through to start anew.
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            main_kws = {}
            if self._preload_modules:
                data = spawn.get_preparation_data('ignore')
                if 'sys_path' in data:
                    main_kws['sys_path'] = data['sys_path']
                if 'init_main_from_path' in data:
                    main_kws['main_path'] = data['init_main_from_path']

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                if not util.is_abstract_socket_namespace(address):
                    # Filesystem socket: restrict it to the owner.
                    os.chmod(address, 0o600)
                listener.listen()

                # The server receives the read end of this pipe; this process
                # keeps the write end.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            main_kws)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    """Run the fork server event loop.

    listener_fd: descriptor of the AF_UNIX socket on which fork requests
        arrive.
    alive_r: read end of the "alive" pipe; EOF on it makes the server exit.
    preload: iterable of module names to import before serving requests.
    main_path: path given to spawn.import_main_path() when '__main__' is
        listed in preload.
    sys_path: if not None, replaces the contents of sys.path before the
        preload imports.

    Does not return normally: the loop ends with SystemExit when alive_r
    reaches EOF, or with an unhandled exception.  Forked children leave
    through os._exit().
    """
    if preload:
        if sys_path is not None:
            sys.path[:] = sys_path
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best effort.
                pass

        # Flush anything the preloaded modules may have written so that
        # buffered output is not duplicated into forked children.
        util._flush_std_streams()

    util._close_stdin()

    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Intentionally empty: the event loop is woken through the wakeup fd.
        pass

    handlers = {
        # A Python-level SIGCHLD handler, so child exits reach the wakeup fd.
        signal.SIGCHLD: sigchld_handler,
        # Keep the server alive when the user presses ^C.
        signal.SIGINT: signal.SIG_IGN,
        }
    # Remember the previous handlers; children restore them in _serve_one().
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # Signal arrival is reported by a byte written to sig_w.
    signal.set_wakeup_fd(sig_w)

    # Maps the pid of each live child to the fd used to report to its client.
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is readable.
                while True:
                    rfds = [key.fileobj for (key, _events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF on the alive pipe: time to shut down.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # SIGCHLD was delivered: drain the wakeup pipe first ...
                    os.read(sig_r, 65536)
                    # ... then reap every child that has already exited.
                    while True:
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            # No children left at all.
                            break
                        if pid == 0:
                            # Children exist but none has exited yet.
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            returncode = os.waitstatus_to_exitcode(sts)
                            # Report the exit code to the requesting client.
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # The client has gone away.
                                pass
                            os.close(child_w)
                        else:
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request.
                    with listener.accept()[0] as s:
                        # Ask for one fd more than allowed so that an
                        # oversized request can be detected.
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        child_r, child_w, *fds = fds
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child process.
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                atexit._clear()
                                atexit.register(util._exit_function)
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                atexit._run_exitfuncs()
                                os._exit(code)
                        else:
                            # Server process: tell the client the child's pid.
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # The client has gone away.
                                pass
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # An aborted connection is ignored; anything else is fatal.
                if e.errno != errno.ECONNABORTED:
                    raise
def _serve_one(child_r, fds, unused_fds, handlers):
    """Turn a freshly forked child into the requested process.

    child_r: fd passed to spawn._main() (a duplicate of it is passed as the
        second argument).
    fds: [alive fd, resource tracker fd, *inherited fds].
    unused_fds: descriptors copied from the server that are closed here.
    handlers: mapping of signal number to the handler to reinstate.

    Returns whatever spawn._main() returns.
    """
    # Undo the signal setup done by the server loop.
    signal.set_wakeup_fd(-1)
    for signum in handlers:
        signal.signal(signum, handlers[signum])

    # Close the server-only descriptors.
    for descriptor in unused_fds:
        os.close(descriptor)

    alive_fd, tracker_fd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    resource_tracker._resource_tracker._fd = tracker_fd
    _forkserver._inherited_fds = inherited

    sentinel = os.dup(child_r)
    return spawn._main(child_r, sentinel)
def read_signed(fd):
    """Read one SIGNED_STRUCT-encoded integer from file descriptor fd.

    Keeps reading until exactly SIGNED_STRUCT.size bytes have arrived.
    Raises EOFError if the descriptor reaches end of file first.
    """
    remaining = SIGNED_STRUCT.size
    buf = b''
    while remaining > 0:
        piece = os.read(fd, remaining)
        if not piece:
            raise EOFError('unexpected EOF')
        buf += piece
        remaining -= len(piece)
    (value,) = SIGNED_STRUCT.unpack(buf)
    return value
def write_signed(fd, n):
    """Write integer n to file descriptor fd in SIGNED_STRUCT encoding.

    Repeats os.write() until the whole packed value has been written.
    Raises RuntimeError if a write reports that zero bytes were written.
    """
    pending = SIGNED_STRUCT.pack(n)
    while len(pending) > 0:
        written = os.write(fd, pending)
        if written == 0:
            raise RuntimeError('should not get here')
        pending = pending[written:]
# Module-level singleton.  The public functions named in __all__ are simply
# its bound methods; main() and _serve_one() also store state on it.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload