use crate::CoreError;
use crate::error::syscall_ret;
use crate::reactor::Fd;
use std::io::Error as IoError;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::fs::FileTypeExt;
use std::os::unix::io::AsRawFd;
use std::path::Path;
/// Returns the raw code of the most recent OS error as reported by
/// `IoError::last_os_error`, or `0` when that error carries no raw code.
#[inline(always)]
fn errno() -> i32 {
    let last = IoError::last_os_error();
    last.raw_os_error().unwrap_or_default()
}
/// Listening Unix-domain stream socket, as produced by `bind_unix_listener`.
pub struct UnixListenerFd {
/// Descriptor handle of the listening socket; `accept` calls `accept4` on it.
pub fd: Fd,
}
/// Unix-domain stream socket, as produced by `UnixListenerFd::accept` or
/// `connect_unix_stream`.
pub struct UnixStreamFd {
/// Descriptor handle of the stream socket.
pub fd: Fd,
}
/// Outcome of `connect_unix_stream` when it does not fail outright.
pub enum UnixConnectResult {
/// `connect` returned success (or reported `EISCONN`).
Connected(UnixStreamFd),
/// `connect` reported `EINPROGRESS` or `EALREADY`; the connection is not established yet.
/// NOTE(review): callers presumably wait for writability and then call
/// `UnixStreamFd::finish_connect` — confirm against the reactor code.
InProgress(UnixStreamFd),
}
/// Borrowed Unix-domain socket address accepted by the bind/connect/chmod helpers.
#[derive(Clone, Copy, Debug)]
pub enum UnixSocketAddr<'a> {
/// Filesystem path. When encoded it must be non-empty, contain no NUL byte, and be
/// shorter than `sockaddr_un::sun_path`.
Path(&'a Path),
/// Abstract-namespace name (bytes placed after a leading NUL in `sun_path`). Encoding
/// rejects it with `ENOSYS` unless the target is Linux or Android.
Abstract(&'a [u8]),
}
/// Controls what `bind_unix_listener` does with an existing filesystem entry at a
/// path address before calling `bind`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum StaleSocketPolicy {
/// Leave the path untouched (the default).
#[default]
Preserve,
/// Remove the path only when `lstat` reports it is a socket; a missing path is fine,
/// any other file type is rejected with `EEXIST`.
UnlinkSocketOnly,
/// Remove the path with `std::fs::remove_file` without checking its file type; a
/// missing path is fine.
UnlinkAnyPath,
}
/// Options for `bind_unix_listener`. Abstract addresses only accept the default value;
/// anything else is rejected there with `EINVAL`.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixSocketBindOptions {
/// How to treat an existing entry at the bind path (path addresses only).
pub stale_socket_policy: StaleSocketPolicy,
/// When `Some`, permission bits applied to the socket path with `chmod` after `bind`.
pub mode: Option<u32>,
}
/// Credentials of the process on the other end of a Unix stream socket, as filled in
/// from `SO_PEERCRED` by `peer_cred_raw`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PeerCred {
/// Peer process id; the `SO_PEERCRED` implementation always sets it to `Some`.
pub pid: Option<i32>,
/// Peer user id.
pub uid: u32,
/// Peer group id.
pub gid: u32,
}
impl UnixListenerFd {
    /// Accepts one pending connection without blocking.
    ///
    /// The accepted socket is created with `SOCK_CLOEXEC | SOCK_NONBLOCK`. Calls
    /// interrupted by a signal (`EINTR`) are retried.
    ///
    /// # Errors
    /// Returns `Ok(None)` when `accept4` reports `EAGAIN`/`EWOULDBLOCK`; any other
    /// failure is returned as a `CoreError` tagged `"accept4"`.
    pub fn accept(&self) -> Result<Option<UnixStreamFd>, CoreError> {
        let flags = libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK;
        loop {
            // SAFETY: the peer address is not requested, so both out-pointers are null,
            // which `accept4` permits; the listening descriptor is borrowed from `self`.
            let raw = unsafe {
                libc::accept4(
                    self.fd.as_raw_fd(),
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    flags,
                )
            };
            if raw < 0 {
                match errno() {
                    libc::EINTR => continue,
                    code if code == libc::EAGAIN || code == libc::EWOULDBLOCK => {
                        return Ok(None);
                    }
                    code => return Err(CoreError::sys(code, "accept4")),
                }
            }
            let fd = Fd::new(raw, "accept4")?;
            return Ok(Some(UnixStreamFd { fd }));
        }
    }
}
impl UnixStreamFd {
    /// Returns the peer's credentials as reported by `peer_cred_raw`.
    pub fn peer_cred(&self) -> Result<Option<PeerCred>, CoreError> {
        peer_cred_raw(&self.fd)
    }

    /// Reads the socket's `SO_ERROR` value.
    ///
    /// Returns `Ok(None)` when the value is zero and `Ok(Some(code))` otherwise.
    ///
    /// # Errors
    /// Fails with a `CoreError` when the `getsockopt` call itself fails.
    pub fn check_connect_error(&self) -> Result<Option<i32>, CoreError> {
        let mut pending: libc::c_int = 0;
        let mut optlen = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: the value pointer refers to a live `c_int` and `optlen` holds exactly
        // its size, so the kernel writes within `pending`.
        let rc = unsafe {
            libc::getsockopt(
                self.fd.as_raw_fd(),
                libc::SOL_SOCKET,
                libc::SO_ERROR,
                std::ptr::addr_of_mut!(pending).cast(),
                &mut optlen,
            )
        };
        syscall_ret(rc, "getsockopt(SO_ERROR)")?;
        Ok((pending != 0).then_some(pending))
    }

    /// Consumes the socket and hands it back only if `SO_ERROR` is clear.
    ///
    /// # Errors
    /// A non-zero `SO_ERROR` is returned as a `CoreError` tagged `"connect(SO_ERROR)"`;
    /// the socket is dropped in that case.
    pub fn finish_connect(self) -> Result<Self, CoreError> {
        if let Some(code) = self.check_connect_error()? {
            return Err(CoreError::sys(code, "connect(SO_ERROR)"));
        }
        Ok(self)
    }
}
/// Creates a non-blocking Unix stream socket, binds it to `addr`, and starts listening.
///
/// For path addresses the stale-socket policy from `opts` runs before `bind`, and
/// `opts.mode`, when set, is applied to the socket path right after `bind`. Abstract
/// addresses accept only the default options.
///
/// # Errors
/// Returns a `CoreError` from address encoding, the stale-socket policy, or any of the
/// `socket`/`bind`/`chmod`/`listen` steps; non-default options on an abstract address
/// yield `EINVAL`. When `chmod` or `listen` fails after a successful `bind`, the socket
/// path is removed again on a best-effort basis.
pub fn bind_unix_listener(
    addr: UnixSocketAddr<'_>,
    opts: UnixSocketBindOptions,
) -> Result<UnixListenerFd, CoreError> {
    let sockaddr = UnixSockAddr::new(addr, "unix bind address")?;
    match addr {
        UnixSocketAddr::Path(path) => apply_stale_socket_policy(path, opts.stale_socket_policy)?,
        UnixSocketAddr::Abstract(_) => {
            let uses_defaults =
                opts.stale_socket_policy == StaleSocketPolicy::Preserve && opts.mode.is_none();
            if !uses_defaults {
                return Err(CoreError::sys(libc::EINVAL, "abstract unix bind options"));
            }
        }
    }

    let fd = new_unix_stream_socket()?;
    // SAFETY: `sockaddr` stays alive across the call and `len()` is the length that was
    // computed for the address it holds.
    let rc = unsafe { libc::bind(fd.as_raw_fd(), sockaddr.as_ptr(), sockaddr.len()) };
    syscall_ret(rc, "bind")?;

    // Past this point a path address has been created by this call, so every failure
    // removes it again before reporting the error.
    if let (UnixSocketAddr::Path(path), Some(mode)) = (addr, opts.mode) {
        chmod_unix_socket(UnixSocketAddr::Path(path), mode)
            .inspect_err(|_| cleanup_created_path(addr))?;
    }
    // SAFETY: plain syscall on a descriptor owned by `fd`.
    let rc = unsafe { libc::listen(fd.as_raw_fd(), libc::SOMAXCONN) };
    syscall_ret(rc, "listen").inspect_err(|_| cleanup_created_path(addr))?;

    Ok(UnixListenerFd { fd })
}
/// Starts a non-blocking connection to the Unix-domain socket at `addr`.
///
/// `connect` is retried while it reports `EINTR`. Success or `EISCONN` yields
/// `UnixConnectResult::Connected`; `EINPROGRESS` or `EALREADY` yields
/// `UnixConnectResult::InProgress`.
///
/// # Errors
/// Returns a `CoreError` when the address cannot be encoded, the socket cannot be
/// created, or `connect` fails with any other error code.
pub fn connect_unix_stream(addr: UnixSocketAddr<'_>) -> Result<UnixConnectResult, CoreError> {
    let sockaddr = UnixSockAddr::new(addr, "unix connect address")?;
    let fd = new_unix_stream_socket()?;

    // `None` means `connect` returned 0; `Some(code)` carries the first non-EINTR errno.
    let failure = loop {
        // SAFETY: `sockaddr` stays alive across the call and `len()` is the length that
        // was computed for the address it holds.
        let rc = unsafe { libc::connect(fd.as_raw_fd(), sockaddr.as_ptr(), sockaddr.len()) };
        if rc == 0 {
            break None;
        }
        match errno() {
            libc::EINTR => continue,
            code => break Some(code),
        }
    };

    match failure {
        None | Some(libc::EISCONN) => Ok(UnixConnectResult::Connected(UnixStreamFd { fd })),
        Some(libc::EINPROGRESS) | Some(libc::EALREADY) => {
            Ok(UnixConnectResult::InProgress(UnixStreamFd { fd }))
        }
        Some(code) => Err(CoreError::sys(code, "connect")),
    }
}
/// Sets the permission bits of the socket file named by a path address.
///
/// The path is inspected with `lstat` first and must be a socket.
///
/// # Errors
/// Returns `EINVAL` for abstract addresses, for paths that are not sockets, and for
/// paths containing a NUL byte; otherwise propagates the `lstat` or `chmod` failure.
pub fn chmod_unix_socket(addr: UnixSocketAddr<'_>, mode: u32) -> Result<(), CoreError> {
    match addr {
        UnixSocketAddr::Abstract(_) => Err(CoreError::sys(libc::EINVAL, "chmod abstract socket")),
        UnixSocketAddr::Path(path) => {
            let file_type = match std::fs::symlink_metadata(path) {
                Ok(metadata) => metadata.file_type(),
                Err(err) => {
                    let code = err.raw_os_error().unwrap_or(libc::EIO);
                    return Err(CoreError::sys(code, "lstat unix socket path"));
                }
            };
            if !file_type.is_socket() {
                return Err(CoreError::sys(libc::EINVAL, "chmod unix socket path"));
            }
            let c_path = path_cstring(path, "chmod unix socket path")?;
            // SAFETY: `c_path` is a NUL-terminated string that outlives the call.
            let rc = unsafe { libc::chmod(c_path.as_ptr(), mode as libc::mode_t) };
            syscall_ret(rc, "chmod")
        }
    }
}
/// Convenience wrapper around `chmod_unix_socket` for callers that hold a plain path.
pub fn chmod_socket_path(path: impl AsRef<Path>, mode: u32) -> Result<(), CoreError> {
    let target = UnixSocketAddr::Path(path.as_ref());
    chmod_unix_socket(target, mode)
}
/// Opens an `AF_UNIX` stream socket with `SOCK_CLOEXEC` and `SOCK_NONBLOCK` set.
///
/// # Errors
/// Returns a `CoreError` tagged `"socket(AF_UNIX)"` when the syscall fails or `Fd::new`
/// rejects the descriptor.
fn new_unix_stream_socket() -> Result<Fd, CoreError> {
    let kind = libc::SOCK_STREAM | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK;
    // SAFETY: `socket` takes only integer arguments.
    let raw = unsafe { libc::socket(libc::AF_UNIX, kind, 0) };
    syscall_ret(raw, "socket(AF_UNIX)")?;
    Fd::new(raw, "socket(AF_UNIX)")
}
/// Applies `policy` to whatever currently exists at `path`.
///
/// `Preserve` does nothing, `UnlinkAnyPath` removes the path unconditionally, and
/// `UnlinkSocketOnly` removes it only when `lstat` reports a socket.
///
/// # Errors
/// `UnlinkSocketOnly` returns `EEXIST` when the path exists but is not a socket;
/// `lstat` and unlink failures other than `ENOENT` are propagated.
fn apply_stale_socket_policy(path: &Path, policy: StaleSocketPolicy) -> Result<(), CoreError> {
    match policy {
        StaleSocketPolicy::Preserve => return Ok(()),
        StaleSocketPolicy::UnlinkAnyPath => return unlink_path(path, "unlink unix socket path"),
        StaleSocketPolicy::UnlinkSocketOnly => {}
    }

    let metadata = match std::fs::symlink_metadata(path) {
        Ok(metadata) => metadata,
        Err(err) => {
            return match err.raw_os_error() {
                // Nothing at the path, so there is nothing stale to remove.
                Some(libc::ENOENT) => Ok(()),
                code => Err(CoreError::sys(
                    code.unwrap_or(libc::EIO),
                    "lstat unix socket path",
                )),
            };
        }
    };

    if metadata.file_type().is_socket() {
        unlink_path(path, "unlink stale unix socket")
    } else {
        Err(CoreError::sys(libc::EEXIST, "stale unix socket path"))
    }
}
/// Removes the file at `path`, treating an already-missing path as success.
///
/// # Errors
/// Any failure other than `ENOENT` is returned as a `CoreError` tagged with `op`
/// (`EIO` when the error has no raw OS code).
fn unlink_path(path: &Path, op: &'static str) -> Result<(), CoreError> {
    let Err(err) = std::fs::remove_file(path) else {
        return Ok(());
    };
    match err.raw_os_error() {
        Some(libc::ENOENT) => Ok(()),
        code => Err(CoreError::sys(code.unwrap_or(libc::EIO), op)),
    }
}
/// Best-effort removal of the filesystem entry behind a path address; abstract
/// addresses have nothing to remove. Removal errors are deliberately ignored.
fn cleanup_created_path(addr: UnixSocketAddr<'_>) {
    match addr {
        UnixSocketAddr::Path(path) => {
            let _ = std::fs::remove_file(path);
        }
        UnixSocketAddr::Abstract(_) => {}
    }
}
/// Encoded `sockaddr_un` together with the address length to pass to `bind`/`connect`.
struct UnixSockAddr {
/// Zero-initialised `sockaddr_un` with `sun_family` and `sun_path` filled in.
inner: libc::sockaddr_un,
/// Number of meaningful bytes in `inner`: the `sun_path` offset plus the encoded name.
len: libc::socklen_t,
}
impl UnixSockAddr {
fn new(addr: UnixSocketAddr<'_>, op: &'static str) -> Result<Self, CoreError> {
let mut inner: libc::sockaddr_un = unsafe { std::mem::zeroed() };
inner.sun_family = libc::AF_UNIX as libc::sa_family_t;
let sun_path_offset = std::mem::offset_of!(libc::sockaddr_un, sun_path);
let len = match addr {
UnixSocketAddr::Path(path) => {
let bytes = path.as_os_str().as_bytes();
if bytes.is_empty() {
return Err(CoreError::sys(libc::EINVAL, op));
}
if bytes.contains(&0) {
return Err(CoreError::sys(libc::EINVAL, op));
}
if bytes.len() >= inner.sun_path.len() {
return Err(CoreError::sys(libc::ENAMETOOLONG, op));
}
for (slot, byte) in inner.sun_path.iter_mut().zip(bytes.iter().copied()) {
*slot = byte as libc::c_char;
}
sun_path_offset + bytes.len() + 1
}
UnixSocketAddr::Abstract(name) => {
validate_abstract_supported()?;
if name.is_empty() {
return Err(CoreError::sys(libc::EINVAL, op));
}
if name.len() + 1 > inner.sun_path.len() {
return Err(CoreError::sys(libc::ENAMETOOLONG, op));
}
inner.sun_path[0] = 0;
for (slot, byte) in inner.sun_path[1..].iter_mut().zip(name.iter().copied()) {
*slot = byte as libc::c_char;
}
sun_path_offset + 1 + name.len()
}
};
let len = libc::socklen_t::try_from(len).map_err(|_| CoreError::sys(libc::EINVAL, op))?;
Ok(Self { inner, len })
}
fn len(&self) -> libc::socklen_t {
self.len
}
fn as_ptr(&self) -> *const libc::sockaddr {
(&self.inner as *const libc::sockaddr_un).cast()
}
}
/// Succeeds only when compiled for Linux or Android; every other target gets `ENOSYS`
/// tagged `"abstract unix socket"`.
fn validate_abstract_supported() -> Result<(), CoreError> {
    let supported = cfg!(any(target_os = "linux", target_os = "android"));
    if !supported {
        return Err(CoreError::sys(libc::ENOSYS, "abstract unix socket"));
    }
    Ok(())
}
/// Converts `path` into a NUL-terminated C string.
///
/// # Errors
/// Returns `EINVAL` tagged with `op` when the path contains an interior NUL byte.
fn path_cstring(path: &Path, op: &'static str) -> Result<std::ffi::CString, CoreError> {
    let raw = path.as_os_str().as_bytes();
    match std::ffi::CString::new(raw) {
        Ok(c_path) => Ok(c_path),
        Err(_) => Err(CoreError::sys(libc::EINVAL, op)),
    }
}
/// Queries `SO_PEERCRED` on `fd` and returns the peer's pid, uid and gid
/// (Linux and Android builds).
///
/// # Errors
/// Propagates a failing `getsockopt` as a `CoreError` tagged
/// `"getsockopt(SO_PEERCRED)"`.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn peer_cred_raw(fd: &Fd) -> Result<Option<PeerCred>, CoreError> {
    // SAFETY: `ucred` is a plain C struct for which all-zero bytes are valid.
    let mut peer: libc::ucred = unsafe { std::mem::zeroed() };
    let mut optlen = std::mem::size_of::<libc::ucred>() as libc::socklen_t;
    // SAFETY: the value pointer refers to a live `ucred` and `optlen` holds exactly its
    // size, so the kernel writes within `peer`.
    let rc = unsafe {
        libc::getsockopt(
            fd.as_raw_fd(),
            libc::SOL_SOCKET,
            libc::SO_PEERCRED,
            std::ptr::addr_of_mut!(peer).cast(),
            &mut optlen,
        )
    };
    syscall_ret(rc, "getsockopt(SO_PEERCRED)")?;

    let credentials = PeerCred {
        pid: Some(peer.pid),
        uid: peer.uid,
        gid: peer.gid,
    };
    Ok(Some(credentials))
}

/// Fallback for targets other than Linux and Android: no credentials are looked up
/// and the result is always `Ok(None)`.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
fn peer_cred_raw(_fd: &Fd) -> Result<Option<PeerCred>, CoreError> {
    let unavailable: Option<PeerCred> = None;
    Ok(unavailable)
}