use super::listener;
use anyhow::{Context, Result};
use listener::Listen;
use listener::{Connection, ConnectionMetadata};
use log::{error, warn};
use std::fs;
use std::fs::Permissions;
use std::io::{Error, ErrorKind};
use std::os::unix::fs::FileTypeExt;
use std::os::unix::fs::PermissionsExt;
use std::os::unix::io::FromRawFd;
use std::os::unix::io::RawFd;
use std::os::unix::net::UnixListener;
use std::path::PathBuf;
use std::time::Duration;
/// Socket path used by `DomainSocketListenerBuilder::build` when no path was supplied.
static DEFAULT_SOCKET_PATH: &str = "/run/parsec/parsec.sock";
/// Listener accepting client connections on a Unix domain socket.
#[derive(Debug)]
pub struct DomainSocketListener {
// Listening socket: either bound by `new` at a filesystem path, or taken from the single
// descriptor reported by `sd_notify::listen_fds`.
listener: UnixListener,
// Read and write timeout applied to every stream returned by `accept`.
timeout: Duration,
}
impl DomainSocketListener {
pub fn new(timeout: Duration, socket_path: PathBuf) -> Result<Self> {
let listeners: Vec<RawFd> = sd_notify::listen_fds()?.collect();
let listener = match listeners.len() {
0 => {
if socket_path.exists() {
let meta = fs::metadata(&socket_path)?;
if meta.file_type().is_socket() {
warn!(
"Removing the existing socket file at {}.",
socket_path.display()
);
fs::remove_file(&socket_path)?;
} else {
error!(
"A file exists at {} but is not a Unix Domain Socket.",
socket_path.display()
);
}
}
let listener = UnixListener::bind(&socket_path).with_context(|| {
format!("Failed to bind to Unix socket at {:?}", socket_path)
})?;
listener.set_nonblocking(true)?;
let permissions = Permissions::from_mode(0o666);
fs::set_permissions(socket_path, permissions)?;
listener
}
1 => {
let nfd = listeners[0];
unsafe { UnixListener::from_raw_fd(nfd) }
}
n => {
error!(
"Received too many file descriptors ({} received, 0 or 1 expected).",
n
);
return Err(Error::new(
ErrorKind::InvalidData,
"too many file descriptors received",
)
.into());
}
};
Ok(Self { listener, timeout })
}
}
impl Listen for DomainSocketListener {
    /// Replaces the read/write timeout applied to connections accepted from now on.
    fn set_timeout(&mut self, duration: Duration) {
        self.timeout = duration;
    }

    /// Accepts one pending client connection, if there is one.
    ///
    /// Returns `None` when no client is waiting (`WouldBlock`), or when accepting, configuring
    /// the stream (read/write timeouts, blocking mode) or reading the peer credentials fails.
    /// Every failure except `WouldBlock` is logged through `format_error!`.
    ///
    /// On success, the returned stream is in blocking mode with both timeouts set to the
    /// current `timeout`, and its metadata carries the peer's uid, gid and (if known) pid.
    fn accept(&self) -> Option<Connection> {
        let stream = match self.listener.accept() {
            Ok((accepted, _peer_addr)) => accepted,
            // Nothing pending on the non-blocking listener: not an error worth logging.
            Err(err) if err.kind() == ErrorKind::WouldBlock => return None,
            Err(err) => {
                format_error!("Failed to connect with a UnixStream", err);
                return None;
            }
        };

        if let Err(err) = stream.set_read_timeout(Some(self.timeout)) {
            format_error!("Failed to set read timeout", err);
            return None;
        }
        if let Err(err) = stream.set_write_timeout(Some(self.timeout)) {
            format_error!("Failed to set write timeout", err);
            return None;
        }
        if let Err(err) = stream.set_nonblocking(false) {
            format_error!("Failed to set stream as blocking", err);
            return None;
        }

        let peer = match peer_credentials::peer_cred(&stream) {
            Ok(creds) => creds,
            Err(err) => {
                format_error!(
                    "Failed to grab peer credentials metadata from UnixStream",
                    err
                );
                return None;
            }
        };

        let metadata = ConnectionMetadata::UnixPeerCredentials {
            uid: peer.uid,
            gid: peer.gid,
            pid: peer.pid,
        };
        Some(Connection {
            stream: Box::new(stream),
            metadata: Some(metadata),
        })
    }
}
/// Builder for a [`DomainSocketListener`].
///
/// A timeout must be supplied through `with_timeout` before calling `build`; when no socket
/// path is supplied, `DEFAULT_SOCKET_PATH` is used.
#[derive(Clone, Debug, Default)]
pub struct DomainSocketListenerBuilder {
    timeout: Option<Duration>,
    socket_path: Option<PathBuf>,
}

impl DomainSocketListenerBuilder {
    /// Returns a builder with neither a timeout nor a socket path set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the read/write timeout applied to accepted connections.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the socket path; `None` means "use the default path" at build time.
    pub fn with_socket_path(self, socket_path: Option<PathBuf>) -> Self {
        Self {
            socket_path,
            ..self
        }
    }

    /// Creates the [`DomainSocketListener`] from the collected settings.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error (and logs it) when no timeout was set, and otherwise
    /// propagates any error from `DomainSocketListener::new`.
    pub fn build(self) -> Result<DomainSocketListener> {
        let timeout = match self.timeout {
            Some(timeout) => timeout,
            None => {
                error!("The listener timeout was not set.");
                return Err(
                    Error::new(ErrorKind::InvalidInput, "listener timeout missing").into(),
                );
            }
        };
        let socket_path = match self.socket_path {
            Some(path) => path,
            None => PathBuf::from(DEFAULT_SOCKET_PATH),
        };
        DomainSocketListener::new(timeout, socket_path)
    }
}
/// Lookup of the credentials of the process on the other end of a connected
/// `std::os::unix::net::UnixStream`.
pub mod peer_credentials {
    use libc::{gid_t, pid_t, uid_t};

    /// Credentials of the peer of a Unix domain socket connection.
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    pub struct UCred {
        /// User id of the peer process.
        pub uid: uid_t,
        /// Group id of the peer process.
        pub gid: gid_t,
        /// Process id of the peer; `None` where the platform implementation does not fetch it.
        pub pid: Option<pid_t>,
    }

    #[cfg(any(target_os = "android", target_os = "linux"))]
    pub use self::impl_linux::peer_cred;

    #[cfg(any(
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "ios",
        target_os = "macos",
        target_os = "openbsd"
    ))]
    pub use self::impl_bsd::peer_cred;

    /// Linux/Android implementation based on `getsockopt(SOL_SOCKET, SO_PEERCRED)`.
    #[cfg(any(target_os = "linux", target_os = "android"))]
    // `trivial_casts`: the `&mut ucred as *mut ucred` step handed to `getsockopt`.
    #[allow(missing_docs, trivial_casts)]
    pub mod impl_linux {
        use super::UCred;
        use libc::{c_void, getsockopt, socklen_t, ucred, SOL_SOCKET, SO_PEERCRED};
        use std::io;
        use std::mem::size_of;
        use std::os::unix::io::AsRawFd;
        use std::os::unix::net::UnixStream;

        /// Returns the uid, gid and pid of the peer connected to `socket`.
        ///
        /// # Errors
        ///
        /// Returns the OS error when `getsockopt` fails, or an `InvalidData` error when the
        /// call succeeds but reports a credentials structure of unexpected size.
        pub fn peer_cred(socket: &UnixStream) -> io::Result<UCred> {
            let expected_size = size_of::<ucred>();
            // Guards the narrowing cast to `socklen_t` on the next line.
            assert!(expected_size <= socklen_t::MAX as usize);
            let mut cred_size = expected_size as socklen_t;
            let mut cred = ucred {
                pid: 1,
                uid: 1,
                gid: 1,
            };

            // SAFETY: `cred` and `cred_size` are live, writable locals for the whole call, and
            // `cred_size` bounds how many bytes may be written into `cred`.
            let ret = unsafe {
                getsockopt(
                    socket.as_raw_fd(),
                    SOL_SOCKET,
                    SO_PEERCRED,
                    &mut cred as *mut ucred as *mut c_void,
                    &mut cred_size,
                )
            };
            if ret != 0 {
                return Err(io::Error::last_os_error());
            }
            if cred_size as usize != expected_size {
                // The call succeeded, so errno was not set for this failure and
                // `last_os_error` would report an unrelated, stale value.
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "getsockopt(SO_PEERCRED) returned credentials of unexpected size",
                ));
            }
            Ok(UCred {
                uid: cred.uid,
                gid: cred.gid,
                pid: Some(cred.pid),
            })
        }
    }

    /// BSD/Apple implementation based on `getpeereid`; the peer pid is not fetched.
    #[cfg(any(
        target_os = "dragonfly",
        target_os = "macos",
        target_os = "ios",
        target_os = "freebsd",
        target_os = "openbsd"
    ))]
    #[allow(missing_docs)]
    pub mod impl_bsd {
        use super::UCred;
        use std::io;
        use std::os::unix::io::AsRawFd;
        use std::os::unix::net::UnixStream;

        /// Returns the uid and gid of the peer connected to `socket` (`pid` is `None`).
        ///
        /// # Errors
        ///
        /// Returns the OS error when `getpeereid` fails.
        pub fn peer_cred(socket: &UnixStream) -> io::Result<UCred> {
            let mut cred = UCred {
                uid: 1,
                gid: 1,
                pid: None,
            };
            // SAFETY: both out-pointers refer to fields of the local `cred`, which stays valid
            // and writable for the duration of the call.
            let ret =
                unsafe { libc::getpeereid(socket.as_raw_fd(), &mut cred.uid, &mut cred.gid) };
            if ret == 0 {
                Ok(cred)
            } else {
                Err(io::Error::last_os_error())
            }
        }
    }
}