use platform::Channel;
use super::*;
use std::{io, mem, ptr, env, process};
use std::time::Duration;
use std::ffi::{OsString, OsStr};
use uuid::Uuid;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::reactor::{PollEvented, Handle as TokioHandle};
use self::mio::Ready;
use futures::{Poll, Async};
use platform::libc;
macro_rules! use_seqpacket {
() => { cfg!(not(target_os = "macos")) };
}
pub(crate) const MAX_MSG_FDS: usize = 32;
pub struct MessageChannel {
socket: PollEvented<ScopedFd>,
cmsg: Cmsg,
}
impl MessageChannel {
pub fn pair(tokio_loop: &TokioHandle) -> io::Result<(Self, Self)> {
let (a, b) = raw_socketpair()?;
Ok((
MessageChannel { socket: PollEvented::new_with_handle(a, tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS) },
MessageChannel { socket: PollEvented::new_with_handle(b, tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS) },
))
}
pub fn establish_with_child<F>(command: &mut process::Command, tokio_loop: &TokioHandle, transmit_and_launch: F) -> io::Result<(Self, process::Child)> where
F: FnOnce(&mut process::Command, ChildMessageChannel) -> io::Result<process::Child>
{
let (a, b) = raw_socketpair()?;
let _guard = CloexecGuard::new(&b.0);
let channel = ChildMessageChannel { fd: b.0 };
let child = transmit_and_launch(command, channel)?;
Ok((
MessageChannel { socket: PollEvented::new_with_handle(a, tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS), },
child,
))
}
pub fn establish_with_child_custom<F, T>(tokio_loop: &TokioHandle, transmit_and_launch: F) -> io::Result<(Self, T)> where
F: FnOnce(ChildMessageChannel) -> io::Result<(ProcessHandle, T)>
{
let (a, b) = raw_socketpair()?;
let _guard = CloexecGuard::new(&b.0);
let channel = ChildMessageChannel { fd: b.0 };
let (_token, t) = transmit_and_launch(channel)?;
Ok((
MessageChannel { socket: PollEvented::new_with_handle(a, tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS), },
t,
))
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChildMessageChannel {
fd: libc::c_int,
}
impl ChildMessageChannel {
pub fn into_channel(self, tokio_loop: &TokioHandle) -> io::Result<MessageChannel> {
Ok(
MessageChannel { socket: PollEvented::new_with_handle(ScopedFd(self.fd), tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS) },
)
}
}
impl Channel for MessageChannel {
fn cmsg(&mut self) -> &mut Cmsg {
&mut self.cmsg
}
}
impl io::Write for MessageChannel {
fn write(&mut self, buffer: &[u8]) -> io::Result<usize> {
if let Async::Ready(_) = self.socket.poll_write_ready()? {
unsafe {
self.cmsg.set_data(buffer);
let result = libc::sendmsg(self.socket.get_ref().0, self.cmsg.ptr(), libc::MSG_DONTWAIT);
if result >= 0 {
Ok(result as usize)
} else {
let err = io::Error::last_os_error();
if err.kind() == io::ErrorKind::WouldBlock {
self.socket.clear_write_ready()?;
}
Err(err)
}
}
} else {
Err(io::ErrorKind::WouldBlock.into())
}
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl AsyncWrite for MessageChannel {
fn shutdown(&mut self) -> Poll<(), io::Error> {
if unsafe { libc::shutdown(self.socket.get_ref().0, libc::SHUT_RDWR) } < 0 {
return Err(io::Error::last_os_error());
}
Ok(Async::Ready(()))
}
}
impl io::Read for MessageChannel {
fn read(&mut self, buffer: &mut [u8]) -> io::Result<usize> {
if let Async::Ready(_) = self.socket.poll_read_ready(Ready::readable())? {
unsafe {
self.cmsg.set_data(buffer);
self.cmsg.reset_fd_count();
let result = libc::recvmsg(self.socket.get_ref().0, self.cmsg.ptr(), libc::MSG_DONTWAIT);
if result >= 0 {
Ok(result as usize)
} else {
let err = io::Error::last_os_error();
if err.kind() == io::ErrorKind::WouldBlock {
self.socket.clear_read_ready(Ready::readable())?;
}
Err(err)
}
}
} else {
Err(io::ErrorKind::WouldBlock.into())
}
}
}
impl AsyncRead for MessageChannel {
}
#[derive(Serialize, Deserialize, Debug)]
pub struct PreMessageChannel {
fd: ScopedFd,
}
impl PreMessageChannel {
pub fn pair() -> io::Result<(Self, Self)> {
let (a, b) = raw_socketpair()?;
Ok((
PreMessageChannel { fd: a },
PreMessageChannel { fd: b },
))
}
pub fn into_channel(self, _remote_process: ProcessHandle, tokio_loop: &TokioHandle) -> io::Result<MessageChannel> {
Ok(MessageChannel { socket: PollEvented::new_with_handle(self.fd, tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS) })
}
pub fn into_sealed_channel(self, tokio_loop: &TokioHandle) -> io::Result<MessageChannel> {
Ok(MessageChannel { socket: PollEvented::new_with_handle(self.fd, tokio_loop)?, cmsg: Cmsg::new(MAX_MSG_FDS) })
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ProcessHandle {
}
impl ProcessHandle {
pub fn current() -> io::Result<Self> {
Ok(ProcessHandle {})
}
pub fn from_child(_child: &process::Child) -> io::Result<Self> {
Ok(ProcessHandle {})
}
pub fn clone(&self) -> io::Result<Self> {
Ok(ProcessHandle {})
}
}
pub struct NamedMessageChannel {
socket: ScopedFd,
tokio_loop: TokioHandle,
name: OsString,
}
impl NamedMessageChannel {
pub fn new(tokio_loop: &TokioHandle) -> io::Result<Self> {
let name = OsString::from(env::temp_dir().join(format!("{}", Uuid::new_v4())));
unsafe {
let typ = if use_seqpacket!() {
libc::SOCK_SEQPACKET
} else {
libc::SOCK_STREAM
};
let socket = ScopedFd(libc::socket(libc::AF_UNIX, typ, 0));
set_cloexec(socket.0)?;
let addr = sockaddr_un(&name);
if libc::bind(socket.0, &addr as *const _ as *const libc::sockaddr, mem::size_of_val(&addr) as _) < 0 {
return Err(io::Error::last_os_error());
}
if libc::listen(socket.0, 1) < 0 {
return Err(io::Error::last_os_error());
}
Ok(NamedMessageChannel {
socket,
tokio_loop: tokio_loop.clone(),
name: OsString::from(name),
})
}
}
pub fn name(&self) -> &OsStr {
&self.name
}
pub fn accept(self, _timeout: Option<Duration>) -> io::Result<MessageChannel> {
unsafe {
let socket = libc::accept(self.socket.0, ptr::null_mut(), ptr::null_mut());
if socket < 0 {
return Err(io::Error::last_os_error());
}
Ok(MessageChannel {
socket: PollEvented::new_with_handle(ScopedFd(socket), &self.tokio_loop)?,
cmsg: Cmsg::new(MAX_MSG_FDS),
})
}
}
pub fn connect<N>(name: N, _timeout: Option<Duration>, tokio_loop: &TokioHandle) -> io::Result<MessageChannel> where
N: AsRef<OsStr>,
{
let name = name.as_ref();
unsafe {
let typ = if use_seqpacket!() {
libc::SOCK_SEQPACKET
} else {
libc::SOCK_STREAM
};
let socket = libc::socket(libc::AF_UNIX, typ, 0);
set_cloexec(socket)?;
let addr = sockaddr_un(name);
if libc::connect(socket, &addr as *const _ as *const libc::sockaddr, mem::size_of_val(&addr) as _) < 0 {
return Err(io::Error::last_os_error());
}
Ok(MessageChannel {
socket: PollEvented::new_with_handle(ScopedFd(socket), tokio_loop)?,
cmsg: Cmsg::new(MAX_MSG_FDS),
})
}
}
}
fn raw_socketpair() -> io::Result<(ScopedFd, ScopedFd)> {
unsafe {
let mut sockets: [libc::c_int; 2] = [0, 0];
let socket_type = if use_seqpacket!() {
libc::SOCK_SEQPACKET
} else {
libc::SOCK_STREAM
};
if libc::socketpair(libc::AF_UNIX, socket_type, 0, sockets.as_mut_ptr()) == -1 {
return Err(io::Error::last_os_error());
}
for &socket in sockets.iter() {
set_cloexec(socket)?;
}
Ok((ScopedFd(sockets[0]), ScopedFd(sockets[1])))
}
}
struct CloexecGuard<'a>(&'a RawFd);
impl<'a> Drop for CloexecGuard<'a> {
fn drop(&mut self) {
if let Err(err) = set_cloexec(*self.0) {
error!("failed to set CLOEXEC: {}", err);
}
}
}
impl<'a> CloexecGuard<'a> {
pub fn new(fd: &'a RawFd) -> io::Result<Self> {
clear_cloexec(*fd)?;
Ok(CloexecGuard(fd))
}
}