#![cfg_attr(feature="cargo-clippy", allow(clippy::similar_names))]
use std::ffi::CString;
use std::fs;
use std::io::{Read, Write, Result, Error, ErrorKind};
use std::net::Shutdown;
use std::os::unix::prelude::*;
use std::os::unix::net::{UnixListener, UnixStream};
use std::mem;
use std::path::Path;
use std::ptr;
use libc::{self, gid_t, mode_t, uid_t};
/// A single accepted connection on a Unix-domain socket.
///
/// Values are created by `Socket::open`, which binds a listener at a
/// filesystem path and accepts exactly one peer. The connection is shut
/// down in both directions when the value is dropped.
#[derive(Debug)]
pub(crate) struct Socket {
/// The connected stream returned by `UnixListener::accept`.
socket: UnixStream,
}
impl Socket {
    /// Binds a Unix-domain socket at `path`, waits for exactly one peer to
    /// connect, and returns the resulting connection.
    ///
    /// The steps, in order:
    ///
    /// 1. the parent directory of `path` is vetted by `enforce_ownership`;
    /// 2. a stale socket at `path` is removed (anything else there is an
    ///    error);
    /// 3. the listener is bound with every permission bit masked off, so the
    ///    socket is unreachable until it has been handed to `uid`/`gid` and
    ///    given `mode`;
    /// 4. the call blocks in `select` until a peer is ready, then accepts it;
    /// 5. the socket file is unlinked again (best effort), whether or not the
    ///    previous steps succeeded.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the directory check, the bind, the
    /// `chown`/`chmod`, the `select` (including `EINTR`), or the `accept`
    /// fails, or if `path` contains an interior NUL byte.
    pub(crate) fn open<P: AsRef<Path>>(
        path: P,
        uid: uid_t,
        gid: gid_t,
        mode: mode_t,
    ) -> Result<Self> {
        let path = path.as_ref();

        Self::enforce_ownership(path)?;
        Self::unlink(path)?;

        // SAFETY: `umask` has no preconditions and always succeeds.
        let previous_umask = unsafe {
            libc::umask(libc::S_IRWXU | libc::S_IRWXG | libc::S_IRWXO)
        };
        let bound = UnixListener::bind(path);
        // The umask is process-wide and only matters while the socket file is
        // being created, so put it back before blocking on a peer rather than
        // holding it at 0777 for an unbounded amount of time.
        // SAFETY: as above.
        let _ = unsafe { libc::umask(previous_umask) };

        let socket = bound.and_then(|listener| {
            Self::restrict_access(path, uid, gid, mode)?;
            Self::wait_for_connection(&listener)?;

            listener
                .accept()
                .map(|(connection, _peer)| Self { socket: connection })
        });

        // Only one peer is ever accepted, so the filesystem entry is no longer
        // needed; failing to remove it must not mask the real result.
        let _ = Self::unlink(path);

        socket
    }

    /// Shuts down both the read and the write half of the connection.
    pub(crate) fn close(&mut self) -> Result<()> {
        self.socket.shutdown(Shutdown::Both)
    }

    /// Hands the socket file at `path` to `uid`/`gid` and applies `mode`.
    fn restrict_access(
        path: &Path,
        uid: uid_t,
        gid: gid_t,
        mode: mode_t,
    ) -> Result<()> {
        let cpath = CString::new(path.as_os_str().as_bytes())?;

        // SAFETY: `cpath` is a valid NUL-terminated string that outlives both
        // calls.
        unsafe {
            if libc::chown(cpath.as_ptr(), uid, gid) == -1 {
                return Err(Error::last_os_error());
            }

            if libc::chmod(cpath.as_ptr(), mode) == -1 {
                return Err(Error::last_os_error());
            }
        }

        Ok(())
    }

    /// Blocks, without a timeout, until `listener` has a pending connection.
    ///
    /// `select` is used ahead of `accept` so that a signal surfaces as an
    /// error here; presumably this is what lets Ctrl-C abort the wait, since
    /// the standard library's `accept` retries when interrupted.
    fn wait_for_connection(listener: &UnixListener) -> Result<()> {
        let fd = listener.as_raw_fd();

        // `FD_SET` on a descriptor outside `0..FD_SETSIZE` is undefined
        // behaviour, so refuse instead of writing past the end of the set.
        // `FD_SETSIZE` is a small constant, so the cast cannot truncate.
        if fd < 0 || fd >= libc::FD_SETSIZE as libc::c_int {
            return Err(Error::new(
                ErrorKind::Other,
                format!("file descriptor {} cannot be used with `select`", fd),
            ));
        }

        // SAFETY: `readfds` is fully initialised by `FD_ZERO` before
        // `assume_init`, `fd` was checked against `FD_SETSIZE` above, and the
        // listener keeps `fd` open for the duration of the call.
        unsafe {
            let mut readfds = mem::MaybeUninit::<libc::fd_set>::uninit();
            libc::FD_ZERO(readfds.as_mut_ptr());
            let mut readfds = readfds.assume_init();
            libc::FD_SET(fd, &mut readfds);

            match libc::select(
                fd + 1,
                &mut readfds,
                ptr::null_mut(),
                ptr::null_mut(),
                ptr::null_mut(),
            ) {
                1 => (),
                -1 => return Err(Error::last_os_error()),
                0 => unreachable!("`select` returned 0 even though no timeout was set"),
                _ => unreachable!("`select` indicated that more than 1 fd is ready"),
            };

            if !libc::FD_ISSET(fd, &mut readfds) {
                unreachable!("`select` returned an unexpected file descriptor");
            }
        }

        Ok(())
    }

    /// Removes `path` if it is a socket.
    ///
    /// Returns `ErrorKind::AlreadyExists` when something other than a socket
    /// lives at `path`. When the metadata cannot be read at all (for example
    /// because nothing exists there) this is treated as "nothing to remove".
    fn unlink(path: &Path) -> Result<()> {
        match fs::metadata(path).map(|md| md.file_type().is_socket()) {
            Ok(true) => fs::remove_file(path),
            Ok(false) => Err(Error::new(
                ErrorKind::AlreadyExists,
                format!(
                    "{} exists and is not a socket",
                    path.to_string_lossy()
                ),
            )),
            Err(_) => Ok(()),
        }
    }

    /// Verifies that the directory containing `path` is safe to place a
    /// socket in.
    ///
    /// The parent must be a directory, must be owned by the effective user of
    /// this process, and must not be writable by group or others.
    ///
    /// # Errors
    ///
    /// Fails if `path` has no parent, if the parent cannot be `stat`ed, or if
    /// any of the checks above does not hold.
    fn enforce_ownership(path: &Path) -> Result<()> {
        let parent = path.parent().ok_or_else(|| {
            Error::new(ErrorKind::AlreadyExists, format!(
                "couldn't determine permissions of the parent directory for {}",
                path.to_string_lossy()
            ))
        })?;

        // A bare relative file name has an empty parent, which `stat` would
        // reject with ENOENT; it really means the current directory.
        let parent = if parent.as_os_str().is_empty() {
            Path::new(".")
        } else {
            parent
        };

        let parent = CString::new(parent.as_os_str().as_bytes())?;

        // SAFETY: `parent` is a valid NUL-terminated string and the buffer is
        // only read after `stat` reported success and therefore filled it in.
        let stat = unsafe {
            let mut stat = mem::MaybeUninit::<libc::stat>::uninit();

            if libc::stat(parent.as_ptr(), stat.as_mut_ptr()) == -1 {
                return Err(Error::last_os_error());
            }

            stat.assume_init()
        };

        // The file type is a multi-bit field: testing the `S_IFDIR` bit alone
        // would also accept block devices and sockets, whose type codes
        // contain that bit.
        if stat.st_mode & libc::S_IFMT != libc::S_IFDIR {
            return Err(Error::new(ErrorKind::Other, format!(
                "the socket path {} is not a directory",
                parent.to_string_lossy(),
            )));
        }

        // NOTE: the comparison is against the effective uid of this process;
        // the message says "root" (presumably the process is expected to run
        // as root).
        // SAFETY: `geteuid` has no preconditions and always succeeds.
        if stat.st_uid != unsafe { libc::geteuid() } {
            return Err(Error::new(ErrorKind::Other, format!(
                "the socket directory {} is not owned by root",
                parent.to_string_lossy(),
            )));
        }

        if stat.st_mode & (libc::S_IWGRP | libc::S_IWOTH) != 0 {
            return Err(Error::new(ErrorKind::Other, format!(
                "the socket directory {} has insecure permissions",
                parent.to_string_lossy(),
            )));
        }

        Ok(())
    }
}
impl Drop for Socket {
    /// Shuts the connection down when the value goes out of scope.
    fn drop(&mut self) {
        // Best effort: there is nobody to report a failed shutdown to while
        // the value is being destroyed, so the outcome is discarded.
        self.close().ok();
    }
}
impl Read for Socket {
    /// Reads from the peer into `buf` via `ctrl_c_aborts_syscalls`.
    ///
    /// A failure to adjust the signal disposition and a failure of the read
    /// itself are both reported through the returned `Result`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let stream = &mut self.socket;

        // The helper wraps the closure's own `Result` in a second one;
        // collapse the two layers into a single result.
        ctrl_c_aborts_syscalls(move || stream.read(buf)).and_then(|count| count)
    }
}
impl Write for Socket {
    /// Writes `buf` to the peer via `ctrl_c_aborts_syscalls`.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let stream = &mut self.socket;

        // Collapse the helper's outer result and the write's own result.
        ctrl_c_aborts_syscalls(move || stream.write(buf)).and_then(|count| count)
    }

    /// Flushes the underlying stream via `ctrl_c_aborts_syscalls`.
    fn flush(&mut self) -> Result<()> {
        let stream = &mut self.socket;

        ctrl_c_aborts_syscalls(move || stream.flush()).and_then(|done| done)
    }
}
fn ctrl_c_aborts_syscalls<F, T>(func: F) -> Result<T>
where F: FnOnce() -> T
{
unsafe {
let mut sigaction_old = mem::MaybeUninit::<libc::sigaction>::uninit();
let sigaction_null = ::std::ptr::null_mut();
sigaction(libc::SIGINT, sigaction_null, sigaction_old.as_mut_ptr())?;
let sigaction_old = sigaction_old.assume_init();
let mut sigaction_new = sigaction_old;
sigaction_new.sa_flags &= !libc::SA_RESTART;
sigaction(libc::SIGINT, &sigaction_new, sigaction_null)?;
let result = func();
sigaction(libc::SIGINT, &sigaction_old, sigaction_null)?;
Ok(result)
}
}
/// Calls `libc::sigaction` and converts its `-1` failure return into the
/// last OS error.
///
/// # Safety
///
/// `new` and `old` are passed to `libc::sigaction` unchanged, so each must be
/// either null or a pointer that is valid for that call.
unsafe fn sigaction(
    sig: libc::c_int,
    new: *const libc::sigaction,
    old: *mut libc::sigaction,
) -> Result<()> {
    match libc::sigaction(sig, new, old) {
        -1 => Err(Error::last_os_error()),
        _ => Ok(()),
    }
}